├── CODEOWNERS
├── templates
├── package.json
├── css.go
├── Makefile
├── package-lock.json
├── README.md
├── error.go
├── input.css
└── error.gohtml
├── .gitignore
├── docs
├── README.md
├── sessions.md
└── usage.md
├── charts
├── wonderwall
│ ├── templates
│ │ ├── serviceaccount.yaml
│ │ ├── configmap.yaml
│ │ ├── idporten-secret.yaml
│ │ ├── idporten-poddisruptionbudget.yaml
│ │ ├── idporten-service.yaml
│ │ ├── fa-poddisruptionbudget.yaml
│ │ ├── fa-service.yaml
│ │ ├── fa-secret.yaml
│ │ ├── idporten-horizontalpodautoscaler.yaml
│ │ ├── fa-horizontalpodautoscaler.yaml
│ │ ├── idporten-ingress.yaml
│ │ ├── fa-ingress.yaml
│ │ ├── fa-azureapp.yaml
│ │ ├── idporten-idportenclient.yaml
│ │ ├── redis.yaml
│ │ ├── networkpolicy.yaml
│ │ ├── _resources.yaml
│ │ └── prometheusrule.yaml
│ ├── Chart.yaml
│ ├── .helmignore
│ ├── values.yaml
│ └── Feature.yaml
└── wonderwall-forward-auth
│ ├── Chart.yaml
│ ├── templates
│ ├── poddisruptionbudget.yaml
│ ├── service.yaml
│ ├── secret.yaml
│ ├── servicemonitor.yaml
│ ├── horizontalpodautoscaler.yaml
│ ├── ingress.yaml
│ ├── prometheusrule.yaml
│ ├── networkpolicy.yaml
│ └── _helpers.tpl
│ ├── .helmignore
│ ├── values.yaml
│ └── Feature.yaml
├── .envrc
├── internal
├── crypto
│ ├── text.go
│ ├── jwk.go
│ ├── crypter_test.go
│ └── crypter.go
├── o11y
│ ├── otel
│ │ ├── otel_test.go
│ │ ├── middleware.go
│ │ └── otel.go
│ └── logging
│ │ └── logging.go
├── http
│ ├── transport.go
│ ├── middleware.go
│ └── request.go
└── retry
│ └── retry.go
├── pkg
├── router
│ └── paths
│ │ └── paths.go
├── handler
│ ├── path.go
│ ├── acr
│ │ └── acr.go
│ ├── handler_sso_server.go
│ └── autologin
│ │ └── autologin.go
├── mock
│ ├── request.go
│ ├── config.go
│ └── client.go
├── middleware
│ ├── correlationid.go
│ ├── cors.go
│ ├── ingress.go
│ ├── context.go
│ ├── logentry.go
│ └── prometheus.go
├── openid
│ ├── oauth2_test.go
│ ├── scopes
│ │ └── scopes.go
│ ├── config
│ │ ├── config.go
│ │ ├── provider_test.go
│ │ └── client.go
│ ├── client
│ │ ├── logout_callback.go
│ │ ├── logout.go
│ │ ├── logout_test.go
│ │ ├── logout_callback_test.go
│ │ ├── login_callback.go
│ │ └── client_test.go
│ ├── cookies.go
│ ├── acr
│ │ ├── acr_test.go
│ │ └── acr.go
│ └── provider
│ │ └── provider.go
├── cookie
│ ├── options.go
│ ├── options_test.go
│ └── cookie.go
├── session
│ ├── store_memory_test.go
│ ├── lock_test.go
│ ├── store_redis_test.go
│ ├── store_memory.go
│ ├── id.go
│ ├── store.go
│ ├── lock.go
│ ├── session_reader.go
│ ├── store_redis.go
│ ├── ticket.go
│ ├── store_test.go
│ └── session.go
├── config
│ ├── ratelimit.go
│ ├── otel.go
│ ├── session.go
│ ├── redis.go
│ ├── cookie.go
│ ├── sso.go
│ └── config_test.go
├── server
│ └── server.go
├── url
│ ├── url.go
│ └── validator.go
└── ingress
│ └── ingress.go
├── Dockerfile
├── docker-compose.yml
├── .github
├── workflows
│ └── dependabot-auto-merge.yaml
└── dependabot.yml
├── LICENSE
├── mise.toml
├── docker-compose.example.yml
├── flake.nix
├── flake.lock
├── cmd
└── wonderwall
│ └── main.go
└── hack
└── dashboard.yaml
/CODEOWNERS:
--------------------------------------------------------------------------------
1 | @navikt/aura
2 |
3 |
--------------------------------------------------------------------------------
/templates/package.json:
--------------------------------------------------------------------------------
1 | {
2 | "devDependencies": {
3 | "tailwindcss": "^4.0.0"
4 | }
5 | }
6 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | bin/
2 | run.sh
3 | .vscode/
4 | .idea
5 | *.iml
6 | .env*
7 | *.out
8 | *.tgz
9 |
10 | # nix stuffs
11 | /.direnv
12 | result
13 |
--------------------------------------------------------------------------------
/templates/css.go:
--------------------------------------------------------------------------------
1 | package templates
2 |
3 | import (
4 | _ "embed"
5 | "html/template"
6 | )
7 |
// CSS is the contents of output.css (the stylesheet the Makefile builds from
// input.css with the Tailwind CLI), embedded into the binary at compile time.
// It is typed as template.CSS, which marks the content as trusted CSS for
// html/template.
//
//go:embed output.css
var CSS template.CSS
10 |
--------------------------------------------------------------------------------
/templates/Makefile:
--------------------------------------------------------------------------------
1 | local:
2 | npx @tailwindcss/cli -i ./input.css -o ./output.css --watch
3 |
4 | build:
5 | npx @tailwindcss/cli -i ./input.css -o ./output.css --minify
6 |
--------------------------------------------------------------------------------
/docs/README.md:
--------------------------------------------------------------------------------
1 | # Table of Contents
2 |
3 | - [Architecture](architecture.md)
4 | - [Configuration](configuration.md)
5 | - [Endpoints](endpoints.md)
6 | - [Usage](usage.md)
7 | - [Session Management](sessions.md)
8 |
--------------------------------------------------------------------------------
/charts/wonderwall/templates/serviceaccount.yaml:
--------------------------------------------------------------------------------
1 | apiVersion: v1
2 | kind: ServiceAccount
3 | metadata:
4 | labels:
5 | {{- include "wonderwall.labels" . | nindent 4 }}
6 | name: {{ include "wonderwall.fullname" . }}
7 |
--------------------------------------------------------------------------------
/charts/wonderwall-forward-auth/Chart.yaml:
--------------------------------------------------------------------------------
1 | apiVersion: v2
2 | name: wonderwall-forward-auth
3 | description: Forward-auth service for loadbalancer-fa
4 | type: application
5 | version: 1.0.0
6 | sources:
7 | - https://github.com/nais/wonderwall/tree/master/charts/wonderwall-forward-auth
8 |
--------------------------------------------------------------------------------
/charts/wonderwall/Chart.yaml:
--------------------------------------------------------------------------------
1 | apiVersion: v2
description: Reverse proxy with OIDC auth. Configuration for sidecar proxies and SSO server. Feature toggle in naiserator must be enabled for sidecars to be injected into Applications.
3 | name: wonderwall
4 | type: application
5 | version: 1.0.0
6 | sources:
7 | - https://github.com/nais/wonderwall/tree/master/charts/wonderwall
8 |
--------------------------------------------------------------------------------
/charts/wonderwall/templates/configmap.yaml:
--------------------------------------------------------------------------------
1 | apiVersion: v1
2 | kind: ConfigMap
3 | metadata:
4 | name: {{ include "wonderwall.fullname" . }}
5 | labels:
6 | {{- include "wonderwall.labels" . | nindent 4 }}
7 | annotations:
8 | reloader.stakater.com/match: "true"
9 | data:
10 | wonderwall_image: "{{ .Values.image.repository }}:{{ .Values.image.tag }}"
11 |
--------------------------------------------------------------------------------
/.envrc:
--------------------------------------------------------------------------------
#!/usr/bin/env bash

# Export all:
# - (should be) .gitignored
# - (potentially) secret environment variables
# - from dotenv-formatted files w/names starting w/`.env`
#   located in the repository root (this `.envrc` itself is excluded).
DOTENV_FILES="$(find . -maxdepth 1 -type f -name '.env*' -and -not -name '.envrc')"

# Walk the newline-separated list line by line rather than expanding
# ${DOTENV_FILES} unquoted: unquoted expansion word-splits file names that
# contain spaces and glob-expands names containing `*`, `?` or `[`.
# The here-string keeps the loop in the current shell so `dotenv` exports persist.
while IFS= read -r file; do
  # An empty find result still yields one empty line; skip it.
  if [ -n "${file}" ]; then
    dotenv "${file}"
  fi
done <<< "${DOTENV_FILES}"
export DOTENV_FILES

# Load nix env for all the cool people
use flake
15 |
--------------------------------------------------------------------------------
/internal/crypto/text.go:
--------------------------------------------------------------------------------
1 | package crypto
2 |
import (
	"crypto/rand"
	"encoding/base64"
	"fmt"
)
7 |
// Text returns an unpadded, base64 URL-encoded string built from `length`
// cryptographically secure random bytes.
//
// Note that length is the number of random bytes, not the length of the
// returned string: the encoded string is ceil(4*length/3) characters long
// (e.g. 32 random bytes encode to 43 characters).
//
// It returns an error if length is negative (instead of panicking in make)
// or if reading from crypto/rand fails.
func Text(length int) (string, error) {
	if length < 0 {
		return "", fmt.Errorf("crypto: random text length must be non-negative, got %d", length)
	}

	data := make([]byte, length)
	if _, err := rand.Read(data); err != nil {
		return "", err
	}

	return base64.RawURLEncoding.EncodeToString(data), nil
}
17 |
--------------------------------------------------------------------------------
/internal/o11y/otel/otel_test.go:
--------------------------------------------------------------------------------
1 | package otel_test
2 |
3 | import (
4 | "testing"
5 |
6 | "github.com/nais/wonderwall/internal/o11y/otel"
7 | "github.com/nais/wonderwall/pkg/config"
8 | "github.com/stretchr/testify/assert"
9 | )
10 |
11 | func TestSetup(t *testing.T) {
12 | // Assert that version for semconv schemas don't conflict with the current otel version.
13 | _, err := otel.Setup(t.Context(), &config.Config{})
14 | assert.NoError(t, err)
15 | }
16 |
--------------------------------------------------------------------------------
/charts/wonderwall-forward-auth/templates/poddisruptionbudget.yaml:
--------------------------------------------------------------------------------
1 | ---
2 | apiVersion: policy/v1
3 | kind: PodDisruptionBudget
4 | metadata:
5 | labels:
6 | {{- include "wonderwall-forward-auth.labels" . | nindent 4 }}
7 | name: {{ include "wonderwall-forward-auth.fullname" . }}
8 | spec:
9 | {{- toYaml .Values.podDisruptionBudget | nindent 2 }}
10 | selector:
11 | matchLabels:
12 | {{- include "wonderwall-forward-auth.selectorLabels" . | nindent 6 }}
13 |
--------------------------------------------------------------------------------
/charts/wonderwall/.helmignore:
--------------------------------------------------------------------------------
1 | # Patterns to ignore when building packages.
2 | # This supports shell glob matching, relative path matching, and
3 | # negation (prefixed with !). Only one pattern per line.
4 | .DS_Store
5 | # Common VCS dirs
6 | .git/
7 | .gitignore
8 | .bzr/
9 | .bzrignore
10 | .hg/
11 | .hgignore
12 | .svn/
13 | # Common backup files
14 | *.swp
15 | *.bak
16 | *.tmp
17 | *.orig
18 | *~
19 | # Various IDEs
20 | .project
21 | .idea/
22 | *.tmproj
23 | .vscode/
24 |
25 |
--------------------------------------------------------------------------------
/charts/wonderwall-forward-auth/.helmignore:
--------------------------------------------------------------------------------
1 | # Patterns to ignore when building packages.
2 | # This supports shell glob matching, relative path matching, and
3 | # negation (prefixed with !). Only one pattern per line.
4 | .DS_Store
5 | # Common VCS dirs
6 | .git/
7 | .gitignore
8 | .bzr/
9 | .bzrignore
10 | .hg/
11 | .hgignore
12 | .svn/
13 | # Common backup files
14 | *.swp
15 | *.bak
16 | *.tmp
17 | *.orig
18 | *~
19 | # Various IDEs
20 | .project
21 | .idea/
22 | *.tmproj
23 | .vscode/
24 |
--------------------------------------------------------------------------------
/charts/wonderwall/templates/idporten-secret.yaml:
--------------------------------------------------------------------------------
1 | {{ if .Values.idporten.enabled }}
2 | ---
3 | apiVersion: v1
4 | kind: Secret
5 | type: kubernetes.io/Opaque
6 | metadata:
7 | name: "{{ .Values.idporten.ssoServerSecretName }}"
8 | labels:
9 | {{- include "wonderwall.labelsIdporten" . | nindent 4 }}
10 | data:
11 | WONDERWALL_ENCRYPTION_KEY: "{{ .Values.idporten.sessionCookieEncryptionKey | required ".Values.idporten.sessionCookieEncryptionKey is required." | b64enc }}"
12 | {{ end }}
13 |
--------------------------------------------------------------------------------
/pkg/router/paths/paths.go:
--------------------------------------------------------------------------------
1 | package paths
2 |
// OAuth2 is the "/oauth2" route path.
const OAuth2 = "/oauth2"

// Login-related route paths.
const (
	Login         = "/login"
	LoginCallback = "/callback"
)

// Logout-related route paths.
const (
	Logout             = "/logout"
	LogoutCallback     = "/logout/callback"
	LogoutFrontChannel = "/logout/frontchannel"
	LogoutLocal        = "/logout/local"
)

// Remaining route paths.
const (
	Ping        = "/ping"
	Refresh     = "/refresh"
	Session     = "/session"
	ForwardAuth = "/forwardauth"
)
16 |
--------------------------------------------------------------------------------
/charts/wonderwall/templates/idporten-poddisruptionbudget.yaml:
--------------------------------------------------------------------------------
1 | {{- if .Values.idporten.enabled }}
2 | apiVersion: policy/v1
3 | kind: PodDisruptionBudget
4 | metadata:
5 | labels:
6 | {{- include "wonderwall.labelsIdporten" . | nindent 4 }}
7 | name: {{ include "wonderwall.fullname" . }}-idporten
8 | spec:
9 | {{- toYaml .Values.podDisruptionBudget | nindent 2 }}
10 | selector:
11 | matchLabels:
12 | {{- include "wonderwall.selectorLabelsIdporten" . | nindent 6 }}
13 | {{- end }}
14 |
--------------------------------------------------------------------------------
/charts/wonderwall/templates/idporten-service.yaml:
--------------------------------------------------------------------------------
1 | {{ if .Values.idporten.enabled }}
2 | apiVersion: v1
3 | kind: Service
4 | metadata:
5 | labels:
6 | {{- include "wonderwall.labelsIdporten" . | nindent 4 }}
7 | name: {{ include "wonderwall.fullname" . }}-idporten
8 | spec:
9 | type: ClusterIP
10 | ports:
11 | - name: http
12 | port: 80
13 | protocol: TCP
14 | targetPort: http
15 | selector:
16 | {{- include "wonderwall.selectorLabelsIdporten" . | nindent 4 }}
17 | {{ end }}
18 |
--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM --platform=$BUILDPLATFORM golang:1.25 AS builder
2 | ENV CGO_ENABLED=0
3 | ENV GOTOOLCHAIN=auto
4 | WORKDIR /src
5 |
6 | COPY go.mod go.sum ./
7 | RUN go mod download
8 |
9 | COPY . .
10 | ARG TARGETOS
11 | ARG TARGETARCH
12 | RUN GOOS=$TARGETOS GOARCH=$TARGETARCH go build -trimpath -ldflags "-s -w" -a -o bin/wonderwall ./cmd/wonderwall
13 |
14 | FROM gcr.io/distroless/static-debian12:nonroot
15 | WORKDIR /app
16 | COPY --from=builder /src/bin/wonderwall /app/wonderwall
17 | ENTRYPOINT ["/app/wonderwall"]
18 |
--------------------------------------------------------------------------------
/charts/wonderwall/templates/fa-poddisruptionbudget.yaml:
--------------------------------------------------------------------------------
1 | {{- if and .Values.azure.enabled .Values.azure.forwardAuth.enabled }}
2 | apiVersion: policy/v1
3 | kind: PodDisruptionBudget
4 | metadata:
5 | labels:
6 | {{- include "wonderwall.labelsForwardAuth" . | nindent 4 }}
7 | name: {{ include "wonderwall.fullname" . }}-fa
8 | spec:
9 | {{- toYaml .Values.podDisruptionBudget | nindent 2 }}
10 | selector:
11 | matchLabels:
12 | {{- include "wonderwall.selectorLabelsForwardAuth" . | nindent 6 }}
13 | {{- end }}
14 |
--------------------------------------------------------------------------------
/pkg/handler/path.go:
--------------------------------------------------------------------------------
1 | package handler
2 |
3 | import (
4 | "net/http"
5 |
6 | "github.com/nais/wonderwall/pkg/ingress"
7 | mw "github.com/nais/wonderwall/pkg/middleware"
8 | )
9 |
10 | // GetPath returns the matching context path from the list of registered ingresses.
11 | // If none match, an empty string is returned.
12 | func GetPath(r *http.Request, ingresses *ingress.Ingresses) string {
13 | path, ok := mw.PathFrom(r.Context())
14 | if !ok {
15 | path = ingresses.MatchingPath(r)
16 | }
17 |
18 | return path
19 | }
20 |
--------------------------------------------------------------------------------
/charts/wonderwall/templates/fa-service.yaml:
--------------------------------------------------------------------------------
1 | {{- if and .Values.azure.enabled .Values.azure.forwardAuth.enabled }}
2 | apiVersion: v1
3 | kind: Service
4 | metadata:
5 | labels:
6 | {{- include "wonderwall.labelsForwardAuth" . | nindent 4 }}
7 | name: {{ include "wonderwall.fullname" . }}-fa
8 | spec:
9 | type: ClusterIP
10 | ports:
11 | - name: http
12 | port: 80
13 | protocol: TCP
14 | targetPort: http
15 | selector:
16 | {{- include "wonderwall.selectorLabelsForwardAuth" . | nindent 4 }}
17 | {{- end }}
18 |
--------------------------------------------------------------------------------
/charts/wonderwall/templates/fa-secret.yaml:
--------------------------------------------------------------------------------
1 | {{- if and .Values.azure.enabled .Values.azure.forwardAuth.enabled }}
2 | ---
3 | apiVersion: v1
4 | kind: Secret
5 | type: kubernetes.io/Opaque
6 | metadata:
7 | name: "{{ .Values.azure.forwardAuth.ssoServerSecretName }}"
8 | labels:
9 | {{- include "wonderwall.labelsForwardAuth" . | nindent 4 }}
10 | data:
11 | WONDERWALL_ENCRYPTION_KEY: "{{ .Values.azure.forwardAuth.sessionCookieEncryptionKey | required ".Values.azure.forwardAuth.sessionCookieEncryptionKey is required." | b64enc }}"
12 | {{- end }}
13 |
--------------------------------------------------------------------------------
/templates/package-lock.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "templates",
3 | "lockfileVersion": 3,
4 | "requires": true,
5 | "packages": {
6 | "": {
7 | "devDependencies": {
8 | "tailwindcss": "^4.0.0"
9 | }
10 | },
11 | "node_modules/tailwindcss": {
12 | "version": "4.0.0",
13 | "resolved": "https://registry.npmjs.org/tailwindcss/-/tailwindcss-4.0.0.tgz",
14 | "integrity": "sha512-ULRPI3A+e39T7pSaf1xoi58AqqJxVCLg8F/uM5A3FadUbnyDTgltVnXJvdkTjwCOGA6NazqHVcwPJC5h2vRYVQ==",
15 | "dev": true,
16 | "license": "MIT"
17 | }
18 | }
19 | }
20 |
--------------------------------------------------------------------------------
/charts/wonderwall-forward-auth/templates/service.yaml:
--------------------------------------------------------------------------------
1 | ---
2 | apiVersion: v1
3 | kind: Service
4 | metadata:
5 | labels:
6 | {{- include "wonderwall-forward-auth.labels" . | nindent 4 }}
7 | name: {{ include "wonderwall-forward-auth.fullname" . }}
8 | spec:
9 | type: ClusterIP
10 | ports:
11 | - name: http
12 | port: 80
13 | protocol: TCP
14 | targetPort: http
15 | - name: http-metrics
16 | port: 8081
17 | protocol: TCP
18 | targetPort: http-metrics
19 | selector:
20 | {{- include "wonderwall-forward-auth.selectorLabels" . | nindent 4 }}
21 |
--------------------------------------------------------------------------------
/internal/http/transport.go:
--------------------------------------------------------------------------------
1 | package http
2 |
3 | import (
4 | "net/http"
5 | "sync"
6 | "time"
7 |
8 | "go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp"
9 | )
10 |
11 | var (
12 | defaultTransport *http.Transport
13 | once sync.Once
14 | )
15 |
16 | func Transport() http.RoundTripper {
17 | once.Do(func() {
18 | t := http.DefaultTransport.(*http.Transport).Clone()
19 | t.MaxIdleConns = 200
20 | t.MaxIdleConnsPerHost = 100
21 | t.IdleConnTimeout = 5 * time.Second
22 |
23 | defaultTransport = t
24 | })
25 |
26 | return otelhttp.NewTransport(defaultTransport)
27 | }
28 |
--------------------------------------------------------------------------------
/pkg/mock/request.go:
--------------------------------------------------------------------------------
1 | package mock
2 |
3 | import (
4 | "net/http"
5 | "net/http/httptest"
6 |
7 | "github.com/nais/wonderwall/pkg/ingress"
8 | mw "github.com/nais/wonderwall/pkg/middleware"
9 | )
10 |
11 | func NewGetRequest(target string, ingresses *ingress.Ingresses) *http.Request {
12 | req := httptest.NewRequest(http.MethodGet, target, nil)
13 |
14 | path := ingresses.MatchingPath(req)
15 | req = mw.RequestWithPath(req, path)
16 |
17 | ing, ok := ingresses.MatchingIngress(req)
18 | if ok {
19 | req = mw.RequestWithIngress(req, ing)
20 | }
21 |
22 | mw.Logger("test")
23 |
24 | return req
25 | }
26 |
--------------------------------------------------------------------------------
/pkg/middleware/correlationid.go:
--------------------------------------------------------------------------------
1 | package middleware
2 |
3 | import (
4 | "context"
5 | "net/http"
6 |
7 | chi_middleware "github.com/go-chi/chi/v5/middleware"
8 | "github.com/google/uuid"
9 | )
10 |
11 | func CorrelationIDHandler(next http.Handler) http.Handler {
12 | fn := func(w http.ResponseWriter, r *http.Request) {
13 | id := r.Header.Get(chi_middleware.RequestIDHeader)
14 | if len(id) == 0 {
15 | id = uuid.New().String()
16 | }
17 |
18 | ctx := r.Context()
19 | ctx = context.WithValue(ctx, chi_middleware.RequestIDKey, id)
20 | next.ServeHTTP(w, r.WithContext(ctx))
21 | }
22 | return http.HandlerFunc(fn)
23 | }
24 |
--------------------------------------------------------------------------------
/docker-compose.yml:
--------------------------------------------------------------------------------
1 | services:
2 | redis:
3 | image: redis:7
4 | ports:
5 | - "6379:6379"
6 | otel:
7 | image: grafana/otel-lgtm:latest
8 | ports:
9 | - "3002:3000"
10 | - "4317:4317"
11 | - "4318:4318"
12 | mock-oauth2-server:
13 | image: ghcr.io/navikt/mock-oauth2-server:3.0.1
14 | ports:
15 | - "8888:8080"
16 | environment:
17 | JSON_CONFIG: "{\"interactiveLogin\":false}"
18 | upstream:
19 | image: mendhak/http-https-echo:30
20 | ports:
21 | - "4000:4000"
22 | environment:
23 | HTTP_PORT: 4000
24 | JWT_HEADER: Authorization
25 | LOG_IGNORE_PATH: /
26 |
--------------------------------------------------------------------------------
/pkg/handler/acr/acr.go:
--------------------------------------------------------------------------------
1 | package acr
2 |
3 | import (
4 | "github.com/nais/wonderwall/pkg/config"
5 | "github.com/nais/wonderwall/pkg/openid/acr"
6 | "github.com/nais/wonderwall/pkg/session"
7 | )
8 |
9 | type Handler struct {
10 | Enabled bool
11 | ExpectedValue string
12 | }
13 |
14 | func (h *Handler) Validate(sess *session.Session) error {
15 | if !h.Enabled || sess == nil {
16 | return nil
17 | }
18 |
19 | return acr.Validate(h.ExpectedValue, sess.Acr())
20 | }
21 |
22 | func NewHandler(cfg *config.Config) *Handler {
23 | return &Handler{
24 | Enabled: len(cfg.OpenID.ACRValues) > 0,
25 | ExpectedValue: cfg.OpenID.ACRValues,
26 | }
27 | }
28 |
--------------------------------------------------------------------------------
/charts/wonderwall-forward-auth/templates/secret.yaml:
--------------------------------------------------------------------------------
1 | ---
2 | apiVersion: v1
3 | kind: Secret
4 | type: kubernetes.io/Opaque
5 | metadata:
6 | name: {{ include "wonderwall-forward-auth.fullname" . }}
7 | labels:
8 | {{- include "wonderwall-forward-auth.labels" . | nindent 4 }}
9 | data:
10 | WONDERWALL_OPENID_CLIENT_SECRET: "{{ .Values.openid.clientSecret | required ".Values.openid.clientSecret is required." | b64enc }}"
11 | WONDERWALL_ENCRYPTION_KEY: "{{ .Values.session.cookieEncryptionKey | required ".Values.session.cookieEncryptionKey is required." | b64enc }}"
12 | WONDERWALL_REDIS_PASSWORD: "{{ .Values.valkey.password | required ".Values.valkey.password is required." | b64enc }}"
13 |
--------------------------------------------------------------------------------
/charts/wonderwall-forward-auth/templates/servicemonitor.yaml:
--------------------------------------------------------------------------------
1 | {{- if .Capabilities.APIVersions.Has "monitoring.coreos.com/v1" }}
2 | ---
3 | apiVersion: monitoring.coreos.com/v1
4 | kind: ServiceMonitor
5 | metadata:
6 | name: {{ include "wonderwall-forward-auth.fullname" . }}
7 | labels: {{- include "wonderwall-forward-auth.labels" . | nindent 4 }}
8 | spec:
9 | endpoints:
10 | - interval: 1m
11 | port: http-metrics
12 | scrapeTimeout: 10s
13 | path: "/"
14 | namespaceSelector:
15 | matchNames:
16 | - {{ .Release.Namespace }}
17 | selector:
18 | matchLabels:
19 | {{- include "wonderwall-forward-auth.selectorLabels" . | nindent 6 }}
20 | {{- end }}
21 |
--------------------------------------------------------------------------------
/charts/wonderwall-forward-auth/templates/horizontalpodautoscaler.yaml:
--------------------------------------------------------------------------------
1 | ---
2 | apiVersion: autoscaling/v2
3 | kind: HorizontalPodAutoscaler
4 | metadata:
5 | labels:
6 | {{- include "wonderwall-forward-auth.labels" . | nindent 4 }}
7 | name: {{ include "wonderwall-forward-auth.fullname" . }}
8 | spec:
9 | minReplicas: {{ .Values.replicas.min }}
10 | maxReplicas: {{ .Values.replicas.max }}
11 | metrics:
12 | - resource:
13 | name: cpu
14 | target:
15 | averageUtilization: 75
16 | type: Utilization
17 | type: Resource
18 | scaleTargetRef:
19 | apiVersion: apps/v1
20 | kind: Deployment
21 | name: {{ include "wonderwall-forward-auth.fullname" . }}
22 |
--------------------------------------------------------------------------------
/pkg/openid/oauth2_test.go:
--------------------------------------------------------------------------------
1 | package openid
2 |
3 | import (
4 | "net/url"
5 | "testing"
6 |
7 | "github.com/stretchr/testify/assert"
8 | )
9 |
10 | func TestStateMismatchError(t *testing.T) {
11 | for _, tt := range []struct {
12 | name, expected, actual string
13 | assertion assert.ErrorAssertionFunc
14 | }{
15 | {"missing actual state", "expected", "", assert.Error},
16 | {"state mismatch", "match", "not-match", assert.Error},
17 | {"state match", "match", "match", assert.NoError},
18 | } {
19 | t.Run(tt.name, func(t *testing.T) {
20 | actual := url.Values{
21 | "state": []string{tt.actual},
22 | }
23 |
24 | err := StateMismatchError(actual, tt.expected)
25 | tt.assertion(t, err)
26 | })
27 | }
28 | }
29 |
--------------------------------------------------------------------------------
/templates/README.md:
--------------------------------------------------------------------------------
1 | # Error templates
2 |
3 | This directory contains `.gohtml` templates for static error pages served by Wonderwall.
4 |
5 | These pages are typically only shown on exceptional errors, i.e. invalid configuration or infrastructure errors.
6 | End-users should generally not see these pages unless something is really wrong.
7 |
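The compiled CSS (`output.css`) is embedded into the Go binary and inlined into the `.gohtml` templates when they are rendered.
This avoids implementing an endpoint to serve the CSS file separately.
Because `output.css` is embedded with `//go:embed`, it must exist (see [Production](#production)) before the Go code is compiled.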
10 |
11 | ## Prerequisites
12 |
13 | If you haven't already, [install the Tailwind CSS CLI](https://tailwindcss.com/docs/installation).
14 |
15 | ## Development
16 |
17 | ```shell
18 | make local
19 | ```
20 |
21 | ## Production
22 |
23 | ```shell
24 | make build
25 | ```
26 |
--------------------------------------------------------------------------------
/charts/wonderwall/templates/idporten-horizontalpodautoscaler.yaml:
--------------------------------------------------------------------------------
1 | {{- if .Values.idporten.enabled }}
2 | apiVersion: autoscaling/v2
3 | kind: HorizontalPodAutoscaler
4 | metadata:
5 | labels:
6 | {{- include "wonderwall.labelsIdporten" . | nindent 4 }}
7 | name: {{ include "wonderwall.fullname" . }}-idporten
8 | spec:
9 | minReplicas: {{ .Values.idporten.replicasMin }}
10 | maxReplicas: {{ .Values.idporten.replicasMax }}
11 | metrics:
12 | - resource:
13 | name: cpu
14 | target:
15 | averageUtilization: 75
16 | type: Utilization
17 | type: Resource
18 | scaleTargetRef:
19 | apiVersion: apps/v1
20 | kind: Deployment
21 | name: {{ include "wonderwall.fullname" . }}-idporten
22 | {{- end }}
23 |
--------------------------------------------------------------------------------
/pkg/middleware/cors.go:
--------------------------------------------------------------------------------
1 | package middleware
2 |
3 | import (
4 | "fmt"
5 | "net/http"
6 | "strings"
7 |
8 | "github.com/rs/cors"
9 |
10 | "github.com/nais/wonderwall/pkg/config"
11 | )
12 |
13 | func Cors(cfg *config.Config, methods []string) func(http.Handler) http.Handler {
14 | ssoDomain := strings.TrimPrefix(cfg.SSO.Domain, ".")
15 |
16 | allowedOrigins := []string{
17 | fmt.Sprintf("https://*.%s", ssoDomain),
18 | fmt.Sprintf("https://%s", ssoDomain),
19 | }
20 |
21 | return cors.New(cors.Options{
22 | AllowedOrigins: allowedOrigins,
23 | AllowedMethods: methods,
24 | AllowCredentials: true,
25 | // This reflects the request headers, essentially allowing all headers.
26 | AllowedHeaders: []string{"*"},
27 | }).Handler
28 | }
29 |
--------------------------------------------------------------------------------
/pkg/openid/scopes/scopes.go:
--------------------------------------------------------------------------------
1 | package scopes
2 |
3 | import (
4 | "fmt"
5 | "strings"
6 | )
7 |
const (
	OpenID        = "openid"
	OfflineAccess = "offline_access"
	// AzureAPITemplate is a format string whose %s is replaced by a client ID
	// (see WithAzureScope).
	AzureAPITemplate = "api://%s/.default"
)

// Scopes is a list of scope values.
type Scopes []string

// String returns the scopes joined by single spaces.
func (s Scopes) String() string {
	return strings.Join(s, " ")
}

// WithAdditional returns a copy of s with scopes appended.
//
// The full slice expression caps the capacity at len(s), so append always
// allocates instead of writing into spare capacity of s. Without it, two
// values derived from the same base could share (and overwrite) one backing
// array.
func (s Scopes) WithAdditional(scopes ...string) Scopes {
	return append(s[:len(s):len(s)], scopes...)
}

// WithAzureScope returns a copy of s with "api://<clientID>/.default" appended.
// Like WithAdditional, the result never shares spare capacity with s.
func (s Scopes) WithAzureScope(clientID string) Scopes {
	return append(s[:len(s):len(s)], fmt.Sprintf(AzureAPITemplate, clientID))
}

// WithOfflineAccess returns a copy of s with OfflineAccess appended.
// Like WithAdditional, the result never shares spare capacity with s.
func (s Scopes) WithOfflineAccess() Scopes {
	return append(s[:len(s):len(s)], OfflineAccess)
}

// DefaultScopes returns a new Scopes containing only OpenID.
func DefaultScopes() Scopes {
	return []string{OpenID}
}
35 |
--------------------------------------------------------------------------------
/pkg/cookie/options.go:
--------------------------------------------------------------------------------
1 | package cookie
2 |
3 | import (
4 | "net/http"
5 | )
6 |
7 | type Options struct {
8 | Domain string
9 | Path string
10 | SameSite http.SameSite
11 | Secure bool
12 | }
13 |
14 | func DefaultOptions() Options {
15 | return Options{
16 | SameSite: http.SameSiteLaxMode,
17 | Secure: true,
18 | }
19 | }
20 |
21 | func (o Options) WithDomain(domain string) Options {
22 | o.Domain = domain
23 | return o
24 | }
25 |
26 | func (o Options) WithPath(path string) Options {
27 | o.Path = path
28 | return o
29 | }
30 |
31 | func (o Options) WithSameSite(sameSite http.SameSite) Options {
32 | o.SameSite = sameSite
33 | return o
34 | }
35 |
36 | func (o Options) WithSecure(secure bool) Options {
37 | o.Secure = secure
38 | return o
39 | }
40 |
--------------------------------------------------------------------------------
/charts/wonderwall/templates/fa-horizontalpodautoscaler.yaml:
--------------------------------------------------------------------------------
1 | {{- if and .Values.azure.enabled .Values.azure.forwardAuth.enabled }}
2 | apiVersion: autoscaling/v2
3 | kind: HorizontalPodAutoscaler
4 | metadata:
5 | labels:
6 | {{- include "wonderwall.labelsForwardAuth" . | nindent 4 }}
7 | name: {{ include "wonderwall.fullname" . }}-fa
8 | spec:
9 | minReplicas: {{ .Values.azure.forwardAuth.replicasMin }}
10 | maxReplicas: {{ .Values.azure.forwardAuth.replicasMax }}
11 | metrics:
12 | - resource:
13 | name: cpu
14 | target:
15 | averageUtilization: 75
16 | type: Utilization
17 | type: Resource
18 | scaleTargetRef:
19 | apiVersion: apps/v1
20 | kind: Deployment
21 | name: {{ include "wonderwall.fullname" . }}-fa
22 | {{- end }}
23 |
--------------------------------------------------------------------------------
/pkg/session/store_memory_test.go:
--------------------------------------------------------------------------------
1 | package session_test
2 |
3 | import (
4 | "testing"
5 |
6 | "github.com/stretchr/testify/assert"
7 |
8 | "github.com/nais/wonderwall/pkg/session"
9 | )
10 |
11 | func TestMemory(t *testing.T) {
12 | crypter := makeCrypter(t)
13 | data := makeData()
14 | encryptedData, err := data.Encrypt(crypter)
15 | assert.NoError(t, err)
16 |
17 | store := session.NewMemory()
18 | key := "key"
19 |
20 | write(t, store, key, encryptedData)
21 |
22 | decrypted := read(t, store, key, encryptedData, crypter)
23 | decryptedEqual(t, data, decrypted)
24 |
25 | data, encryptedData = update(t, store, key, data, crypter)
26 |
27 | decrypted = read(t, store, key, encryptedData, crypter)
28 | decryptedEqual(t, data, decrypted)
29 |
30 | del(t, store, key)
31 | }
32 |
--------------------------------------------------------------------------------
/templates/error.go:
--------------------------------------------------------------------------------
1 | package templates
2 |
3 | import (
4 | _ "embed"
5 | "html/template"
6 | "io"
7 |
8 | log "github.com/sirupsen/logrus"
9 | )
10 |
// errorGoHtml holds the raw contents of error.gohtml, embedded into the binary
// at build time.
//
//go:embed error.gohtml
var errorGoHtml string

// errorTemplate is the parsed form of errorGoHtml; it is populated by init.
var errorTemplate *template.Template
14 |
// ErrorVariables holds the values passed to the error page template when it is
// rendered by ExecError. How each field is displayed is decided by error.gohtml,
// which is not shown here.
type ErrorVariables struct {
	// CorrelationID identifies the failed request; presumably shown to the user for support purposes — TODO confirm in error.gohtml.
	CorrelationID string
	// CSS is stylesheet content; the template.CSS type marks it as trusted so html/template does not escape it.
	CSS template.CSS
	// DefaultRedirectURI is presumably a fallback link offered on the error page — TODO confirm.
	DefaultRedirectURI string
	// HttpStatusCode is the HTTP status code associated with the error.
	HttpStatusCode int
	// RetryURI is presumably a link that lets the user retry the failed flow — TODO confirm.
	RetryURI string
}
22 |
23 | func init() {
24 | var err error
25 |
26 | errorTemplate = template.New("error")
27 | errorTemplate, err = errorTemplate.Parse(errorGoHtml)
28 | if err != nil {
29 | log.Fatalf("parsing error template: %+v", err)
30 | }
31 | }
32 |
// ExecError renders the error page template with vars and writes the result to w.
// It returns any error produced by template execution, including write errors on w.
func ExecError(w io.Writer, vars ErrorVariables) error {
	return errorTemplate.Execute(w, vars)
}
36 |
--------------------------------------------------------------------------------
/charts/wonderwall-forward-auth/templates/ingress.yaml:
--------------------------------------------------------------------------------
{{- /*
Ingress routing every path on .Values.sso.domain to the Service named by
"wonderwall-forward-auth.fullname", port 80.
nginx global auth is disabled for this Ingress (enable-global-auth: "false"),
presumably so the forward-auth service is not itself placed behind forward auth — TODO confirm.
The proxy buffer size is raised to 16k, presumably to fit large auth-related headers/cookies.
*/ -}}
---
apiVersion: networking.k8s.io/v1
kind: Ingress
metadata:
  annotations:
    nginx.ingress.kubernetes.io/proxy-buffer-size: 16k
    nginx.ingress.kubernetes.io/enable-global-auth: "false"
  labels:
    {{- include "wonderwall-forward-auth.labels" . | nindent 4 }}
  name: {{ include "wonderwall-forward-auth.fullname" . }}
spec:
  ingressClassName: {{ .Values.ingressClassName }}
  rules:
    - host: {{ .Values.sso.domain }}
      http:
        paths:
          - backend:
              service:
                name: {{ include "wonderwall-forward-auth.fullname" . }}
                port:
                  number: 80
            path: /
            pathType: ImplementationSpecific
24 |
--------------------------------------------------------------------------------
/pkg/openid/config/config.go:
--------------------------------------------------------------------------------
1 | package config
2 |
3 | import (
4 | wonderwallconfig "github.com/nais/wonderwall/pkg/config"
5 | )
6 |
// Config bundles the OpenID client configuration and the identity provider
// configuration. Instances are created with NewConfig.
type Config interface {
	// Client returns the client configuration.
	Client() Client
	// Provider returns the provider configuration.
	Provider() Provider
}
11 |
// openidconfig is the Config implementation returned by NewConfig; it simply
// stores the two sub-configurations.
type openidconfig struct {
	clientConfig   Client
	providerConfig Provider
}

// Client returns the stored client configuration.
func (c *openidconfig) Client() Client {
	return c.clientConfig
}

// Provider returns the stored provider configuration.
func (c *openidconfig) Provider() Provider {
	return c.providerConfig
}
24 |
25 | func NewConfig(cfg *wonderwallconfig.Config) (Config, error) {
26 | clientCfg, err := NewClientConfig(cfg)
27 | if err != nil {
28 | return nil, err
29 | }
30 |
31 | providerCfg, err := NewProviderConfig(cfg)
32 | if err != nil {
33 | return nil, err
34 | }
35 |
36 | return &openidconfig{
37 | clientConfig: clientCfg,
38 | providerConfig: providerCfg,
39 | }, nil
40 | }
41 |
--------------------------------------------------------------------------------
/charts/wonderwall/templates/idporten-ingress.yaml:
--------------------------------------------------------------------------------
{{- /*
Ingress for the ID-porten SSO server, rendered only when .Values.idporten.enabled is true.
Routes every path on .Values.idporten.ssoServerHost to the "<fullname>-idporten" Service on port 80.
The prometheus.io annotations ask for /oauth2/ping to be scraped; the proxy buffer size is
raised to 16k, presumably to fit large auth-related headers/cookies — TODO confirm.
*/ -}}
{{ if .Values.idporten.enabled }}
apiVersion: networking.k8s.io/v1
kind: Ingress
metadata:
  annotations:
    nginx.ingress.kubernetes.io/proxy-buffer-size: 16k
    prometheus.io/path: /oauth2/ping
    prometheus.io/scrape: "true"
  labels:
    {{- include "wonderwall.labelsIdporten" . | nindent 4 }}
  name: {{ include "wonderwall.fullname" . }}-idporten
spec:
  ingressClassName: {{ .Values.idporten.ingressClassName }}
  rules:
    - host: {{ .Values.idporten.ssoServerHost }}
      http:
        paths:
          - backend:
              service:
                name: {{ include "wonderwall.fullname" . }}-idporten
                port:
                  number: 80
            path: /
            pathType: ImplementationSpecific
{{ end }}
26 |
--------------------------------------------------------------------------------
/.github/workflows/dependabot-auto-merge.yaml:
--------------------------------------------------------------------------------
# Enables GitHub auto-merge (squash) on pull requests opened by Dependabot,
# except for updates that dependabot/fetch-metadata classifies as semver-major.
name: Dependabot auto-merge
on: pull_request

# Write access to contents and pull requests lets the job's GITHUB_TOKEN enable
# auto-merge via `gh pr merge`.
permissions:
  contents: write
  pull-requests: write

jobs:
  dependabot:
    runs-on: ubuntu-latest
    # Only run for PRs authored by Dependabot.
    if: ${{ github.event.pull_request.user.login == 'dependabot[bot]' }}
    steps:
      - name: Dependabot metadata
        id: metadata
        uses: dependabot/fetch-metadata@08eff52bf64351f401fb50d4972fa95b9f2c2d1b # ratchet:dependabot/fetch-metadata@v2
        with:
          github-token: "${{ secrets.GITHUB_TOKEN }}"
      # Major version bumps are left for manual review.
      - name: Enable auto-merge for Dependabot PRs
        if: steps.metadata.outputs.update-type != 'version-update:semver-major'
        run: gh pr merge --auto --squash "$PR_URL"
        env:
          PR_URL: ${{github.event.pull_request.html_url}}
          GITHUB_TOKEN: ${{secrets.GITHUB_TOKEN}}
24 |
--------------------------------------------------------------------------------
/pkg/middleware/ingress.go:
--------------------------------------------------------------------------------
1 | package middleware
2 |
3 | import (
4 | "net/http"
5 |
6 | "github.com/nais/wonderwall/pkg/ingress"
7 | )
8 |
// IngressSource supplies the currently known set of ingresses. IngressMiddleware
// calls GetIngresses on every request rather than caching the result.
type IngressSource interface {
	GetIngresses() *ingress.Ingresses
}

// IngressMiddleware stores, in each request's context, the path and (when one
// matches) the ingress resolved from its IngressSource.
type IngressMiddleware struct {
	IngressSource
}

// Ingress returns an IngressMiddleware backed by source. Handler has a pointer
// receiver, so call it on an addressable value such as a local variable.
func Ingress(source IngressSource) IngressMiddleware {
	return IngressMiddleware{IngressSource: source}
}
20 |
21 | func (i *IngressMiddleware) Handler(next http.Handler) http.Handler {
22 | fn := func(w http.ResponseWriter, r *http.Request) {
23 | ingresses := i.GetIngresses()
24 | ctx := r.Context()
25 |
26 | path := ingresses.MatchingPath(r)
27 | ctx = WithPath(ctx, path)
28 |
29 | matchingIngress, ok := ingresses.MatchingIngress(r)
30 | if ok {
31 | ctx = WithIngress(ctx, matchingIngress)
32 | }
33 |
34 | next.ServeHTTP(w, r.WithContext(ctx))
35 | }
36 | return http.HandlerFunc(fn)
37 | }
38 |
--------------------------------------------------------------------------------
/pkg/session/lock_test.go:
--------------------------------------------------------------------------------
1 | package session_test
2 |
3 | import (
4 | "context"
5 | "testing"
6 | "time"
7 |
8 | "github.com/alicebob/miniredis/v2"
9 | "github.com/redis/go-redis/v9"
10 | "github.com/stretchr/testify/assert"
11 |
12 | "github.com/nais/wonderwall/pkg/session"
13 | )
14 |
15 | func TestRedisLock(t *testing.T) {
16 | s, err := miniredis.Run()
17 | if err != nil {
18 | panic(err)
19 | }
20 | defer s.Close()
21 |
22 | client := redis.NewClient(&redis.Options{
23 | Network: "tcp",
24 | Addr: s.Addr(),
25 | })
26 |
27 | key := "some-key"
28 | ctx := context.Background()
29 | lock := session.NewRedisLock(client, key)
30 |
31 | err = lock.Acquire(ctx, time.Minute)
32 | assert.NoError(t, err)
33 |
34 | err = lock.Acquire(ctx, time.Minute)
35 | assert.Error(t, err)
36 | assert.ErrorIs(t, err, session.ErrAcquireLock)
37 |
38 | err = lock.Release(ctx)
39 | assert.NoError(t, err)
40 | }
41 |
--------------------------------------------------------------------------------
/pkg/config/ratelimit.go:
--------------------------------------------------------------------------------
1 | package config
2 |
3 | import (
4 | "time"
5 |
6 | flag "github.com/spf13/pflag"
7 | )
8 |
// RateLimit holds the login rate-limiting settings. The corresponding
// command-line flags and their defaults are registered by rateLimitFlags.
type RateLimit struct {
	// Enabled toggles rate limiting per user-agent.
	Enabled bool `json:"enabled"`
	// Logins is the maximum number of login attempts permitted within Window before rate limiting applies.
	Logins int `json:"logins"`
	// Window is the time window for counting consecutive attempts towards the limit.
	Window time.Duration `json:"window"`
}

// Flag / configuration keys for the RateLimit settings.
const (
	RateLimitEnabled = "ratelimit.enabled"
	RateLimitLogins  = "ratelimit.logins"
	RateLimitWindow  = "ratelimit.window"
)
20 |
21 | func rateLimitFlags() {
22 | flag.Bool(RateLimitEnabled, true, "Enable rate limiting per user-agent.")
23 | flag.Int(RateLimitLogins, 5, "Maximum permitted login attempts within 'ratelimit.window' before rate limiting.")
24 | flag.Duration(RateLimitWindow, 5*time.Second, "Time window for counting consecutive attempts towards rate limit."+
25 | "Each attempt within the window will increment the attempt counter and reset the window."+
26 | "If the window expires with no additional attempts, the counter is discarded.")
27 | }
28 |
--------------------------------------------------------------------------------
/charts/wonderwall/templates/fa-ingress.yaml:
--------------------------------------------------------------------------------
{{- /*
Ingress for the Azure forward-auth ("-fa") service, rendered only when both
.Values.azure.enabled and .Values.azure.forwardAuth.enabled are true.
Routes every path on .Values.azure.forwardAuth.ssoDomain to the "<fullname>-fa" Service on port 80.
nginx global auth is disabled here (enable-global-auth: "false"), presumably so the
forward-auth service is not itself protected by forward auth — TODO confirm.
The prometheus.io annotations ask for /oauth2/ping to be scraped.
*/ -}}
{{- if and .Values.azure.enabled .Values.azure.forwardAuth.enabled }}
apiVersion: networking.k8s.io/v1
kind: Ingress
metadata:
  annotations:
    nginx.ingress.kubernetes.io/proxy-buffer-size: 16k
    nginx.ingress.kubernetes.io/enable-global-auth: "false"
    prometheus.io/path: /oauth2/ping
    prometheus.io/scrape: "true"
  labels:
    {{- include "wonderwall.labelsForwardAuth" . | nindent 4 }}
  name: {{ include "wonderwall.fullname" . }}-fa
spec:
  ingressClassName: {{ .Values.azure.forwardAuth.ingressClassName }}
  rules:
    - host: {{ .Values.azure.forwardAuth.ssoDomain }}
      http:
        paths:
          - backend:
              service:
                name: {{ include "wonderwall.fullname" . }}-fa
                port:
                  number: 80
            path: /
            pathType: ImplementationSpecific
{{- end }}
27 |
--------------------------------------------------------------------------------
/internal/o11y/logging/logging.go:
--------------------------------------------------------------------------------
1 | package logging
2 |
3 | import (
4 | "fmt"
5 | "time"
6 |
7 | log "github.com/sirupsen/logrus"
8 | )
9 |
10 | func textFormatter() log.Formatter {
11 | return &log.TextFormatter{
12 | DisableTimestamp: false,
13 | FullTimestamp: true,
14 | TimestampFormat: time.RFC3339Nano,
15 | }
16 | }
17 |
18 | func jsonFormatter() log.Formatter {
19 | return &log.JSONFormatter{
20 | TimestampFormat: time.RFC3339Nano,
21 | }
22 | }
23 |
24 | func Setup(level, format string) error {
25 | switch format {
26 | case "json":
27 | log.SetFormatter(jsonFormatter())
28 | case "text":
29 | log.SetFormatter(textFormatter())
30 | default:
31 | return fmt.Errorf("log format '%s' is not recognized", format)
32 | }
33 |
34 | logLevel, err := log.ParseLevel(level)
35 | if err != nil {
36 | return fmt.Errorf("while setting log level: %s", err)
37 | }
38 |
39 | log.SetLevel(logLevel)
40 | log.Tracef("Trace logging enabled")
41 |
42 | return nil
43 | }
44 |
--------------------------------------------------------------------------------
/charts/wonderwall/templates/fa-azureapp.yaml:
--------------------------------------------------------------------------------
{{- /*
AzureAdApplication for the forward-auth ("-fa") service, rendered only when both
.Values.azure.enabled and .Values.azure.forwardAuth.enabled are true.
- The name gets a "-<resourceSuffix>" suffix when .Values.resourceSuffix is set.
- Credentials are written to the secret named by .Values.azure.forwardAuth.clientSecretName.
- Group claims are emitted only when .Values.azure.forwardAuth.groupIds is non-empty.
- Reply and front-channel logout URLs are derived from "wonderwall.azure.forwardAuthURL".
*/ -}}
{{- if and .Values.azure.enabled .Values.azure.forwardAuth.enabled }}
---
apiVersion: nais.io/v1
kind: AzureAdApplication
metadata:
  {{- if .Values.resourceSuffix }}
  name: {{ include "wonderwall.fullname" . }}-fa-{{ .Values.resourceSuffix }}
  {{- else }}
  name: {{ include "wonderwall.fullname" . }}-fa
  {{- end }}
  labels:
    {{- include "wonderwall.labelsForwardAuth" . | nindent 4 }}
spec:
  secretName: {{ .Values.azure.forwardAuth.clientSecretName }}
  allowAllUsers: true
  {{- if .Values.azure.forwardAuth.groupIds }}
  claims:
    groups:
      {{- range .Values.azure.forwardAuth.groupIds }}
      - id: {{ . }}
      {{- end }}
  {{- end }}
  logoutUrl: "{{ include "wonderwall.azure.forwardAuthURL" . }}/oauth2/logout/frontchannel"
  replyUrls:
    - url: "{{- include "wonderwall.azure.forwardAuthURL" . }}/oauth2/callback"
  tenant: nav.no
{{- end }}
28 |
--------------------------------------------------------------------------------
/pkg/session/store_redis_test.go:
--------------------------------------------------------------------------------
1 | package session_test
2 |
3 | import (
4 | "testing"
5 |
6 | "github.com/alicebob/miniredis/v2"
7 | "github.com/redis/go-redis/v9"
8 | "github.com/stretchr/testify/assert"
9 |
10 | "github.com/nais/wonderwall/pkg/session"
11 | )
12 |
13 | func TestRedis(t *testing.T) {
14 | crypter := makeCrypter(t)
15 | data := makeData()
16 | encryptedData, err := data.Encrypt(crypter)
17 | assert.NoError(t, err)
18 |
19 | s, err := miniredis.Run()
20 | if err != nil {
21 | panic(err)
22 | }
23 | defer s.Close()
24 |
25 | client := redis.NewClient(&redis.Options{
26 | Network: "tcp",
27 | Addr: s.Addr(),
28 | })
29 |
30 | store := session.NewRedis(client)
31 | key := "key"
32 |
33 | write(t, store, key, encryptedData)
34 |
35 | decrypted := read(t, store, key, encryptedData, crypter)
36 | decryptedEqual(t, data, decrypted)
37 |
38 | data, encryptedData = update(t, store, key, data, crypter)
39 |
40 | decrypted = read(t, store, key, encryptedData, crypter)
41 | decryptedEqual(t, data, decrypted)
42 |
43 | del(t, store, key)
44 | }
45 |
--------------------------------------------------------------------------------
/.github/dependabot.yml:
--------------------------------------------------------------------------------
1 | version: 2
2 | updates:
3 | - package-ecosystem: "gomod"
4 | directory: "/"
5 | schedule:
6 | interval: "weekly"
7 | day: "monday"
8 | time: "09:00"
9 | timezone: "Europe/Oslo"
10 | groups:
11 | otel:
12 | patterns:
13 | - 'go.opentelemetry.io/*'
14 | redis:
15 | patterns:
16 | - 'github.com/redis/go-redis/*'
17 | cooldown:
18 | default-days: 7
19 | - package-ecosystem: "github-actions"
20 | directory: "/"
21 | schedule:
22 | interval: "weekly"
23 | day: "monday"
24 | time: "09:00"
25 | timezone: "Europe/Oslo"
26 | groups:
27 | gh-actions:
28 | patterns:
29 | - '*'
30 | cooldown:
31 | default-days: 7
32 | - package-ecosystem: "docker"
33 | directory: "/"
34 | schedule:
35 | interval: "weekly"
36 | day: "monday"
37 | time: "09:00"
38 | timezone: "Europe/Oslo"
39 | groups:
40 | docker:
41 | patterns:
42 | - '*'
43 | cooldown:
44 | default-days: 7
45 |
--------------------------------------------------------------------------------
/charts/wonderwall-forward-auth/values.yaml:
--------------------------------------------------------------------------------
1 | nameOverride: ""
2 | fullnameOverride: ""
3 |
4 | image:
5 | repository: europe-north1-docker.pkg.dev/nais-io/nais/images/wonderwall
6 | tag: latest
7 |
8 | # mapped by fasit
9 | fasit:
10 | tenant:
11 | name:
12 |
13 | resources:
14 | limits:
15 | cpu: "2"
16 | memory: 512Mi
17 | requests:
18 | cpu: 100m
19 | memory: 64Mi
20 | replicas:
21 | min: 2
22 | max: 4
23 | podDisruptionBudget:
24 | maxUnavailable: 1
25 | ingressClassName: nais-ingress-fa
26 | otel:
27 | endpoint: http://opentelemetry-management-collector.nais-system:4317
28 |
29 | openid:
30 | clientID:
31 | clientSecret:
32 | extraAudience:
33 | extraScopes:
34 | wellKnownURL: https://auth.nais.io/.well-known/openid-configuration
35 | session:
36 | maxLifetime: 10h
37 | # 256 bits key, in standard base64 encoding
38 | cookieEncryptionKey:
39 | cookieName: nais-io-forward-auth
40 | sso:
41 | defaultRedirectURL:
42 | domain:
43 |
44 | valkey:
45 | host:
46 | port:
47 | username:
48 | password:
49 | connectionIdleTimeoutSeconds: 299
50 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 NAV
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/internal/o11y/otel/middleware.go:
--------------------------------------------------------------------------------
1 | package otel
2 |
3 | import (
4 | "context"
5 | "fmt"
6 | "net/http"
7 |
8 | chi_middleware "github.com/go-chi/chi/v5/middleware"
9 | httpinternal "github.com/nais/wonderwall/internal/http"
10 | "go.opentelemetry.io/otel/attribute"
11 | "go.opentelemetry.io/otel/trace"
12 | )
13 |
14 | func Middleware(next http.Handler) http.Handler {
15 | fn := func(w http.ResponseWriter, r *http.Request) {
16 | ctx := r.Context()
17 | span := trace.SpanFromContext(ctx)
18 |
19 | attrs := httpinternal.Attributes(r)
20 | for k, v := range attrs {
21 | switch v := v.(type) {
22 | case bool:
23 | span.SetAttributes(attribute.Bool(k, v))
24 | default:
25 | span.SetAttributes(attribute.String(k, fmt.Sprint(v)))
26 | }
27 | }
28 |
29 | // Override request ID with trace ID if available.
30 | if span.SpanContext().HasTraceID() {
31 | id := span.SpanContext().TraceID().String()
32 | ctx = context.WithValue(ctx, chi_middleware.RequestIDKey, id)
33 | next.ServeHTTP(w, r.WithContext(ctx))
34 | } else {
35 | next.ServeHTTP(w, r)
36 | }
37 | }
38 |
39 | return http.HandlerFunc(fn)
40 | }
41 |
--------------------------------------------------------------------------------
/charts/wonderwall/templates/idporten-idportenclient.yaml:
--------------------------------------------------------------------------------
{{- /*
IDPortenClient for the ID-porten SSO server, rendered only when .Values.idporten.enabled is true.
- The name gets a "-<resourceSuffix>" suffix when .Values.resourceSuffix is set.
- "helm.sh/resource-policy: keep" tells Helm not to delete this resource when the release removes it.
- "digdir.nais.io/preserve" presumably asks the operator to keep the client registration at
  DigDir if the resource is deleted — TODO confirm.
- Callback and logout URIs are derived from "wonderwall.idporten.ssoServerURL".
*/ -}}
{{- if .Values.idporten.enabled }}
---
apiVersion: nais.io/v1
kind: IDPortenClient
metadata:
  {{- if .Values.resourceSuffix }}
  name: {{ include "wonderwall.fullname" . }}-idporten-{{ .Values.resourceSuffix }}
  {{- else }}
  name: {{ include "wonderwall.fullname" . }}-idporten
  {{- end }}
  labels:
    {{- include "wonderwall.labelsIdporten" . | nindent 4 }}
  annotations:
    "digdir.nais.io/preserve": "true"
    "helm.sh/resource-policy": "keep"
spec:
  clientURI: "{{ .Values.idporten.ssoDefaultRedirectURL }}"
  redirectURIs:
    - "{{- include "wonderwall.idporten.ssoServerURL" . }}/oauth2/callback"
  secretName: "{{ .Values.idporten.clientSecretName }}"
  frontchannelLogoutURI: "{{ include "wonderwall.idporten.ssoServerURL" . }}/oauth2/logout/frontchannel"
  postLogoutRedirectURIs:
    - "{{- include "wonderwall.idporten.ssoServerURL" . }}/oauth2/logout/callback"
  accessTokenLifetime: {{ .Values.idporten.clientAccessTokenLifetime }}
  sessionLifetime: {{ .Values.idporten.clientSessionLifetime }}
{{- end }}
27 |
--------------------------------------------------------------------------------
/internal/retry/retry.go:
--------------------------------------------------------------------------------
1 | package retry
2 |
3 | import (
4 | "context"
5 | "time"
6 |
7 | "github.com/sethvargo/go-retry"
8 | )
9 |
// RetryableError re-exports go-retry's RetryableError so callers can mark an
// error as retryable without importing github.com/sethvargo/go-retry directly.
var RetryableError = retry.RetryableError

// fibonacci holds the backoff parameters assembled from Options.
type fibonacci struct {
	// base is the starting value passed to retry.NewFibonacci (default 50ms, set in fibonacciBackoff).
	base time.Duration
	// max is passed to retry.WithMaxDuration to bound total retrying time (default 5s, set in fibonacciBackoff).
	max time.Duration
}

// Option customizes the backoff used by Do and DoValue.
type Option func(*fibonacci)
18 |
// WithBase overrides the starting value of the Fibonacci backoff (default 50ms).
func WithBase(d time.Duration) Option {
	return func(f *fibonacci) {
		f.base = d
	}
}

// WithMax overrides the maximum total duration of retrying (default 5s). The
// duration is measured from when the backoff is built, i.e. when Do/DoValue is called.
func WithMax(d time.Duration) Option {
	return func(f *fibonacci) {
		f.max = d
	}
}

// Do calls f via retry.Do with a Fibonacci backoff configured by opts and
// returns the resulting error. ctx is passed through to retry.Do.
func Do(ctx context.Context, f retry.RetryFunc, opts ...Option) error {
	return retry.Do(ctx, fibonacciBackoff(opts...), f)
}

// DoValue is like Do but for functions that also return a value of type T.
func DoValue[T any](ctx context.Context, f retry.RetryFuncValue[T], opts ...Option) (T, error) {
	return retry.DoValue(ctx, fibonacciBackoff(opts...), f)
}
38 |
39 | func fibonacciBackoff(opts ...Option) retry.Backoff {
40 | f := &fibonacci{
41 | base: 50 * time.Millisecond,
42 | max: 5 * time.Second,
43 | }
44 |
45 | for _, opt := range opts {
46 | opt(f)
47 | }
48 |
49 | b := retry.NewFibonacci(f.base)
50 | // beware: this starts a timer when invoked, on which the max duration is evaluated against
51 | b = retry.WithMaxDuration(f.max, b)
52 | return b
53 | }
54 |
--------------------------------------------------------------------------------
/pkg/config/otel.go:
--------------------------------------------------------------------------------
1 | package config
2 |
3 | import (
4 | "os"
5 |
6 | flag "github.com/spf13/pflag"
7 | "github.com/spf13/viper"
8 | )
9 |
// OpenTelemetry holds the tracing settings. Values set via environment
// variables are applied by resolveOtel.
type OpenTelemetry struct {
	// Enabled toggles OpenTelemetry tracing.
	Enabled bool `json:"enabled"`
	// ServiceName is the service name reported to OpenTelemetry.
	ServiceName string `json:"service-name"`
}

// Flag / configuration keys for the OpenTelemetry settings.
const (
	OpenTelemetryEnabled     = "otel.enabled"
	OpenTelemetryServiceName = "otel.service-name"
)

// otelFlags registers the OpenTelemetry command-line flags (tracing disabled
// and service name "wonderwall" by default).
func otelFlags() {
	flag.Bool(OpenTelemetryEnabled, false, "Enable OpenTelemetry tracing. Automatically enabled if OTEL_EXPORTER_OTLP_ENDPOINT is set.")
	flag.String(OpenTelemetryServiceName, "wonderwall", "Service name to use for OpenTelemetry. The OTEL_SERVICE_NAME environment variable overrides this value.")
}
24 |
25 | func resolveOtel() {
26 | otelEndpoint := os.Getenv("OTEL_EXPORTER_OTLP_ENDPOINT")
27 | if otelEndpoint != "" {
28 | logger.Debugf("config: OTLP endpoint set to %q, enabling OpenTelemetry", otelEndpoint)
29 | viper.Set(OpenTelemetryEnabled, "true")
30 | }
31 |
32 | otelServiceName := os.Getenv("OTEL_SERVICE_NAME")
33 | if otelServiceName != "" {
34 | logger.Debugf("config: OTEL_SERVICE_NAME set to %q; overriding %q flag", otelServiceName, OpenTelemetryServiceName)
35 | viper.Set(OpenTelemetryServiceName, otelServiceName)
36 | }
37 | }
38 |
--------------------------------------------------------------------------------
/charts/wonderwall/templates/redis.yaml:
--------------------------------------------------------------------------------
{{- /*
Session-store resources for each enabled provider, rendered via shared templates defined
outside this file: common.redis.tpl, common.serviceintegration.tpl and common.aivenapplication.tpl.
- azure / openid: one "readwrite" application, whose secret name is .Values.<provider>.redisSecretName.
- idporten: one "readwrite" and one "read" application, whose secret names are
  .Values.idporten.redisSecretNames.readwrite / .read.
*/ -}}
{{- if .Values.azure.enabled }}
{{ $provider := "azure" }}
{{ include "common.redis.tpl" (dict "root" . "provider" $provider) }}
{{ include "common.serviceintegration.tpl" (dict "root" . "provider" $provider) }}
{{ include "common.aivenapplication.tpl" (dict "root" . "provider" $provider "access" "readwrite" "secretName" .Values.azure.redisSecretName) }}
{{- end }}

{{- if .Values.idporten.enabled }}
{{ $provider := "idporten" }}
{{ include "common.redis.tpl" (dict "root" . "provider" $provider) }}
{{ include "common.serviceintegration.tpl" (dict "root" . "provider" $provider) }}
{{ include "common.aivenapplication.tpl" (dict "root" . "provider" $provider "access" "readwrite" "secretName" .Values.idporten.redisSecretNames.readwrite) }}
{{ include "common.aivenapplication.tpl" (dict "root" . "provider" $provider "access" "read" "secretName" .Values.idporten.redisSecretNames.read) }}
{{- end }}

{{- if .Values.openid.enabled }}
{{ $provider := "openid" }}
{{ include "common.redis.tpl" (dict "root" . "provider" $provider) }}
{{ include "common.serviceintegration.tpl" (dict "root" . "provider" $provider) }}
{{ include "common.aivenapplication.tpl" (dict "root" . "provider" $provider "access" "readwrite" "secretName" .Values.openid.redisSecretName) }}
{{- end }}
22 |
--------------------------------------------------------------------------------
/pkg/openid/client/logout_callback.go:
--------------------------------------------------------------------------------
1 | package client
2 |
3 | import (
4 | "fmt"
5 | "net/http"
6 |
7 | "github.com/nais/wonderwall/pkg/openid"
8 | urlpkg "github.com/nais/wonderwall/pkg/url"
9 | )
10 |
// LogoutCallback holds what is needed to handle a logout callback request:
// the client, the (possibly nil) logout cookie, the redirect validator and the
// incoming request.
type LogoutCallback struct {
	*Client
	// cookie may be nil; PostLogoutRedirectURI and stateMismatchError handle that case.
	cookie    *openid.LogoutCookie
	validator urlpkg.Validator
	request   *http.Request
}

// NewLogoutCallback bundles its arguments into a LogoutCallback. cookie may be nil.
func NewLogoutCallback(c *Client, r *http.Request, cookie *openid.LogoutCookie, validator urlpkg.Validator) *LogoutCallback {
	return &LogoutCallback{
		Client:    c,
		cookie:    cookie,
		validator: validator,
		request:   r,
	}
}
26 |
27 | func (in *LogoutCallback) PostLogoutRedirectURI() string {
28 | if in.cookie != nil && in.stateMismatchError() == nil && in.validator.IsValidRedirect(in.request, in.cookie.RedirectTo) {
29 | return in.cookie.RedirectTo
30 | }
31 |
32 | defaultRedirect := in.cfg.Client().PostLogoutRedirectURI()
33 | if defaultRedirect != "" {
34 | return defaultRedirect
35 | }
36 |
37 | ingress, err := urlpkg.MatchingIngress(in.request)
38 | if err != nil {
39 | return "/"
40 | }
41 |
42 | return ingress.String()
43 | }
44 |
45 | func (in *LogoutCallback) stateMismatchError() error {
46 | if in.cookie == nil {
47 | return fmt.Errorf("logout cookie is nil")
48 | }
49 |
50 | return openid.StateMismatchError(in.request.URL.Query(), in.cookie.State)
51 | }
52 |
--------------------------------------------------------------------------------
/templates/input.css:
--------------------------------------------------------------------------------
@import 'tailwindcss';

/*
  Custom colour scales registered as Tailwind v4 theme variables. Each
  --color-<name>-<shade> variable makes <name>-<shade> available to colour
  utilities (e.g. bg-primary-500, text-action-700).
  "primary" is a red scale and "action" a blue scale, lightest (50) to darkest (950).
*/
@theme {
  --color-primary-50: #ffe6e6;
  --color-primary-100: #ffc2c2;
  --color-primary-200: #f68282;
  --color-primary-300: #f25c5c;
  --color-primary-400: #de2e2e;
  --color-primary-500: #c30000;
  --color-primary-600: #ad0000;
  --color-primary-700: #8c0000;
  --color-primary-800: #5c0000;
  --color-primary-900: #260000;
  --color-primary-950: #180000;

  --color-action-50: #e6f0ff;
  --color-action-100: #cce1ff;
  --color-action-200: #99c3ff;
  --color-action-300: #66a5f4;
  --color-action-400: #3386e0;
  --color-action-500: #0067c5;
  --color-action-600: #0056b4;
  --color-action-700: #00459c;
  --color-action-800: #00347d;
  --color-action-900: #002252;
  --color-action-950: #00131a;
}

/*
  The default border color has changed to `currentColor` in Tailwind CSS v4,
  so we've added these compatibility styles to make sure everything still
  looks the same as it did with Tailwind CSS v3.

  If we ever want to remove these styles, we need to add an explicit border
  color utility to any element that depends on these defaults.
*/
@layer base {
  *,
  ::after,
  ::before,
  ::backdrop,
  ::file-selector-button {
    border-color: var(--color-gray-200, currentColor);
  }
}
46 |
--------------------------------------------------------------------------------
/mise.toml:
--------------------------------------------------------------------------------
# Toolchain versions managed by mise.
[tools]
go = "1.25.5"

# Build the wonderwall binary to bin/wonderwall (-a forces rebuilding all packages).
[tasks.build]
run = "go build -a -o bin/wonderwall ./cmd/wonderwall"

# Run wonderwall locally on 127.0.0.1:3000, in front of an upstream on localhost:4000.
# It uses an OpenID provider at localhost:8888 and Redis at localhost:6379, and exports
# OpenTelemetry data to localhost:4317. Formats the code first.
[tasks.local]
depends = ["fmt"]
env = { OTEL_EXPORTER_OTLP_ENDPOINT = "http://localhost:4317" }
run = '''go run cmd/wonderwall/main.go \
--openid.client-id=bogus \
--openid.client-secret=not-so-secret \
--openid.well-known-url=http://localhost:8888/default/.well-known/openid-configuration \
--ingress=http://localhost:3000 \
--bind-address=127.0.0.1:3000 \
--upstream-host=localhost:4000 \
--redis.uri=redis://localhost:6379 \
--log-level=info \
--log-format=text
'''

# Format all Go sources in place with gofumpt.
[tasks.fmt]
run = "go tool gofumpt -w ./"

# Run all tests uncached (-count=1) in shuffled order, writing a coverage profile to cover.out.
[tasks.test]
depends = ["fmt"]
run = "go test -count=1 -shuffle=on ./... -coverprofile cover.out"

# Static analysis: go vet, staticcheck, govulncheck, and ratchet lint over the workflow files.
[tasks.check]
depends = ["fmt"]
run = [
"go vet ./...",
"go tool staticcheck ./...",
"go tool govulncheck -show=traces ./...",
"go tool ratchet lint .github/workflows/*.yaml"
]

[tasks.actions-update]
description = "Upgrade all github actions to latest version satisfying their version tag"
run = "go tool ratchet update .github/workflows/*.yaml"

[tasks.actions-upgrade]
description = "Upgrade all github actions to latest"
run = "go tool ratchet upgrade .github/workflows/*.yaml"
45 |
--------------------------------------------------------------------------------
/pkg/handler/handler_sso_server.go:
--------------------------------------------------------------------------------
1 | package handler
2 |
3 | import (
4 | "net/http"
5 |
6 | "github.com/nais/wonderwall/pkg/config"
7 | "github.com/nais/wonderwall/pkg/cookie"
8 | "github.com/nais/wonderwall/pkg/router"
9 | "github.com/nais/wonderwall/pkg/url"
10 | )
11 |
// Compile-time assertion that *SSOServer implements router.Source.
var _ router.Source = &SSOServer{}

// SSOServer is a Standalone handler configured by NewSSOServer to act as a
// single sign-on server, with an SSO-specific redirect policy and cookies scoped
// to the SSO domain.
type SSOServer struct {
	*Standalone
}
17 |
18 | func NewSSOServer(cfg *config.Config, handler *Standalone) (*SSOServer, error) {
19 | redirect, err := url.NewSSOServerRedirect(cfg)
20 | if err != nil {
21 | return nil, err
22 | }
23 |
24 | handler.Redirect = redirect
25 | handler.CookieOptions = cookie.DefaultOptions().
26 | WithPath("/").
27 | WithDomain(cfg.SSO.Domain).
28 | WithSameSite(cfg.Cookie.SameSite.ToHttp()).
29 | WithSecure(cfg.Cookie.Secure)
30 |
31 | return &SSOServer{Standalone: handler}, nil
32 | }
33 |
// Logout delegates to the embedded Standalone handler.
// NOTE(review): this method would also be promoted from *Standalone via
// embedding; the explicit forwarder presumably exists for readability — confirm before removing.
func (s *SSOServer) Logout(w http.ResponseWriter, r *http.Request) {
	s.Standalone.Logout(w, r)
}

// LogoutFrontChannel delegates to the embedded Standalone handler.
func (s *SSOServer) LogoutFrontChannel(w http.ResponseWriter, r *http.Request) {
	s.Standalone.LogoutFrontChannel(w, r)
}

// LogoutLocal delegates to the embedded Standalone handler.
func (s *SSOServer) LogoutLocal(w http.ResponseWriter, r *http.Request) {
	s.Standalone.LogoutLocal(w, r)
}

// Wildcard redirects unhandled requests with 302 Found to the configured
// SSO server default redirect URL (Config.SSO.ServerDefaultRedirectURL).
func (s *SSOServer) Wildcard(w http.ResponseWriter, r *http.Request) {
	http.Redirect(w, r, s.Config.SSO.ServerDefaultRedirectURL, http.StatusFound)
}
50 |
--------------------------------------------------------------------------------
/charts/wonderwall-forward-auth/templates/prometheusrule.yaml:
--------------------------------------------------------------------------------
{{- /*
PrometheusRule for the forward-auth service, rendered only when the cluster serves the
monitoring.coreos.com/v1 API. It defines a single critical alert that fires when the summed
increase of requests_total with code="500" (for this service and namespace) over a 5m window
stays above 30 for 5 minutes.
*/ -}}
{{- if .Capabilities.APIVersions.Has "monitoring.coreos.com/v1" }}
---
apiVersion: monitoring.coreos.com/v1
kind: PrometheusRule
metadata:
  name: {{ include "wonderwall-forward-auth.fullname" . }}
  labels:
    {{- include "wonderwall-forward-auth.labels" . | nindent 4 }}
spec:
  groups:
    - name: "wonderwall-forward-auth"
      rules:
        - alert: wonderwall-forward-auth (Zitadel) reports a high amount of internal errors
          expr: sum(increase(requests_total{service="{{ include "wonderwall-forward-auth.fullname" . }}", namespace="{{ .Release.Namespace }}", code="500"}[5m])) > 30
          for: 5m
          annotations:
            summary: Wonderwall has responded with HTTP 500 for a high amount of requests within the last 5 minutes.
            consequence: This probably means that end-users are having trouble with authentication.
            action: |
              * Check the logs and metrics in the dashboard
              * Check Aiven Valkey (session store) and verify Aiven network connectivity
              * Check the Zitadel dashboard:
            dashboard_url: "https://monitoring.nais.io/d/ben86a369fj7kd?var-tenant={{ .Values.fasit.tenant.name }}"
          labels:
            severity: critical
            namespace: {{ .Release.Namespace }}
{{ end }}
28 |
--------------------------------------------------------------------------------
/pkg/openid/cookies.go:
--------------------------------------------------------------------------------
1 | package openid
2 |
3 | import (
4 | "encoding/json"
5 | "fmt"
6 | "net/http"
7 |
8 | "github.com/nais/wonderwall/internal/crypto"
9 | "github.com/nais/wonderwall/pkg/cookie"
10 | )
11 |
// LoginCookie is the JSON payload read from the encrypted login cookie by
// GetLoginCookie. It presumably carries the state of an in-flight login
// until its callback is handled — TODO confirm against callers.
type LoginCookie struct {
	Acr string `json:"acr"`
	// CodeVerifier is presumably the PKCE code verifier for the authorization request.
	CodeVerifier string `json:"code_verifier"`
	Nonce        string `json:"nonce"`
	RedirectURI  string `json:"redirect_uri"`
	Referer      string `json:"referer"`
	State        string `json:"state"`
}

// LogoutCookie is the JSON payload read from the encrypted logout cookie by
// GetLogoutCookie.
type LogoutCookie struct {
	// State is compared against the logout callback's query parameters.
	State string `json:"state"`
	// RedirectTo is the requested post-logout redirect target; it is still subject to validation.
	RedirectTo string `json:"redirect_to"`
}
25 |
26 | func GetLoginCookie(r *http.Request, crypter crypto.Crypter) (*LoginCookie, error) {
27 | loginCookieJson, err := cookie.GetDecrypted(r, cookie.Login, crypter)
28 | if err != nil {
29 | return nil, err
30 | }
31 |
32 | var loginCookie LoginCookie
33 | err = json.Unmarshal([]byte(loginCookieJson), &loginCookie)
34 | if err != nil {
35 | return nil, fmt.Errorf("unmarshalling: %w", err)
36 | }
37 |
38 | return &loginCookie, nil
39 | }
40 |
41 | func GetLogoutCookie(r *http.Request, crypter crypto.Crypter) (*LogoutCookie, error) {
42 | logoutCookieJson, err := cookie.GetDecrypted(r, cookie.Logout, crypter)
43 | if err != nil {
44 | return nil, err
45 | }
46 |
47 | var logoutCookie LogoutCookie
48 | err = json.Unmarshal([]byte(logoutCookieJson), &logoutCookie)
49 | if err != nil {
50 | return nil, fmt.Errorf("unmarshalling: %w", err)
51 | }
52 |
53 | return &logoutCookie, nil
54 | }
55 |
--------------------------------------------------------------------------------
/pkg/openid/acr/acr_test.go:
--------------------------------------------------------------------------------
1 | package acr
2 |
3 | import (
4 | "testing"
5 |
6 | "github.com/stretchr/testify/assert"
7 | )
8 |
9 | func TestValidateAcr(t *testing.T) {
10 | for _, tt := range []struct {
11 | name string
12 | expected string
13 | actual string
14 | wantErr bool
15 | }{
16 | {"no mapping found, not equal", "some-value", "some-other-value", true},
17 | {"no mapping found, expected equals actual", "some-value", "some-value", false},
18 | {"Level3, higher acr accepted", "Level3", "idporten-loa-high", false},
19 | {"Level3, no matching value", "Level3", "Level2", true},
20 | {"Level3 -> idporten-loa-substantial", "Level3", "idporten-loa-substantial", false},
21 | {"idporten-loa-substantial", "idporten-loa-substantial", "idporten-loa-substantial", false},
22 | {"idporten-loa-substantial, higher acr accepted", "idporten-loa-substantial", "idporten-loa-high", false},
23 | {"Level4, lower acr not accepted", "Level4", "idporten-loa-substantial", true},
24 | {"Level4, no matching value", "Level4", "Level5", true},
25 | {"Level4 -> idporten-loa-high", "Level4", "idporten-loa-high", false},
26 | {"idporten-loa-high", "idporten-loa-high", "idporten-loa-high", false},
27 | {"idporten-loa-high, lower acr not accepted", "idporten-loa-high", "idporten-loa-substantial", true},
28 | } {
29 | t.Run(tt.name, func(t *testing.T) {
30 | err := Validate(tt.expected, tt.actual)
31 | if tt.wantErr {
32 | assert.Error(t, err)
33 | } else {
34 | assert.NoError(t, err)
35 | }
36 | })
37 | }
38 | }
39 |
--------------------------------------------------------------------------------
/pkg/session/store_memory.go:
--------------------------------------------------------------------------------
1 | package session
2 |
3 | import (
4 | "context"
5 | "fmt"
6 | "sync"
7 | "time"
8 | )
9 |
// memorySessionStore is a Store backed by an in-process map; every access is serialized by lock.
// Data lives only in this process (NewStore falls back to it when Redis is not configured).
type memorySessionStore struct {
	// lock guards sessions.
	lock sync.Mutex
	// sessions maps session keys to their encrypted payloads.
	sessions map[string]*EncryptedData
}

// Compile-time assertion that memorySessionStore implements Store.
var _ Store = &memorySessionStore{}
16 |
17 | func NewMemory() Store {
18 | return &memorySessionStore{
19 | sessions: make(map[string]*EncryptedData),
20 | }
21 | }
22 |
23 | func (s *memorySessionStore) Read(_ context.Context, key string) (*EncryptedData, error) {
24 | s.lock.Lock()
25 | defer s.lock.Unlock()
26 |
27 | data, ok := s.sessions[key]
28 | if !ok {
29 | return nil, fmt.Errorf("%w: no such session: %s", ErrNotFound, key)
30 | }
31 |
32 | return data, nil
33 | }
34 |
35 | func (s *memorySessionStore) Write(_ context.Context, key string, value *EncryptedData, _ time.Duration) error {
36 | s.lock.Lock()
37 | defer s.lock.Unlock()
38 |
39 | s.sessions[key] = value
40 | return nil
41 | }
42 |
43 | func (s *memorySessionStore) Delete(_ context.Context, keys ...string) error {
44 | s.lock.Lock()
45 | defer s.lock.Unlock()
46 |
47 | for _, key := range keys {
48 | delete(s.sessions, key)
49 | }
50 |
51 | return nil
52 | }
53 |
54 | func (s *memorySessionStore) Update(ctx context.Context, key string, value *EncryptedData) error {
55 | _, err := s.Read(ctx, key)
56 | if err != nil {
57 | return err
58 | }
59 |
60 | s.lock.Lock()
61 | defer s.lock.Unlock()
62 |
63 | s.sessions[key] = value
64 | return nil
65 | }
66 |
67 | func (s *memorySessionStore) MakeLock(_ string) Lock {
68 | return NewNoOpLock()
69 | }
70 |
--------------------------------------------------------------------------------
/internal/http/middleware.go:
--------------------------------------------------------------------------------
1 | package http
2 |
import (
	"encoding/json"
	"net/http"

	"go.opentelemetry.io/otel/attribute"
	"go.opentelemetry.io/otel/trace"
)
9 |
10 | // DisallowNonNavigationalRequests checks if the request is non-navigational, and if so, responds with a 401.
11 | // We do this to separate between redirects for browser navigation and redirects for resource requests.
12 | //
13 | // This should only be used for endpoints that are only supposed to be _navigated to_ from a browser.
14 | // The 401 response prevents redirecting non-navigation requests to the identity provider, which usually results in
15 | // a CORS error for typical Fetch or XHR requests from the browser.
16 | //
17 | // This depends on the presence of the Fetch metadata headers, mostly present in modern browsers.
18 | // For compatibility with older browsers, requests without these headers are still allowed to pass through.
19 | func DisallowNonNavigationalRequests(next http.Handler) http.Handler {
20 | return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
21 | if HasSecFetchMetadata(r) && !IsNavigationRequest(r) {
22 | span := trace.SpanFromContext(r.Context())
23 | span.SetAttributes(attribute.Bool("request.disallowed", true))
24 |
25 | w.Header().Set("Content-Type", "application/json")
26 | w.WriteHeader(http.StatusUnauthorized)
27 | w.Write([]byte(`{"error": "unauthenticated", "error_description": "this is an interactive endpoint; user-agents must be navigated to this endpoint", "error_path": "` + r.URL.Path + `"}`))
28 | return
29 | }
30 |
31 | next.ServeHTTP(w, r)
32 | })
33 | }
34 |
--------------------------------------------------------------------------------
/internal/crypto/jwk.go:
--------------------------------------------------------------------------------
1 | package crypto
2 |
3 | import (
4 | "crypto/rand"
5 | "crypto/rsa"
6 | "fmt"
7 |
8 | "github.com/lestrrat-go/jwx/v3/jwa"
9 | "github.com/lestrrat-go/jwx/v3/jwk"
10 | )
11 |
// JwkSet pairs a set of private JWKs with the corresponding public set.
type JwkSet struct {
	// Private holds the private key(s).
	Private jwk.Set
	// Public holds the public counterparts; NewJwkSet derives it from Private via jwk.PublicSetOf.
	Public jwk.Set
}
16 |
17 | func NewJwk() (jwk.Key, error) {
18 | privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
19 | if err != nil {
20 | return nil, fmt.Errorf("generating key: %w", err)
21 | }
22 |
23 | key, err := jwk.Import(privateKey)
24 | if err != nil {
25 | return nil, fmt.Errorf("importing key: %w", err)
26 | }
27 |
28 | err = key.Set(jwk.AlgorithmKey, jwa.RS256().String())
29 | if err != nil {
30 | return nil, fmt.Errorf("setting algorithm: %w", err)
31 | }
32 |
33 | err = key.Set(jwk.KeyTypeKey, jwa.RSA().String())
34 | if err != nil {
35 | return nil, fmt.Errorf("setting key type: %w", err)
36 | }
37 |
38 | err = jwk.AssignKeyID(key)
39 | if err != nil {
40 | return nil, fmt.Errorf("assigning key id: %w", err)
41 | }
42 |
43 | return key, nil
44 | }
45 |
46 | func NewJwkSet() (*JwkSet, error) {
47 | key, err := NewJwk()
48 | if err != nil {
49 | return nil, fmt.Errorf("creating jwk: %w", err)
50 | }
51 |
52 | privateKeys := jwk.NewSet()
53 | err = privateKeys.AddKey(key)
54 | if err != nil {
55 | return nil, fmt.Errorf("adding key to set: %w", err)
56 | }
57 |
58 | publicKeys, err := jwk.PublicSetOf(privateKeys)
59 | if err != nil {
60 | return nil, fmt.Errorf("creating public set: %w", err)
61 | }
62 |
63 | return &JwkSet{
64 | Private: privateKeys,
65 | Public: publicKeys,
66 | }, nil
67 | }
68 |
--------------------------------------------------------------------------------
/pkg/session/id.go:
--------------------------------------------------------------------------------
1 | package session
2 |
3 | import (
4 | "fmt"
5 | "net/http"
6 |
7 | "github.com/nais/wonderwall/internal/crypto"
8 | "github.com/nais/wonderwall/pkg/openid"
9 | openidconfig "github.com/nais/wonderwall/pkg/openid/config"
10 | )
11 |
12 | // ExternalID returns the external session ID, derived from the given request or id_token; e.g. `sid` or `session_state`.
13 | // If none are present, a generated ID is returned.
14 | func ExternalID(r *http.Request, cfg openidconfig.Provider, idToken *openid.IDToken) (string, error) {
15 | // 1. check for 'sid' claim in id_token
16 | sessionID, err := idToken.Sid()
17 | if err == nil {
18 | return sessionID, nil
19 | }
20 | // 1a. error if sid claim is required according to openid config
21 | if err != nil && cfg.SidClaimRequired() {
22 | return "", err
23 | }
24 |
25 | // 2. check for session_state in callback params
26 | sessionID, err = getSessionStateFrom(r)
27 | if err == nil {
28 | return sessionID, nil
29 | }
30 | // 2a. error if session_state is required according to openid config
31 | if err != nil && cfg.SessionStateRequired() {
32 | return "", err
33 | }
34 |
35 | // 3. generate ID if all else fails
36 | sessionID, err = crypto.Text(64)
37 | if err != nil {
38 | return "", fmt.Errorf("generating session ID: %w", err)
39 | }
40 | return sessionID, nil
41 | }
42 |
43 | func getSessionStateFrom(r *http.Request) (string, error) {
44 | params := r.URL.Query()
45 |
46 | sessionState := params.Get("session_state")
47 | if len(sessionState) == 0 {
48 | return "", fmt.Errorf("missing required 'session_state' in params")
49 | }
50 | return sessionState, nil
51 | }
52 |
--------------------------------------------------------------------------------
/internal/crypto/crypter_test.go:
--------------------------------------------------------------------------------
1 | package crypto_test
2 |
3 | import (
4 | "crypto/rand"
5 | "testing"
6 |
7 | "github.com/stretchr/testify/assert"
8 |
9 | "github.com/nais/wonderwall/internal/crypto"
10 | )
11 |
var (
	// plaintext is the message encrypted by the tests and benchmark in this file.
	plaintext = []byte("foo bar, this is a very nice plaintext")
	// key is a 32-byte (256-bit) buffer that init fills with random bytes.
	key = make([]byte, 32)
)

const (
	// Run this many iterations to make sure the IV is not re-used
	ivIterations = 50000
)
21 |
22 | // Generate a new encryption key on every test run.
23 | func init() {
24 | _, err := rand.Read(key)
25 | if err != nil {
26 | panic(err)
27 | }
28 | }
29 |
30 | // Test that encryption with a 256-bit key works,
31 | // and that the ciphertext differs from one message to the next.
32 | func TestEncrypt(t *testing.T) {
33 | var cur, prev []byte
34 | var err error
35 |
36 | crypter := crypto.NewCrypter(key)
37 |
38 | for i := ivIterations; i != 0; i-- {
39 | cur, err = crypter.Encrypt(plaintext)
40 | assert.Nil(t, err)
41 | assert.NotNil(t, cur)
42 | assert.NotEqual(t, prev, cur, "IV re-used")
43 | prev = make([]byte, len(cur))
44 | copy(prev, cur)
45 | }
46 | }
47 |
48 | // Test that encrypted messages can be decrypted.
49 | func TestDecrypt(t *testing.T) {
50 | crypter := crypto.NewCrypter(key)
51 |
52 | ciphertext, err := crypter.Encrypt(plaintext)
53 | assert.Nil(t, err)
54 |
55 | decrypted, err := crypter.Decrypt(ciphertext)
56 | assert.Nil(t, err)
57 | assert.Equal(t, plaintext, decrypted)
58 | }
59 |
60 | func BenchmarkEncrypt(b *testing.B) {
61 | crypter := crypto.NewCrypter(key)
62 |
63 | for n := 0; n < b.N; n++ {
64 | crypter.Encrypt(plaintext)
65 | }
66 | }
67 |
--------------------------------------------------------------------------------
/pkg/config/session.go:
--------------------------------------------------------------------------------
1 | package config
2 |
3 | import (
4 | "fmt"
5 | "time"
6 |
7 | flag "github.com/spf13/pflag"
8 | )
9 |
// Session holds session-related settings. Each field corresponds to the flag named by the matching
// Session* constant (e.g. "session.max-lifetime"), registered in sessionFlags.
type Session struct {
	// ForwardAuth enables the endpoint for forward authentication.
	ForwardAuth bool `json:"forward-auth"`
	// ForwardAuthSetHeaders sets the 'X-Wonderwall-Forward-Auth-Token' header on forward-auth responses;
	// Validate requires ForwardAuth when this is enabled.
	ForwardAuthSetHeaders bool `json:"forward-auth-set-headers"`
	// Inactivity enables expiry of sessions that have not refreshed tokens within InactivityTimeout.
	Inactivity bool `json:"inactivity"`
	// InactivityTimeout is the inactivity timeout for user sessions.
	InactivityTimeout time.Duration `json:"inactivity-timeout"`
	// MaxLifetime is the maximum lifetime of a user session.
	MaxLifetime time.Duration `json:"max-lifetime"`
}
17 |
18 | func (s *Session) Validate() error {
19 | if s.ForwardAuthSetHeaders && !s.ForwardAuth {
20 | return fmt.Errorf("%q must be enabled when %q is enabled", SessionForwardAuth, SessionForwardAuthSetHeaders)
21 | }
22 | return nil
23 | }
24 |
// Flag names for the Session settings; registered by sessionFlags.
const (
	SessionForwardAuth           = "session.forward-auth"
	SessionForwardAuthSetHeaders = "session.forward-auth-set-headers"
	SessionInactivity            = "session.inactivity"
	SessionInactivityTimeout     = "session.inactivity-timeout"
	SessionMaxLifetime           = "session.max-lifetime"
)
32 |
// sessionFlags registers the session.* flags on pflag's package-level flag set.
// Defaults: forward-auth off, set-headers off, inactivity off, inactivity timeout 30m, max lifetime 10h.
func sessionFlags() {
	flag.Bool(SessionForwardAuth, false, "Enable endpoint for forward authentication.")
	flag.Bool(SessionForwardAuthSetHeaders, false, "Set 'X-Wonderwall-Forward-Auth-Token' header for responses from forward-auth endpoint.")
	flag.Bool(SessionInactivity, false, "Automatically expire user sessions if they have not refreshed their tokens within a given duration.")
	flag.Duration(SessionInactivityTimeout, 30*time.Minute, "Inactivity timeout for user sessions.")
	flag.Duration(SessionMaxLifetime, 10*time.Hour, "Max lifetime for user sessions.")
}
40 |
--------------------------------------------------------------------------------
/pkg/mock/config.go:
--------------------------------------------------------------------------------
1 | package mock
2 |
3 | import (
4 | "time"
5 |
6 | "github.com/nais/wonderwall/pkg/config"
7 | "github.com/nais/wonderwall/pkg/ingress"
8 | openidconfig "github.com/nais/wonderwall/pkg/openid/config"
9 | )
10 |
const (
	// Ingress is the single ingress URL configured by Config.
	Ingress = "http://wonderwall"
)
14 |
15 | func Config() *config.Config {
16 | return &config.Config{
17 | EncryptionKey: `G8Roe6AcoBpdr5GhO3cs9iORl4XIC8eq`, // 256 bits key
18 | Ingresses: []string{Ingress},
19 | OpenID: config.OpenID{
20 | ACRValues: "idporten-loa-high",
21 | ClientID: "client-id",
22 | IDTokenSigningAlg: "RS256",
23 | PostLogoutRedirectURI: "https://google.com",
24 | Provider: "test",
25 | Scopes: []string{"some-scope"},
26 | UILocales: "nb",
27 | },
28 | Session: config.Session{
29 | MaxLifetime: time.Hour,
30 | },
31 | }
32 | }
33 |
34 | type TestConfiguration struct {
35 | TestClient *TestClientConfiguration
36 | TestProvider *TestProviderConfiguration
37 | }
38 |
39 | func (c *TestConfiguration) Client() openidconfig.Client {
40 | return c.TestClient
41 | }
42 |
43 | func (c *TestConfiguration) Provider() openidconfig.Provider {
44 | return c.TestProvider
45 | }
46 |
47 | func NewTestConfiguration(cfg *config.Config) *TestConfiguration {
48 | return &TestConfiguration{
49 | TestClient: clientConfiguration(cfg),
50 | TestProvider: providerConfiguration(cfg),
51 | }
52 | }
53 |
54 | func Ingresses(cfg *config.Config) *ingress.Ingresses {
55 | parsed, err := ingress.ParseIngresses(cfg)
56 | if err != nil {
57 | panic(err)
58 | }
59 |
60 | return parsed
61 | }
62 |
--------------------------------------------------------------------------------
/docker-compose.example.yml:
--------------------------------------------------------------------------------
# Example local setup: Redis as session store, mock-oauth2-server as identity provider,
# Wonderwall in front of an echo server acting as the upstream application.
services:
  redis:
    image: redis:7
    ports:
      - "6379:6379"
  mock-oauth2-server:
    image: ghcr.io/navikt/mock-oauth2-server:3.0.1
    ports:
      # Host port 8888 -> container port 8080.
      - "8888:8080"
    environment:
      JSON_CONFIG: "{\"interactiveLogin\":false}"
  wonderwall:
    image: ghcr.io/nais/wonderwall:latest
    ports:
      - "3000:3000"
    command: >
      --openid.client-id=bogus
      --openid.client-secret=not-so-secret
      --openid.well-known-url=http://localhost:8888/default/.well-known/openid-configuration
      --ingress=http://localhost:3000
      --bind-address=0.0.0.0:3000
      --upstream-host=upstream:4000
      --redis.uri=redis://redis:6379
      --log-level=debug
      --log-format=text
    # No depends_on is declared; presumably restarting on failure lets Wonderwall retry until Redis and the mock server are up.
    restart: on-failure
    extra_hosts:
      # Wonderwall needs to both reach and redirect user agents to the mock-oauth2-server:
      # - 'mock-oauth2-server:8080' (the container port) is reachable from the Wonderwall container, but the name is not resolvable for user agents on the host (e.g. during redirects).
      # - 'localhost:8888' allows user agents to resolve redirects to the mock-oauth2-server, but breaks connectivity from the container itself.
      # This additional mapping allows the container to reach the mock-oauth2-server at 'localhost' through the host network, as well as allowing user agents to correctly resolve redirects.
      - localhost:host-gateway
  upstream:
    image: mendhak/http-https-echo:30
    ports:
      - "4000:4000"
    environment:
      HTTP_PORT: 4000
      JWT_HEADER: Authorization
      LOG_IGNORE_PATH: /
41 |
--------------------------------------------------------------------------------
/pkg/handler/autologin/autologin.go:
--------------------------------------------------------------------------------
1 | package autologin
2 |
3 | import (
4 | "net/http"
5 | "strings"
6 | "sync"
7 |
8 | "github.com/bmatcuk/doublestar/v4"
9 |
10 | "github.com/nais/wonderwall/pkg/config"
11 | )
12 |
// DefaultIgnorePatterns are request paths that New always excludes from auto-login, in addition to
// the configured cfg.AutoLoginIgnorePaths.
var DefaultIgnorePatterns = []string{
	"/favicon.ico",
	"/robots.txt",
}
17 |
18 | type AutoLogin struct {
19 | Enabled bool
20 | IgnorePatterns []string
21 | cache sync.Map
22 | }
23 |
24 | func (a *AutoLogin) NeedsLogin(r *http.Request, isAuthenticated bool) bool {
25 | if isAuthenticated || !a.Enabled {
26 | return false
27 | }
28 |
29 | path := r.URL.Path
30 | if !strings.HasPrefix(path, "/") {
31 | path = "/" + path
32 | }
33 |
34 | if path != "/" {
35 | path = strings.TrimSuffix(path, "/")
36 | }
37 |
38 | if result, found := a.cache.Load(path); found {
39 | return result.(bool)
40 | }
41 |
42 | for _, pattern := range a.IgnorePatterns {
43 | match, _ := doublestar.Match(pattern, path)
44 | if match {
45 | a.cache.Store(path, false)
46 | return false
47 | }
48 | }
49 |
50 | a.cache.Store(path, true)
51 | return true
52 | }
53 |
54 | func New(cfg *config.Config) (*AutoLogin, error) {
55 | seen := make(map[string]bool)
56 | patterns := make([]string, 0)
57 |
58 | for _, path := range append(DefaultIgnorePatterns, cfg.AutoLoginIgnorePaths...) {
59 | if len(path) == 0 {
60 | continue
61 | }
62 |
63 | if path != "/" {
64 | path = strings.TrimSuffix(path, "/")
65 | }
66 |
67 | if _, found := seen[path]; !found {
68 | seen[path] = true
69 | patterns = append(patterns, path)
70 | }
71 | }
72 |
73 | return &AutoLogin{
74 | Enabled: cfg.AutoLogin,
75 | IgnorePatterns: patterns,
76 | cache: sync.Map{},
77 | }, nil
78 | }
79 |
--------------------------------------------------------------------------------
/flake.nix:
--------------------------------------------------------------------------------
{
  description = "Wonderwall";

  inputs = {
    nixpkgs.url = "github:NixOS/nixpkgs/nixpkgs-unstable";
    flake-utils.url = "github:numtide/flake-utils";
    gitignore = {
      url = "github:hercules-ci/gitignore.nix";
      inputs.nixpkgs.follows = "nixpkgs";
    };
  };

  # Per-system outputs: the Go package, a dev shell with its build inputs, a Docker image, and the formatter.
  outputs = {self, ...} @ inputs:
    inputs.flake-utils.lib.eachDefaultSystem (system: let
      pkgs = import inputs.nixpkgs {localSystem = {inherit system;};};
      name = "wonderwall";
      # The Go binary, built from the repository sources filtered through .gitignore.
      wonderwall = pkgs.buildGoModule {
        inherit name;
        # nativeBuildInputs = with pkgs; [
        #   golangci-lint
        # ];
        # GOLANGCI_LINT = "${pkgs.golangci-lint}";
        src = inputs.gitignore.lib.gitignoreSource ./.;
        # Hash of the vendored Go dependencies; update it when go.mod/go.sum change
        # (temporarily using a fake hash makes nix report the correct value).
        vendorHash = "sha256-3RqVAgA9iJhX0mbwlVMH+NSUz3H9Uobgs8zm1x9fb1o="; # nixpkgs.lib.fakeSha256;
      };
    in {
      # Development shell with the same build inputs as the package.
      devShells.default = pkgs.mkShell {
        inputsFrom = [wonderwall];
      };
      packages = {
        inherit wonderwall;
        docker = let
          imageRef = "europe-north1-docker.pkg.dev/nais-management-233d";
          teamName = "nais";
          # "<revCount>-<shortRev>" when the flake has a git rev (clean checkout), "gitDirty" otherwise.
          dockerTag =
            if pkgs.lib.hasAttr "rev" self
            then "${builtins.toString self.revCount}-${self.shortRev}"
            else "gitDirty";
        in
          pkgs.dockerTools.buildImage {
            config = {Entrypoint = ["${wonderwall}/bin/${name}"];};
            name = "${imageRef}/${teamName}/${name}";
            tag = "${dockerTag}";
          };
      };
      packages.default = wonderwall;
      formatter = pkgs.alejandra;
    });
}
50 |
--------------------------------------------------------------------------------
/templates/error.gohtml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
8 | Teknisk feil
9 |
10 |
11 |
12 |
13 |
14 |
15 | Beklager, noe gikk galt.
16 |
17 |
18 | Statuskode {{.HttpStatusCode}}
19 |
20 |
21 | En teknisk feil gjør at siden er utilgjengelig. Dette skyldes ikke noe du gjorde.
22 | Vent litt og prøv igjen.
23 |
24 |
34 |
35 | ID: {{.CorrelationID}}
36 |
37 |
38 |
39 |
40 |
41 |
42 |
--------------------------------------------------------------------------------
/pkg/openid/acr/acr.go:
--------------------------------------------------------------------------------
1 | package acr
2 |
3 | import (
4 | "fmt"
5 | )
6 |
7 | const (
8 | IDPortenLevel3 = "Level3"
9 | IDPortenLevelSubstantial = "idporten-loa-substantial"
10 | IDPortenLevel4 = "Level4"
11 | IDPortenLevelHigh = "idporten-loa-high"
12 | )
13 |
14 | // IDPortenLegacyMapping is a translation table of valid acr_values that maps values from "old" to "new" ID-porten.
15 | var IDPortenLegacyMapping = map[string]string{
16 | IDPortenLevel3: IDPortenLevelSubstantial,
17 | IDPortenLevel4: IDPortenLevelHigh,
18 | }
19 |
20 | // acceptedValuesMapping is a map of ACR (authentication context class reference) values.
21 | // Each value has an associated list of values that are regarded as equivalent or greater in terms of assurance levels.
22 | // Example:
23 | // - if we require an ACR value of "idporten-loa-substantial", then both "idporten-loa-substantial" and "idporten-loa-high" are accepted values.
24 | // - if we require an ACR value of "idporten-loa-high", then only "idporten-loa-high" is an acceptable value.
25 | var acceptedValuesMapping = map[string][]string{
26 | IDPortenLevelSubstantial: {IDPortenLevelSubstantial, IDPortenLevelHigh},
27 | IDPortenLevelHigh: {IDPortenLevelHigh},
28 | }
29 |
30 | func Validate(expected, actual string) error {
31 | if translated, found := IDPortenLegacyMapping[expected]; found {
32 | expected = translated
33 | }
34 |
35 | acceptedValues, found := acceptedValuesMapping[expected]
36 | if !found {
37 | if expected == actual {
38 | return nil
39 | }
40 | return fmt.Errorf("invalid acr: got %q, expected %q", actual, expected)
41 | }
42 |
43 | for _, accepted := range acceptedValues {
44 | if actual == accepted {
45 | return nil
46 | }
47 | }
48 |
49 | return fmt.Errorf("invalid acr: got %q, must be one of %s", actual, acceptedValues)
50 | }
51 |
--------------------------------------------------------------------------------
/pkg/middleware/context.go:
--------------------------------------------------------------------------------
1 | package middleware
2 |
3 | import (
4 | "context"
5 | "net/http"
6 |
7 | "github.com/nais/wonderwall/pkg/ingress"
8 | )
9 |
10 | type contextKey string
11 |
12 | const (
13 | ctxAccessToken = contextKey("AccessToken")
14 | ctxIDToken = contextKey("IDToken")
15 | ctxIngress = contextKey("Ingress")
16 | ctxPath = contextKey("Path")
17 | )
18 |
19 | func AccessTokenFrom(ctx context.Context) (string, bool) {
20 | accessToken, ok := ctx.Value(ctxAccessToken).(string)
21 | return accessToken, ok
22 | }
23 |
24 | func WithAccessToken(ctx context.Context, accessToken string) context.Context {
25 | return context.WithValue(ctx, ctxAccessToken, accessToken)
26 | }
27 |
28 | func IDTokenFrom(ctx context.Context) (string, bool) {
29 | idToken, ok := ctx.Value(ctxIDToken).(string)
30 | return idToken, ok
31 | }
32 |
33 | func WithIDToken(ctx context.Context, idToken string) context.Context {
34 | return context.WithValue(ctx, ctxIDToken, idToken)
35 | }
36 |
37 | func IngressFrom(ctx context.Context) (ingress.Ingress, bool) {
38 | i, ok := ctx.Value(ctxIngress).(ingress.Ingress)
39 | return i, ok
40 | }
41 |
42 | func WithIngress(ctx context.Context, ingress ingress.Ingress) context.Context {
43 | return context.WithValue(ctx, ctxIngress, ingress)
44 | }
45 |
46 | func RequestWithIngress(r *http.Request, ing ingress.Ingress) *http.Request {
47 | ctx := r.Context()
48 | ctx = WithIngress(ctx, ing)
49 | return r.WithContext(ctx)
50 | }
51 |
52 | func PathFrom(ctx context.Context) (string, bool) {
53 | path, ok := ctx.Value(ctxPath).(string)
54 | return path, ok
55 | }
56 |
57 | func WithPath(ctx context.Context, path string) context.Context {
58 | return context.WithValue(ctx, ctxPath, path)
59 | }
60 |
61 | func RequestWithPath(r *http.Request, path string) *http.Request {
62 | ctx := r.Context()
63 | ctx = WithPath(ctx, path)
64 | return r.WithContext(ctx)
65 | }
66 |
--------------------------------------------------------------------------------
/pkg/openid/client/logout.go:
--------------------------------------------------------------------------------
1 | package client
2 |
3 | import (
4 | "encoding/json"
5 | "fmt"
6 | "net/http"
7 |
8 | "github.com/nais/wonderwall/internal/crypto"
9 | "github.com/nais/wonderwall/pkg/cookie"
10 | "github.com/nais/wonderwall/pkg/openid"
11 | urlpkg "github.com/nais/wonderwall/pkg/url"
12 | )
13 |
// Logout holds the state for one logout flow started by NewLogout.
type Logout struct {
	*Client
	// Cookie is the logout cookie; its State is sent to the end-session endpoint, and
	// SetCookie records the redirect target in it.
	Cookie *openid.LogoutCookie
	// logoutCallbackURL is sent as post_logout_redirect_uri in SingleLogoutURL.
	logoutCallbackURL string
}
19 |
20 | func NewLogout(c *Client, r *http.Request) (*Logout, error) {
21 | logoutCallbackURL, err := urlpkg.LogoutCallback(r)
22 | if err != nil {
23 | return nil, fmt.Errorf("generating logout callback url: %w", err)
24 | }
25 |
26 | state, err := crypto.Text(32)
27 | if err != nil {
28 | return nil, fmt.Errorf("generating state: %w", err)
29 | }
30 |
31 | logoutCookie := &openid.LogoutCookie{
32 | State: state,
33 | }
34 |
35 | return &Logout{
36 | Client: c,
37 | Cookie: logoutCookie,
38 | logoutCallbackURL: logoutCallbackURL,
39 | }, nil
40 | }
41 |
42 | func (in *Logout) SingleLogoutURL(idToken string) string {
43 | endSessionEndpoint := in.cfg.Provider().EndSessionEndpointURL()
44 | v := endSessionEndpoint.Query()
45 | v.Set("post_logout_redirect_uri", in.logoutCallbackURL)
46 | v.Set("state", in.Cookie.State)
47 |
48 | if len(idToken) > 0 {
49 | v.Set("id_token_hint", idToken)
50 | }
51 |
52 | endSessionEndpoint.RawQuery = v.Encode()
53 | return endSessionEndpoint.String()
54 | }
55 |
56 | func (in *Logout) SetCookie(w http.ResponseWriter, opts cookie.Options, crypter crypto.Crypter, canonicalRedirect string) error {
57 | in.Cookie.RedirectTo = canonicalRedirect
58 |
59 | logoutCookieJson, err := json.Marshal(in.Cookie)
60 | if err != nil {
61 | return fmt.Errorf("marshalling logout cookie: %w", err)
62 | }
63 |
64 | value := string(logoutCookieJson)
65 | return cookie.EncryptAndSet(w, cookie.Logout, value, opts, crypter)
66 | }
67 |
--------------------------------------------------------------------------------
/charts/wonderwall-forward-auth/templates/networkpolicy.yaml:
--------------------------------------------------------------------------------
---
# Ingress to the forward-auth pods is allowed only from Prometheus and Alloy in nais-system, and
# from nais-system pods labelled with the configured ingress class (.Values.ingressClassName).
apiVersion: networking.k8s.io/v1
kind: NetworkPolicy
metadata:
  labels:
    {{- include "wonderwall-forward-auth.labels" . | nindent 4 }}
  name: {{ include "wonderwall-forward-auth.fullname" . }}
spec:
  ingress:
    - from:
        - namespaceSelector:
            matchLabels:
              kubernetes.io/metadata.name: nais-system
          podSelector:
            matchLabels:
              app.kubernetes.io/name: prometheus
    - from:
        - namespaceSelector:
            matchLabels:
              kubernetes.io/metadata.name: nais-system
          podSelector:
            matchLabels:
              app.kubernetes.io/name: alloy
    - from:
        - namespaceSelector:
            matchLabels:
              kubernetes.io/metadata.name: nais-system
          podSelector:
            matchLabels:
              nais.io/ingressClass: {{ .Values.ingressClassName }}
  podSelector:
    matchLabels:
      {{- include "wonderwall-forward-auth.selectorLabels" . | nindent 6 }}
  policyTypes:
    - Ingress
{{- if .Capabilities.APIVersions.Has "networking.gke.io/v1alpha3" }}
---
# Rendered only when the cluster serves networking.gke.io/v1alpha3 (GKE FQDN network policies):
# allows egress from the forward-auth pods to auth.nais.io on TCP 443.
apiVersion: networking.gke.io/v1alpha3
kind: FQDNNetworkPolicy
metadata:
  name: {{ include "wonderwall-forward-auth.fullname" . }}-fqdn
  labels:
    {{- include "wonderwall-forward-auth.labels" . | nindent 4 }}
  annotations:
    # Presumably skips AAAA (IPv6) lookups when resolving the FQDNs below.
    fqdnnetworkpolicies.networking.gke.io/aaaa-lookups: "skip"
spec:
  egress:
    - ports:
        - port: 443
          protocol: TCP
      to:
        - fqdns:
            - auth.nais.io
  podSelector:
    matchLabels:
      {{- include "wonderwall-forward-auth.selectorLabels" . | nindent 6 }}
  policyTypes:
    - Egress
{{- end }}
60 |
--------------------------------------------------------------------------------
/pkg/server/server.go:
--------------------------------------------------------------------------------
1 | package server
2 |
3 | import (
4 | "context"
5 | "errors"
6 | "net/http"
7 | "os"
8 | "os/signal"
9 | "syscall"
10 | "time"
11 |
12 | "github.com/go-chi/chi/v5"
13 | log "github.com/sirupsen/logrus"
14 |
15 | "github.com/nais/wonderwall/pkg/config"
16 | )
17 |
// Start serves r on cfg.BindAddress and blocks until the server has shut down.
//
// On SIGHUP, SIGINT, SIGTERM or SIGQUIT it keeps serving for cfg.ShutdownWaitBeforePeriod, then calls
// server.Shutdown with a timeout of cfg.ShutdownGracefulPeriod - cfg.ShutdownWaitBeforePeriod.
// The process exits via log.Fatal if that timeout elapses or Shutdown returns an error.
// NOTE(review): if ShutdownGracefulPeriod <= ShutdownWaitBeforePeriod the timeout is non-positive and
// the watchdog forces an exit immediately — confirm config validation prevents this.
//
// It returns a non-nil error only when ListenAndServe fails with something other than
// http.ErrServerClosed; otherwise it returns nil once shutdown has completed.
func Start(cfg *config.Config, r chi.Router) error {
	server := http.Server{
		Addr:    cfg.BindAddress,
		Handler: r,
	}

	// serverCtx is cancelled by the signal goroutine after graceful shutdown has finished;
	// Start waits on it once ListenAndServe returns.
	serverCtx, serverStopCtx := context.WithCancel(context.Background())

	sig := make(chan os.Signal, 1)
	signal.Notify(sig, syscall.SIGHUP, syscall.SIGINT, syscall.SIGTERM, syscall.SIGQUIT)
	go func() {
		s := <-sig
		log.Infof("server: received %q; waiting for %s before starting graceful shutdown...", s, cfg.ShutdownWaitBeforePeriod)
		// The server keeps accepting requests during this delay.
		time.Sleep(cfg.ShutdownWaitBeforePeriod)

		// the total terminationGracePeriodSeconds in Kubernetes starts immediately when SIGTERM is sent, so we need to subtract the wait-before period to exit before SIGKILL
		shutdownTimeout := cfg.ShutdownGracefulPeriod - cfg.ShutdownWaitBeforePeriod
		shutdownCtx, shutdownStopCtx := context.WithTimeout(serverCtx, shutdownTimeout)

		// Watchdog: force an exit if Shutdown has not finished before the deadline. After a normal finish,
		// shutdownStopCtx is called, so Err() reports context.Canceled and no exit is forced.
		go func() {
			<-shutdownCtx.Done()
			if errors.Is(shutdownCtx.Err(), context.DeadlineExceeded) {
				log.Fatalf("server: graceful shutdown timed out after %s; forcing exit.", shutdownTimeout)
			}
		}()

		log.Infof("server: starting graceful shutdown (will timeout after %s)...", shutdownTimeout)
		err := server.Shutdown(shutdownCtx)
		if err != nil {
			log.Fatal(err)
		}
		shutdownStopCtx()
		serverStopCtx()
	}()

	log.Infof("server: listening on %s", cfg.BindAddress)
	// ListenAndServe returns http.ErrServerClosed as soon as Shutdown is called; any other error
	// (e.g. failing to bind the address) is returned to the caller.
	err := server.ListenAndServe()
	if err != nil && !errors.Is(err, http.ErrServerClosed) {
		return err
	}

	// Wait for the signal goroutine to finish draining connections.
	<-serverCtx.Done()
	log.Infof("server: shutdown completed")
	return nil
}
63 |
--------------------------------------------------------------------------------
/pkg/session/store.go:
--------------------------------------------------------------------------------
1 | package session
2 |
3 | import (
4 | "context"
5 | "fmt"
6 | "time"
7 |
8 | "github.com/prometheus/client_golang/prometheus"
9 | "github.com/redis/go-redis/extra/redisotel/v9"
10 | "github.com/redis/go-redis/extra/redisprometheus/v9"
11 | log "github.com/sirupsen/logrus"
12 |
13 | "github.com/nais/wonderwall/pkg/config"
14 | )
15 |
// Store is the backing store for encrypted session data, keyed by string.
// Implementations visible here: the in-memory store (NewMemory) and Redis (NewRedis).
type Store interface {
	// Write stores value under key. The in-memory implementation ignores expiration.
	Write(ctx context.Context, key string, value *EncryptedData, expiration time.Duration) error
	// Read returns the value stored under key; the in-memory implementation returns an error
	// wrapping ErrNotFound when the key is absent.
	Read(ctx context.Context, key string) (*EncryptedData, error)
	// Delete removes the given keys.
	Delete(ctx context.Context, keys ...string) error
	// Update replaces the value under key; the in-memory implementation errors if the key is absent.
	Update(ctx context.Context, key string, value *EncryptedData) error

	// MakeLock returns a Lock associated with key (a no-op lock for the in-memory implementation).
	MakeLock(key string) Lock
}
24 |
25 | func NewStore(cfg *config.Config) (Store, error) {
26 | if len(cfg.Redis.Address) == 0 && len(cfg.Redis.URI) == 0 {
27 | log.Warnf("Redis not configured, using in-memory session backing store; not suitable for multi-pod deployments!")
28 | return NewMemory(), nil
29 | }
30 |
31 | redisClient, err := cfg.Redis.Client()
32 | if err != nil {
33 | return nil, fmt.Errorf("failed to create Redis Client: %w", err)
34 | }
35 |
36 | collector := redisprometheus.NewCollector("wonderwall", "", redisClient)
37 | prometheus.Register(collector)
38 |
39 | ctx, cancel := context.WithTimeout(context.Background(), time.Second*30)
40 | defer cancel()
41 |
42 | err = redisClient.Ping(ctx).Err()
43 | if err != nil {
44 | return nil, fmt.Errorf("failed to connect to configured Redis: %w", err)
45 | }
46 |
47 | if cfg.OpenTelemetry.Enabled {
48 | opts := []redisotel.TracingOption{
49 | redisotel.WithDBStatement(false),
50 | redisotel.WithCallerEnabled(false),
51 | }
52 | if err := redisotel.InstrumentTracing(redisClient, opts...); err != nil {
53 | return nil, fmt.Errorf("failed to instrument Redis Client: %w", err)
54 | }
55 | log.Infof("session: using redis as backing store with OpenTelemetry instrumentation")
56 | } else {
57 | log.Infof("session: using redis as backing store")
58 | }
59 |
60 | return NewRedis(redisClient), nil
61 | }
62 |
--------------------------------------------------------------------------------
/charts/wonderwall/templates/networkpolicy.yaml:
--------------------------------------------------------------------------------
1 | apiVersion: networking.k8s.io/v1
2 | kind: NetworkPolicy
3 | metadata:
4 | labels:
5 | {{- include "wonderwall.labels" . | nindent 4 }}
6 | name: {{ include "wonderwall.fullname" . }}
7 | spec:
8 | egress:
9 | - to:
10 | - namespaceSelector: {}
11 | podSelector:
12 | matchLabels:
13 | k8s-app: kube-dns
14 | - to:
15 | - namespaceSelector:
16 | matchLabels:
17 | linkerd.io/is-control-plane: "true"
18 | - to:
19 | - namespaceSelector:
20 | matchLabels:
21 | kubernetes.io/metadata.name: nais-system
22 | podSelector:
23 | matchLabels:
24 | app.kubernetes.io/name: tempo
25 | ingress:
26 | - from:
27 | - namespaceSelector:
28 | matchLabels:
29 | kubernetes.io/metadata.name: nais-system
30 | podSelector:
31 | matchLabels:
32 | app.kubernetes.io/name: prometheus
33 | - from:
34 | - namespaceSelector:
35 | matchLabels:
36 | kubernetes.io/metadata.name: nais-system
37 | podSelector:
38 | matchLabels:
39 | app.kubernetes.io/name: alloy
40 | - from:
41 | - namespaceSelector:
42 | matchLabels:
43 | kubernetes.io/metadata.name: nais-system
44 | podSelector:
45 | matchLabels:
46 | nais.io/ingressClass: {{ .Values.azure.forwardAuth.ingressClassName }}
47 | - from:
48 | - namespaceSelector:
49 | matchLabels:
50 | kubernetes.io/metadata.name: nais-system
51 | podSelector:
52 | matchLabels:
53 | nais.io/ingressClass: {{ .Values.idporten.ingressClassName }}
54 | - from:
55 | - namespaceSelector:
56 | matchLabels:
57 | linkerd.io/is-control-plane: "true"
58 | podSelector:
59 | matchLabels:
60 | {{- include "wonderwall.selectorLabels" . | nindent 6 }}
61 | policyTypes:
62 | - Ingress
63 | - Egress
64 |
--------------------------------------------------------------------------------
/pkg/session/lock.go:
--------------------------------------------------------------------------------
1 | package session
2 |
3 | import (
4 | "context"
5 | "errors"
6 | "fmt"
7 | "time"
8 |
9 | "github.com/bsm/redislock"
10 | "github.com/nais/wonderwall/internal/o11y/otel"
11 | "github.com/redis/go-redis/v9"
12 | "go.opentelemetry.io/otel/attribute"
13 | )
14 |
// KeyTemplate is the fmt format used to derive a lock's Redis key from the key it guards.
const KeyTemplate = "%s.lock"

// ErrAcquireLock reports that a lock could not be obtained; RedisLock.Acquire
// returns it when redislock reports the lock as not obtained.
var ErrAcquireLock = errors.New("could not acquire lock")

// Lock is a lock associated with a single key.
type Lock interface {
	// Acquire attempts to obtain the lock, held for at most duration.
	Acquire(ctx context.Context, duration time.Duration) error
	// Release gives up the lock.
	Release(ctx context.Context) error
}
25 |
26 | var _ Lock = &RedisLock{}
27 |
28 | type RedisLock struct {
29 | locker *redislock.Client
30 | lock *redislock.Lock
31 | key string
32 | }
33 |
34 | func NewRedisLock(client redis.Cmdable, key string) *RedisLock {
35 | return &RedisLock{
36 | locker: redislock.New(client),
37 | key: key,
38 | }
39 | }
40 |
41 | func (r *RedisLock) Acquire(ctx context.Context, duration time.Duration) error {
42 | ctx, span := otel.StartSpan(ctx, "RedisLock.Acquire")
43 | defer span.End()
44 | span.SetAttributes(attribute.String("redis.lock_duration", duration.String()))
45 | span.SetAttributes(attribute.Bool("redis.lock_acquired", false))
46 |
47 | lock, err := r.locker.Obtain(ctx, lockKey(r.key), duration, nil)
48 | if errors.Is(err, redislock.ErrNotObtained) {
49 | return ErrAcquireLock
50 | }
51 | if err != nil {
52 | return err
53 | }
54 |
55 | r.lock = lock
56 | span.SetAttributes(attribute.Bool("redis.lock_acquired", true))
57 | return nil
58 | }
59 |
60 | func (r *RedisLock) Release(ctx context.Context) error {
61 | ctx, span := otel.StartSpan(ctx, "RedisLock.Release")
62 | defer span.End()
63 | return r.lock.Release(ctx)
64 | }
65 |
66 | var _ Lock = &NoOpLock{}
67 |
68 | type NoOpLock struct{}
69 |
70 | func NewNoOpLock() *NoOpLock {
71 | return new(NoOpLock)
72 | }
73 |
74 | func (n *NoOpLock) Acquire(_ context.Context, _ time.Duration) error {
75 | return nil
76 | }
77 |
78 | func (n *NoOpLock) Release(_ context.Context) error {
79 | return nil
80 | }
81 |
82 | func lockKey(key string) string {
83 | return fmt.Sprintf(KeyTemplate, key)
84 | }
85 |
--------------------------------------------------------------------------------
/charts/wonderwall-forward-auth/Feature.yaml:
--------------------------------------------------------------------------------
1 | environmentKinds:
2 | - management
3 | dependencies:
4 | - allOf:
5 | - nais-netpols-management
6 | values:
7 | openid.clientID:
8 | computed:
9 | template: |
10 | {{ .Env.wonderwall_forward_auth_zitadel_client_id | quote }}
11 | openid.clientSecret:
12 | computed:
13 | template: |
14 | {{ .Env.wonderwall_forward_auth_zitadel_client_secret | quote }}
15 | openid.extraAudience:
16 | description: Comma separated list of additional audiences for id_token validation.
17 | computed:
18 | template: |
19 | {{ .Env.wonderwall_forward_auth_zitadel_project_id | quote }}
20 | openid.extraScopes:
21 | description: Comma separated list of additional scopes to request from the OpenID provider.
22 | computed:
23 | template: |
24 | "urn:zitadel:iam:org:id:{{ .Env.zitadel_organization_id }}"
25 | replicas.min:
26 | config:
27 | type: int
28 | replicas.max:
29 | config:
30 | type: int
31 | session.cookieEncryptionKey:
32 | description: Cookie encryption key, 256 bits (e.g. 32 ASCII characters) encoded with standard base64.
33 | computed:
34 | template: |
35 | {{ .Env.wonderwall_forward_auth_encryption_key | quote }}
36 | sso.domain:
37 | description: Domain for forward auth
38 | computed:
39 | template: |
40 | {{ .Tenant.Name }}.cloud.nais.io
41 | sso.defaultRedirectURL:
42 | description: Default redirect URL for forward auth
43 | computed:
44 | template: |
45 | {{ printf "https://%s" (subdomain . "console") | quote }}
46 | valkey.host:
47 | computed:
48 | template: |
49 | {{ .Env.wonderwall_forward_auth_valkey_host | quote }}
50 | valkey.port:
51 | computed:
52 | template: |
53 | {{ .Env.wonderwall_forward_auth_valkey_port | quote }}
54 | valkey.username:
55 | computed:
56 | template: |
57 | {{ .Env.wonderwall_forward_auth_valkey_username | quote }}
58 | valkey.password:
59 | computed:
60 | template: |
61 | {{ .Env.wonderwall_forward_auth_valkey_password | quote }}
62 |
--------------------------------------------------------------------------------
/charts/wonderwall-forward-auth/templates/_helpers.tpl:
--------------------------------------------------------------------------------
{{/*
Expand the name of the chart: .Values.nameOverride when set, otherwise the
chart name, truncated to 63 characters with any trailing "-" removed.
*/}}
{{- define "wonderwall-forward-auth.name" -}}
{{- .Values.nameOverride | default .Chart.Name | trunc 63 | trimSuffix "-" }}
{{- end }}
7 |
{{/*
Create a default fully qualified app name, truncated to 63 characters because
some Kubernetes name fields are limited to this (by the DNS naming spec), with
any trailing "-" removed.
Precedence: .Values.fullnameOverride; else the release name when it already
contains the chart name (or .Values.nameOverride); else "<release>-<name>".
*/}}
{{- define "wonderwall-forward-auth.fullname" -}}
{{- $name := .Values.nameOverride | default .Chart.Name }}
{{- $full := printf "%s-%s" .Release.Name $name }}
{{- if .Values.fullnameOverride }}
{{- $full = .Values.fullnameOverride }}
{{- else if contains $name .Release.Name }}
{{- $full = .Release.Name }}
{{- end }}
{{- $full | trunc 63 | trimSuffix "-" }}
{{- end }}
25 |
{{/*
Chart name and version ("<name>-<version>") for the helm.sh/chart label; "+"
is replaced by "_" and the result is truncated to 63 characters.
*/}}
{{- define "wonderwall-forward-auth.chart" -}}
{{- $chart := printf "%s-%s" .Chart.Name .Chart.Version }}
{{- $chart | replace "+" "_" | trunc 63 | trimSuffix "-" }}
{{- end }}
32 |
{{/*
Common labels: chart, selector labels, app version (only when the chart
declares one) and managed-by.
*/}}
{{- define "wonderwall-forward-auth.labels" -}}
helm.sh/chart: {{ include "wonderwall-forward-auth.chart" . }}
{{ include "wonderwall-forward-auth.selectorLabels" . }}
{{- with .Chart.AppVersion }}
app.kubernetes.io/version: {{ . | quote }}
{{- end }}
app.kubernetes.io/managed-by: {{ $.Release.Service }}
{{- end }}
44 |
{{/*
Selector labels: chart name (see "wonderwall-forward-auth.name") and release name.
*/}}
{{- define "wonderwall-forward-auth.selectorLabels" -}}
app.kubernetes.io/name: {{ include "wonderwall-forward-auth.name" $ }}
app.kubernetes.io/instance: {{ $.Release.Name }}
{{- end }}
52 |
{{/*
Name of the service account to use: .Values.serviceAccount.name when set;
otherwise the full name when .Values.serviceAccount.create is true, or
"default" when it is not.
*/}}
{{- define "wonderwall-forward-auth.serviceAccountName" -}}
{{- $fallback := "default" }}
{{- if .Values.serviceAccount.create }}
{{- $fallback = include "wonderwall-forward-auth.fullname" . }}
{{- end }}
{{- .Values.serviceAccount.name | default $fallback }}
{{- end }}
63 |
--------------------------------------------------------------------------------
/pkg/openid/config/provider_test.go:
--------------------------------------------------------------------------------
1 | package config_test
2 |
3 | import (
4 | "testing"
5 |
6 | "github.com/nais/wonderwall/pkg/config"
7 | "github.com/stretchr/testify/assert"
8 |
9 | "github.com/nais/wonderwall/pkg/mock"
10 | openidconfig "github.com/nais/wonderwall/pkg/openid/config"
11 | )
12 |
13 | func TestProviderMetadata_Validate(t *testing.T) {
14 | metadata := &openidconfig.ProviderMetadata{
15 | ACRValuesSupported: openidconfig.Supported{"idporten-loa-substantial", "idporten-loa-high"},
16 | UILocalesSupported: openidconfig.Supported{"nb", "nb", "en", "se"},
17 | IDTokenSigningAlgValuesSupported: openidconfig.Supported{"RS256"},
18 | }
19 |
20 | for _, tt := range []struct {
21 | name string
22 | config config.OpenID
23 | assertion assert.ErrorAssertionFunc
24 | }{
25 | {
26 | name: "happy path",
27 | config: config.OpenID{ACRValues: "idporten-loa-high", UILocales: "nb"},
28 | assertion: assert.NoError,
29 | },
30 | {
31 | name: "invalid acr",
32 | config: config.OpenID{ACRValues: "Level5"},
33 | assertion: assert.Error,
34 | },
35 | {
36 | name: "invalid locale",
37 | config: config.OpenID{UILocales: "de"},
38 | assertion: assert.Error,
39 | },
40 | {
41 | name: "has acr translation for Level4",
42 | config: config.OpenID{ACRValues: "Level4"},
43 | assertion: assert.NoError,
44 | },
45 | {
46 | name: "has acr translation for Level3",
47 | config: config.OpenID{ACRValues: "Level3"},
48 | assertion: assert.NoError,
49 | },
50 | {
51 | name: "invalid signing algorithm",
52 | config: config.OpenID{IDTokenSigningAlg: "HS256"},
53 | assertion: assert.Error,
54 | },
55 | } {
56 | t.Run(tt.name, func(t *testing.T) {
57 | cfg := mock.Config()
58 | if tt.config.ACRValues != "" {
59 | cfg.OpenID.ACRValues = tt.config.ACRValues
60 | }
61 | if tt.config.UILocales != "" {
62 | cfg.OpenID.UILocales = tt.config.UILocales
63 | }
64 | if tt.config.IDTokenSigningAlg != "" {
65 | cfg.OpenID.IDTokenSigningAlg = tt.config.IDTokenSigningAlg
66 | }
67 |
68 | err := metadata.Validate(cfg.OpenID)
69 | tt.assertion(t, err)
70 | })
71 | }
72 | }
73 |
--------------------------------------------------------------------------------
/charts/wonderwall/values.yaml:
--------------------------------------------------------------------------------
1 | nameOverride: ""
2 | fullnameOverride: ""
3 |
4 | image:
5 | repository: europe-north1-docker.pkg.dev/nais-io/nais/images/wonderwall
6 | tag: latest
7 | imagePullSecrets: []
8 |
9 | aiven:
10 | project:
11 | prometheusEndpointId:
12 | redisPlan:
13 | azure:
14 | enabled: false
15 | redisSecretName: wonderwall-azure-redis-rw
16 | sessionMaxLifetime: 10h
17 | forwardAuth:
18 | enabled: false
19 | replicasMin: 2
20 | replicasMax: 4
21 | clientSecretName: azure-sso-server
22 | ingressClassName: nais-ingress-fa
23 | # 256-bit key, encoded with standard base64
24 | sessionCookieEncryptionKey:
25 | sessionCookieName: forwardauth
26 | ssoDefaultRedirectURL:
27 | ssoDomain:
28 | ssoServerSecretName: wonderwall-azure-sso-server
29 | groupIds: [] # [""] - additional group IDs to grant access to
30 | idporten:
31 | enabled: false
32 | clientAccessTokenLifetime: 3600
33 | clientSessionLifetime: 21600
34 | clientSecretName: idporten-sso-server
35 | ingressClassName: nais-ingress-external
36 | openidAcrValues: idporten-loa-high
37 | openidLocale: nb
38 | openidPostLogoutRedirectURL:
39 | openidResourceIndicator:
40 | redisSecretNames:
41 | read: wonderwall-idporten-redis-ro
42 | readwrite: wonderwall-idporten-redis-rw
43 | replicasMax: 4
44 | replicasMin: 2
45 | sessionCookieName:
46 | # 256-bit key, encoded with standard base64
47 | sessionCookieEncryptionKey:
48 | sessionInactivity: true
49 | sessionInactivityTimeout: 1h
50 | sessionMaxLifetime: 6h
51 | ssoServerHost:
52 | # secret for configuring Wonderwall server itself
53 | ssoServerSecretName: wonderwall-idporten-sso-server
54 | ssoDefaultRedirectURL:
55 | ssoDomain:
56 | openid:
57 | enabled: true
58 | redisSecretName: wonderwall-openid-redis-rw
59 | # e.g. https://<issuer-host>/.well-known/openid-configuration
60 | wellKnownUrl:
61 | redis:
62 | connectionIdleTimeout: 299
63 | resources:
64 | limits:
65 | cpu: "2"
66 | memory: 512Mi
67 | requests:
68 | cpu: 100m
69 | memory: 64Mi
70 | resourceSuffix: ""
71 | podDisruptionBudget:
72 | maxUnavailable: 1
73 | otel:
74 | endpoint: http://opentelemetry-collector.nais-system:4317
75 |
--------------------------------------------------------------------------------
/flake.lock:
--------------------------------------------------------------------------------
1 | {
2 | "nodes": {
3 | "flake-utils": {
4 | "inputs": {
5 | "systems": "systems"
6 | },
7 | "locked": {
8 | "lastModified": 1710146030,
9 | "narHash": "sha256-SZ5L6eA7HJ/nmkzGG7/ISclqe6oZdOZTNoesiInkXPQ=",
10 | "owner": "numtide",
11 | "repo": "flake-utils",
12 | "rev": "b1d9ab70662946ef0850d488da1c9019f3a9752a",
13 | "type": "github"
14 | },
15 | "original": {
16 | "owner": "numtide",
17 | "repo": "flake-utils",
18 | "type": "github"
19 | }
20 | },
21 | "gitignore": {
22 | "inputs": {
23 | "nixpkgs": [
24 | "nixpkgs"
25 | ]
26 | },
27 | "locked": {
28 | "lastModified": 1709087332,
29 | "narHash": "sha256-HG2cCnktfHsKV0s4XW83gU3F57gaTljL9KNSuG6bnQs=",
30 | "owner": "hercules-ci",
31 | "repo": "gitignore.nix",
32 | "rev": "637db329424fd7e46cf4185293b9cc8c88c95394",
33 | "type": "github"
34 | },
35 | "original": {
36 | "owner": "hercules-ci",
37 | "repo": "gitignore.nix",
38 | "type": "github"
39 | }
40 | },
41 | "nixpkgs": {
42 | "locked": {
43 | "lastModified": 1722141560,
44 | "narHash": "sha256-Ul3rIdesWaiW56PS/Ak3UlJdkwBrD4UcagCmXZR9Z7Y=",
45 | "owner": "NixOS",
46 | "repo": "nixpkgs",
47 | "rev": "038fb464fcfa79b4f08131b07f2d8c9a6bcc4160",
48 | "type": "github"
49 | },
50 | "original": {
51 | "owner": "NixOS",
52 | "ref": "nixpkgs-unstable",
53 | "repo": "nixpkgs",
54 | "type": "github"
55 | }
56 | },
57 | "root": {
58 | "inputs": {
59 | "flake-utils": "flake-utils",
60 | "gitignore": "gitignore",
61 | "nixpkgs": "nixpkgs"
62 | }
63 | },
64 | "systems": {
65 | "locked": {
66 | "lastModified": 1681028828,
67 | "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
68 | "owner": "nix-systems",
69 | "repo": "default",
70 | "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
71 | "type": "github"
72 | },
73 | "original": {
74 | "owner": "nix-systems",
75 | "repo": "default",
76 | "type": "github"
77 | }
78 | }
79 | },
80 | "root": "root",
81 | "version": 7
82 | }
83 |
--------------------------------------------------------------------------------
/pkg/session/session_reader.go:
--------------------------------------------------------------------------------
1 | package session
2 |
3 | import (
4 | "context"
5 | "errors"
6 | "fmt"
7 | "net/http"
8 |
9 | "github.com/nais/wonderwall/internal/crypto"
10 | "github.com/nais/wonderwall/internal/o11y/otel"
11 | "github.com/nais/wonderwall/internal/retry"
12 | "github.com/nais/wonderwall/pkg/config"
13 | "go.opentelemetry.io/otel/attribute"
14 | "go.opentelemetry.io/otel/trace"
15 | )
16 |
17 | var _ Reader = &reader{}
18 |
19 | type reader struct {
20 | cfg *config.Config
21 | cookieCrypter crypto.Crypter
22 | store Store
23 | }
24 |
25 | func NewReader(cfg *config.Config, cookieCrypter crypto.Crypter) (Reader, error) {
26 | store, err := NewStore(cfg)
27 | if err != nil {
28 | return nil, err
29 | }
30 |
31 | return &reader{
32 | cfg: cfg,
33 | cookieCrypter: cookieCrypter,
34 | store: store,
35 | }, nil
36 | }
37 |
38 | func (in *reader) Get(r *http.Request) (*Session, error) {
39 | r, span := otel.StartSpanFromRequest(r, "Session.Get")
40 | defer span.End()
41 |
42 | ticket, err := getTicket(r, in.cookieCrypter)
43 | if err != nil {
44 | return nil, err
45 | }
46 |
47 | return in.getForTicket(r.Context(), ticket)
48 | }
49 |
50 | func (in *reader) getForTicket(ctx context.Context, ticket *Ticket) (*Session, error) {
51 | span := trace.SpanFromContext(ctx)
52 | span.SetAttributes(attribute.Bool("session.valid_session", false))
53 |
54 | encrypted, err := retry.DoValue(ctx, func(ctx context.Context) (*EncryptedData, error) {
55 | encrypted, err := in.store.Read(ctx, ticket.Key())
56 | if errors.Is(err, ErrNotFound) {
57 | return nil, err
58 | }
59 | if err != nil {
60 | return nil, retry.RetryableError(err)
61 | }
62 | return encrypted, nil
63 | })
64 | if err != nil {
65 | return nil, fmt.Errorf("reading from store: %w", err)
66 | }
67 |
68 | data, err := encrypted.Decrypt(ticket.Crypter())
69 | if err != nil {
70 | return nil, fmt.Errorf("%w: decrypting session data: %w", ErrInvalid, err)
71 | }
72 |
73 | sess := NewSession(data, ticket)
74 | if sess != nil {
75 | span.SetAttributes(attribute.String("session.id", sess.ExternalSessionID()))
76 | }
77 | err = data.Validate()
78 | if err != nil {
79 | return sess, err
80 | }
81 |
82 | span.SetAttributes(attribute.Bool("session.valid_session", true))
83 | data.Metadata.SetSpanAttributes(span)
84 | return sess, nil
85 | }
86 |
--------------------------------------------------------------------------------
/pkg/cookie/options_test.go:
--------------------------------------------------------------------------------
1 | package cookie_test
2 |
3 | import (
4 | "net/http"
5 | "testing"
6 |
7 | "github.com/stretchr/testify/assert"
8 |
9 | "github.com/nais/wonderwall/pkg/cookie"
10 | )
11 |
12 | func TestDefaultOptions(t *testing.T) {
13 | opts := cookie.DefaultOptions()
14 |
15 | assert.Equal(t, http.SameSiteLaxMode, opts.SameSite)
16 | assert.True(t, opts.Secure)
17 | assert.Empty(t, opts.Domain)
18 | assert.Empty(t, opts.Path)
19 | }
20 |
21 | func TestOptions_WithDomain(t *testing.T) {
22 | domain := ".some.domain"
23 | opts := cookie.Options{}.WithDomain(domain)
24 |
25 | assert.Equal(t, ".some.domain", opts.Domain)
26 |
27 | opts = cookie.Options{
28 | Domain: ".domain",
29 | }
30 | newOpts := opts.WithDomain(".some.other.domain")
31 |
32 | assert.Equal(t, ".domain", opts.Domain, "original options should be unchanged")
33 | assert.Equal(t, ".some.other.domain", newOpts.Domain, "copy of options should have new value")
34 | }
35 |
36 | func TestOptions_WithPath(t *testing.T) {
37 | path := "/some/path"
38 | opts := cookie.Options{}.WithPath(path)
39 |
40 | assert.Equal(t, "/some/path", opts.Path)
41 |
42 | opts = cookie.Options{
43 | Path: "/some/path",
44 | }
45 | newOpts := opts.WithPath("/some/other/path")
46 |
47 | assert.Equal(t, "/some/path", opts.Path, "original options should be unchanged")
48 | assert.Equal(t, "/some/other/path", newOpts.Path, "copy of options should have new value")
49 | }
50 |
51 | func TestOptions_WithSameSite(t *testing.T) {
52 | sameSite := http.SameSiteDefaultMode
53 | opts := cookie.Options{}.WithSameSite(sameSite)
54 |
55 | assert.Equal(t, http.SameSiteDefaultMode, opts.SameSite)
56 |
57 | opts = cookie.Options{
58 | SameSite: http.SameSiteLaxMode,
59 | }
60 | newOpts := opts.WithSameSite(sameSite)
61 |
62 | assert.Equal(t, http.SameSiteLaxMode, opts.SameSite, "original options should be unchanged")
63 | assert.Equal(t, http.SameSiteDefaultMode, newOpts.SameSite, "copy of options should have new value")
64 | }
65 |
66 | func TestOptions_WithSecure(t *testing.T) {
67 | opts := cookie.Options{}.WithSecure(true)
68 |
69 | assert.True(t, opts.Secure)
70 |
71 | opts = cookie.Options{
72 | Secure: false,
73 | }
74 | newOpts := opts.WithSecure(true)
75 |
76 | assert.False(t, opts.Secure, "original options should be unchanged")
77 | assert.True(t, newOpts.Secure, "copy of options should have new value")
78 | }
79 |
--------------------------------------------------------------------------------
/pkg/mock/client.go:
--------------------------------------------------------------------------------
1 | package mock
2 |
3 | import (
4 | "github.com/lestrrat-go/jwx/v3/jwk"
5 | "github.com/nais/wonderwall/internal/crypto"
6 | "github.com/nais/wonderwall/pkg/config"
7 | openidconfig "github.com/nais/wonderwall/pkg/openid/config"
8 | "github.com/nais/wonderwall/pkg/openid/scopes"
9 | )
10 |
11 | type TestClientConfiguration struct {
12 | *config.Config
13 | clientJwk jwk.Key
14 | trustedAudiences map[string]bool
15 | }
16 |
17 | var _ openidconfig.Client = (*TestClientConfiguration)(nil)
18 |
19 | func (c *TestClientConfiguration) ACRValues() string {
20 | return c.Config.OpenID.ACRValues
21 | }
22 |
23 | func (c *TestClientConfiguration) Audiences() map[string]bool {
24 | return c.trustedAudiences
25 | }
26 |
27 | func (c *TestClientConfiguration) AuthMethod() openidconfig.AuthMethod {
28 | return openidconfig.AuthMethodPrivateKeyJWT
29 | }
30 |
31 | func (c *TestClientConfiguration) ClientID() string {
32 | return c.Config.OpenID.ClientID
33 | }
34 |
35 | func (c *TestClientConfiguration) ClientJWK() jwk.Key {
36 | return c.clientJwk
37 | }
38 |
39 | func (c *TestClientConfiguration) ClientSecret() string {
40 | return c.Config.OpenID.ClientSecret
41 | }
42 |
43 | func (c *TestClientConfiguration) NewClientAuthJWTType() bool {
44 | return c.Config.OpenID.NewClientAuthJWTType
45 | }
46 |
47 | func (c *TestClientConfiguration) SetPostLogoutRedirectURI(uri string) {
48 | c.Config.OpenID.PostLogoutRedirectURI = uri
49 | }
50 |
51 | func (c *TestClientConfiguration) PostLogoutRedirectURI() string {
52 | return c.Config.OpenID.PostLogoutRedirectURI
53 | }
54 |
55 | func (c *TestClientConfiguration) ResourceIndicator() string {
56 | return c.Config.OpenID.ResourceIndicator
57 | }
58 |
59 | func (c *TestClientConfiguration) Scopes() scopes.Scopes {
60 | return scopes.DefaultScopes().WithAdditional(c.Config.OpenID.Scopes...)
61 | }
62 |
63 | func (c *TestClientConfiguration) UILocales() string {
64 | return c.Config.OpenID.UILocales
65 | }
66 |
67 | func (c *TestClientConfiguration) WellKnownURL() string {
68 | return c.Config.OpenID.WellKnownURL
69 | }
70 |
71 | func clientConfiguration(cfg *config.Config) *TestClientConfiguration {
72 | key, err := crypto.NewJwk()
73 | if err != nil {
74 | panic(err)
75 | }
76 |
77 | return &TestClientConfiguration{
78 | Config: cfg,
79 | clientJwk: key,
80 | trustedAudiences: cfg.OpenID.TrustedAudiences(),
81 | }
82 | }
83 |
--------------------------------------------------------------------------------
/pkg/config/redis.go:
--------------------------------------------------------------------------------
1 | package config
2 |
3 | import (
4 | "crypto/tls"
5 | "time"
6 |
7 | "github.com/redis/go-redis/v9"
8 | flag "github.com/spf13/pflag"
9 | )
10 |
// Redis holds the connection settings for the Redis session store; see Client
// for how they are combined.
type Redis struct {
	Address               string `json:"address"`                 // host:port; ignored when URI is set
	Username              string `json:"username"`                // overrides the URI's username when non-empty
	Password              string `json:"password"`                // overrides the URI's password when non-empty
	TLS                   bool   `json:"tls"`                     // TLS for Address; ignored when URI is set
	URI                   string `json:"uri"`                     // takes precedence over Address and TLS
	ConnectionIdleTimeout int    `json:"connection-idle-timeout"` // seconds; 0 keeps go-redis default, -1 disables
}
19 |
20 | func (r *Redis) Client() (*redis.Client, error) {
21 | opts := &redis.Options{
22 | Network: "tcp",
23 | Addr: r.Address,
24 | }
25 |
26 | if r.TLS {
27 | opts.TLSConfig = &tls.Config{}
28 | }
29 |
30 | if r.URI != "" {
31 | var err error
32 |
33 | opts, err = redis.ParseURL(r.URI)
34 | if err != nil {
35 | return nil, err
36 | }
37 | }
38 |
39 | opts.MinIdleConns = 1
40 | opts.MaxRetries = 5
41 |
42 | if r.Username != "" {
43 | opts.Username = r.Username
44 | }
45 |
46 | if r.Password != "" {
47 | opts.Password = r.Password
48 | }
49 |
50 | if r.ConnectionIdleTimeout > 0 {
51 | opts.ConnMaxIdleTime = time.Duration(r.ConnectionIdleTimeout) * time.Second
52 | } else if r.ConnectionIdleTimeout == -1 {
53 | opts.ConnMaxIdleTime = -1
54 | }
55 |
56 | return redis.NewClient(opts), nil
57 | }
58 |
// Flag and configuration keys for the Redis settings.
const (
	RedisURI                   = "redis.uri"
	RedisAddress               = "redis.address"
	RedisUsername              = "redis.username"
	RedisPassword              = "redis.password"
	RedisTLS                   = "redis.tls"
	RedisConnectionIdleTimeout = "redis.connection-idle-timeout"
)
67 |
68 | func redisFlags() {
69 | flag.String(RedisURI, "", "Redis URI string. An empty value will fall back to 'redis-address'.")
70 | flag.String(RedisAddress, "", "Deprecated: prefer using 'redis.uri'. Address of the Redis instance (host:port). An empty value will use in-memory session storage. Does not override address set by 'redis.uri'.")
71 | flag.String(RedisPassword, "", "Password for Redis. Overrides password set by 'redis.uri'.")
72 | flag.Bool(RedisTLS, true, "Whether or not to use TLS for connecting to Redis. Does not override TLS config set by 'redis.uri'.")
73 | flag.String(RedisUsername, "", "Username for Redis. Overrides username set by 'redis.uri'.")
74 | flag.Int(RedisConnectionIdleTimeout, 0, "Idle timeout for Redis connections, in seconds. If non-zero, the value should be less than the client timeout configured at the Redis server. A value of -1 disables timeout. If zero, the default value from go-redis is used (30 minutes). Overrides options set by 'redis.uri'.")
75 | }
76 |
--------------------------------------------------------------------------------
/pkg/config/cookie.go:
--------------------------------------------------------------------------------
1 | package config
2 |
3 | import (
4 | "fmt"
5 | "net/http"
6 | "net/url"
7 | "slices"
8 | "strings"
9 |
10 | flag "github.com/spf13/pflag"
11 | )
12 |
13 | type Cookie struct {
14 | Prefix string `json:"prefix"`
15 | SameSite SameSite `json:"same-site"`
16 | Secure bool `json:"secure"`
17 | }
18 |
19 | func (c *Cookie) Validate(cfg *Config) error {
20 | if err := c.SameSite.Validate(); err != nil {
21 | return err
22 | }
23 |
24 | if c.Secure {
25 | return nil
26 | }
27 |
28 | for _, ingress := range cfg.Ingresses {
29 | u, err := url.ParseRequestURI(ingress)
30 | if err != nil {
31 | return fmt.Errorf("parsing ingress URL %q: %w", ingress, err)
32 | }
33 |
34 | if !strings.EqualFold(u.Hostname(), "localhost") {
35 | return fmt.Errorf("ingress %q is not localhost (was %q); cannot disable secure cookies", ingress, u.Hostname())
36 | }
37 |
38 | if u.Scheme != "http" {
39 | return fmt.Errorf("ingress %q is not HTTP (was %q); cannot disable secure cookies", ingress, u.Scheme)
40 | }
41 | }
42 |
43 | logger.Warn("secure cookies are disabled; not suitable for production use!")
44 | return nil
45 | }
46 |
47 | type SameSite string
48 |
49 | const (
50 | SameSiteLax SameSite = "Lax"
51 | SameSiteNone SameSite = "None"
52 | SameSiteStrict SameSite = "Strict"
53 | )
54 |
55 | // ToHttp returns the equivalent http.SameSite value for the SameSite attribute.
56 | func (s SameSite) ToHttp() http.SameSite {
57 | switch s {
58 | case SameSiteNone:
59 | return http.SameSiteNoneMode
60 | case SameSiteStrict:
61 | return http.SameSiteStrictMode
62 | default:
63 | return http.SameSiteLaxMode
64 | }
65 | }
66 |
67 | func (s SameSite) Validate() error {
68 | all := []SameSite{
69 | SameSiteLax,
70 | SameSiteNone,
71 | SameSiteStrict,
72 | }
73 |
74 | if slices.Contains(all, s) {
75 | return nil
76 | }
77 | return fmt.Errorf("%q must be one of %q (was %q)", CookieSameSite, all, s)
78 | }
79 |
// Flag and configuration keys for the cookie settings.
const (
	EncryptionKey  = "encryption-key"
	CookiePrefix   = "cookie.prefix"
	CookieSameSite = "cookie.same-site"
	CookieSecure   = "cookie.secure"
)
86 |
87 | func cookieFlags() {
88 | flag.String(CookiePrefix, "io.nais.wonderwall", "Prefix for cookie names.")
89 | flag.String(CookieSameSite, string(SameSiteLax), "SameSite attribute for session cookies.")
90 | flag.Bool(CookieSecure, true, "Set secure flag on session cookies. Can only be disabled when `ingress` only consist of localhost hosts. Generally, disabling this is only necessary when using Safari.")
91 | flag.String(EncryptionKey, "", "Base64 encoded 256-bit cookie encryption key; must be identical in instances that share session store.")
92 | }
93 |
--------------------------------------------------------------------------------
/pkg/openid/client/logout_test.go:
--------------------------------------------------------------------------------
1 | package client_test
2 |
3 | import (
4 | "net/url"
5 | "testing"
6 |
7 | "github.com/stretchr/testify/assert"
8 |
9 | "github.com/nais/wonderwall/pkg/mock"
10 | "github.com/nais/wonderwall/pkg/openid/client"
11 | )
12 |
13 | const (
14 | LogoutCallbackURI = mock.Ingress + "/oauth2/logout/callback"
15 | PostLogoutRedirectURI = "http://some-other-url"
16 | EndSessionEndpoint = "http://provider/endsession"
17 | )
18 |
19 | func TestLogout_SingleLogoutURL(t *testing.T) {
20 | t.Run("with id_token", func(t *testing.T) {
21 | logout := newLogout(t)
22 | idToken := "some-id-token"
23 | state := logout.Cookie.State
24 |
25 | raw := logout.SingleLogoutURL(idToken)
26 | assert.NotEmpty(t, raw)
27 |
28 | logoutUrl, err := url.Parse(raw)
29 | assert.NoError(t, err)
30 |
31 | query := logoutUrl.Query()
32 | assert.Len(t, query, 3)
33 |
34 | assert.Contains(t, query, "id_token_hint")
35 | assert.Equal(t, idToken, query.Get("id_token_hint"))
36 |
37 | assert.Contains(t, query, "post_logout_redirect_uri")
38 | assert.Equal(t, LogoutCallbackURI, query.Get("post_logout_redirect_uri"))
39 |
40 | assert.Contains(t, query, "state")
41 | assert.Equal(t, state, query.Get("state"))
42 |
43 | logoutUrl.RawQuery = ""
44 | assert.Equal(t, EndSessionEndpoint, logoutUrl.String())
45 | })
46 |
47 | t.Run("without id_token", func(t *testing.T) {
48 | logout := newLogout(t)
49 | idToken := ""
50 | state := logout.Cookie.State
51 |
52 | raw := logout.SingleLogoutURL(idToken)
53 | assert.NotEmpty(t, raw)
54 |
55 | logoutUrl, err := url.Parse(raw)
56 | assert.NoError(t, err)
57 |
58 | query := logoutUrl.Query()
59 | assert.Len(t, query, 2)
60 |
61 | assert.NotContains(t, query, "id_token_hint")
62 | assert.Equal(t, idToken, query.Get("id_token_hint"))
63 |
64 | assert.Contains(t, query, "post_logout_redirect_uri")
65 | assert.Equal(t, LogoutCallbackURI, query.Get("post_logout_redirect_uri"))
66 |
67 | assert.Contains(t, query, "state")
68 | assert.Equal(t, state, query.Get("state"))
69 |
70 | logoutUrl.RawQuery = ""
71 | assert.Equal(t, EndSessionEndpoint, logoutUrl.String())
72 | })
73 | }
74 |
75 | func newLogout(t *testing.T) *client.Logout {
76 | cfg := mock.Config()
77 |
78 | openidCfg := mock.NewTestConfiguration(cfg)
79 | openidCfg.TestClient.SetPostLogoutRedirectURI(PostLogoutRedirectURI)
80 | openidCfg.TestProvider.SetEndSessionEndpoint(EndSessionEndpoint)
81 | ingresses := mock.Ingresses(cfg)
82 |
83 | req := mock.NewGetRequest(mock.Ingress+"/oauth2/logout", ingresses)
84 |
85 | logout, err := newTestClientWithConfig(openidCfg).Logout(req)
86 | assert.NoError(t, err)
87 |
88 | return logout
89 | }
90 |
--------------------------------------------------------------------------------
/pkg/url/url.go:
--------------------------------------------------------------------------------
1 | package url
2 |
3 | import (
4 | "errors"
5 | "net/http"
6 | "net/url"
7 |
8 | mw "github.com/nais/wonderwall/pkg/middleware"
9 | "github.com/nais/wonderwall/pkg/router/paths"
10 | )
11 |
// RedirectQueryParameter names the query parameter on the login and logout URLs that carries the location to send
// the user agent to once the flow completes.
const RedirectQueryParameter = "redirect"

// ErrNoMatchingIngress is returned by MatchingIngress (and the callback URL builders) when the request context holds
// no matched ingress.
var ErrNoMatchingIngress = errors.New("request host does not match any configured ingresses")
17 |
18 | // Login constructs a URL string that points to the login path for the given target URL.
19 | // The given redirect string should point to the location to be redirected to after login.
20 | func Login(target *url.URL, redirect string) string {
21 | u := target.JoinPath(paths.OAuth2, paths.Login)
22 |
23 | if len(redirect) > 0 {
24 | v := u.Query()
25 | v.Set(RedirectQueryParameter, redirect)
26 | u.RawQuery = v.Encode()
27 | }
28 |
29 | return u.String()
30 | }
31 |
32 | // LoginRelative constructs the relative URL with an absolute path that points to the application's login path, given an optional path prefix.
33 | // The given redirect string should point to the location to be redirected to after login.
34 | func LoginRelative(prefix, redirect string) string {
35 | u := new(url.URL)
36 | u.Path = prefix
37 |
38 | if prefix == "" {
39 | u.Path = "/"
40 | }
41 |
42 | return Login(u, redirect)
43 | }
44 |
45 | // Logout constructs a URL string that points to the logout path for the given target URL.
46 | // The given redirect string should point to the location to be redirected to after logout.
47 | func Logout(target *url.URL, redirect string) string {
48 | u := target.JoinPath(paths.OAuth2, paths.Logout)
49 |
50 | if len(redirect) > 0 {
51 | v := u.Query()
52 | v.Set(RedirectQueryParameter, redirect)
53 | u.RawQuery = v.Encode()
54 | }
55 |
56 | return u.String()
57 | }
58 |
59 | func LoginCallback(r *http.Request) (string, error) {
60 | return makeCallbackURL(r, paths.LoginCallback)
61 | }
62 |
63 | func LogoutCallback(r *http.Request) (string, error) {
64 | return makeCallbackURL(r, paths.LogoutCallback)
65 | }
66 |
67 | func makeCallbackURL(r *http.Request, callbackPath string) (string, error) {
68 | u, err := MatchingIngress(r)
69 | if err != nil {
70 | return "", err
71 | }
72 |
73 | return u.JoinPath(paths.OAuth2, callbackPath).String(), nil
74 | }
75 |
76 | func MatchingPath(r *http.Request) *url.URL {
77 | u := &url.URL{}
78 |
79 | p, found := mw.PathFrom(r.Context())
80 | if found && len(p) > 0 {
81 | u.Path = p
82 | } else {
83 | u.Path = "/"
84 | }
85 |
86 | return u
87 | }
88 |
89 | func MatchingIngress(r *http.Request) (*url.URL, error) {
90 | ing, found := mw.IngressFrom(r.Context())
91 | if !found {
92 | return nil, ErrNoMatchingIngress
93 | }
94 |
95 | return ing.NewURL(), nil
96 | }
97 |
--------------------------------------------------------------------------------
/pkg/openid/client/logout_callback_test.go:
--------------------------------------------------------------------------------
1 | package client_test
2 |
3 | import (
4 | "testing"
5 |
6 | "github.com/stretchr/testify/assert"
7 |
8 | "github.com/nais/wonderwall/pkg/config"
9 | "github.com/nais/wonderwall/pkg/mock"
10 | "github.com/nais/wonderwall/pkg/openid"
11 | "github.com/nais/wonderwall/pkg/openid/client"
12 | "github.com/nais/wonderwall/pkg/url"
13 | )
14 |
15 | func TestLogoutCallback_PostLogoutRedirectURI(t *testing.T) {
16 | const defaultState = "some-state"
17 | const defaultRedirectURI = "http://some-fancy-logout-page"
18 |
19 | for _, tt := range []struct {
20 | name string
21 | emptyDefaultURI bool
22 | cookie *openid.LogoutCookie
23 | expected string
24 | }{
25 | {
26 | name: "happy path",
27 | expected: defaultRedirectURI,
28 | },
29 | {
30 | name: "empty default uri",
31 | emptyDefaultURI: true,
32 | expected: mock.Ingress,
33 | },
34 | {
35 | name: "state mismatch",
36 | cookie: &openid.LogoutCookie{
37 | State: "some-other-state",
38 | },
39 | expected: defaultRedirectURI,
40 | },
41 | {
42 | name: "happy path, redirect in cookie",
43 | cookie: &openid.LogoutCookie{
44 | State: defaultState,
45 | RedirectTo: "http://wonderwall/some/path",
46 | },
47 | expected: "http://wonderwall/some/path",
48 | },
49 | {
50 | name: "empty redirect in cookie",
51 | cookie: &openid.LogoutCookie{
52 | State: defaultState,
53 | RedirectTo: "",
54 | },
55 | expected: defaultRedirectURI,
56 | },
57 | {
58 | name: "state mismatch, with redirect in cookie",
59 | cookie: &openid.LogoutCookie{
60 | State: "some-other-state",
61 | RedirectTo: "http://wonderwall/some/path",
62 | },
63 | expected: defaultRedirectURI,
64 | },
65 | {
66 | name: "invalid redirect in cookie",
67 | cookie: &openid.LogoutCookie{
68 | State: defaultState,
69 | RedirectTo: "http://not-wonderwall/some/path",
70 | },
71 | expected: defaultRedirectURI,
72 | },
73 | } {
74 | t.Run(tt.name, func(t *testing.T) {
75 | cfg := mock.Config()
76 | cfg.OpenID.PostLogoutRedirectURI = defaultRedirectURI
77 |
78 | if tt.emptyDefaultURI {
79 | cfg.OpenID.PostLogoutRedirectURI = ""
80 | }
81 |
82 | lc := newLogoutCallback(cfg, defaultState, tt.cookie)
83 |
84 | uri := lc.PostLogoutRedirectURI()
85 | assert.NotEmpty(t, uri)
86 | assert.Equal(t, tt.expected, uri)
87 | })
88 | }
89 | }
90 |
91 | func newLogoutCallback(cfg *config.Config, state string, cookie *openid.LogoutCookie) *client.LogoutCallback {
92 | openidCfg := mock.NewTestConfiguration(cfg)
93 | ingresses := mock.Ingresses(cfg)
94 | validator := url.NewAbsoluteValidator(ingresses.Hosts())
95 | req := mock.NewGetRequest(mock.Ingress+"/oauth2/logout/callback?state="+state, ingresses)
96 | return newTestClientWithConfig(openidCfg).LogoutCallback(req, cookie, validator)
97 | }
98 |
--------------------------------------------------------------------------------
/pkg/session/store_redis.go:
--------------------------------------------------------------------------------
1 | package session
2 |
3 | import (
4 | "context"
5 | "errors"
6 | "fmt"
7 | "time"
8 |
9 | "github.com/nais/wonderwall/internal/o11y/otel"
10 | "github.com/nais/wonderwall/pkg/metrics"
11 | "github.com/redis/go-redis/v9"
12 | "go.opentelemetry.io/otel/attribute"
13 | )
14 |
15 | type redisSessionStore struct {
16 | client redis.Cmdable
17 | }
18 |
19 | var _ Store = &redisSessionStore{}
20 |
21 | func NewRedis(client redis.Cmdable) Store {
22 | return &redisSessionStore{
23 | client: client,
24 | }
25 | }
26 |
27 | func (s *redisSessionStore) Read(ctx context.Context, key string) (*EncryptedData, error) {
28 | ctx, span := otel.StartSpan(ctx, "RedisSessionStore.Read")
29 | defer span.End()
30 | span.SetAttributes(attribute.Bool("redis.key_exists", false))
31 |
32 | encryptedData := &EncryptedData{}
33 | err := metrics.ObserveRedisLatency(metrics.RedisOperationRead, func() error {
34 | return s.client.Get(ctx, key).Scan(encryptedData)
35 | })
36 | if err == nil {
37 | span.SetAttributes(attribute.Bool("redis.key_exists", true))
38 | return encryptedData, nil
39 | }
40 |
41 | if errors.Is(err, redis.Nil) {
42 | return nil, fmt.Errorf("%w: %w", ErrNotFound, err)
43 | }
44 |
45 | return nil, err
46 | }
47 |
48 | func (s *redisSessionStore) Write(ctx context.Context, key string, value *EncryptedData, expiration time.Duration) error {
49 | ctx, span := otel.StartSpan(ctx, "RedisSessionStore.Write")
50 | defer span.End()
51 | span.SetAttributes(attribute.String("redis.key_expiry", expiration.String()))
52 |
53 | err := metrics.ObserveRedisLatency(metrics.RedisOperationWrite, func() error {
54 | return s.client.Set(ctx, key, value, expiration).Err()
55 | })
56 | if err != nil {
57 | return err
58 | }
59 |
60 | return nil
61 | }
62 |
63 | func (s *redisSessionStore) Delete(ctx context.Context, keys ...string) error {
64 | ctx, span := otel.StartSpan(ctx, "RedisSessionStore.Delete")
65 | defer span.End()
66 |
67 | err := metrics.ObserveRedisLatency(metrics.RedisOperationDelete, func() error {
68 | return s.client.Del(ctx, keys...).Err()
69 | })
70 | if err == nil {
71 | return nil
72 | }
73 |
74 | if errors.Is(err, redis.Nil) {
75 | return fmt.Errorf("%w: %w", ErrNotFound, err)
76 | }
77 |
78 | return err
79 | }
80 |
81 | func (s *redisSessionStore) Update(ctx context.Context, key string, value *EncryptedData) error {
82 | ctx, span := otel.StartSpan(ctx, "RedisSessionStore.Update")
83 | defer span.End()
84 |
85 | _, err := s.Read(ctx, key)
86 | if err != nil {
87 | return err
88 | }
89 |
90 | err = metrics.ObserveRedisLatency(metrics.RedisOperationUpdate, func() error {
91 | return s.client.Set(ctx, key, value, redis.KeepTTL).Err()
92 | })
93 | if err != nil {
94 | return err
95 | }
96 |
97 | return nil
98 | }
99 |
100 | func (s *redisSessionStore) MakeLock(key string) Lock {
101 | return NewRedisLock(s.client, key)
102 | }
103 |
--------------------------------------------------------------------------------
/pkg/session/ticket.go:
--------------------------------------------------------------------------------
1 | package session
2 |
3 | import (
4 | "encoding/json"
5 | "errors"
6 | "fmt"
7 | "net/http"
8 |
9 | "github.com/nais/liberator/pkg/keygen"
10 | "go.opentelemetry.io/otel/attribute"
11 | "go.opentelemetry.io/otel/trace"
12 |
13 | "github.com/nais/wonderwall/internal/crypto"
14 | "github.com/nais/wonderwall/pkg/cookie"
15 | )
16 |
// Ticket contains the user agent's data required to access their associated session.
// SetCookie serializes it to JSON and stores it encrypted in the session cookie; getTicket reverses that.
// The field order here determines the JSON encoding.
type Ticket struct {
	// SessionKey identifies the session (serialized as "id").
	SessionKey string `json:"id"`
	// EncryptionKey is the data encryption key (DEK) used to encrypt the session's data (serialized as "dek").
	// Its size is equal to the expected key size for the used AEAD, defined in crypto.KeySize.
	// NOTE(review): NewTicket guarantees this size, but tickets decoded by getTicket are not length-checked.
	EncryptionKey []byte `json:"dek"`
	// crypter caches the crypto.Crypter built from EncryptionKey by Crypter. It is unexported and therefore not part
	// of the JSON encoding.
	crypter crypto.Crypter
}
26 |
27 | func NewTicket(sessionKey string) (*Ticket, error) {
28 | encKey, err := keygen.Keygen(crypto.KeySize)
29 | if err != nil {
30 | return nil, fmt.Errorf("generate encryption key: %w", err)
31 | }
32 |
33 | return &Ticket{SessionKey: sessionKey, EncryptionKey: encKey}, nil
34 | }
35 |
36 | // Crypter returns a crypto.Crypter initialized with the session's data encryption key.
37 | func (c *Ticket) Crypter() crypto.Crypter {
38 | if c.crypter == nil {
39 | c.crypter = crypto.NewCrypter(c.EncryptionKey)
40 | }
41 | return c.crypter
42 | }
43 |
44 | // Key returns the key that identifies the session.
45 | func (c *Ticket) Key() string {
46 | return c.SessionKey
47 | }
48 |
49 | // SetCookie marshals the Ticket, encrypts the value with the given crypto.Crypter, and writes the resulting cookie to the
50 | // given http.ResponseWriter, applying any cookie.Options to the cookie itself.
51 | func (c *Ticket) SetCookie(w http.ResponseWriter, opts cookie.Options, crypter crypto.Crypter) error {
52 | b, err := json.Marshal(c)
53 | if err != nil {
54 | return fmt.Errorf("marshalling ticket: %w", err)
55 | }
56 |
57 | return cookie.EncryptAndSet(w, cookie.Session, string(b), opts, crypter)
58 | }
59 |
60 | // getTicket returns a Ticket from the session cookie found in the http.Request, given a crypto.Crypter that
61 | // can decrypt the cookie is provided.
62 | func getTicket(r *http.Request, crypter crypto.Crypter) (*Ticket, error) {
63 | span := trace.SpanFromContext(r.Context())
64 | span.SetAttributes(attribute.Bool("session.valid_ticket", false))
65 |
66 | ticketJson, err := cookie.GetDecrypted(r, cookie.Session, crypter)
67 | if errors.Is(err, http.ErrNoCookie) {
68 | return nil, fmt.Errorf("ticket: cookie %w", ErrNotFound)
69 | }
70 | if errors.Is(err, cookie.ErrInvalidValue) || errors.Is(err, cookie.ErrDecrypt) {
71 | return nil, fmt.Errorf("ticket: cookie: %w: %w", ErrInvalid, err)
72 | }
73 | if err != nil {
74 | return nil, err
75 | }
76 |
77 | var ticket Ticket
78 | err = json.Unmarshal([]byte(ticketJson), &ticket)
79 | if err != nil {
80 | return nil, fmt.Errorf("ticket: unmarshalling: %w", err)
81 | }
82 |
83 | span.SetAttributes(attribute.Bool("session.valid_ticket", true))
84 | return &ticket, nil
85 | }
86 |
--------------------------------------------------------------------------------
/pkg/config/sso.go:
--------------------------------------------------------------------------------
1 | package config
2 |
3 | import (
4 | "fmt"
5 | "net/url"
6 |
7 | flag "github.com/spf13/pflag"
8 | )
9 |
// SSOMode is the role an instance plays when single sign-on is enabled: one server acting as the OIDC Relying Party,
// and any number of proxies that delegate most endpoint operations to it.
type SSOMode string

// SSOModeServer runs the instance as the SSO server.
const SSOModeServer SSOMode = "server"

// SSOModeProxy runs the instance as an SSO proxy.
const SSOModeProxy SSOMode = "proxy"

// SSO holds the single sign-on settings.
type SSO struct {
	// Enabled turns on single sign-on mode.
	Enabled bool `json:"enabled"`
	// Domain is the domain session cookies are set for, usually the second-level domain name (e.g. example.com).
	// Required in server mode.
	Domain string `json:"domain"`
	// Mode is the SSO role of this instance.
	Mode SSOMode `json:"mode"`
	// SessionCookieName must be the same across all SSO servers and proxies. Required when SSO is enabled.
	SessionCookieName string `json:"session-cookie-name"`
	// ServerURL is the URL proxies use to point to the SSO server. Required in proxy mode.
	ServerURL string `json:"server-url"`
	// ServerDefaultRedirectURL is where the SSO server redirects when a given redirect query parameter is invalid.
	// Required in server mode.
	ServerDefaultRedirectURL string `json:"server-default-redirect-url"`
}

// IsServer reports whether SSO is enabled with this instance in server mode.
func (s SSO) IsServer() bool {
	if !s.Enabled {
		return false
	}
	return s.Mode == SSOModeServer
}
29 |
30 | func (s SSO) Validate(c *Config) error {
31 | if !s.Enabled {
32 | return nil
33 | }
34 |
35 | if len(c.Redis.Address) == 0 && len(c.Redis.URI) == 0 {
36 | return fmt.Errorf("at least one of %q or %q must be set when %s is set", RedisAddress, RedisURI, SSOEnabled)
37 | }
38 |
39 | if len(s.SessionCookieName) == 0 {
40 | return fmt.Errorf("%q must not be empty when %s is set", SSOSessionCookieName, SSOEnabled)
41 | }
42 |
43 | switch s.Mode {
44 | case SSOModeProxy:
45 | _, err := url.ParseRequestURI(s.ServerURL)
46 | if err != nil {
47 | return fmt.Errorf("%q must be a valid url: %w", SSOServerURL, err)
48 | }
49 | case SSOModeServer:
50 | if len(s.Domain) == 0 {
51 | return fmt.Errorf("%q cannot be empty", SSODomain)
52 | }
53 |
54 | _, err := url.ParseRequestURI(s.ServerDefaultRedirectURL)
55 | if err != nil {
56 | return fmt.Errorf("%q must be a valid url: %w", SSOServerDefaultRedirectURL, err)
57 | }
58 | default:
59 | return fmt.Errorf("%q must be one of [%q, %q]", SSOModeFlag, SSOModeServer, SSOModeProxy)
60 | }
61 |
62 | return nil
63 | }
64 |
// Command-line flag names for the SSO settings, as registered by ssoFlags.
const (
	// SSOEnabled toggles single sign-on mode.
	SSOEnabled = "sso.enabled"
	// SSODomain is the cookie domain used in server mode.
	SSODomain = "sso.domain"
	// SSOModeFlag selects "server" or "proxy".
	SSOModeFlag = "sso.mode"
	// SSOServerDefaultRedirectURL is the SSO server's fallback redirect target.
	SSOServerDefaultRedirectURL = "sso.server-default-redirect-url"
	// SSOSessionCookieName is the shared session cookie name.
	SSOSessionCookieName = "sso.session-cookie-name"
	// SSOServerURL is the SSO server URL used by proxies.
	SSOServerURL = "sso.server-url"
)
73 |
74 | func ssoFlags() {
75 | flag.Bool(SSOEnabled, false, "Enable single sign-on mode; one server acting as the OIDC Relying Party, and N proxies. The proxies delegate most endpoint operations to the server, and only implements a reverse proxy that reads the user's session data from the shared store.")
76 | flag.String(SSODomain, "", "The domain that the session cookies should be set for, usually the second-level domain name (e.g. example.com).")
77 | flag.String(SSOModeFlag, string(SSOModeServer), "The SSO mode for this instance. Must be one of 'server' or 'proxy'.")
78 | flag.String(SSOSessionCookieName, "", "Session cookie name. Must be the same across all SSO Servers and Proxies.")
79 | flag.String(SSOServerDefaultRedirectURL, "", "The URL that the SSO server should redirect to by default if a given redirect query parameter is invalid.")
80 | flag.String(SSOServerURL, "", "The URL used by the proxy to point to the SSO server instance.")
81 | }
82 |
--------------------------------------------------------------------------------
/internal/crypto/crypter.go:
--------------------------------------------------------------------------------
1 | package crypto
2 |
3 | import (
4 | cryptorand "crypto/rand"
5 | "encoding/base64"
6 | "fmt"
7 |
8 | "github.com/nais/liberator/pkg/keygen"
9 | log "github.com/sirupsen/logrus"
10 | "golang.org/x/crypto/chacha20poly1305"
11 |
12 | "github.com/nais/wonderwall/pkg/config"
13 | )
14 |
15 | const (
16 | KeySize = chacha20poly1305.KeySize
17 |
18 | // MaxPlaintextSize is set to 64 MB, which is a fairly generous limit. The implementation in x/crypto/xchacha20poly1305 has a plaintext limit to 256 GB.
19 | // We generally only handle data that is stored within a cookie or a session store, i.e. it should be reasonably small.
20 | // In most cases the data is around 4 KB or less, mostly depending on the length of the tokens returned from the identity provider.
21 | MaxPlaintextSize = 64 * 1024 * 1024
22 | )
23 |
24 | type crypter struct {
25 | key []byte
26 | }
27 |
28 | type Crypter interface {
29 | Encrypt([]byte) ([]byte, error)
30 | Decrypt([]byte) ([]byte, error)
31 | }
32 |
33 | func NewCrypter(key []byte) Crypter {
34 | return &crypter{
35 | key: key,
36 | }
37 | }
38 |
39 | func EncryptionKeyOrGenerate(cfg *config.Config) ([]byte, error) {
40 | key, err := base64.StdEncoding.DecodeString(cfg.EncryptionKey)
41 | if err != nil {
42 | if len(cfg.EncryptionKey) > 0 {
43 | return nil, fmt.Errorf("decode encryption key: %w", err)
44 | }
45 | }
46 |
47 | if len(key) == 0 {
48 | log.Warn("no encryption key was provided, generating a random ephemeral key; sessions will not be able to be decrypted after restart")
49 | key, err = keygen.Keygen(KeySize)
50 | if err != nil {
51 | return nil, fmt.Errorf("generate random encryption key: %w", err)
52 | }
53 | }
54 |
55 | if len(key) != chacha20poly1305.KeySize {
56 | return nil, fmt.Errorf("bad key length (expected %d, got %d)", chacha20poly1305.KeySize, len(key))
57 | }
58 |
59 | return key, nil
60 | }
61 |
62 | // Encrypt encrypts a plaintext with XChaCha20-Poly1305.
63 | func (c *crypter) Encrypt(plaintext []byte) ([]byte, error) {
64 | aead, err := chacha20poly1305.NewX(c.key)
65 | if err != nil {
66 | return nil, err
67 | }
68 |
69 | plaintextSize := len(plaintext)
70 | if plaintextSize > MaxPlaintextSize {
71 | return nil, fmt.Errorf("crypter: plaintext too large (%d > %d)", plaintextSize, MaxPlaintextSize)
72 | }
73 |
74 | // Select a random nonce, and leave capacity for the ciphertext.
75 | nonce := make([]byte, aead.NonceSize(), aead.NonceSize()+plaintextSize+aead.Overhead())
76 | _, err = cryptorand.Read(nonce)
77 | if err != nil {
78 | return nil, err
79 | }
80 |
81 | return aead.Seal(nonce, nonce, plaintext, nil), nil
82 | }
83 |
84 | // Decrypt decrypts a ciphertext encrypted with XChaCha20-Poly1305.
85 | func (c *crypter) Decrypt(ciphertext []byte) ([]byte, error) {
86 | aead, err := chacha20poly1305.NewX(c.key)
87 | if err != nil {
88 | return nil, err
89 | }
90 |
91 | if len(ciphertext) < aead.NonceSize() {
92 | return nil, fmt.Errorf("ciphertext is too short")
93 | }
94 |
95 | // Split nonce and ciphertext.
96 | nonce, encrypted := ciphertext[:aead.NonceSize()], ciphertext[aead.NonceSize():]
97 | return aead.Open(nil, nonce, encrypted, nil)
98 | }
99 |
--------------------------------------------------------------------------------
/charts/wonderwall/templates/_resources.yaml:
--------------------------------------------------------------------------------
{{/*
Valkey (Redis-compatible) instance resource template, rendered as an aiven.io/v1alpha1 Valkey resource.
The helm.sh/resource-policy: keep annotation tells Helm not to delete the instance when the release is uninstalled.
Termination protection is also enabled.
The allkeys-lru maxmemory policy lets Valkey evict any key (including sessions) when memory is full.
Expects a dict as input with the following keys:
- root: The root values, e.g "."
- provider: The identity provider, e.g. "azure" or "idporten"
Fails rendering unless .Values.aiven.project and .Values.aiven.redisPlan are set.
*/}}
{{- define "common.redis.tpl" -}}
{{- $root := .root }}
{{- $provider := .provider }}
{{- $name := include "aiven.redisName" (dict "root" $root "provider" $provider) }}
---
apiVersion: aiven.io/v1alpha1
kind: Valkey
metadata:
  name: {{ $name }}
  annotations:
    helm.sh/resource-policy: keep
  labels:
    {{- include "wonderwall.labels" $root | nindent 4 }}
    wonderwall.nais.io/provider: {{ $provider }}
spec:
  project: {{ $root.Values.aiven.project | required ".Values.aiven.project is required." }}
  plan: {{ $root.Values.aiven.redisPlan | required ".Values.aiven.redisPlan is required." }}
  maintenanceWindowDow: "sunday"
  maintenanceWindowTime: "02:00:00"
  terminationProtection: true
  userConfig:
    valkey_maxmemory_policy: "allkeys-lru"
{{- end }}
30 |
{{/*
Prometheus ServiceIntegration resource template. Its source service is the Valkey instance with the same
"aiven.redisName" as the one rendered by "common.redis.tpl" for this provider.
Expects a dict as input with the following keys:
- root: The root values, e.g "."
- provider: The identity provider, e.g. "azure" or "idporten"
Fails rendering unless .Values.aiven.project and .Values.aiven.prometheusEndpointId are set.
Only the last "/"-separated segment of prometheusEndpointId is used as the destination endpoint ID.
*/}}
{{- define "common.serviceintegration.tpl" -}}
{{- $root := .root }}
{{- $provider := .provider }}
{{- $name := include "aiven.serviceintegrationName" (dict "root" $root "provider" $provider) }}
{{- $redisName := include "aiven.redisName" (dict "root" $root "provider" $provider) }}
---
apiVersion: aiven.io/v1alpha1
kind: ServiceIntegration
metadata:
  name: {{ $name }}
  labels:
    {{- include "wonderwall.labels" $root | nindent 4 }}
    wonderwall.nais.io/provider: {{ $provider }}
spec:
  sourceServiceName: {{ $redisName }}
  project: {{ $root.Values.aiven.project | required ".Values.aiven.project is required." }}
  integrationType: prometheus
  destinationEndpointId: {{ $root.Values.aiven.prometheusEndpointId | required ".Values.aiven.prometheusEndpointId is required." | splitList "/" | last }}
{{- end }}
56 |
{{/*
AivenApplication resource template. It requests access to the provider's Valkey instance and names the secret that
should be generated for it.
The resource is named "<instance>-<access>", so one resource per access level can coexist.
Expects a dict as input with the following keys:
- root: The root values, e.g "."
- provider: The identity provider, e.g. "azure" or "idporten"
- access: The access level for the Valkey instance, e.g. "readwrite" or "read"
- secretName: The name of the secret that should be generated
*/}}
{{- define "common.aivenapplication.tpl" -}}
{{- $root := .root }}
{{- $provider := .provider }}
{{- $access := .access }}
{{- $secretName := .secretName -}}
{{- $instance := include "aiven.instanceName" (dict "provider" $provider "root" $root) }}
---
apiVersion: aiven.nais.io/v1
kind: AivenApplication
metadata:
  name: {{ $instance }}-{{ $access }}
  labels:
    {{- include "wonderwall.labels" $root | nindent 4 }}
    wonderwall.nais.io/provider: {{ $provider }}
spec:
  valkey:
    - access: {{ $access }}
      instance: {{ $instance }}
  secretName: {{ $secretName }}
  protected: true
{{- end }}
86 |
--------------------------------------------------------------------------------
/pkg/session/store_test.go:
--------------------------------------------------------------------------------
1 | package session_test
2 |
3 | import (
4 | "context"
5 | "testing"
6 | "time"
7 |
8 | jwtlib "github.com/lestrrat-go/jwx/v3/jwt"
9 | "github.com/nais/liberator/pkg/keygen"
10 | "github.com/stretchr/testify/assert"
11 |
12 | "github.com/nais/wonderwall/internal/crypto"
13 | "github.com/nais/wonderwall/pkg/openid"
14 | "github.com/nais/wonderwall/pkg/session"
15 | )
16 |
17 | func decryptedEqual(t *testing.T, expected, actual *session.Data) {
18 | assert.Equal(t, expected.AccessToken, actual.AccessToken)
19 | assert.Equal(t, expected.RefreshToken, actual.RefreshToken)
20 | assert.Equal(t, expected.IDToken, actual.IDToken)
21 | assert.Equal(t, expected.ExternalSessionID, actual.ExternalSessionID)
22 | assert.Equal(t, expected.Acr, actual.Acr)
23 | assert.WithinDuration(t, expected.Metadata.Session.CreatedAt, actual.Metadata.Session.CreatedAt, 0)
24 | assert.WithinDuration(t, expected.Metadata.Session.EndsAt, actual.Metadata.Session.EndsAt, 0)
25 | assert.WithinDuration(t, expected.Metadata.Tokens.ExpireAt, actual.Metadata.Tokens.ExpireAt, 0)
26 | assert.WithinDuration(t, expected.Metadata.Tokens.RefreshedAt, actual.Metadata.Tokens.RefreshedAt, 0)
27 | }
28 |
29 | func makeCrypter(t *testing.T) crypto.Crypter {
30 | key, err := keygen.Keygen(32)
31 | assert.NoError(t, err)
32 | return crypto.NewCrypter(key)
33 | }
34 |
35 | func makeData() *session.Data {
36 | idToken := jwtlib.New()
37 | idToken.Set("jti", "id-token-jti")
38 |
39 | accessToken := "some-access-token"
40 | refreshToken := "some-refresh-token"
41 |
42 | tokens := &openid.Tokens{
43 | AccessToken: accessToken,
44 | IDToken: openid.NewIDToken("id_token", idToken),
45 | RefreshToken: refreshToken,
46 | }
47 |
48 | expiresIn := time.Hour
49 | endsIn := time.Hour
50 |
51 | metadata := session.NewMetadata(expiresIn, endsIn)
52 | return session.NewData("myid", tokens, metadata)
53 | }
54 |
55 | func write(t *testing.T, store session.Store, key string, value *session.EncryptedData) {
56 | err := store.Write(context.Background(), key, value, time.Minute)
57 | assert.NoError(t, err)
58 | }
59 |
60 | func read(t *testing.T, store session.Store, key string, encrypted *session.EncryptedData, crypter crypto.Crypter) *session.Data {
61 | result, err := store.Read(context.Background(), key)
62 | assert.NoError(t, err)
63 | assert.Equal(t, encrypted, result)
64 |
65 | decrypted, err := result.Decrypt(crypter)
66 | assert.NoError(t, err)
67 |
68 | return decrypted
69 | }
70 |
71 | func update(t *testing.T, store session.Store, key string, data *session.Data, crypter crypto.Crypter) (*session.Data, *session.EncryptedData) {
72 | data.AccessToken = "new-access-token"
73 | data.RefreshToken = "new-refresh-token"
74 | encryptedData, err := data.Encrypt(crypter)
75 | assert.NoError(t, err)
76 |
77 | err = store.Update(context.Background(), key, encryptedData)
78 | assert.NoError(t, err)
79 |
80 | return data, encryptedData
81 | }
82 |
83 | func del(t *testing.T, store session.Store, key string) {
84 | err := store.Delete(context.Background(), key)
85 | assert.NoError(t, err)
86 |
87 | result, err := store.Read(context.Background(), key)
88 | assert.Error(t, err)
89 | assert.ErrorIs(t, err, session.ErrNotFound)
90 | assert.Nil(t, result)
91 | }
92 |
--------------------------------------------------------------------------------
/internal/http/request.go:
--------------------------------------------------------------------------------
1 | package http
2 |
3 | import (
4 | "net/http"
5 | "net/url"
6 | "strings"
7 |
8 | "github.com/nais/wonderwall/pkg/cookie"
9 | )
10 |
11 | // IsNavigationRequest checks if the request is a navigation request by using Sec-Fetch headers.
12 | // This is used to separate between redirects for browser navigation and redirects for resource requests (e.g., Fetch or XHR).
13 | // We fall back to checking the Accept header if the browser doesn't support fetch metadata.
14 | func IsNavigationRequest(r *http.Request) bool {
15 | // we assume that navigation requests are always GET requests
16 | if r.Method != http.MethodGet {
17 | return false
18 | }
19 |
20 | mode := r.Header.Get("Sec-Fetch-Mode")
21 | dest := r.Header.Get("Sec-Fetch-Dest")
22 | if mode == "" && dest == "" {
23 | return Accepts(r, "text/html")
24 | }
25 |
26 | return mode == "navigate" && dest == "document"
27 | }
28 |
29 | func HasSecFetchMetadata(r *http.Request) bool {
30 | return r.Header.Get("Sec-Fetch-Mode") != "" && r.Header.Get("Sec-Fetch-Dest") != ""
31 | }
32 |
33 | func Accepts(r *http.Request, accepted ...string) bool {
34 | // iterate over all Accept headers
35 | for _, header := range r.Header.Values("Accept") {
36 | // iterate over all comma-separated values in a single Accept header
37 | for _, v := range strings.Split(header, ",") {
38 | v = strings.ToLower(v)
39 | v = strings.TrimSpace(v)
40 | v = strings.Split(v, ";")[0]
41 |
42 | for _, accept := range accepted {
43 | if v == accept {
44 | return true
45 | }
46 | }
47 | }
48 | }
49 |
50 | return false
51 | }
52 |
53 | // Attributes returns a map of interesting properties for the request.
54 | func Attributes(r *http.Request) map[string]any {
55 | return map[string]any{
56 | "request.cookies": nonEmptyRequestCookies(r),
57 | "request.host": r.Host,
58 | "request.is_navigational": IsNavigationRequest(r),
59 | "request.method": r.Method,
60 | "request.path": r.URL.Path,
61 | "request.protocol": r.Proto,
62 | "request.referer": refererStripped(r),
63 | "request.sec_fetch_dest": r.Header.Get("Sec-Fetch-Dest"),
64 | "request.sec_fetch_mode": r.Header.Get("Sec-Fetch-Mode"),
65 | "request.sec_fetch_site": r.Header.Get("Sec-Fetch-Site"),
66 | "request.user_agent": r.UserAgent(),
67 | }
68 | }
69 |
70 | func nonEmptyRequestCookies(r *http.Request) string {
71 | result := make([]string, 0)
72 |
73 | for _, c := range r.Cookies() {
74 | if !isRelevantCookie(c.Name) || len(c.Value) <= 0 {
75 | continue
76 | }
77 |
78 | result = append(result, c.Name)
79 | }
80 |
81 | return strings.Join(result, ", ")
82 | }
83 |
84 | func isRelevantCookie(name string) bool {
85 | switch name {
86 | case cookie.Session,
87 | cookie.Login,
88 | cookie.Logout:
89 | return true
90 | }
91 |
92 | return false
93 | }
94 |
95 | func refererStripped(r *http.Request) string {
96 | referer := r.Referer()
97 | refererUrl, err := url.Parse(referer)
98 | if err == nil {
99 | refererUrl.RawQuery = ""
100 | refererUrl.RawFragment = ""
101 | referer = refererUrl.String()
102 | }
103 |
104 | return referer
105 | }
106 |
--------------------------------------------------------------------------------
/pkg/middleware/logentry.go:
--------------------------------------------------------------------------------
1 | package middleware
2 |
3 | import (
4 | "fmt"
5 | "net/http"
6 | "strings"
7 | "time"
8 |
9 | "github.com/go-chi/chi/v5/middleware"
10 | log "github.com/sirupsen/logrus"
11 | "go.opentelemetry.io/otel/trace"
12 |
13 | httpinternal "github.com/nais/wonderwall/internal/http"
14 | "github.com/nais/wonderwall/pkg/router/paths"
15 | )
16 |
17 | type logger struct {
18 | Logger *log.Logger
19 | Provider string
20 | }
21 |
22 | // Logger provides a middleware that logs requests and responses.
23 | func Logger(provider string) logger {
24 | return logger{
25 | Logger: log.StandardLogger(),
26 | Provider: provider,
27 | }
28 | }
29 |
30 | // LogEntryFrom returns a log entry from the request context.
31 | func LogEntryFrom(r *http.Request) *log.Entry {
32 | ctx := r.Context()
33 | entry, ok := ctx.Value(middleware.LogEntryCtxKey).(*logEntryAdapter)
34 | if ok {
35 | return entry.Logger
36 | }
37 |
38 | return log.NewEntry(log.StandardLogger()).
39 | WithField("fallback_logger", true).
40 | WithFields(httpinternal.Attributes(r)).
41 | WithFields(traceFields(r))
42 | }
43 |
44 | func (l *logger) Handler(next http.Handler) http.Handler {
45 | fn := func(w http.ResponseWriter, r *http.Request) {
46 | entry := l.newLogEntry(r)
47 | ww := middleware.NewWrapResponseWriter(w, r.ProtoMajor)
48 |
49 | if !strings.HasSuffix(r.URL.Path, paths.Ping) {
50 | t1 := time.Now()
51 | defer func() {
52 | entry.Write(ww.Status(), ww.BytesWritten(), ww.Header(), time.Since(t1), nil)
53 | }()
54 | }
55 |
56 | next.ServeHTTP(ww, middleware.WithLogEntry(r, entry))
57 | }
58 | return http.HandlerFunc(fn)
59 | }
60 |
61 | func (l *logger) newLogEntry(r *http.Request) *logEntryAdapter {
62 | return &logEntryAdapter{
63 | requestFields: httpinternal.Attributes(r),
64 | Logger: l.Logger.WithContext(r.Context()).
65 | WithField("provider", l.Provider).
66 | WithFields(traceFields(r)),
67 | }
68 | }
69 |
70 | // logEntryAdapter implements [middleware.LogEntry]
71 | type logEntryAdapter struct {
72 | Logger *log.Entry
73 | requestFields log.Fields
74 | }
75 |
76 | func (l *logEntryAdapter) Write(status, bytes int, _ http.Header, elapsed time.Duration, _ any) {
77 | responseFields := log.Fields{
78 | "response_status": status,
79 | "response_bytes": bytes,
80 | "response_elapsed_ms": float64(elapsed.Nanoseconds()) / 1000000.0, // in milliseconds, with fractional
81 | }
82 |
83 | l.Logger.WithFields(l.requestFields).
84 | WithFields(responseFields).
85 | Debugf("response: %d %s", status, http.StatusText(status))
86 | }
87 |
88 | func (l *logEntryAdapter) Panic(v interface{}, _ []byte) {
89 | stacktrace := "#"
90 |
91 | fields := log.Fields{
92 | "stacktrace": stacktrace,
93 | "error": fmt.Sprintf("%+v", v),
94 | }
95 |
96 | l.Logger = l.Logger.WithFields(fields)
97 | }
98 |
99 | func traceFields(r *http.Request) log.Fields {
100 | fields := log.Fields{}
101 | span := trace.SpanFromContext(r.Context())
102 | if span.SpanContext().HasTraceID() {
103 | fields["trace_id"] = span.SpanContext().TraceID().String()
104 | } else {
105 | fields["correlation_id"] = middleware.GetReqID(r.Context())
106 | }
107 |
108 | return fields
109 | }
110 |
--------------------------------------------------------------------------------
/pkg/session/session.go:
--------------------------------------------------------------------------------
1 | package session
2 |
3 | import (
4 | "context"
5 | "errors"
6 | "fmt"
7 | "net/http"
8 | "time"
9 |
10 | "github.com/nais/wonderwall/internal/crypto"
11 | "github.com/nais/wonderwall/pkg/cookie"
12 | "github.com/nais/wonderwall/pkg/openid"
13 | )
14 |
// Sentinel errors for session handling; match them with errors.Is.
// For example, AccessToken wraps ErrInvalid.
var (
	// ErrInactive signals an inactive session.
	// NOTE(review): the message has no subject ("is inactive"); it presumably
	// relies on being wrapped with context by the caller. Confirm.
	ErrInactive = errors.New("is inactive")
	// ErrInvalid signals an invalid session or session content (e.g. an expired access token).
	ErrInvalid = errors.New("session is invalid")
	// ErrInvalidExternal signals that the session's state at the identity provider is invalid.
	ErrInvalidExternal = errors.New("session has invalid state at identity provider")
	// ErrNotFound signals that no session was found.
	ErrNotFound = errors.New("not found")
)
21 |
// Reader knows how to read a session.
type Reader interface {
	// Get returns the session for a given http.Request, or an error if the session is invalid or not found.
	// NOTE(review): presumably the error wraps one of the package sentinels
	// (ErrInvalid, ErrNotFound, ...). Verify against the implementations.
	Get(r *http.Request) (*Session, error)
}
27 |
// Writer knows how to create, update and delete a session.
type Writer interface {
	// Create creates and stores a session in the Store from the given tokens.
	// sessionLifetime is presumably the session's maximum lifetime. Confirm in the implementation.
	Create(r *http.Request, tokens *openid.Tokens, sessionLifetime time.Duration) (*Session, error)
	// Delete deletes a session for a given Session.
	Delete(ctx context.Context, session *Session) error
	// DeleteForExternalID deletes a session for a given external session ID (e.g. front-channel logout).
	DeleteForExternalID(ctx context.Context, id string) error
	// Refresh refreshes the user's tokens and returns the updated session. If the session should not be
	// refreshed, it will return the existing session without modifications.
	Refresh(r *http.Request, sess *Session) (*Session, error)
}
40 |
// Manager is both a Reader and a Writer.
type Manager interface {
	Reader
	Writer

	// GetOrRefresh returns the session for a given http.Request. If the tokens within the session are expired and the
	// session is still valid, it will automatically attempt to refresh and update the session.
	GetOrRefresh(r *http.Request) (*Session, error)
}
50 |
// Session pairs the session's data with its ticket.
// The ticket provides the key (see key), the crypter used by encrypt, and
// cookie writing (see SetCookie). data may be nil; several accessors return
// zero values in that case.
type Session struct {
	data   *Data
	ticket *Ticket
}
55 |
56 | func (in *Session) AccessToken() (string, error) {
57 | if in.data != nil && in.data.HasActiveAccessToken() {
58 | return in.data.AccessToken, nil
59 | }
60 |
61 | return "", fmt.Errorf("%w: access token is expired", ErrInvalid)
62 | }
63 |
64 | func (in *Session) Acr() string {
65 | if in.data != nil {
66 | return in.data.Acr
67 | }
68 | return ""
69 | }
70 |
71 | func (in *Session) ExternalSessionID() string {
72 | if in.data != nil {
73 | return in.data.ExternalSessionID
74 | }
75 |
76 | return ""
77 | }
78 |
79 | func (in *Session) IDToken() string {
80 | return in.data.IDToken
81 | }
82 |
83 | func (in *Session) MetadataVerbose() MetadataVerbose {
84 | return in.data.Metadata.Verbose()
85 | }
86 |
87 | func (in *Session) SetCookie(w http.ResponseWriter, opts cookie.Options, crypter crypto.Crypter) error {
88 | return in.ticket.SetCookie(w, opts, crypter)
89 | }
90 |
91 | func (in *Session) canRefresh() bool {
92 | return in.data != nil && in.data.HasRefreshToken() && !in.data.Metadata.IsRefreshOnCooldown()
93 | }
94 |
95 | func (in *Session) encrypt() (*EncryptedData, error) {
96 | return in.data.Encrypt(in.ticket.Crypter())
97 | }
98 |
99 | func (in *Session) key() string {
100 | return in.ticket.Key()
101 | }
102 |
103 | func (in *Session) shouldRefresh() bool {
104 | return in.data != nil && in.data.Metadata.ShouldRefresh()
105 | }
106 |
107 | func NewSession(data *Data, ticket *Ticket) *Session {
108 | return &Session{data: data, ticket: ticket}
109 | }
110 |
--------------------------------------------------------------------------------
/docs/sessions.md:
--------------------------------------------------------------------------------
1 | # Session Management
2 |
3 | When a user authenticates themselves, they receive a session. Sessions are stored server-side; we only store a session identifier at the end-user's user agent.
4 |
5 | A session has three states:
6 |
7 | - _active_ - the session is valid
8 | - _inactive_ - the session has reached the _inactivity timeout_ and is considered invalid
9 | - _expired_ - the session has reached its _maximum lifetime_ and is considered invalid
10 |
11 | Requests with an _invalid_ session are considered _unauthenticated_.
12 |
13 | ## Session Metadata
14 |
15 | User agents can access their own session metadata by using [the `/oauth2/session` endpoint](endpoints.md#oauth2session).
16 |
17 | ## Session Expiry
18 |
19 | Every session has a maximum lifetime.
20 | The lifetime is indicated by the `session.ends_at` and `session.ends_in_seconds` fields in the session metadata.
21 |
22 | When the session reaches the maximum lifetime, it is considered to be _expired_, after which the user is essentially unauthenticated.
23 | A new session must be acquired by redirecting the user to [the `/oauth2/login` endpoint](endpoints.md#oauth2login) again.
24 |
25 | The maximum lifetime can be configured with the `session.max-lifetime` flag.
26 |
27 | ## Session Refreshing
28 |
29 | The tokens within the session will usually expire before the session itself.
30 | This is indicated by the `tokens.expire_at` and `tokens.expire_in_seconds` fields in the session metadata.
31 |
32 | If you've configured a session lifetime that is longer than the token expiry, you'll probably want to _refresh_ the tokens to avoid redirecting end-users to the `/oauth2/login` endpoint whenever the access tokens have expired.
33 |
34 | ### Automatic vs Manual Refresh
35 |
36 | The behaviour for refreshing depends on the [runtime mode](configuration.md#modes) for Wonderwall.
37 |
38 | In standalone mode, tokens are automatically refreshed.
39 | Tokens are automatically renewed at the _earliest_ 5 minutes before they expire.
40 | If the token _has_ already expired, a refresh attempt is still triggered automatically, as long as the session itself has not ended and is not marked as inactive.
41 |
42 | Automatic refreshes happen whenever the end-user visits or requests any path that is proxied to the upstream application.
43 |
44 | In SSO mode, tokens cannot be refreshed automatically. They must be refreshed by sending a request to [the `/oauth2/session/refresh` endpoint](endpoints.md#oauth2sessionrefresh).
45 |
46 | ## Session Inactivity
47 |
48 | A session can be marked as _inactive_ before it _expires_ (reaches the maximum lifetime).
49 | This happens if the time since the last _refresh_ exceeds the given _inactivity timeout_.
50 |
51 | An _inactive_ session _cannot_ be refreshed; a new session must be acquired by redirecting the user to the `/oauth2/login` endpoint.
52 | This is useful if you want to ensure that an end-user must re-authenticate with the identity provider if they've been away from an authenticated session for some time.
53 |
54 | Inactivity support is enabled with the `session.inactivity` option.
55 |
56 | The activity state of the session is indicated by the `session.active` field in the session metadata.
57 |
58 | The time until the session will be marked as inactive is indicated by the `session.timeout_at` and `session.timeout_in_seconds` fields in the session metadata.
59 |
60 | The timeout is configured with `session.inactivity-timeout`.
61 | If this timeout is shorter than the token expiry, the session metadata fields `tokens.expire_at` and `tokens.expire_in_seconds` will be reduced accordingly to reflect the inactivity timeout.
62 |
--------------------------------------------------------------------------------
/charts/wonderwall/Feature.yaml:
--------------------------------------------------------------------------------
1 | dependencies:
2 | - allOf:
3 | - aiven-operator
4 | - aivenator
5 | - mutilator
6 | - replicator
7 | environmentKinds:
8 | - tenant
9 | - legacy
10 | timeout: "1800s"
11 | values:
12 | aiven.project:
13 | description: Aiven project for Redis.
14 | computed:
15 | template: '"{{ .Env.aiven_project }}"'
16 | aiven.redisPlan:
17 | description: Aiven plan for Redis.
18 | required: true
19 | config:
20 | type: string
21 | aiven.prometheusEndpointId:
22 | description: Aiven Prometheus integration endpoint ID.
23 | computed:
24 | template: '"{{ .Env.aiven_prometheus_endpoint_id }}"'
25 | azure.enabled:
26 | description: Enable Azure AD. Requires Azurerator to be enabled.
27 | config:
28 | type: bool
29 | azure.forwardAuth.enabled:
30 | description: Enable forward auth server. Requires Azurerator and loadbalancer-fa to be enabled.
31 | config:
32 | type: bool
33 | azure.forwardAuth.groupIds:
34 | description: Additional group IDs to grant access to.
35 | config:
36 | type: string_array
37 | azure.forwardAuth.sessionCookieEncryptionKey:
38 | description: Cookie encryption key, 256 bits (e.g. 32 ASCII characters) encoded with standard base64.
39 | config:
40 | type: string
41 | secret: true
42 | azure.forwardAuth.ssoDomain:
43 | description: Cookie domain for forward auth.
44 | config:
45 | type: string
46 | azure.forwardAuth.ssoDefaultRedirectURL:
47 | description: Default redirect URL for forward auth.
48 | config:
49 | type: string
50 | idporten.enabled:
51 | description: Enable ID-porten. Requires Digdirator to be enabled.
52 | config:
53 | type: bool
54 | idporten.openidResourceIndicator:
55 | description: Resource indicator for audience-restricted tokens.
56 | config:
57 | type: string
58 | idporten.openidPostLogoutRedirectURL:
59 | description: Where to redirect the user after global logout.
60 | config:
61 | type: string
62 | idporten.replicasMax:
63 | description: Maximum replicas for SSO server.
64 | config:
65 | type: int
66 | idporten.replicasMin:
67 | description: Minimum replicas for SSO server.
68 | config:
69 | type: int
70 | idporten.sessionCookieEncryptionKey:
71 | description: Cookie encryption key, 256 bits (e.g. 32 ASCII characters) encoded with standard base64.
72 | config:
73 | type: string
74 | secret: true
75 | idporten.sessionCookieName:
76 | description: Cookie name for SSO sessions.
77 | config:
78 | type: string
79 | idporten.ssoDefaultRedirectURL:
80 | description: Fallback URL for invalid SSO redirects.
81 | config:
82 | type: string
83 | idporten.ssoDomain:
84 | description: Allowed domain for SSO (for cookies, CORS and redirect URL validation).
85 | config:
86 | type: string
87 | idporten.ssoServerHost:
88 | description: Host for SSO server.
89 | config:
90 | type: string
91 | idporten.ingressClassName:
92 | description: Ingress class for SSO server.
93 | config:
94 | type: string
95 | image.tag:
96 | config:
97 | type: string
98 | openid.wellKnownUrl:
99 | description: Well-known URL to generic identity provider. Optional. Only needed if a default global provider is desired.
100 | config:
101 | type: string
102 | ignoreKind:
103 | - legacy
104 | resourceSuffix:
105 | description: Suffix for resources that may conflict in parallel environments.
106 | config:
107 | type: string
108 |
--------------------------------------------------------------------------------
/pkg/openid/provider/provider.go:
--------------------------------------------------------------------------------
1 | package provider
2 |
3 | import (
4 | "context"
5 | "fmt"
6 | "sync"
7 | "time"
8 |
9 | "github.com/lestrrat-go/httprc/v3"
10 | "github.com/lestrrat-go/jwx/v3/jwa"
11 | "github.com/lestrrat-go/jwx/v3/jwk"
12 | "github.com/nais/wonderwall/internal/o11y/otel"
13 | openidconfig "github.com/nais/wonderwall/pkg/openid/config"
14 | "go.opentelemetry.io/otel/attribute"
15 | )
16 |
const (
	// JwkMinimumRefreshInterval is the minimum time between two forced JWKS
	// refreshes in RefreshPublicJwkSet. Calls within this window are served
	// from the cache instead.
	JwkMinimumRefreshInterval = 5 * time.Second
)
20 |
// JwksProvider serves the identity provider's public JSON Web Key Set from a
// cache, with rate-limited forced refreshes.
type JwksProvider struct {
	// config supplies the JWKS URI and the ID token signing algorithm.
	config openidconfig.Provider
	// jwksCache caches the key set fetched from the JWKS URI.
	jwksCache *jwk.Cache
	// jwksLock serializes forced refreshes and records when the last one was attempted.
	jwksLock *jwksLock
}
26 |
// jwksLock guards forced JWKS refreshes.
type jwksLock struct {
	// lastRefresh is when the last forced refresh was attempted. It is set
	// before the refresh runs, so failed attempts also start a cooldown.
	lastRefresh time.Time
	sync.Mutex
}
31 |
32 | func (p *JwksProvider) GetPublicJwkSet(ctx context.Context) (*jwk.Set, error) {
33 | url := p.config.JwksURI()
34 | set, err := p.jwksCache.Lookup(ctx, url)
35 | if err != nil {
36 | return nil, fmt.Errorf("provider: fetching jwks: %w", err)
37 | }
38 |
39 | set, err = ensureJwkSetWithAlg(set, p.config.IDTokenSigningAlg())
40 | if err != nil {
41 | return nil, fmt.Errorf("provider: mutating jwks: %w", err)
42 | }
43 |
44 | return &set, nil
45 | }
46 |
47 | func (p *JwksProvider) RefreshPublicJwkSet(ctx context.Context) (*jwk.Set, error) {
48 | ctx, span := otel.StartSpan(ctx, "JwksProvider.RefreshPublicJwkSet")
49 | defer span.End()
50 | p.jwksLock.Lock()
51 | defer p.jwksLock.Unlock()
52 |
53 | // redirect to cache if recently refreshed to avoid overwhelming provider
54 | diff := time.Since(p.jwksLock.lastRefresh)
55 | if diff < JwkMinimumRefreshInterval {
56 | span.SetAttributes(attribute.Bool("jwks.cooldown", true))
57 | return p.GetPublicJwkSet(ctx)
58 | }
59 |
60 | p.jwksLock.lastRefresh = time.Now()
61 |
62 | url := p.config.JwksURI()
63 | set, err := p.jwksCache.Refresh(ctx, url)
64 | if err != nil {
65 | return nil, fmt.Errorf("provider: refreshing jwks: %w", err)
66 | }
67 |
68 | set, err = ensureJwkSetWithAlg(set, p.config.IDTokenSigningAlg())
69 | if err != nil {
70 | return nil, fmt.Errorf("provider: mutating jwks: %w", err)
71 | }
72 |
73 | span.SetAttributes(attribute.Bool("jwks.refreshed", true))
74 | return &set, nil
75 | }
76 |
77 | func NewJwksProvider(ctx context.Context, openidCfg openidconfig.Config) (*JwksProvider, error) {
78 | providerCfg := openidCfg.Provider()
79 |
80 | uri := providerCfg.JwksURI()
81 | cache, err := jwk.NewCache(ctx, httprc.NewClient())
82 | if err != nil {
83 | return nil, fmt.Errorf("creating jwks cache: %w", err)
84 | }
85 |
86 | if err := cache.Register(ctx, uri); err != nil {
87 | return nil, fmt.Errorf("registering jwks provider uri to cache: %w", err)
88 | }
89 |
90 | return &JwksProvider{
91 | config: providerCfg,
92 | jwksCache: cache,
93 | jwksLock: &jwksLock{},
94 | }, nil
95 | }
96 |
97 | func ensureJwkSetWithAlg(set jwk.Set, expectedAlg jwa.KeyAlgorithm) (jwk.Set, error) {
98 | for i := 0; i < set.Len(); i++ {
99 | key, ok := set.Key(i)
100 | if !ok {
101 | continue
102 | }
103 |
104 | alg, ok := key.Algorithm()
105 | if ok {
106 | // drop keys with "alg=none"
107 | if alg == jwa.NoSignature() {
108 | if err := set.RemoveKey(key); err != nil {
109 | return nil, fmt.Errorf("removing key: %w", err)
110 | }
111 | }
112 |
113 | // don't mutate keys with a valid algorithm
114 | continue
115 | }
116 |
117 | // set "alg" to expected algorithm for keys that don't have one set
118 | if err := key.Set(jwk.AlgorithmKey, expectedAlg); err != nil {
119 | return nil, fmt.Errorf("setting key algorithm: %w", err)
120 | }
121 | }
122 |
123 | return set, nil
124 | }
125 |
--------------------------------------------------------------------------------
/pkg/middleware/prometheus.go:
--------------------------------------------------------------------------------
1 | // This code was originally written by Rene Zbinden and modified by Vladimir Konovalov.
2 | // Copied from https://github.com/766b/chi-prometheus and further adapted.
3 |
4 | package middleware
5 |
6 | import (
7 | "net/http"
8 | "strconv"
9 | "strings"
10 | "time"
11 |
12 | chi_middleware "github.com/go-chi/chi/v5/middleware"
13 | "github.com/prometheus/client_golang/prometheus"
14 |
15 | "github.com/nais/wonderwall/pkg/metrics"
16 | "github.com/nais/wonderwall/pkg/router/paths"
17 | )
18 |
// defaultBuckets are the latency histogram buckets, in seconds, used when
// Prometheus is called without explicit buckets.
var defaultBuckets = []float64{.001, .01, .05, .1, .5, 1, 1.5, 2, 2.5, 3, 3.5, 4, 4.5, 5}
20 |
// Metric naming used by Prometheus.
const (
	serviceName = "wonderwall"               // value of the constant "service" label
	reqsName    = "requests_total"           // name of the request counter
	latencyName = "request_duration_seconds" // name of the latency histogram
)
26 |
// PrometheusMiddleware is a handler that exposes prometheus metrics for the number of requests
// and their latency, partitioned by status code, method, HTTP path and host.
type PrometheusMiddleware struct {
	reqs    *prometheus.CounterVec   // request counter
	latency *prometheus.HistogramVec // request latency in seconds
}
33 |
34 | // Prometheus returns a new PrometheusMiddleware handler.
35 | func Prometheus(provider string, buckets ...float64) *PrometheusMiddleware {
36 | var m PrometheusMiddleware
37 | m.reqs = prometheus.NewCounterVec(
38 | prometheus.CounterOpts{
39 | Name: reqsName,
40 | Help: "How many HTTP requests processed, partitioned by status code, method and HTTP path.",
41 | ConstLabels: prometheus.Labels{"service": serviceName, metrics.LabelProvider: provider},
42 | },
43 | []string{"code", "method", "path", "host"},
44 | )
45 |
46 | if len(buckets) == 0 {
47 | buckets = defaultBuckets
48 | }
49 | m.latency = prometheus.NewHistogramVec(prometheus.HistogramOpts{
50 | Name: latencyName,
51 | Help: "How long it took to process the request, partitioned by status code, method and HTTP path.",
52 | ConstLabels: prometheus.Labels{"service": serviceName, metrics.LabelProvider: provider},
53 | Buckets: buckets,
54 | },
55 | []string{"code", "method", "path", "host"},
56 | )
57 |
58 | prometheus.Register(m.reqs)
59 | prometheus.Register(m.latency)
60 |
61 | return &m
62 | }
63 |
64 | func (m *PrometheusMiddleware) Initialize(path, method string, code int) {
65 | m.reqs.WithLabelValues(
66 | strconv.Itoa(code),
67 | method,
68 | path,
69 | )
70 | }
71 |
72 | func (m *PrometheusMiddleware) Handler(next http.Handler) http.Handler {
73 | relevantPaths := map[string]bool{
74 | paths.OAuth2 + paths.Login: true,
75 | paths.OAuth2 + paths.LoginCallback: true,
76 | paths.OAuth2 + paths.Logout: true,
77 | paths.OAuth2 + paths.LogoutCallback: true,
78 | paths.OAuth2 + paths.LogoutFrontChannel: true,
79 | paths.OAuth2 + paths.LogoutLocal: true,
80 | paths.OAuth2 + paths.Ping: false,
81 | paths.OAuth2 + paths.Session: true,
82 | paths.OAuth2 + paths.Session + paths.Refresh: true,
83 | paths.OAuth2 + paths.Session + paths.ForwardAuth: true,
84 | }
85 |
86 | fn := func(w http.ResponseWriter, r *http.Request) {
87 | found := false
88 | for path, relevant := range relevantPaths {
89 | if strings.HasSuffix(r.URL.Path, path) && relevant {
90 | found = true
91 | break
92 | }
93 | }
94 |
95 | if !found {
96 | next.ServeHTTP(w, r)
97 | return
98 | }
99 |
100 | start := time.Now()
101 | ww := chi_middleware.NewWrapResponseWriter(w, r.ProtoMajor)
102 | next.ServeHTTP(ww, r)
103 | statusCode := strconv.Itoa(ww.Status())
104 | duration := time.Since(start)
105 | m.reqs.WithLabelValues(statusCode, r.Method, r.URL.Path, r.Host).Inc()
106 | m.latency.WithLabelValues(statusCode, r.Method, r.URL.Path, r.Host).Observe(duration.Seconds())
107 | }
108 | return http.HandlerFunc(fn)
109 | }
110 |
--------------------------------------------------------------------------------
/pkg/openid/client/login_callback.go:
--------------------------------------------------------------------------------
1 | package client
2 |
3 | import (
4 | "context"
5 | "errors"
6 | "fmt"
7 | "net/http"
8 |
9 | "github.com/nais/wonderwall/internal/o11y/otel"
10 | "github.com/nais/wonderwall/pkg/openid"
11 | "go.opentelemetry.io/otel/attribute"
12 | "go.opentelemetry.io/otel/trace"
13 | )
14 |
// Sentinel errors returned (wrapped) by LoginCallback, one per failure stage,
// so callers can classify failures with errors.Is.
var (
	ErrCallbackIdentityProvider = errors.New("callback: identity provider error") // the IdP returned an "error" query parameter
	ErrCallbackInvalidCookie    = errors.New("callback: invalid cookie")          // the login cookie is missing
	ErrCallbackInvalidState     = errors.New("callback: invalid state")           // the "state" parameter does not match the cookie
	ErrCallbackInvalidIssuer    = errors.New("callback: invalid issuer")          // the "iss" parameter failed RFC 9207 verification
	ErrCallbackRedeemTokens     = errors.New("callback: redeeming tokens")        // exchanging the code or parsing the tokens failed
)
22 |
23 | func (c *Client) LoginCallback(r *http.Request, cookie *openid.LoginCookie) (*openid.Tokens, error) {
24 | span := trace.SpanFromContext(r.Context())
25 | if cookie == nil {
26 | return nil, fmt.Errorf("%w: %s", ErrCallbackInvalidCookie, "cookie is nil")
27 | }
28 |
29 | query := r.URL.Query()
30 |
31 | if oauthError := query.Get("error"); len(oauthError) > 0 {
32 | oauthErrorDescription := query.Get("error_description")
33 | span.SetAttributes(attribute.String("login.oauth_error", oauthError), attribute.String("login.oauth_error_description", oauthErrorDescription))
34 | return nil, fmt.Errorf("%w: %s: %s", ErrCallbackIdentityProvider, oauthError, oauthErrorDescription)
35 | }
36 |
37 | if err := openid.StateMismatchError(query, cookie.State); err != nil {
38 | span.SetAttributes(attribute.Bool("login.state_mismatch", true))
39 | return nil, fmt.Errorf("%w: %s", ErrCallbackInvalidState, err)
40 | }
41 |
42 | if err := c.authorizationServerIssuerIdentification(query.Get("iss")); err != nil {
43 | span.SetAttributes(attribute.Bool("login.issuer_mismatch", true))
44 | return nil, fmt.Errorf("%w: %s", ErrCallbackInvalidIssuer, err)
45 | }
46 |
47 | tokens, err := c.redeemTokens(r.Context(), query.Get("code"), cookie)
48 | if err != nil {
49 | return nil, fmt.Errorf("%w: %s", ErrCallbackRedeemTokens, err)
50 | }
51 |
52 | return tokens, nil
53 | }
54 |
55 | // Verify iss parameter if provider supports RFC 9207 - OAuth 2.0 Authorization Server Issuer Identification
56 | func (c *Client) authorizationServerIssuerIdentification(iss string) error {
57 | if !c.cfg.Provider().AuthorizationResponseIssParameterSupported() {
58 | return nil
59 | }
60 |
61 | if len(iss) == 0 {
62 | return fmt.Errorf("missing issuer parameter")
63 | }
64 |
65 | expectedIss := c.cfg.Provider().Issuer()
66 | if iss != expectedIss {
67 | return fmt.Errorf("issuer mismatch: expected %q, got %q", expectedIss, iss)
68 | }
69 |
70 | return nil
71 | }
72 |
73 | func (c *Client) redeemTokens(ctx context.Context, code string, cookie *openid.LoginCookie) (*openid.Tokens, error) {
74 | ctx, span := otel.StartSpan(ctx, "Client.RedeemTokens")
75 | defer span.End()
76 | clientAuth, err := c.ClientAuthenticationParams()
77 | if err != nil {
78 | return nil, err
79 | }
80 |
81 | payload := openid.ExchangeAuthorizationCodeParams(
82 | c.cfg.Client().ClientID(),
83 | code,
84 | cookie.CodeVerifier,
85 | cookie.RedirectURI,
86 | ).With(clientAuth).AuthCodeOptions()
87 |
88 | rawTokens, err := c.AuthCodeGrant(ctx, code, payload)
89 | if err != nil {
90 | return nil, fmt.Errorf("exchanging authorization code for token: %w", err)
91 | }
92 | span.SetAttributes(attribute.Int64("oauth.token_expires_in_seconds", rawTokens.ExpiresIn))
93 |
94 | jwkSet, err := c.jwksProvider.GetPublicJwkSet(ctx)
95 | if err != nil {
96 | return nil, fmt.Errorf("getting jwks: %w", err)
97 | }
98 |
99 | tokens, err := openid.NewTokens(rawTokens, jwkSet, c.cfg, cookie)
100 | if err != nil {
101 | // JWKS might not be up to date, so we'll want to force a refresh for the next attempt
102 | _, _ = c.jwksProvider.RefreshPublicJwkSet(ctx)
103 | return nil, fmt.Errorf("parsing tokens: %w", err)
104 | }
105 |
106 | return tokens, nil
107 | }
108 |
--------------------------------------------------------------------------------
/pkg/ingress/ingress.go:
--------------------------------------------------------------------------------
1 | package ingress
2 |
3 | import (
4 | "fmt"
5 | "net/http"
6 | "net/url"
7 | "strings"
8 |
9 | "github.com/nais/wonderwall/pkg/config"
10 | )
11 |
const (
	// XForwardedHost is the header consulted, in addition to the Host header,
	// when matching a request to an ingress in MatchingIngress.
	XForwardedHost = "X-Forwarded-Host"
)
15 |
// Ingresses is the parsed, de-duplicated set of configured ingresses. It also
// holds the distinct hosts, paths and URL strings derived from them. The
// derived slices are built by iterating a map, so their order is unspecified.
type Ingresses struct {
	ingressMap map[string]Ingress // keyed by Ingress.String()
	hosts      []string           // distinct Ingress.Host values
	paths      []string           // distinct Ingress.Path values
	urls       []string           // distinct Ingress.String values
}
22 |
23 | func ParseIngresses(cfg *config.Config) (*Ingresses, error) {
24 | ingresses := cfg.Ingresses
25 | if len(ingresses) == 0 {
26 | return nil, fmt.Errorf("must have at least 1 ingress")
27 | }
28 |
29 | seen := make(map[string]Ingress)
30 |
31 | for _, raw := range ingresses {
32 | ingress, err := ParseIngress(raw)
33 | if err != nil {
34 | return nil, fmt.Errorf("parsing ingress '%s': %w", raw, err)
35 | }
36 |
37 | if _, found := seen[ingress.String()]; !found {
38 | seen[ingress.String()] = *ingress
39 | }
40 | }
41 |
42 | return &Ingresses{
43 | ingressMap: seen,
44 | hosts: mapIngresses(seen, Ingress.Host),
45 | paths: mapIngresses(seen, Ingress.Path),
46 | urls: mapIngresses(seen, Ingress.String),
47 | }, nil
48 | }
49 |
50 | func (i *Ingresses) Hosts() []string {
51 | return i.hosts
52 | }
53 |
54 | func (i *Ingresses) Paths() []string {
55 | return i.paths
56 | }
57 |
58 | func (i *Ingresses) Strings() []string {
59 | return i.urls
60 | }
61 |
62 | func (i *Ingresses) MatchingIngress(r *http.Request) (Ingress, bool) {
63 | for _, ingress := range i.ingressMap {
64 | hostMatch := ingress.Host() == r.Host || ingress.Host() == r.Header.Get(XForwardedHost)
65 | pathMatch := ingress.Path() == i.MatchingPath(r)
66 |
67 | if hostMatch && pathMatch {
68 | return ingress, true
69 | }
70 | }
71 |
72 | return Ingress{}, false
73 | }
74 |
75 | func (i *Ingresses) MatchingPath(r *http.Request) string {
76 | reqPath := r.URL.Path
77 | result := ""
78 |
79 | for _, p := range i.Paths() {
80 | if len(p) == 0 {
81 | continue
82 | }
83 |
84 | if strings.HasPrefix(reqPath, p) && len(p) > len(result) {
85 | result = p
86 | }
87 | }
88 |
89 | return result
90 | }
91 |
92 | func (i *Ingresses) Single() Ingress {
93 | var res Ingress
94 |
95 | for _, v := range i.ingressMap {
96 | res = v
97 | break
98 | }
99 |
100 | return res
101 | }
102 |
103 | func mapIngresses(ingresses map[string]Ingress, fn func(i Ingress) string) []string {
104 | seen := make(map[string]bool, 0)
105 | result := make([]string, 0)
106 |
107 | for _, ingress := range ingresses {
108 | value := fn(ingress)
109 |
110 | if _, found := seen[value]; !found {
111 | seen[value] = true
112 | result = append(result, value)
113 | }
114 | }
115 |
116 | return result
117 | }
118 |
119 | func ParseIngress(ingress string) (*Ingress, error) {
120 | if len(ingress) == 0 {
121 | return nil, fmt.Errorf("ingress cannot be empty")
122 | }
123 |
124 | u, err := url.ParseRequestURI(ingress)
125 | if err != nil {
126 | return nil, err
127 | }
128 |
129 | if len(u.Host) == 0 {
130 | return nil, fmt.Errorf("must have non-empty host")
131 | }
132 |
133 | err = mustScheme(u)
134 | if err != nil {
135 | return nil, err
136 | }
137 |
138 | u.Path = strings.TrimRight(u.Path, "/")
139 |
140 | return &Ingress{
141 | URL: u,
142 | }, nil
143 | }
144 |
145 | func mustScheme(u *url.URL) error {
146 | validSchemes := []string{"http", "https"}
147 |
148 | valid := false
149 | for _, scheme := range validSchemes {
150 | if u.Scheme == scheme {
151 | valid = true
152 | }
153 | }
154 |
155 | if !valid {
156 | return fmt.Errorf("invalid URL scheme, must be one of %s", validSchemes)
157 | }
158 |
159 | return nil
160 | }
161 |
// Ingress is a parsed ingress URL (see ParseIngress). It embeds *url.URL, so a
// zero Ingress has a nil URL, and its accessor methods panic.
type Ingress struct {
	*url.URL
}
165 |
166 | func (i Ingress) Path() string {
167 | return i.URL.Path
168 | }
169 |
170 | func (i Ingress) Host() string {
171 | return i.URL.Host
172 | }
173 |
174 | func (i Ingress) String() string {
175 | return i.URL.String()
176 | }
177 |
178 | func (i Ingress) NewURL() *url.URL {
179 | u := *i.URL
180 | return &u
181 | }
182 |
--------------------------------------------------------------------------------
/cmd/wonderwall/main.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | "context"
5 | "fmt"
6 | "net/http"
7 |
8 | "github.com/KimMachineGun/automemlimit/memlimit"
9 | "github.com/nais/wonderwall/internal/crypto"
10 | "github.com/nais/wonderwall/internal/o11y/otel"
11 | "github.com/nais/wonderwall/pkg/config"
12 | "github.com/nais/wonderwall/pkg/cookie"
13 | "github.com/nais/wonderwall/pkg/handler"
14 | "github.com/nais/wonderwall/pkg/metrics"
15 | openidconfig "github.com/nais/wonderwall/pkg/openid/config"
16 | "github.com/nais/wonderwall/pkg/openid/provider"
17 | "github.com/nais/wonderwall/pkg/router"
18 | "github.com/nais/wonderwall/pkg/server"
19 | log "github.com/sirupsen/logrus"
20 | )
21 |
22 | func main() {
23 | err := run()
24 | if err != nil {
25 | log.Fatalf("Fatal error: %s", err)
26 | }
27 | }
28 |
// run wires up and starts Wonderwall:
//   - loads the configuration and sets GOMEMLIMIT (a failure there is only logged at debug level);
//   - obtains the encryption key via crypto.EncryptionKeyOrGenerate and builds a crypter from it;
//   - configures the global cookie names (a custom prefix, and in SSO mode the SSO session cookie name);
//   - optionally sets up OpenTelemetry;
//   - builds the routing source for the configured mode (SSO server, SSO proxy or standalone);
//   - starts the optional metrics and probe servers in background goroutines;
//
// and finally returns the result of server.Start.
//
// Deferred calls run in reverse order on return, so the OpenTelemetry shutdown
// runs before the context is cancelled.
func run() error {
	cfg, err := config.Initialize()
	if err != nil {
		return err
	}

	// Best effort: a failure to set the memory limit is not fatal.
	if _, err := memlimit.SetGoMemLimitWithOpts(); err != nil {
		log.Debugf("setting GOMEMLIMIT: %+v", err)
	}

	key, err := crypto.EncryptionKeyOrGenerate(cfg)
	if err != nil {
		return err
	}

	crypt := crypto.NewCrypter(key)

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	if cfg.Cookie.Prefix != cookie.DefaultPrefix {
		cookie.ConfigureCookieNamesWithPrefix(cfg.Cookie.Prefix)
	}

	// In SSO mode the SSO session cookie name is used as the cookie-name prefix
	// (overriding any prefix configured above) and as the session cookie name.
	// NOTE(review): confirm that overriding cfg.Cookie.Prefix here is intended.
	if cfg.SSO.Enabled {
		cookie.ConfigureCookieNamesWithPrefix(cfg.SSO.SessionCookieName)
		cookie.Session = cfg.SSO.SessionCookieName
	}

	if cfg.OpenTelemetry.Enabled {
		otelShutdown, err := otel.Setup(ctx, cfg)
		if err != nil {
			return fmt.Errorf("initializing OpenTelemetry: %w", err)
		}
		// Runs when run returns (not at the end of this if-block), before the deferred cancel.
		defer otelShutdown(ctx)
	}

	var src router.Source

	if cfg.SSO.Enabled {
		switch cfg.SSO.Mode {
		case config.SSOModeServer:
			src, err = ssoServer(ctx, cfg, crypt)
		case config.SSOModeProxy:
			src, err = ssoProxy(cfg, crypt)
		default:
			return fmt.Errorf("invalid SSO mode: %q", cfg.SSO.Mode)
		}
	} else {
		src, err = standalone(ctx, cfg, crypt)
	}
	if err != nil {
		return fmt.Errorf("initializing routing handler: %w", err)
	}

	r := router.New(src, cfg)

	// Metrics server in the background; a server error terminates the process via log.Fatalf.
	if cfg.MetricsBindAddress != "" {
		go func() {
			log.Debugf("metrics: listening on %s", cfg.MetricsBindAddress)
			err := metrics.Handle(cfg.MetricsBindAddress, cfg.OpenID.Provider)
			if err != nil {
				log.Fatalf("fatal: metrics server error: %s", err)
			}
		}()
	}

	// Probe server in the background: answers 200 "ok" on "/" and "/healthz";
	// a server error terminates the process via log.Fatalf.
	if cfg.ProbeBindAddress != "" {
		go func() {
			mux := http.NewServeMux()
			healthz := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
				w.WriteHeader(http.StatusOK)
				w.Write([]byte("ok"))
			})
			mux.HandleFunc("/", healthz)
			mux.HandleFunc("/healthz", healthz)
			log.Debugf("probe: listening on %s", cfg.ProbeBindAddress)
			err := http.ListenAndServe(cfg.ProbeBindAddress, mux)
			if err != nil {
				log.Fatalf("fatal: probe server error: %s", err)
			}
		}()
	}

	return server.Start(cfg, r)
}
115 |
116 | func standalone(ctx context.Context, cfg *config.Config, crypt crypto.Crypter) (*handler.Standalone, error) {
117 | openidConfig, err := openidconfig.NewConfig(cfg)
118 | if err != nil {
119 | return nil, err
120 | }
121 |
122 | jwksProvider, err := provider.NewJwksProvider(ctx, openidConfig)
123 | if err != nil {
124 | return nil, err
125 | }
126 |
127 | return handler.NewStandalone(cfg, jwksProvider, openidConfig, crypt)
128 | }
129 |
130 | func ssoServer(ctx context.Context, cfg *config.Config, crypt crypto.Crypter) (*handler.SSOServer, error) {
131 | h, err := standalone(ctx, cfg, crypt)
132 | if err != nil {
133 | return nil, err
134 | }
135 |
136 | return handler.NewSSOServer(cfg, h)
137 | }
138 |
139 | func ssoProxy(cfg *config.Config, crypt crypto.Crypter) (*handler.SSOProxy, error) {
140 | return handler.NewSSOProxy(cfg, crypt)
141 | }
142 |
--------------------------------------------------------------------------------
/pkg/url/validator.go:
--------------------------------------------------------------------------------
1 | package url
2 |
3 | import (
4 | "net/http"
5 | "net/url"
6 | "regexp"
7 | "strings"
8 |
9 | mw "github.com/nais/wonderwall/pkg/middleware"
10 | )
11 |
// Used to check final redirects are not susceptible to open redirects.
// Matches a forward slash or backslash, then either a (possibly empty) run of
// whitespace including vertical tab, or one or two dots, then another forward
// slash or backslash. Examples: //, /\, / /, / \, /./, /../ (and the same
// patterns starting with a backslash).
var invalidRedirectRegex = regexp.MustCompile(`[/\\](?:[\s\v]*|\.{1,2})[/\\]`)
15 |
// Compile-time check that *AbsoluteValidator implements Validator.
var _ Validator = &AbsoluteValidator{}

// Validator decides whether a redirect target is acceptable for a request.
type Validator interface {
	// IsValidRedirect reports whether redirect may be used as a redirect target for r.
	IsValidRedirect(r *http.Request, redirect string) bool
}
21 |
// AbsoluteValidator accepts absolute http(s) redirect URLs whose host is
// permitted by allowedDomains, as decided by isAllowedHost.
type AbsoluteValidator struct {
	allowedDomains []string // domains passed to isAllowedHost
}
25 |
26 | func NewAbsoluteValidator(allowedDomains []string) *AbsoluteValidator {
27 | return &AbsoluteValidator{allowedDomains: allowedDomains}
28 | }
29 |
// IsValidRedirect validates that the given redirect string is a valid absolute URL.
// It must use the 'http' or 'https' scheme.
// It must point to a host that matches the configured list of allowed domains.
//
// The checks run in that order. The first failing check is logged at info level on the
// request's log entry, and the redirect is rejected.
func (v *AbsoluteValidator) IsValidRedirect(r *http.Request, redirect string) bool {
	u, ok := parsableRequestURI(r, redirect)
	if !ok {
		return false
	}

	// Each predicate is evaluated once; the first failure determines the log message.
	switch {
	case isRelativeURL(u):
		mw.LogEntryFrom(r).Infof("validator: not an absolute URL")
		return false
	case !isValidScheme(u):
		mw.LogEntryFrom(r).Infof("validator: invalid scheme; must be one of ['http', 'https']")
		return false
	case !isAllowedHost(u, v.allowedDomains):
		mw.LogEntryFrom(r).Infof("validator: host does not match any allowlisted domains: %q", v.allowedDomains)
		return false
	default:
		return true
	}
}
60 |
61 | var _ Validator = &RelativeValidator{}
62 |
63 | type RelativeValidator struct{}
64 |
65 | func NewRelativeValidator() *RelativeValidator {
66 | return &RelativeValidator{}
67 | }
68 |
69 | // IsValidRedirect validates that the given redirect string is a valid relative URL.
70 | // It must be an absolute path (i.e. has a leading '/').
71 | func (v *RelativeValidator) IsValidRedirect(r *http.Request, redirect string) bool {
72 | u, ok := parsableRequestURI(r, redirect)
73 | if !ok {
74 | return false
75 | }
76 |
77 | if isRelativeURL(u) && isValidAbsolutePath(u.String()) {
78 | return true
79 | }
80 |
81 | mw.LogEntryFrom(r).Infof("validator: not a valid relative URL")
82 | return false
83 | }
84 |
// parsableRequestURI parses redirect with url.ParseRequestURI, which accepts either an
// absolute URI or an absolute path. It reports false, after logging the reason, when
// redirect is empty or cannot be parsed.
func parsableRequestURI(r *http.Request, redirect string) (*url.URL, bool) {
	if len(redirect) == 0 {
		mw.LogEntryFrom(r).Debugf("validator: redirect is empty")
		return nil, false
	}

	parsed, parseErr := url.ParseRequestURI(redirect)
	if parseErr == nil {
		return parsed, true
	}

	mw.LogEntryFrom(r).Infof("validator: %+v", parseErr)
	return nil, false
}
99 |
// isAllowedHost reports whether u has a non-empty host (and hostname) that matches at least
// one entry of allowedDomains. An empty allowlist never matches.
func isAllowedHost(u *url.URL, allowedDomains []string) bool {
	if len(allowedDomains) == 0 || u.Host == "" || u.Hostname() == "" {
		return false
	}

	matched := false
	for i := 0; i < len(allowedDomains) && !matched; i++ {
		matched = isAllowedDomain(u, allowedDomains[i])
	}
	return matched
}
116 |
// isValidScheme reports whether u's scheme is exactly "http" or "https".
func isValidScheme(u *url.URL) bool {
	switch u.Scheme {
	case "http", "https":
		return true
	default:
		return false
	}
}
120 |
// isRelativeURL reports whether u has neither a scheme nor a host.
func isRelativeURL(u *url.URL) bool {
	hasScheme := u.Scheme != ""
	hasHost := u.Host != ""
	return !(hasScheme || hasHost)
}
124 |
// isValidAbsolutePath reports whether redirect starts with exactly one leading '/' and
// contains none of the sequences matched by invalidRedirectRegex.
func isValidAbsolutePath(redirect string) bool {
	switch {
	case !strings.HasPrefix(redirect, "/"):
		return false
	case strings.HasPrefix(redirect, "//"):
		return false
	default:
		return !invalidRedirectRegex.MatchString(redirect)
	}
}
128 |
// isAllowedDomain reports whether u's host matches allowed, either exactly or as a subdomain.
// An empty allowed entry never matches.
//
// An exact match is accepted against either the host including any port (u.Host) or the
// bare hostname (u.Hostname()).
//
// A subdomain match requires the host to end with "."+allowed (or with allowed itself when it
// already starts with '.'). It is checked against u.Host, so allowed entries that include a
// port keep working. It is also checked against u.Host with only its port removed, so a
// subdomain URL with an explicit port (e.g. https://sub.example.com:8443) matches an allowed
// entry without a port (e.g. example.com), consistent with the exact match. Stripping only the
// port keeps IPv6 brackets intact, so bracketed IPv6 hosts never match a DNS suffix.
func isAllowedDomain(u *url.URL, allowed string) bool {
	if len(allowed) == 0 {
		return false
	}

	host := u.Host
	hostname := u.Hostname()

	// exact match on host:port or host
	if host == allowed || hostname == allowed {
		return true
	}

	// subdomain of allowed domain
	suffix := allowed
	if !strings.HasPrefix(suffix, ".") {
		suffix = "." + suffix
	}

	// u.Port() is "" when there is no (valid) port; then only a dangling ':' is trimmed.
	hostWithoutPort := strings.TrimSuffix(host, ":"+u.Port())
	return strings.HasSuffix(host, suffix) || strings.HasSuffix(hostWithoutPort, suffix)
}
148 |
--------------------------------------------------------------------------------
/docs/usage.md:
--------------------------------------------------------------------------------
1 | # Usage
2 |
3 | The contract for using Wonderwall is fairly straightforward.
4 |
5 | For any endpoint that requires authentication:
6 |
7 | 1. Validate the `Authorization` header, and any tokens within.
8 | 2. If the `Authorization` header is missing, redirect the user to the [login endpoint](#1-login).
9 | 3. If the JWT `access_token` in the `Authorization` header is invalid or expired, redirect the user to
10 | the [login endpoint](#1-login).
11 | 4. If you need to log out a user, redirect the user to the [logout endpoint](#2-logout).
12 |
13 | Note that Wonderwall does not validate the `access_token` that is attached; this is the responsibility of the upstream application.
14 | Wonderwall only validates the `id_token` in accordance with the OpenID Connect Core specifications.
15 |
16 | ## Scenarios
17 |
18 | ### 1. Login
19 |
20 | When you must authenticate a user, redirect the user to [the `/oauth2/login` endpoint](endpoints.md#oauth2login).
21 |
22 | #### 1.1. Autologin
23 |
24 | The `auto-login` option will configure Wonderwall to enforce authentication for **all** requests, except for the paths that are explicitly [excluded](configuration.md#auto-login-ignore-paths).
25 |
26 | If the user is _unauthenticated_ or has an [_inactive_ or _expired_ session](sessions.md), all requests will be short-circuited (i.e. they return early and are **not** proxied to your application).
27 | The short-circuited response depends on whether the request is a _top-level navigation_ request or not.
28 |
29 | A _top-level navigation request_ is a `GET` request that has the [Fetch metadata request headers](https://developer.mozilla.org/en-US/docs/Glossary/Fetch_metadata_request_header) `Sec-Fetch-Dest=document` and `Sec-Fetch-Mode=navigate`.
30 | If the user agent does not support the Fetch metadata headers, we look for an `Accept` header that includes `text/html`, which all major browsers send for navigation requests.
31 | This does not work with Internet Explorer 8, so hopefully you are not required to support that browser.
32 |
33 | A top-level navigation request results in an HTTP 302 Found response with the `Location` header pointing to [the `/oauth2/login` endpoint](endpoints.md#oauth2login).
34 | The `redirect` parameter in the login URL is set to the value found in the `Referer` header, so that the user is redirected back to their intended location after login.
35 | If the `Referer` header is empty, the `redirect` parameter is set to the matching ingress path for the original request.
36 |
37 | Other requests are considered non-navigational requests and result in an HTTP 401 Unauthorized response with the `Location` header set as described above.
38 |
39 | For defence in depth, you should still check the `Authorization` header for a token and validate the token even when using auto-login.
40 |
41 | ### 2. Logout
42 |
43 | When you must log out a user, redirect the user to [the `/oauth2/logout` endpoint](endpoints.md#oauth2logout).
44 |
45 | The user's session with the sidecar will be cleared, and the user will be redirected to the identity provider for
46 | global/single-logout, if logged in with SSO (single sign-on) at the identity provider.
47 |
48 | #### 2.1 Local Logout
49 |
50 | If you only want to perform a _local logout_ for the user, perform a `GET` request from the user's browser / user agent to [the `/oauth2/logout/local` endpoint](endpoints.md#oauth2logoutlocal).
51 |
52 | This will only clear the user's local session (i.e. remove the cookies) with the sidecar, without performing global logout at the identity provider.
53 | The endpoint responds with an HTTP 204 after successful logout. It will **not** respond with a redirect.
54 |
55 | A local logout is useful for scenarios where users frequently switch between multiple accounts.
56 | This means that they do not have to re-enter their credentials (e.g. username, password, 2FA) between each local logout, as they still have an SSO-session logged in with the identity provider.
57 | If the user is using a shared device with other users, only performing a local logout is thus a security risk.
58 |
59 | **Ensure you understand the difference in intentions between the two logout endpoints. If you're unsure, use `/oauth2/logout`.**
60 |
61 | ### 3. Advanced: Session Management
62 |
63 | See the [session management](sessions.md) page for details.
64 |
--------------------------------------------------------------------------------
/pkg/openid/client/client_test.go:
--------------------------------------------------------------------------------
1 | package client_test
2 |
3 | import (
4 | "encoding/base64"
5 | "encoding/json"
6 | "strings"
7 | "testing"
8 | "time"
9 |
10 | "github.com/lestrrat-go/jwx/v3/jwa"
11 | "github.com/lestrrat-go/jwx/v3/jws"
12 | "github.com/lestrrat-go/jwx/v3/jwt"
13 | "github.com/stretchr/testify/assert"
14 |
15 | "github.com/nais/wonderwall/pkg/mock"
16 | "github.com/nais/wonderwall/pkg/openid/client"
17 | )
18 |
// TestClientAuthenticationAssertion verifies the client authentication JWT produced by
// ClientAuthenticationAssertion. It checks:
//   - the signature, against the client JWK's public key;
//   - the required claims (iat, exp, jti) and the aud, iss and sub values;
//   - the validity window;
//   - the protected header (typ, alg, kid).
//
// Whenever a failed assertion would leave a nil value that is used afterwards, the test
// returns early. The failure is then reported normally instead of as a nil-pointer panic.
func TestClientAuthenticationAssertion(t *testing.T) {
	cfg := mock.Config()
	cfg.OpenID.ClientID = "some-client-id"

	openidConfig := mock.NewTestConfiguration(cfg)
	openidConfig.TestProvider.SetIssuer("some-issuer")
	c := newTestClientWithConfig(openidConfig)

	expiry := 30 * time.Second
	jwtAssertion, err := c.ClientAuthenticationAssertion(expiry)
	if !assert.NoError(t, err) {
		return
	}

	assertFlattenedAudience(t, jwtAssertion)

	key := openidConfig.Client().ClientJWK()
	if !assert.NotNil(t, key) {
		return
	}
	publicKey, err := key.PublicKey()
	if !assert.NoError(t, err) {
		return
	}

	alg, ok := publicKey.Algorithm()
	assert.True(t, ok)

	opts := []jwt.ParseOption{
		jwt.WithKey(alg, publicKey),
		jwt.WithRequiredClaim(jwt.IssuedAtKey),
		jwt.WithRequiredClaim(jwt.ExpirationKey),
		jwt.WithRequiredClaim(jwt.JwtIDKey),
	}
	assertion, err := jwt.ParseString(jwtAssertion, opts...)
	if !assert.NoError(t, err) {
		return
	}

	aud, ok := assertion.Audience()
	assert.True(t, ok)
	assert.ElementsMatch(t, []string{"some-issuer"}, aud)

	iss, ok := assertion.Issuer()
	assert.True(t, ok)
	assert.Equal(t, "some-client-id", iss)

	sub, ok := assertion.Subject()
	assert.True(t, ok)
	assert.Equal(t, "some-client-id", sub)

	iat, ok := assertion.IssuedAt()
	assert.True(t, ok)
	assert.True(t, iat.Before(time.Now()))

	exp, ok := assertion.Expiration()
	assert.True(t, ok)
	assert.True(t, exp.After(time.Now()))
	assert.True(t, exp.Before(time.Now().Add(expiry)))

	msg, err := jws.ParseString(jwtAssertion)
	if !assert.NoError(t, err) {
		return
	}
	if !assert.Len(t, msg.Signatures(), 1) {
		return
	}
	headers := msg.Signatures()[0].ProtectedHeaders()

	typ, ok := headers.Type()
	assert.True(t, ok)
	assert.Equal(t, "JWT", typ)

	alg, ok = headers.Algorithm()
	assert.True(t, ok)
	assert.Equal(t, jwa.RS256(), alg)

	expectedKid, ok := key.KeyID()
	assert.True(t, ok)
	kid, ok := headers.KeyID()
	assert.True(t, ok)
	assert.Equal(t, expectedKid, kid)
}
89 |
// TestClientAuthenticationAssertionHeader verifies that when NewClientAuthJWTType is enabled,
// the client authentication assertion carries the "client-authentication+jwt" typ header.
//
// The test returns early on failures that would otherwise cause a nil-pointer panic below.
func TestClientAuthenticationAssertionHeader(t *testing.T) {
	cfg := mock.Config()
	cfg.OpenID.ClientID = "some-client-id"
	cfg.OpenID.NewClientAuthJWTType = true

	openidConfig := mock.NewTestConfiguration(cfg)
	openidConfig.TestProvider.SetIssuer("some-issuer")
	c := newTestClientWithConfig(openidConfig)

	expiry := 30 * time.Second
	jwtAssertion, err := c.ClientAuthenticationAssertion(expiry)
	if !assert.NoError(t, err) {
		return
	}

	msg, err := jws.ParseString(jwtAssertion)
	if !assert.NoError(t, err) {
		return
	}
	if !assert.Len(t, msg.Signatures(), 1) {
		return
	}
	headers := msg.Signatures()[0].ProtectedHeaders()

	typ, ok := headers.Type()
	assert.True(t, ok)
	assert.Equal(t, "client-authentication+jwt", typ)
}
112 |
// assertFlattenedAudience asserts that the raw JWT assertion has a flattened audience claim, i.e. aud is a string value.
// We do this as the jwx library only exposes the audience as a slice of strings for parsed JWTs.
// The expected audience is the issuer the tests in this file configure ("some-issuer").
//
// It is marked as a test helper so that failures point at the caller. It stops at the first
// failure that would make later steps index out of range or decode garbage.
func assertFlattenedAudience(t *testing.T, jwtAssertion string) {
	t.Helper()

	// A compact JWS consists of header, payload and signature separated by dots.
	parts := strings.Split(jwtAssertion, ".")
	if !assert.Len(t, parts, 3) {
		return
	}

	rawClaims, err := base64.RawURLEncoding.DecodeString(parts[1])
	if !assert.NoError(t, err) {
		return
	}

	claims := make(map[string]any)
	err = json.Unmarshal(rawClaims, &claims)
	if !assert.NoError(t, err) {
		return
	}

	assert.Equal(t, "some-issuer", claims["aud"])
}
128 |
129 | func newTestClientWithConfig(config *mock.TestConfiguration) *client.Client {
130 | jwksProvider := mock.NewTestJwksProvider()
131 | return client.NewClient(config, jwksProvider)
132 | }
133 |
--------------------------------------------------------------------------------
/internal/o11y/otel/otel.go:
--------------------------------------------------------------------------------
1 | package otel
2 |
3 | import (
4 | "context"
5 | "net/http"
6 | "time"
7 |
8 | "github.com/nais/wonderwall/pkg/config"
9 | log "github.com/sirupsen/logrus"
10 | "github.com/uptrace/opentelemetry-go-extra/otellogrus"
11 | "go.opentelemetry.io/otel"
12 | "go.opentelemetry.io/otel/attribute"
13 | "go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc"
14 | "go.opentelemetry.io/otel/propagation"
15 | "go.opentelemetry.io/otel/sdk/resource"
16 | tracesdk "go.opentelemetry.io/otel/sdk/trace"
17 | "go.opentelemetry.io/otel/semconv/v1.37.0"
18 | "go.opentelemetry.io/otel/trace"
19 | "go.opentelemetry.io/otel/trace/noop"
20 | )
21 |
22 | const (
23 | // How long between each time OT sends something to the collector.
24 | batchTimeout = 5 * time.Second
25 | )
26 |
27 | var tracer = noop.NewTracerProvider().Tracer("noop")
28 |
// Setup configures OpenTelemetry tracing for the process. It:
//   - installs a trace-context + baggage propagator;
//   - builds a resource describing this instance from cfg;
//   - creates an OTLP/gRPC tracer provider, registers it globally and replaces the
//     package-level tracer;
//   - adds a logrus hook that records panic, fatal, error and warning entries as span events.
//
// The returned function shuts the tracer provider down. A shutdown error is logged with
// log.Fatalf, which terminates the process.
func Setup(ctx context.Context, cfg *config.Config) (func(context.Context), error) {
	otel.SetTextMapPropagator(newPropagator())

	attrs := attributesFrom(cfg)
	res, err := newResource(attrs)
	if err != nil {
		return nil, err
	}

	provider, err := newTraceProvider(ctx, res)
	if err != nil {
		return nil, err
	}
	otel.SetTracerProvider(provider)
	tracer = provider.Tracer(cfg.OpenTelemetry.ServiceName)

	log.Debug("opentelemetry: initialized configuration")

	// Attach warning-and-above log entries as events on the span in the entry's context.
	hookLevels := []log.Level{log.PanicLevel, log.FatalLevel, log.ErrorLevel, log.WarnLevel}
	log.AddHook(otellogrus.NewHook(otellogrus.WithLevels(hookLevels...)))

	return func(shutdownCtx context.Context) {
		if shutdownErr := provider.Shutdown(shutdownCtx); shutdownErr != nil {
			log.Fatalf("fatal: otel shutdown error: %+v", shutdownErr)
		}
	}, nil
}
63 |
// StartSpan starts a span named spanName from ctx using the package-level tracer and returns
// the derived context together with the span.
func StartSpan(ctx context.Context, spanName string) (context.Context, trace.Span) {
	spanCtx, span := tracer.Start(ctx, spanName)
	return spanCtx, span
}
67 |
// StartSpanFromRequest starts a span named spanName from the context of an incoming HTTP
// request. It returns a copy of the request carrying the updated context, along with the span.
func StartSpanFromRequest(r *http.Request, spanName string) (*http.Request, trace.Span) {
	spanCtx, span := StartSpan(r.Context(), spanName)
	return r.WithContext(spanCtx), span
}
74 |
75 | func AddErrorEvent(span trace.Span, eventName, errType string, err error) {
76 | span.AddEvent(eventName, trace.WithAttributes(
77 | semconv.ExceptionTypeKey.String(errType),
78 | semconv.ExceptionMessageKey.String(err.Error()),
79 | ))
80 | }
81 |
82 | func attributesFrom(cfg *config.Config) []attribute.KeyValue {
83 | attrs := []attribute.KeyValue{
84 | semconv.ServiceName(cfg.OpenTelemetry.ServiceName),
85 | semconv.ServiceVersion(cfg.Version),
86 | attribute.String("wonderwall.identity_provider.name", string(cfg.OpenID.Provider)),
87 | attribute.String("wonderwall.identity_provider.url", cfg.OpenID.WellKnownURL),
88 | attribute.Bool("wonderwall.autologin", cfg.AutoLogin),
89 | attribute.Bool("wonderwall.sso", cfg.SSO.Enabled),
90 | }
91 | if cfg.SSO.Enabled {
92 | attrs = append(attrs, attribute.String("wonderwall.sso.mode", string(cfg.SSO.Mode)))
93 | }
94 | return attrs
95 | }
96 |
97 | func newResource(attributes []attribute.KeyValue) (*resource.Resource, error) {
98 | return resource.Merge(
99 | resource.Default(),
100 | resource.NewWithAttributes(
101 | semconv.SchemaURL,
102 | attributes...,
103 | ),
104 | )
105 | }
106 |
107 | func newPropagator() propagation.TextMapPropagator {
108 | return propagation.NewCompositeTextMapPropagator(
109 | propagation.TraceContext{},
110 | propagation.Baggage{},
111 | )
112 | }
113 |
// newTraceProvider creates a tracer provider that exports spans through a batching processor
// (using batchTimeout) to an OTLP/gRPC exporter, tagging them with res. The exporter is created
// without explicit options, so it relies on the library's default configuration.
func newTraceProvider(ctx context.Context, res *resource.Resource) (*tracesdk.TracerProvider, error) {
	exporter, err := otlptracegrpc.New(ctx)
	if err != nil {
		return nil, err
	}

	batching := tracesdk.WithBatcher(exporter, tracesdk.WithBatchTimeout(batchTimeout))
	return tracesdk.NewTracerProvider(batching, tracesdk.WithResource(res)), nil
}
129 |
--------------------------------------------------------------------------------
/pkg/cookie/cookie.go:
--------------------------------------------------------------------------------
1 | package cookie
2 |
3 | import (
4 | "encoding/base64"
5 | "errors"
6 | "fmt"
7 | "net/http"
8 | "time"
9 |
10 | "github.com/nais/wonderwall/internal/crypto"
11 | )
12 |
13 | const (
14 | DefaultPrefix = "io.nais.wonderwall"
15 | )
16 |
17 | var (
18 | Login = login(DefaultPrefix)
19 | LoginCount = loginCount(DefaultPrefix)
20 | Logout = logout(DefaultPrefix)
21 | Retry = retry(DefaultPrefix)
22 | Session = session(DefaultPrefix)
23 | ErrInvalidValue = errors.New("invalid value")
24 | ErrDecrypt = errors.New("unable to decrypt, key or scheme mismatch")
25 | )
26 |
27 | type Cookie struct {
28 | *http.Cookie
29 | }
30 |
31 | func (in *Cookie) Encrypt(crypter crypto.Crypter) (*Cookie, error) {
32 | plaintext := []byte(in.Cookie.Value)
33 | ciphertext, err := crypter.Encrypt(plaintext)
34 | if err != nil {
35 | return nil, fmt.Errorf("unable to encrypt cookie '%s': %w", in.Cookie.Name, err)
36 | }
37 |
38 | value := base64.RawURLEncoding.EncodeToString(ciphertext)
39 | in.Cookie.Value = value
40 | return in, nil
41 | }
42 |
43 | func (in *Cookie) Decrypt(crypter crypto.Crypter) (string, error) {
44 | ciphertext, err := base64.RawURLEncoding.DecodeString(in.Value)
45 | if err != nil {
46 | return "", fmt.Errorf("%w: named '%s': %w", ErrInvalidValue, in.Name, err)
47 | }
48 |
49 | plaintext, err := crypter.Decrypt(ciphertext)
50 | if err != nil {
51 | return "", fmt.Errorf("%w: named '%s': %w", ErrDecrypt, in.Name, err)
52 | }
53 |
54 | return string(plaintext), err
55 | }
56 |
// Clear writes a Set-Cookie header that expires the cookie named name. The cookie has
// MaxAge -1, Expires at the Unix epoch, and HttpOnly set. Path defaults to "/"; a non-empty
// opts.Path overrides it, and opts.Domain is applied when non-empty.
func Clear(w http.ResponseWriter, name string, opts Options) {
	path := "/"
	if opts.Path != "" {
		path = opts.Path
	}

	expired := &http.Cookie{
		Name:     name,
		Path:     path,
		Domain:   opts.Domain,
		Expires:  time.Unix(0, 0),
		MaxAge:   -1,
		HttpOnly: true,
		Secure:   opts.Secure,
		SameSite: opts.SameSite,
	}
	http.SetCookie(w, expired)
}
81 |
// Get returns the request cookie named key wrapped in a Cookie. If the cookie is missing, the
// returned error wraps the error from (*http.Request).Cookie.
func Get(r *http.Request, key string) (*Cookie, error) {
	c, err := r.Cookie(key)
	if err == nil {
		return &Cookie{Cookie: c}, nil
	}
	return nil, fmt.Errorf("no cookie named '%s': %w", key, err)
}
90 |
91 | func GetDecrypted(r *http.Request, key string, crypter crypto.Crypter) (string, error) {
92 | encryptedCookie, err := Get(r, key)
93 | if err != nil {
94 | return "", err
95 | }
96 |
97 | return encryptedCookie.Decrypt(crypter)
98 | }
99 |
// Make builds an HttpOnly cookie named name with the given value. Path defaults to "/"; a
// non-empty opts.Path overrides it, and opts.Domain is applied when non-empty. Secure and
// SameSite are taken from opts.
func Make(name, value string, opts Options) *Cookie {
	path := "/"
	if opts.Path != "" {
		path = opts.Path
	}

	return &Cookie{&http.Cookie{
		Name:     name,
		Value:    value,
		Path:     path,
		Domain:   opts.Domain,
		HttpOnly: true,
		Secure:   opts.Secure,
		SameSite: opts.SameSite,
	}}
}
120 |
// Set adds the wrapped cookie to w as a Set-Cookie header.
func Set(w http.ResponseWriter, cookie *Cookie) {
	raw := cookie.Cookie
	http.SetCookie(w, raw)
}
124 |
// EncryptAndSet builds a cookie named key with the given options, encrypts value with crypter,
// and writes the result to w. Nothing is written if encryption fails.
func EncryptAndSet(w http.ResponseWriter, key, value string, opts Options, crypter crypto.Crypter) error {
	plain := Make(key, value, opts)
	encrypted, err := plain.Encrypt(crypter)
	if err == nil {
		Set(w, encrypted)
	}
	return err
}
134 |
135 | func ConfigureCookieNamesWithPrefix(prefix string) {
136 | Login = login(prefix)
137 | Logout = logout(prefix)
138 | Retry = retry(prefix)
139 | Session = session(prefix)
140 | }
141 |
// withPrefix joins prefix and s with a single '.'.
// fmt.Sprint only inserts spaces between operands when neither is a string; all operands
// here are strings, so the result is a plain concatenation.
func withPrefix(prefix, s string) string {
	return fmt.Sprint(prefix, ".", s)
}

// login returns the cookie name "<prefix>.callback".
func login(prefix string) string { return withPrefix(prefix, "callback") }

// loginCount returns the cookie name "<prefix>.logincount".
func loginCount(prefix string) string { return withPrefix(prefix, "logincount") }

// logout returns the cookie name "<prefix>.logout".
func logout(prefix string) string { return withPrefix(prefix, "logout") }

// retry returns the cookie name "<prefix>.retry".
func retry(prefix string) string { return withPrefix(prefix, "retry") }

// session returns the cookie name "<prefix>.session".
func session(prefix string) string { return withPrefix(prefix, "session") }
165 |
--------------------------------------------------------------------------------
/hack/dashboard.yaml:
--------------------------------------------------------------------------------
title: Wonderwall
editable: true
tags: [generated, yaml]
auto_refresh: 1m
time: ["now-24h", "now"]
timezone: default # valid values are: utc, browser, default

# Render to JSON using https://github.com/K-Phoen/grabana v0.17.0 or newer
# Import into Grafana using UI (remember to select folder)

variables:
  - custom:
      name: env
      default: dev
      values_map:
        dev: dev
        prod: prod
  - query:
      name: namespace
      label: Namespace
      datasource: $env-gcp
      request: "label_values(kube_pod_container_info{container=\"wonderwall\"}, namespace)"
      include_all: true
      default_all: true
      all_value: ".*"
  - datasource:
      name: ds
      type: prometheus
      regex: $env-gcp
      include_all: true
      hide: variable
  - query:
      name: redis_op
      label: Redis Operation
      datasource: $env-gcp
      request: "label_values(wonderwall_redis_latency_bucket, operation)"
      include_all: true
      default_all: true
      hide: variable

rows:
  - name: Versions
    collapse: false
    panels:
      - single_stat:
          title: Sidecar versions in use
          datasource: $ds
          transparent: true
          span: 12
          targets:
            - prometheus:
                query: count(label_replace(kube_pod_container_info{container="wonderwall",namespace=~"$namespace"}, "version", "$1", "image", ".*:(.*)")) by (version)
                legend: "{{ version }}"
                instant: true
  - name: Resource usage
    collapse: false
    panels:
      - graph:
          title: Memory usage - $ds
          datasource: $ds
          transparent: true
          targets:
            - prometheus:
                query: sum(container_memory_working_set_bytes{container="wonderwall",namespace=~"$namespace"}) by (pod, namespace)
                legend: "working set {{ pod }} in {{ namespace }}"
            # container_memory_usage_bytes also counts page cache; query the RSS metric so the data matches its legend.
            - prometheus:
                query: sum(container_memory_rss{container="wonderwall",namespace=~"$namespace"}) by (pod, namespace)
                legend: "Resident set size {{ pod }} in {{ namespace }}"
      - graph:
          title: CPU usage - $ds
          datasource: $ds
          transparent: true
          targets:
            - prometheus:
                query: sum(irate(container_cpu_usage_seconds_total{container="wonderwall",namespace=~"$namespace"}[2m])) by (pod, namespace)
                legend: "{{ pod }} in {{ namespace }}"
  - name: HTTP
    collapse: false
    panels:
      - graph:
          title: HTTP requests
          datasource: $ds
          transparent: true
          targets:
            - prometheus:
                query: sum(rate(requests_total{job="wonderwall",namespace=~"$namespace"}[5m])) by (code)
                legend: "{{ code }}"
      - graph:
          title: HTTP latency
          datasource: $ds
          transparent: true
          targets:
            # Average latency per path: rate of total request time divided by rate of request count.
            # (The rate of the _sum series alone is seconds spent per second, not latency.)
            - prometheus:
                query: sum(rate(request_duration_seconds_sum{job="wonderwall",namespace=~"$namespace"}[5m])) by (path) / sum(rate(request_duration_seconds_count{job="wonderwall",namespace=~"$namespace"}[5m])) by (path)
                legend: "{{ path }}"
  - name: Redis Latency
    collapse: false
    panels:
      - heatmap:
          # Must be done manually in Grafana after import: Set max datapoints to 25
          title: $redis_op
          datasource: $ds
          repeat: redis_op
          data_format: time_series_buckets
          hide_zero_buckets: true
          transparent: true
          span: 4
          tooltip:
            show: true
            showhistogram: false
            decimals: 0
          yaxis:
            unit: "dtdurations"
            decimals: 0
          targets:
            - prometheus:
                query: sum(increase(wonderwall_redis_latency_bucket{operation="$redis_op",namespace=~"$namespace"}[$__interval])) by (le)
                legend: "{{ le }}"
                format: heatmap
120 |
--------------------------------------------------------------------------------
/pkg/openid/config/client.go:
--------------------------------------------------------------------------------
1 | package config
2 |
3 | import (
4 | "fmt"
5 |
6 | "github.com/lestrrat-go/jwx/v3/jwk"
7 | log "github.com/sirupsen/logrus"
8 |
9 | "github.com/nais/wonderwall/pkg/config"
10 | "github.com/nais/wonderwall/pkg/openid/scopes"
11 | )
12 |
// AuthMethod identifies how the client authenticates itself to the identity provider.
type AuthMethod string

// AuthMethodPrivateKeyJWT is selected by NewClientConfig when a client JWK is configured.
const AuthMethodPrivateKeyJWT AuthMethod = "private_key_jwt"

// AuthMethodClientSecret is selected by NewClientConfig when only a client secret is configured.
const AuthMethodClientSecret AuthMethod = "client_secret"
19 |
20 | type Client interface {
21 | ACRValues() string
22 | Audiences() map[string]bool
23 | AuthMethod() AuthMethod
24 | ClientID() string
25 | ClientJWK() jwk.Key
26 | ClientSecret() string
27 | NewClientAuthJWTType() bool
28 | PostLogoutRedirectURI() string
29 | ResourceIndicator() string
30 | Scopes() scopes.Scopes
31 | UILocales() string
32 | WellKnownURL() string
33 | }
34 |
35 | type client struct {
36 | config.OpenID
37 | authMethod AuthMethod
38 | clientJwk jwk.Key
39 | trustedAudiences map[string]bool
40 | }
41 |
42 | var _ Client = (*client)(nil)
43 |
44 | func (in *client) ACRValues() string {
45 | return in.OpenID.ACRValues
46 | }
47 |
48 | func (in *client) Audiences() map[string]bool {
49 | return in.trustedAudiences
50 | }
51 |
52 | func (in *client) AuthMethod() AuthMethod {
53 | return in.authMethod
54 | }
55 |
56 | func (in *client) ClientID() string {
57 | return in.OpenID.ClientID
58 | }
59 |
60 | func (in *client) ClientJWK() jwk.Key {
61 | return in.clientJwk
62 | }
63 |
64 | func (in *client) ClientSecret() string {
65 | return in.OpenID.ClientSecret
66 | }
67 |
68 | func (in *client) NewClientAuthJWTType() bool {
69 | return in.OpenID.NewClientAuthJWTType
70 | }
71 |
72 | func (in *client) PostLogoutRedirectURI() string {
73 | return in.OpenID.PostLogoutRedirectURI
74 | }
75 |
76 | func (in *client) ResourceIndicator() string {
77 | return in.OpenID.ResourceIndicator
78 | }
79 |
80 | func (in *client) Scopes() scopes.Scopes {
81 | return scopes.DefaultScopes().WithAdditional(in.OpenID.Scopes...)
82 | }
83 |
84 | func (in *client) UILocales() string {
85 | return in.OpenID.UILocales
86 | }
87 |
88 | func (in *client) WellKnownURL() string {
89 | return in.OpenID.WellKnownURL
90 | }
91 |
// NewClientConfig builds the Client configuration from cfg.
//
// At least one of the client JWK and client secret must be set. When a client JWK is set, it is
// parsed and private_key_jwt authentication is used, even if a client secret is also set.
// Otherwise client_secret authentication is used.
//
// The provider selects the implementation: ID-porten, Azure, or the generic client for any
// other non-empty value. An empty provider is an error. The client ID and well-known URL are
// required.
func NewClientConfig(cfg *config.Config) (Client, error) {
	c := &client{
		OpenID:           cfg.OpenID,
		trustedAudiences: cfg.OpenID.TrustedAudiences(),
	}

	hasJWK := len(cfg.OpenID.ClientJWK) > 0
	hasSecret := len(cfg.OpenID.ClientSecret) > 0

	switch {
	case hasJWK:
		if hasSecret {
			log.WithField("logger", "wonderwall.config").Debug("both client JWK and client secret were set; using client JWK...")
		}

		key, err := jwk.ParseKey([]byte(cfg.OpenID.ClientJWK))
		if err != nil {
			return nil, fmt.Errorf("parsing client JWK: %w", err)
		}
		c.clientJwk = key
		c.authMethod = AuthMethodPrivateKeyJWT
	case hasSecret:
		c.authMethod = AuthMethodClientSecret
	default:
		return nil, fmt.Errorf("missing required config: at least one of %q or %q must be set", config.OpenIDClientJWK, config.OpenIDClientSecret)
	}

	var result Client
	switch cfg.OpenID.Provider {
	case "":
		return nil, fmt.Errorf("missing required config %q", config.OpenIDProvider)
	case config.ProviderIDPorten:
		result = c.IDPorten()
	case config.ProviderAzure:
		result = c.Azure()
	default:
		result = c
	}

	if result.ClientID() == "" {
		return nil, fmt.Errorf("missing required config %q", config.OpenIDClientID)
	}
	if result.WellKnownURL() == "" {
		return nil, fmt.Errorf("missing required config %q", config.OpenIDWellKnownURL)
	}

	return result, nil
}
142 |
143 | type azure struct {
144 | *client
145 | }
146 |
147 | func (in *client) Azure() *azure {
148 | return &azure{
149 | client: in,
150 | }
151 | }
152 |
153 | func (in *azure) Scopes() scopes.Scopes {
154 | return scopes.DefaultScopes().
155 | WithAzureScope(in.OpenID.ClientID).
156 | WithOfflineAccess().
157 | WithAdditional(in.OpenID.Scopes...)
158 | }
159 |
160 | type idporten struct {
161 | *client
162 | }
163 |
164 | func (in *client) IDPorten() *idporten {
165 | return &idporten{
166 | client: in,
167 | }
168 | }
169 |
--------------------------------------------------------------------------------
/pkg/config/config_test.go:
--------------------------------------------------------------------------------
1 | package config_test
2 |
3 | import (
4 | "testing"
5 |
6 | "github.com/nais/wonderwall/pkg/config"
7 | "github.com/stretchr/testify/assert"
8 | "github.com/stretchr/testify/require"
9 | )
10 |
// TestConfig_Validate checks config.Validate for the default configuration and for the SSO
// server and SSO proxy modes. For each mode, the base configuration must validate, and every
// listed mutation of a copy of it must make validation fail.
//
// The earlier "missing cookie name" cases in both SSO groups exactly duplicated the
// "missing session cookie name" cases (both cleared SSO.SessionCookieName), so only one of
// each is kept.
func TestConfig_Validate(t *testing.T) {
	type test struct {
		name   string
		mutate func(cfg *config.Config)
	}

	// run validates base as the happy path, then asserts that each error case fails validation.
	// Each case mutates a shallow copy (cfg := *base): mutations must reassign fields rather
	// than modify slices or maps shared with base in place.
	run := func(name string, base *config.Config, errorCases []test) {
		t.Run(name, func(t *testing.T) {
			t.Run("happy path", func(t *testing.T) {
				assert.NoError(t, base.Validate())
			})

			for _, tt := range errorCases {
				t.Run(tt.name, func(t *testing.T) {
					cfg := *base
					tt.mutate(&cfg)
					assert.Error(t, cfg.Validate())
				})
			}
		})
	}

	base, err := config.Initialize()
	require.NoError(t, err)

	run("default", base, []test{
		{
			"invalid value for cookie.same-site",
			func(cfg *config.Config) {
				cfg.Cookie.SameSite = "invalid"
			},
		},
		{
			"upstream ip must be set if port is set",
			func(cfg *config.Config) {
				cfg.UpstreamIP = ""
				cfg.UpstreamPort = 8080
			},
		},
		{
			"upstream port must be set if ip is set",
			func(cfg *config.Config) {
				cfg.UpstreamIP = "127.0.0.1"
				cfg.UpstreamPort = 0
			},
		},
		{
			"upstream port must not exceed 65535",
			func(cfg *config.Config) {
				cfg.UpstreamIP = "127.0.0.1"
				cfg.UpstreamPort = 65536
			},
		},
		{
			"upstream port must not be negative",
			func(cfg *config.Config) {
				cfg.UpstreamIP = "127.0.0.1"
				cfg.UpstreamPort = -1
			},
		},
		{
			"shutdown graceful period must be greater than wait before period",
			func(cfg *config.Config) {
				cfg.ShutdownGracefulPeriod = 1
				cfg.ShutdownWaitBeforePeriod = 1
			},
		},
		{
			"secure cookies cannot be disabled for non-localhost ingress",
			func(cfg *config.Config) {
				cfg.Cookie.Secure = false
				cfg.Ingresses = []string{"http://not-localhost.example"}
			},
		},
		{
			"secure cookies cannot be disabled for secure ingress",
			func(cfg *config.Config) {
				cfg.Cookie.Secure = false
				cfg.Ingresses = []string{"https://localhost:3000"}
			},
		},
	})

	server := *base
	server.SSO.Enabled = true
	server.SSO.Mode = config.SSOModeServer
	server.SSO.Domain = "example.com"
	server.SSO.SessionCookieName = "some-cookie"
	server.SSO.ServerDefaultRedirectURL = "https://default.local"
	server.Redis.Address = "localhost:6379"

	run("sso server", &server, []test{
		{
			"missing redis",
			func(cfg *config.Config) {
				cfg.Redis = config.Redis{}
			},
		},
		{
			"missing session cookie name",
			func(cfg *config.Config) {
				cfg.SSO.SessionCookieName = ""
			},
		},
		{
			"missing domain",
			func(cfg *config.Config) {
				cfg.SSO.Domain = ""
			},
		},
		{
			"invalid default redirect url",
			func(cfg *config.Config) {
				cfg.SSO.ServerDefaultRedirectURL = "invalid"
			},
		},
		{
			"invalid mode",
			func(cfg *config.Config) {
				cfg.SSO.Mode = "invalid"
			},
		},
	})

	proxy := *base
	proxy.SSO.Enabled = true
	proxy.SSO.Mode = config.SSOModeProxy
	proxy.SSO.ServerURL = "https://sso-server.local"
	proxy.SSO.SessionCookieName = "some-cookie"
	proxy.Redis.Address = "localhost:6379"

	run("sso proxy", &proxy, []test{
		{
			"missing redis",
			func(cfg *config.Config) {
				cfg.Redis = config.Redis{}
			},
		},
		{
			"missing session cookie name",
			func(cfg *config.Config) {
				cfg.SSO.SessionCookieName = ""
			},
		},
		{
			"invalid server url",
			func(cfg *config.Config) {
				cfg.SSO.ServerURL = "invalid"
			},
		},
		{
			"invalid mode",
			func(cfg *config.Config) {
				cfg.SSO.Mode = "invalid"
			},
		},
	})
}
181 |
--------------------------------------------------------------------------------
/charts/wonderwall/templates/prometheusrule.yaml:
--------------------------------------------------------------------------------
{{ if .Values.idporten.enabled }}
---
# Alerting rules for Wonderwall serving ID-porten logins. This document is
# only rendered when .Values.idporten.enabled is set.
apiVersion: monitoring.coreos.com/v1
kind: PrometheusRule
metadata:
  name: {{ include "wonderwall.fullname" . }}-idporten-alerts
  labels:
    {{- include "wonderwall.labels" . | nindent 4 }}
spec:
  groups:
    - name: "wonderwall-idporten"
      rules:
        # Sidecar deployments: HTTP 500 responses counted from the
        # "nais-system/monitoring-wonderwall" scrape job, restricted to provider="idporten".
        - alert: Wonderwall sidecars for ID-porten reports a high amount of internal errors
          expr: sum(increase(requests_total{job="nais-system/monitoring-wonderwall", code="500", provider="idporten"}[5m])) > 10
          for: 5m
          annotations:
            summary: Wonderwall has responded with HTTP 500 for a high amount of requests within the last 5 minutes.
            consequence: This probably means that end-users are having trouble with authentication.
            # NOTE(review): "Check the logs for errors:" has no link after the colon and the
            # DigDir status entry only points to "/"; the intended URLs appear to be missing.
            action: |
              * Check the logs for errors:
              * Check DigDir status: /
              * Check Aiven Redis (session store) and verify Aiven network connectivity
              * Check with DigDir in [#nav-digdir](https://nav-it.slack.com/archives/C013RTT99G9)
            dashboard_url: "https://monitoring.nais.io/d/wQPQ7uHnz/wonderwall?var-provider=idporten"
          labels:
            severity: critical
            namespace: {{ .Release.Namespace }}
        # SSO server deployment: HTTP 500 responses from the "wonderwall-idporten" app
        # running in this release's namespace.
        - alert: Wonderwall SSO server for ID-porten reports a high amount of internal errors
          expr: sum(increase(requests_total{app="wonderwall-idporten", namespace="{{ .Release.Namespace }}", code="500"}[5m])) > 10
          for: 5m
          annotations:
            summary: Wonderwall has responded with HTTP 500 for a high amount of requests within the last 5 minutes.
            consequence: This probably means that end-users are having trouble with authentication.
            # NOTE(review): "Check the logs for errors:" has no link after the colon and the
            # DigDir status entry only points to "/"; the intended URLs appear to be missing.
            action: |
              * Check the logs for errors:
              * Check DigDir status: /
              * Check Aiven Redis (session store) and verify Aiven network connectivity
              * Check with DigDir in [#nav-digdir](https://nav-it.slack.com/archives/C013RTT99G9)
            # The dashboard namespace follows the release namespace, matching the
            # namespace selector in `expr` (previously hard-coded to nais-system).
            dashboard_url: "https://monitoring.nais.io/d/wQPQ7uHnz/wonderwall?var-provider=idporten&var-namespace={{ .Release.Namespace }}"
          labels:
            severity: critical
            namespace: {{ .Release.Namespace }}
{{ end }}
{{ if .Values.azure.enabled }}
---
# Alerting rules for Wonderwall serving Azure AD logins. This document is only
# rendered when .Values.azure.enabled is set; the forward-auth rule is added
# only when .Values.azure.forwardAuth.enabled is also set.
apiVersion: monitoring.coreos.com/v1
kind: PrometheusRule
metadata:
  name: {{ include "wonderwall.fullname" . }}-azure-alerts
  labels:
    {{- include "wonderwall.labels" . | nindent 4 }}
spec:
  groups:
    - name: "wonderwall-azure"
      rules:
        # Sidecar deployments: HTTP 500 responses counted from the
        # "nais-system/monitoring-wonderwall" scrape job, restricted to provider="azure".
        - alert: Wonderwall for Azure AD reports a high amount of internal errors
          expr: sum(increase(requests_total{job="nais-system/monitoring-wonderwall", code="500", provider="azure"}[5m])) > 30
          for: 5m
          annotations:
            summary: Wonderwall has responded with HTTP 500 for a high amount of requests within the last 5 minutes.
            consequence: This probably means that end-users are having trouble with authentication.
            # NOTE(review): "Check the logs for errors:" has no link after the colon;
            # the intended URL appears to be missing.
            action: |
              * Check the logs for errors:
              * Check Azure status: https://status.azure.com/nb-no/status
              * Check Aiven Redis (session store) and verify Aiven network connectivity
            dashboard_url: "https://monitoring.nais.io/d/wQPQ7uHnz/wonderwall?var-provider=azure"
          labels:
            severity: critical
            namespace: {{ .Release.Namespace }}
        {{- if .Values.azure.forwardAuth.enabled }}
        # Forward-auth deployment: HTTP 500 responses from the "wonderwall-fa" app
        # running in this release's namespace.
        - alert: Wonderwall-fa (forward-auth / ansatt) for Azure AD reports a high amount of internal errors
          expr: sum(increase(requests_total{app="wonderwall-fa", namespace="{{ .Release.Namespace }}", code="500"}[5m])) > 30
          for: 5m
          annotations:
            summary: Wonderwall has responded with HTTP 500 for a high amount of requests within the last 5 minutes.
            consequence: This probably means that end-users are having trouble with authentication.
            # NOTE(review): "Check the logs for errors:" has no link after the colon;
            # the intended URL appears to be missing.
            action: |
              * Check the logs for errors:
              * Check Azure status: https://status.azure.com/nb-no/status
              * Check Aiven Redis (session store) and verify Aiven network connectivity
            # The dashboard namespace follows the release namespace, matching the
            # namespace selector in `expr` (previously hard-coded to nais-system).
            dashboard_url: "https://monitoring.nais.io/d/wQPQ7uHnz/wonderwall?var-provider=azure&var-namespace={{ .Release.Namespace }}"
          labels:
            severity: critical
            namespace: {{ .Release.Namespace }}
        {{- end }}
{{ end }}
87 |
--------------------------------------------------------------------------------