├── CODEOWNERS ├── templates ├── package.json ├── css.go ├── Makefile ├── package-lock.json ├── README.md ├── error.go ├── input.css └── error.gohtml ├── .gitignore ├── docs ├── README.md ├── sessions.md └── usage.md ├── charts ├── wonderwall │ ├── templates │ │ ├── serviceaccount.yaml │ │ ├── configmap.yaml │ │ ├── idporten-secret.yaml │ │ ├── idporten-poddisruptionbudget.yaml │ │ ├── idporten-service.yaml │ │ ├── fa-poddisruptionbudget.yaml │ │ ├── fa-service.yaml │ │ ├── fa-secret.yaml │ │ ├── idporten-horizontalpodautoscaler.yaml │ │ ├── fa-horizontalpodautoscaler.yaml │ │ ├── idporten-ingress.yaml │ │ ├── fa-ingress.yaml │ │ ├── fa-azureapp.yaml │ │ ├── idporten-idportenclient.yaml │ │ ├── redis.yaml │ │ ├── networkpolicy.yaml │ │ ├── _resources.yaml │ │ └── prometheusrule.yaml │ ├── Chart.yaml │ ├── .helmignore │ ├── values.yaml │ └── Feature.yaml └── wonderwall-forward-auth │ ├── Chart.yaml │ ├── templates │ ├── poddisruptionbudget.yaml │ ├── service.yaml │ ├── secret.yaml │ ├── servicemonitor.yaml │ ├── horizontalpodautoscaler.yaml │ ├── ingress.yaml │ ├── prometheusrule.yaml │ ├── networkpolicy.yaml │ └── _helpers.tpl │ ├── .helmignore │ ├── values.yaml │ └── Feature.yaml ├── .envrc ├── internal ├── crypto │ ├── text.go │ ├── jwk.go │ ├── crypter_test.go │ └── crypter.go ├── o11y │ ├── otel │ │ ├── otel_test.go │ │ ├── middleware.go │ │ └── otel.go │ └── logging │ │ └── logging.go ├── http │ ├── transport.go │ ├── middleware.go │ └── request.go └── retry │ └── retry.go ├── pkg ├── router │ └── paths │ │ └── paths.go ├── handler │ ├── path.go │ ├── acr │ │ └── acr.go │ ├── handler_sso_server.go │ └── autologin │ │ └── autologin.go ├── mock │ ├── request.go │ ├── config.go │ └── client.go ├── middleware │ ├── correlationid.go │ ├── cors.go │ ├── ingress.go │ ├── context.go │ ├── logentry.go │ └── prometheus.go ├── openid │ ├── oauth2_test.go │ ├── scopes │ │ └── scopes.go │ ├── config │ │ ├── config.go │ │ ├── provider_test.go │ │ └── client.go │ ├── client 
│ │ ├── logout_callback.go │ │ ├── logout.go │ │ ├── logout_test.go │ │ ├── logout_callback_test.go │ │ ├── login_callback.go │ │ └── client_test.go │ ├── cookies.go │ ├── acr │ │ ├── acr_test.go │ │ └── acr.go │ └── provider │ │ └── provider.go ├── cookie │ ├── options.go │ ├── options_test.go │ └── cookie.go ├── session │ ├── store_memory_test.go │ ├── lock_test.go │ ├── store_redis_test.go │ ├── store_memory.go │ ├── id.go │ ├── store.go │ ├── lock.go │ ├── session_reader.go │ ├── store_redis.go │ ├── ticket.go │ ├── store_test.go │ └── session.go ├── config │ ├── ratelimit.go │ ├── otel.go │ ├── session.go │ ├── redis.go │ ├── cookie.go │ ├── sso.go │ └── config_test.go ├── server │ └── server.go ├── url │ ├── url.go │ └── validator.go └── ingress │ └── ingress.go ├── Dockerfile ├── docker-compose.yml ├── .github ├── workflows │ └── dependabot-auto-merge.yaml └── dependabot.yml ├── LICENSE ├── mise.toml ├── docker-compose.example.yml ├── flake.nix ├── flake.lock ├── cmd └── wonderwall │ └── main.go └── hack └── dashboard.yaml /CODEOWNERS: -------------------------------------------------------------------------------- 1 | @navikt/aura 2 | 3 | -------------------------------------------------------------------------------- /templates/package.json: -------------------------------------------------------------------------------- 1 | { 2 | "devDependencies": { 3 | "tailwindcss": "^4.0.0" 4 | } 5 | } 6 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | bin/ 2 | run.sh 3 | .vscode/ 4 | .idea 5 | *.iml 6 | .env* 7 | *.out 8 | *.tgz 9 | 10 | # nix stuffs 11 | /.direnv 12 | result 13 | -------------------------------------------------------------------------------- /templates/css.go: -------------------------------------------------------------------------------- 1 | package templates 2 | 3 | import ( 4 | _ "embed" 5 | "html/template" 6 | ) 7 
| 8 | //go:embed output.css 9 | var CSS template.CSS 10 | -------------------------------------------------------------------------------- /templates/Makefile: -------------------------------------------------------------------------------- 1 | local: 2 | npx @tailwindcss/cli -i ./input.css -o ./output.css --watch 3 | 4 | build: 5 | npx @tailwindcss/cli -i ./input.css -o ./output.css --minify 6 | -------------------------------------------------------------------------------- /docs/README.md: -------------------------------------------------------------------------------- 1 | # Table of Contents 2 | 3 | - [Architecture](architecture.md) 4 | - [Configuration](configuration.md) 5 | - [Endpoints](endpoints.md) 6 | - [Usage](usage.md) 7 | - [Session Management](sessions.md) 8 | -------------------------------------------------------------------------------- /charts/wonderwall/templates/serviceaccount.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: v1 2 | kind: ServiceAccount 3 | metadata: 4 | labels: 5 | {{- include "wonderwall.labels" . | nindent 4 }} 6 | name: {{ include "wonderwall.fullname" . }} 7 | -------------------------------------------------------------------------------- /charts/wonderwall-forward-auth/Chart.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: v2 2 | name: wonderwall-forward-auth 3 | description: Forward-auth service for loadbalancer-fa 4 | type: application 5 | version: 1.0.0 6 | sources: 7 | - https://github.com/nais/wonderwall/tree/master/charts/wonderwall-forward-auth 8 | -------------------------------------------------------------------------------- /charts/wonderwall/Chart.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: v2 2 | description: Reverse proxy with OIDC auth. Configuration for sidecar proxies and SSO server. 
Feature toggle in naiserator must be enabled for sidecars to be injected to Applications. 3 | name: wonderwall 4 | type: application 5 | version: 1.0.0 6 | sources: 7 | - https://github.com/nais/wonderwall/tree/master/charts/wonderwall 8 | -------------------------------------------------------------------------------- /charts/wonderwall/templates/configmap.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: v1 2 | kind: ConfigMap 3 | metadata: 4 | name: {{ include "wonderwall.fullname" . }} 5 | labels: 6 | {{- include "wonderwall.labels" . | nindent 4 }} 7 | annotations: 8 | reloader.stakater.com/match: "true" 9 | data: 10 | wonderwall_image: "{{ .Values.image.repository }}:{{ .Values.image.tag }}" 11 | -------------------------------------------------------------------------------- /.envrc: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # Export all: 4 | # - (should be) .gitignored 5 | # - (potentially) secret environment variables 6 | # - from dotenv-formatted files w/names starting w/`.env` 7 | DOTENV_FILES="$(find . -maxdepth 1 -type f -name '.env*' -and -not -name '.envrc')" 8 | for file in ${DOTENV_FILES}; do 9 | dotenv "${file}" 10 | done 11 | export DOTENV_FILES 12 | 13 | # Load nix env for all the cool people 14 | use flake 15 | -------------------------------------------------------------------------------- /internal/crypto/text.go: -------------------------------------------------------------------------------- 1 | package crypto 2 | 3 | import ( 4 | "crypto/rand" 5 | "encoding/base64" 6 | ) 7 | 8 | // Text generates a cryptographically secure random string of a given length, and base64 URL-encodes it. 
9 | func Text(length int) (string, error) { 10 | data := make([]byte, length) 11 | if _, err := rand.Read(data); err != nil { 12 | return "", err 13 | } 14 | 15 | return base64.RawURLEncoding.EncodeToString(data), nil 16 | } 17 | -------------------------------------------------------------------------------- /internal/o11y/otel/otel_test.go: -------------------------------------------------------------------------------- 1 | package otel_test 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/nais/wonderwall/internal/o11y/otel" 7 | "github.com/nais/wonderwall/pkg/config" 8 | "github.com/stretchr/testify/assert" 9 | ) 10 | 11 | func TestSetup(t *testing.T) { 12 | // Assert that version for semconv schemas don't conflict with the current otel version. 13 | _, err := otel.Setup(t.Context(), &config.Config{}) 14 | assert.NoError(t, err) 15 | } 16 | -------------------------------------------------------------------------------- /charts/wonderwall-forward-auth/templates/poddisruptionbudget.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | apiVersion: policy/v1 3 | kind: PodDisruptionBudget 4 | metadata: 5 | labels: 6 | {{- include "wonderwall-forward-auth.labels" . | nindent 4 }} 7 | name: {{ include "wonderwall-forward-auth.fullname" . }} 8 | spec: 9 | {{- toYaml .Values.podDisruptionBudget | nindent 2 }} 10 | selector: 11 | matchLabels: 12 | {{- include "wonderwall-forward-auth.selectorLabels" . | nindent 6 }} 13 | -------------------------------------------------------------------------------- /charts/wonderwall/.helmignore: -------------------------------------------------------------------------------- 1 | # Patterns to ignore when building packages. 2 | # This supports shell glob matching, relative path matching, and 3 | # negation (prefixed with !). Only one pattern per line. 
4 | .DS_Store 5 | # Common VCS dirs 6 | .git/ 7 | .gitignore 8 | .bzr/ 9 | .bzrignore 10 | .hg/ 11 | .hgignore 12 | .svn/ 13 | # Common backup files 14 | *.swp 15 | *.bak 16 | *.tmp 17 | *.orig 18 | *~ 19 | # Various IDEs 20 | .project 21 | .idea/ 22 | *.tmproj 23 | .vscode/ 24 | 25 | -------------------------------------------------------------------------------- /charts/wonderwall-forward-auth/.helmignore: -------------------------------------------------------------------------------- 1 | # Patterns to ignore when building packages. 2 | # This supports shell glob matching, relative path matching, and 3 | # negation (prefixed with !). Only one pattern per line. 4 | .DS_Store 5 | # Common VCS dirs 6 | .git/ 7 | .gitignore 8 | .bzr/ 9 | .bzrignore 10 | .hg/ 11 | .hgignore 12 | .svn/ 13 | # Common backup files 14 | *.swp 15 | *.bak 16 | *.tmp 17 | *.orig 18 | *~ 19 | # Various IDEs 20 | .project 21 | .idea/ 22 | *.tmproj 23 | .vscode/ 24 | -------------------------------------------------------------------------------- /charts/wonderwall/templates/idporten-secret.yaml: -------------------------------------------------------------------------------- 1 | {{ if .Values.idporten.enabled }} 2 | --- 3 | apiVersion: v1 4 | kind: Secret 5 | type: kubernetes.io/Opaque 6 | metadata: 7 | name: "{{ .Values.idporten.ssoServerSecretName }}" 8 | labels: 9 | {{- include "wonderwall.labelsIdporten" . | nindent 4 }} 10 | data: 11 | WONDERWALL_ENCRYPTION_KEY: "{{ .Values.idporten.sessionCookieEncryptionKey | required ".Values.idporten.sessionCookieEncryptionKey is required." 
| b64enc }}" 12 | {{ end }} 13 | -------------------------------------------------------------------------------- /pkg/router/paths/paths.go: -------------------------------------------------------------------------------- 1 | package paths 2 | 3 | const ( 4 | OAuth2 = "/oauth2" 5 | Login = "/login" 6 | LoginCallback = "/callback" 7 | Logout = "/logout" 8 | LogoutCallback = "/logout/callback" 9 | LogoutFrontChannel = "/logout/frontchannel" 10 | LogoutLocal = "/logout/local" 11 | Ping = "/ping" 12 | Refresh = "/refresh" 13 | Session = "/session" 14 | ForwardAuth = "/forwardauth" 15 | ) 16 | -------------------------------------------------------------------------------- /charts/wonderwall/templates/idporten-poddisruptionbudget.yaml: -------------------------------------------------------------------------------- 1 | {{- if .Values.idporten.enabled }} 2 | apiVersion: policy/v1 3 | kind: PodDisruptionBudget 4 | metadata: 5 | labels: 6 | {{- include "wonderwall.labelsIdporten" . | nindent 4 }} 7 | name: {{ include "wonderwall.fullname" . }}-idporten 8 | spec: 9 | {{- toYaml .Values.podDisruptionBudget | nindent 2 }} 10 | selector: 11 | matchLabels: 12 | {{- include "wonderwall.selectorLabelsIdporten" . | nindent 6 }} 13 | {{- end }} 14 | -------------------------------------------------------------------------------- /charts/wonderwall/templates/idporten-service.yaml: -------------------------------------------------------------------------------- 1 | {{ if .Values.idporten.enabled }} 2 | apiVersion: v1 3 | kind: Service 4 | metadata: 5 | labels: 6 | {{- include "wonderwall.labelsIdporten" . | nindent 4 }} 7 | name: {{ include "wonderwall.fullname" . }}-idporten 8 | spec: 9 | type: ClusterIP 10 | ports: 11 | - name: http 12 | port: 80 13 | protocol: TCP 14 | targetPort: http 15 | selector: 16 | {{- include "wonderwall.selectorLabelsIdporten" . 
| nindent 4 }} 17 | {{ end }} 18 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM --platform=$BUILDPLATFORM golang:1.25 AS builder 2 | ENV CGO_ENABLED=0 3 | ENV GOTOOLCHAIN=auto 4 | WORKDIR /src 5 | 6 | COPY go.mod go.sum ./ 7 | RUN go mod download 8 | 9 | COPY . . 10 | ARG TARGETOS 11 | ARG TARGETARCH 12 | RUN GOOS=$TARGETOS GOARCH=$TARGETARCH go build -trimpath -ldflags "-s -w" -a -o bin/wonderwall ./cmd/wonderwall 13 | 14 | FROM gcr.io/distroless/static-debian12:nonroot 15 | WORKDIR /app 16 | COPY --from=builder /src/bin/wonderwall /app/wonderwall 17 | ENTRYPOINT ["/app/wonderwall"] 18 | -------------------------------------------------------------------------------- /charts/wonderwall/templates/fa-poddisruptionbudget.yaml: -------------------------------------------------------------------------------- 1 | {{- if and .Values.azure.enabled .Values.azure.forwardAuth.enabled }} 2 | apiVersion: policy/v1 3 | kind: PodDisruptionBudget 4 | metadata: 5 | labels: 6 | {{- include "wonderwall.labelsForwardAuth" . | nindent 4 }} 7 | name: {{ include "wonderwall.fullname" . }}-fa 8 | spec: 9 | {{- toYaml .Values.podDisruptionBudget | nindent 2 }} 10 | selector: 11 | matchLabels: 12 | {{- include "wonderwall.selectorLabelsForwardAuth" . | nindent 6 }} 13 | {{- end }} 14 | -------------------------------------------------------------------------------- /pkg/handler/path.go: -------------------------------------------------------------------------------- 1 | package handler 2 | 3 | import ( 4 | "net/http" 5 | 6 | "github.com/nais/wonderwall/pkg/ingress" 7 | mw "github.com/nais/wonderwall/pkg/middleware" 8 | ) 9 | 10 | // GetPath returns the matching context path from the list of registered ingresses. 11 | // If none match, an empty string is returned. 
12 | func GetPath(r *http.Request, ingresses *ingress.Ingresses) string { 13 | path, ok := mw.PathFrom(r.Context()) 14 | if !ok { 15 | path = ingresses.MatchingPath(r) 16 | } 17 | 18 | return path 19 | } 20 | -------------------------------------------------------------------------------- /charts/wonderwall/templates/fa-service.yaml: -------------------------------------------------------------------------------- 1 | {{- if and .Values.azure.enabled .Values.azure.forwardAuth.enabled }} 2 | apiVersion: v1 3 | kind: Service 4 | metadata: 5 | labels: 6 | {{- include "wonderwall.labelsForwardAuth" . | nindent 4 }} 7 | name: {{ include "wonderwall.fullname" . }}-fa 8 | spec: 9 | type: ClusterIP 10 | ports: 11 | - name: http 12 | port: 80 13 | protocol: TCP 14 | targetPort: http 15 | selector: 16 | {{- include "wonderwall.selectorLabelsForwardAuth" . | nindent 4 }} 17 | {{- end }} 18 | -------------------------------------------------------------------------------- /charts/wonderwall/templates/fa-secret.yaml: -------------------------------------------------------------------------------- 1 | {{- if and .Values.azure.enabled .Values.azure.forwardAuth.enabled }} 2 | --- 3 | apiVersion: v1 4 | kind: Secret 5 | type: kubernetes.io/Opaque 6 | metadata: 7 | name: "{{ .Values.azure.forwardAuth.ssoServerSecretName }}" 8 | labels: 9 | {{- include "wonderwall.labelsForwardAuth" . | nindent 4 }} 10 | data: 11 | WONDERWALL_ENCRYPTION_KEY: "{{ .Values.azure.forwardAuth.sessionCookieEncryptionKey | required ".Values.azure.forwardAuth.sessionCookieEncryptionKey is required." 
| b64enc }}" 12 | {{- end }} 13 | -------------------------------------------------------------------------------- /templates/package-lock.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "templates", 3 | "lockfileVersion": 3, 4 | "requires": true, 5 | "packages": { 6 | "": { 7 | "devDependencies": { 8 | "tailwindcss": "^4.0.0" 9 | } 10 | }, 11 | "node_modules/tailwindcss": { 12 | "version": "4.0.0", 13 | "resolved": "https://registry.npmjs.org/tailwindcss/-/tailwindcss-4.0.0.tgz", 14 | "integrity": "sha512-ULRPI3A+e39T7pSaf1xoi58AqqJxVCLg8F/uM5A3FadUbnyDTgltVnXJvdkTjwCOGA6NazqHVcwPJC5h2vRYVQ==", 15 | "dev": true, 16 | "license": "MIT" 17 | } 18 | } 19 | } 20 | -------------------------------------------------------------------------------- /charts/wonderwall-forward-auth/templates/service.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | apiVersion: v1 3 | kind: Service 4 | metadata: 5 | labels: 6 | {{- include "wonderwall-forward-auth.labels" . | nindent 4 }} 7 | name: {{ include "wonderwall-forward-auth.fullname" . }} 8 | spec: 9 | type: ClusterIP 10 | ports: 11 | - name: http 12 | port: 80 13 | protocol: TCP 14 | targetPort: http 15 | - name: http-metrics 16 | port: 8081 17 | protocol: TCP 18 | targetPort: http-metrics 19 | selector: 20 | {{- include "wonderwall-forward-auth.selectorLabels" . 
| nindent 4 }} 21 | -------------------------------------------------------------------------------- /internal/http/transport.go: -------------------------------------------------------------------------------- 1 | package http 2 | 3 | import ( 4 | "net/http" 5 | "sync" 6 | "time" 7 | 8 | "go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp" 9 | ) 10 | 11 | var ( 12 | defaultTransport *http.Transport 13 | once sync.Once 14 | ) 15 | 16 | func Transport() http.RoundTripper { 17 | once.Do(func() { 18 | t := http.DefaultTransport.(*http.Transport).Clone() 19 | t.MaxIdleConns = 200 20 | t.MaxIdleConnsPerHost = 100 21 | t.IdleConnTimeout = 5 * time.Second 22 | 23 | defaultTransport = t 24 | }) 25 | 26 | return otelhttp.NewTransport(defaultTransport) 27 | } 28 | -------------------------------------------------------------------------------- /pkg/mock/request.go: -------------------------------------------------------------------------------- 1 | package mock 2 | 3 | import ( 4 | "net/http" 5 | "net/http/httptest" 6 | 7 | "github.com/nais/wonderwall/pkg/ingress" 8 | mw "github.com/nais/wonderwall/pkg/middleware" 9 | ) 10 | 11 | func NewGetRequest(target string, ingresses *ingress.Ingresses) *http.Request { 12 | req := httptest.NewRequest(http.MethodGet, target, nil) 13 | 14 | path := ingresses.MatchingPath(req) 15 | req = mw.RequestWithPath(req, path) 16 | 17 | ing, ok := ingresses.MatchingIngress(req) 18 | if ok { 19 | req = mw.RequestWithIngress(req, ing) 20 | } 21 | 22 | mw.Logger("test") 23 | 24 | return req 25 | } 26 | -------------------------------------------------------------------------------- /pkg/middleware/correlationid.go: -------------------------------------------------------------------------------- 1 | package middleware 2 | 3 | import ( 4 | "context" 5 | "net/http" 6 | 7 | chi_middleware "github.com/go-chi/chi/v5/middleware" 8 | "github.com/google/uuid" 9 | ) 10 | 11 | func CorrelationIDHandler(next http.Handler) http.Handler { 12 | fn := 
func(w http.ResponseWriter, r *http.Request) { 13 | id := r.Header.Get(chi_middleware.RequestIDHeader) 14 | if len(id) == 0 { 15 | id = uuid.New().String() 16 | } 17 | 18 | ctx := r.Context() 19 | ctx = context.WithValue(ctx, chi_middleware.RequestIDKey, id) 20 | next.ServeHTTP(w, r.WithContext(ctx)) 21 | } 22 | return http.HandlerFunc(fn) 23 | } 24 | -------------------------------------------------------------------------------- /docker-compose.yml: -------------------------------------------------------------------------------- 1 | services: 2 | redis: 3 | image: redis:7 4 | ports: 5 | - "6379:6379" 6 | otel: 7 | image: grafana/otel-lgtm:latest 8 | ports: 9 | - "3002:3000" 10 | - "4317:4317" 11 | - "4318:4318" 12 | mock-oauth2-server: 13 | image: ghcr.io/navikt/mock-oauth2-server:3.0.1 14 | ports: 15 | - "8888:8080" 16 | environment: 17 | JSON_CONFIG: "{\"interactiveLogin\":false}" 18 | upstream: 19 | image: mendhak/http-https-echo:30 20 | ports: 21 | - "4000:4000" 22 | environment: 23 | HTTP_PORT: 4000 24 | JWT_HEADER: Authorization 25 | LOG_IGNORE_PATH: / 26 | -------------------------------------------------------------------------------- /pkg/handler/acr/acr.go: -------------------------------------------------------------------------------- 1 | package acr 2 | 3 | import ( 4 | "github.com/nais/wonderwall/pkg/config" 5 | "github.com/nais/wonderwall/pkg/openid/acr" 6 | "github.com/nais/wonderwall/pkg/session" 7 | ) 8 | 9 | type Handler struct { 10 | Enabled bool 11 | ExpectedValue string 12 | } 13 | 14 | func (h *Handler) Validate(sess *session.Session) error { 15 | if !h.Enabled || sess == nil { 16 | return nil 17 | } 18 | 19 | return acr.Validate(h.ExpectedValue, sess.Acr()) 20 | } 21 | 22 | func NewHandler(cfg *config.Config) *Handler { 23 | return &Handler{ 24 | Enabled: len(cfg.OpenID.ACRValues) > 0, 25 | ExpectedValue: cfg.OpenID.ACRValues, 26 | } 27 | } 28 | -------------------------------------------------------------------------------- 
/charts/wonderwall-forward-auth/templates/secret.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | apiVersion: v1 3 | kind: Secret 4 | type: kubernetes.io/Opaque 5 | metadata: 6 | name: {{ include "wonderwall-forward-auth.fullname" . }} 7 | labels: 8 | {{- include "wonderwall-forward-auth.labels" . | nindent 4 }} 9 | data: 10 | WONDERWALL_OPENID_CLIENT_SECRET: "{{ .Values.openid.clientSecret | required ".Values.openid.clientSecret is required." | b64enc }}" 11 | WONDERWALL_ENCRYPTION_KEY: "{{ .Values.session.cookieEncryptionKey | required ".Values.session.cookieEncryptionKey is required." | b64enc }}" 12 | WONDERWALL_REDIS_PASSWORD: "{{ .Values.valkey.password | required ".Values.valkey.password is required." | b64enc }}" 13 | -------------------------------------------------------------------------------- /charts/wonderwall-forward-auth/templates/servicemonitor.yaml: -------------------------------------------------------------------------------- 1 | {{- if .Capabilities.APIVersions.Has "monitoring.coreos.com/v1" }} 2 | --- 3 | apiVersion: monitoring.coreos.com/v1 4 | kind: ServiceMonitor 5 | metadata: 6 | name: {{ include "wonderwall-forward-auth.fullname" . }} 7 | labels: {{- include "wonderwall-forward-auth.labels" . | nindent 4 }} 8 | spec: 9 | endpoints: 10 | - interval: 1m 11 | port: http-metrics 12 | scrapeTimeout: 10s 13 | path: "/" 14 | namespaceSelector: 15 | matchNames: 16 | - {{ .Release.Namespace }} 17 | selector: 18 | matchLabels: 19 | {{- include "wonderwall-forward-auth.selectorLabels" . | nindent 6 }} 20 | {{- end }} 21 | -------------------------------------------------------------------------------- /charts/wonderwall-forward-auth/templates/horizontalpodautoscaler.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | apiVersion: autoscaling/v2 3 | kind: HorizontalPodAutoscaler 4 | metadata: 5 | labels: 6 | {{- include "wonderwall-forward-auth.labels" . 
| nindent 4 }} 7 | name: {{ include "wonderwall-forward-auth.fullname" . }} 8 | spec: 9 | minReplicas: {{ .Values.replicas.min }} 10 | maxReplicas: {{ .Values.replicas.max }} 11 | metrics: 12 | - resource: 13 | name: cpu 14 | target: 15 | averageUtilization: 75 16 | type: Utilization 17 | type: Resource 18 | scaleTargetRef: 19 | apiVersion: apps/v1 20 | kind: Deployment 21 | name: {{ include "wonderwall-forward-auth.fullname" . }} 22 | -------------------------------------------------------------------------------- /pkg/openid/oauth2_test.go: -------------------------------------------------------------------------------- 1 | package openid 2 | 3 | import ( 4 | "net/url" 5 | "testing" 6 | 7 | "github.com/stretchr/testify/assert" 8 | ) 9 | 10 | func TestStateMismatchError(t *testing.T) { 11 | for _, tt := range []struct { 12 | name, expected, actual string 13 | assertion assert.ErrorAssertionFunc 14 | }{ 15 | {"missing actual state", "expected", "", assert.Error}, 16 | {"state mismatch", "match", "not-match", assert.Error}, 17 | {"state match", "match", "match", assert.NoError}, 18 | } { 19 | t.Run(tt.name, func(t *testing.T) { 20 | actual := url.Values{ 21 | "state": []string{tt.actual}, 22 | } 23 | 24 | err := StateMismatchError(actual, tt.expected) 25 | tt.assertion(t, err) 26 | }) 27 | } 28 | } 29 | -------------------------------------------------------------------------------- /templates/README.md: -------------------------------------------------------------------------------- 1 | # Error templates 2 | 3 | This directory contains `.gohtml` templates for static error pages served by Wonderwall. 4 | 5 | These pages are typically only shown on exceptional errors, i.e. invalid configuration or infrastructure errors. 6 | End-users should generally not see these pages unless something is really wrong. 7 | 8 | We embed the CSS directly into the `.gohtml` templates. 9 | This avoids implementing an endpoint to serve the CSS file separately. 
10 | 11 | ## Prerequisites 12 | 13 | If you haven't already, [install the Tailwind CSS CLI](https://tailwindcss.com/docs/installation). 14 | 15 | ## Development 16 | 17 | ```shell 18 | make local 19 | ``` 20 | 21 | ## Production 22 | 23 | ```shell 24 | make build 25 | ``` 26 | -------------------------------------------------------------------------------- /charts/wonderwall/templates/idporten-horizontalpodautoscaler.yaml: -------------------------------------------------------------------------------- 1 | {{- if .Values.idporten.enabled }} 2 | apiVersion: autoscaling/v2 3 | kind: HorizontalPodAutoscaler 4 | metadata: 5 | labels: 6 | {{- include "wonderwall.labelsIdporten" . | nindent 4 }} 7 | name: {{ include "wonderwall.fullname" . }}-idporten 8 | spec: 9 | minReplicas: {{ .Values.idporten.replicasMin }} 10 | maxReplicas: {{ .Values.idporten.replicasMax }} 11 | metrics: 12 | - resource: 13 | name: cpu 14 | target: 15 | averageUtilization: 75 16 | type: Utilization 17 | type: Resource 18 | scaleTargetRef: 19 | apiVersion: apps/v1 20 | kind: Deployment 21 | name: {{ include "wonderwall.fullname" . }}-idporten 22 | {{- end }} 23 | -------------------------------------------------------------------------------- /pkg/middleware/cors.go: -------------------------------------------------------------------------------- 1 | package middleware 2 | 3 | import ( 4 | "fmt" 5 | "net/http" 6 | "strings" 7 | 8 | "github.com/rs/cors" 9 | 10 | "github.com/nais/wonderwall/pkg/config" 11 | ) 12 | 13 | func Cors(cfg *config.Config, methods []string) func(http.Handler) http.Handler { 14 | ssoDomain := strings.TrimPrefix(cfg.SSO.Domain, ".") 15 | 16 | allowedOrigins := []string{ 17 | fmt.Sprintf("https://*.%s", ssoDomain), 18 | fmt.Sprintf("https://%s", ssoDomain), 19 | } 20 | 21 | return cors.New(cors.Options{ 22 | AllowedOrigins: allowedOrigins, 23 | AllowedMethods: methods, 24 | AllowCredentials: true, 25 | // This reflects the request headers, essentially allowing all headers. 
26 | AllowedHeaders: []string{"*"}, 27 | }).Handler 28 | } 29 | -------------------------------------------------------------------------------- /pkg/openid/scopes/scopes.go: -------------------------------------------------------------------------------- 1 | package scopes 2 | 3 | import ( 4 | "fmt" 5 | "strings" 6 | ) 7 | 8 | const ( 9 | OpenID = "openid" 10 | OfflineAccess = "offline_access" 11 | AzureAPITemplate = "api://%s/.default" 12 | ) 13 | 14 | type Scopes []string 15 | 16 | func (s Scopes) String() string { 17 | return strings.Join(s, " ") 18 | } 19 | 20 | func (s Scopes) WithAdditional(scopes ...string) Scopes { 21 | return append(s, scopes...) 22 | } 23 | 24 | func (s Scopes) WithAzureScope(clientID string) Scopes { 25 | return append(s, fmt.Sprintf(AzureAPITemplate, clientID)) 26 | } 27 | 28 | func (s Scopes) WithOfflineAccess() Scopes { 29 | return append(s, OfflineAccess) 30 | } 31 | 32 | func DefaultScopes() Scopes { 33 | return []string{OpenID} 34 | } 35 | -------------------------------------------------------------------------------- /pkg/cookie/options.go: -------------------------------------------------------------------------------- 1 | package cookie 2 | 3 | import ( 4 | "net/http" 5 | ) 6 | 7 | type Options struct { 8 | Domain string 9 | Path string 10 | SameSite http.SameSite 11 | Secure bool 12 | } 13 | 14 | func DefaultOptions() Options { 15 | return Options{ 16 | SameSite: http.SameSiteLaxMode, 17 | Secure: true, 18 | } 19 | } 20 | 21 | func (o Options) WithDomain(domain string) Options { 22 | o.Domain = domain 23 | return o 24 | } 25 | 26 | func (o Options) WithPath(path string) Options { 27 | o.Path = path 28 | return o 29 | } 30 | 31 | func (o Options) WithSameSite(sameSite http.SameSite) Options { 32 | o.SameSite = sameSite 33 | return o 34 | } 35 | 36 | func (o Options) WithSecure(secure bool) Options { 37 | o.Secure = secure 38 | return o 39 | } 40 | -------------------------------------------------------------------------------- 
/charts/wonderwall/templates/fa-horizontalpodautoscaler.yaml: -------------------------------------------------------------------------------- 1 | {{- if and .Values.azure.enabled .Values.azure.forwardAuth.enabled }} 2 | apiVersion: autoscaling/v2 3 | kind: HorizontalPodAutoscaler 4 | metadata: 5 | labels: 6 | {{- include "wonderwall.labelsForwardAuth" . | nindent 4 }} 7 | name: {{ include "wonderwall.fullname" . }}-fa 8 | spec: 9 | minReplicas: {{ .Values.azure.forwardAuth.replicasMin }} 10 | maxReplicas: {{ .Values.azure.forwardAuth.replicasMax }} 11 | metrics: 12 | - resource: 13 | name: cpu 14 | target: 15 | averageUtilization: 75 16 | type: Utilization 17 | type: Resource 18 | scaleTargetRef: 19 | apiVersion: apps/v1 20 | kind: Deployment 21 | name: {{ include "wonderwall.fullname" . }}-fa 22 | {{- end }} 23 | -------------------------------------------------------------------------------- /pkg/session/store_memory_test.go: -------------------------------------------------------------------------------- 1 | package session_test 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stretchr/testify/assert" 7 | 8 | "github.com/nais/wonderwall/pkg/session" 9 | ) 10 | 11 | func TestMemory(t *testing.T) { 12 | crypter := makeCrypter(t) 13 | data := makeData() 14 | encryptedData, err := data.Encrypt(crypter) 15 | assert.NoError(t, err) 16 | 17 | store := session.NewMemory() 18 | key := "key" 19 | 20 | write(t, store, key, encryptedData) 21 | 22 | decrypted := read(t, store, key, encryptedData, crypter) 23 | decryptedEqual(t, data, decrypted) 24 | 25 | data, encryptedData = update(t, store, key, data, crypter) 26 | 27 | decrypted = read(t, store, key, encryptedData, crypter) 28 | decryptedEqual(t, data, decrypted) 29 | 30 | del(t, store, key) 31 | } 32 | -------------------------------------------------------------------------------- /templates/error.go: -------------------------------------------------------------------------------- 1 | package templates 2 | 3 | import 
( 4 | _ "embed" 5 | "html/template" 6 | "io" 7 | 8 | log "github.com/sirupsen/logrus" 9 | ) 10 | 11 | //go:embed error.gohtml 12 | var errorGoHtml string 13 | var errorTemplate *template.Template 14 | 15 | type ErrorVariables struct { 16 | CorrelationID string 17 | CSS template.CSS 18 | DefaultRedirectURI string 19 | HttpStatusCode int 20 | RetryURI string 21 | } 22 | 23 | func init() { 24 | var err error 25 | 26 | errorTemplate = template.New("error") 27 | errorTemplate, err = errorTemplate.Parse(errorGoHtml) 28 | if err != nil { 29 | log.Fatalf("parsing error template: %+v", err) 30 | } 31 | } 32 | 33 | func ExecError(w io.Writer, vars ErrorVariables) error { 34 | return errorTemplate.Execute(w, vars) 35 | } 36 | -------------------------------------------------------------------------------- /charts/wonderwall-forward-auth/templates/ingress.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | apiVersion: networking.k8s.io/v1 3 | kind: Ingress 4 | metadata: 5 | annotations: 6 | nginx.ingress.kubernetes.io/proxy-buffer-size: 16k 7 | nginx.ingress.kubernetes.io/enable-global-auth: "false" 8 | labels: 9 | {{- include "wonderwall-forward-auth.labels" . | nindent 4 }} 10 | name: {{ include "wonderwall-forward-auth.fullname" . }} 11 | spec: 12 | ingressClassName: {{ .Values.ingressClassName }} 13 | rules: 14 | - host: {{ .Values.sso.domain }} 15 | http: 16 | paths: 17 | - backend: 18 | service: 19 | name: {{ include "wonderwall-forward-auth.fullname" . 
}} 20 | port: 21 | number: 80 22 | path: / 23 | pathType: ImplementationSpecific 24 | -------------------------------------------------------------------------------- /pkg/openid/config/config.go: -------------------------------------------------------------------------------- 1 | package config 2 | 3 | import ( 4 | wonderwallconfig "github.com/nais/wonderwall/pkg/config" 5 | ) 6 | 7 | type Config interface { 8 | Client() Client 9 | Provider() Provider 10 | } 11 | 12 | type openidconfig struct { 13 | clientConfig Client 14 | providerConfig Provider 15 | } 16 | 17 | func (c *openidconfig) Client() Client { 18 | return c.clientConfig 19 | } 20 | 21 | func (c *openidconfig) Provider() Provider { 22 | return c.providerConfig 23 | } 24 | 25 | func NewConfig(cfg *wonderwallconfig.Config) (Config, error) { 26 | clientCfg, err := NewClientConfig(cfg) 27 | if err != nil { 28 | return nil, err 29 | } 30 | 31 | providerCfg, err := NewProviderConfig(cfg) 32 | if err != nil { 33 | return nil, err 34 | } 35 | 36 | return &openidconfig{ 37 | clientConfig: clientCfg, 38 | providerConfig: providerCfg, 39 | }, nil 40 | } 41 | -------------------------------------------------------------------------------- /charts/wonderwall/templates/idporten-ingress.yaml: -------------------------------------------------------------------------------- 1 | {{ if .Values.idporten.enabled }} 2 | apiVersion: networking.k8s.io/v1 3 | kind: Ingress 4 | metadata: 5 | annotations: 6 | nginx.ingress.kubernetes.io/proxy-buffer-size: 16k 7 | prometheus.io/path: /oauth2/ping 8 | prometheus.io/scrape: "true" 9 | labels: 10 | {{- include "wonderwall.labelsIdporten" . | nindent 4 }} 11 | name: {{ include "wonderwall.fullname" . }}-idporten 12 | spec: 13 | ingressClassName: {{ .Values.idporten.ingressClassName }} 14 | rules: 15 | - host: {{ .Values.idporten.ssoServerHost }} 16 | http: 17 | paths: 18 | - backend: 19 | service: 20 | name: {{ include "wonderwall.fullname" . 
}}-idporten 21 | port: 22 | number: 80 23 | path: / 24 | pathType: ImplementationSpecific 25 | {{ end }} 26 | -------------------------------------------------------------------------------- /.github/workflows/dependabot-auto-merge.yaml: -------------------------------------------------------------------------------- 1 | name: Dependabot auto-merge 2 | on: pull_request 3 | 4 | permissions: 5 | contents: write 6 | pull-requests: write 7 | 8 | jobs: 9 | dependabot: 10 | runs-on: ubuntu-latest 11 | if: ${{ github.event.pull_request.user.login == 'dependabot[bot]' }} 12 | steps: 13 | - name: Dependabot metadata 14 | id: metadata 15 | uses: dependabot/fetch-metadata@08eff52bf64351f401fb50d4972fa95b9f2c2d1b # ratchet:dependabot/fetch-metadata@v2 16 | with: 17 | github-token: "${{ secrets.GITHUB_TOKEN }}" 18 | - name: Enable auto-merge for Dependabot PRs 19 | if: steps.metadata.outputs.update-type != 'version-update:semver-major' 20 | run: gh pr merge --auto --squash "$PR_URL" 21 | env: 22 | PR_URL: ${{github.event.pull_request.html_url}} 23 | GITHUB_TOKEN: ${{secrets.GITHUB_TOKEN}} 24 | -------------------------------------------------------------------------------- /pkg/middleware/ingress.go: -------------------------------------------------------------------------------- 1 | package middleware 2 | 3 | import ( 4 | "net/http" 5 | 6 | "github.com/nais/wonderwall/pkg/ingress" 7 | ) 8 | 9 | type IngressSource interface { 10 | GetIngresses() *ingress.Ingresses 11 | } 12 | 13 | type IngressMiddleware struct { 14 | IngressSource 15 | } 16 | 17 | func Ingress(source IngressSource) IngressMiddleware { 18 | return IngressMiddleware{IngressSource: source} 19 | } 20 | 21 | func (i *IngressMiddleware) Handler(next http.Handler) http.Handler { 22 | fn := func(w http.ResponseWriter, r *http.Request) { 23 | ingresses := i.GetIngresses() 24 | ctx := r.Context() 25 | 26 | path := ingresses.MatchingPath(r) 27 | ctx = WithPath(ctx, path) 28 | 29 | matchingIngress, ok := 
ingresses.MatchingIngress(r) 30 | if ok { 31 | ctx = WithIngress(ctx, matchingIngress) 32 | } 33 | 34 | next.ServeHTTP(w, r.WithContext(ctx)) 35 | } 36 | return http.HandlerFunc(fn) 37 | } 38 | -------------------------------------------------------------------------------- /pkg/session/lock_test.go: -------------------------------------------------------------------------------- 1 | package session_test 2 | 3 | import ( 4 | "context" 5 | "testing" 6 | "time" 7 | 8 | "github.com/alicebob/miniredis/v2" 9 | "github.com/redis/go-redis/v9" 10 | "github.com/stretchr/testify/assert" 11 | 12 | "github.com/nais/wonderwall/pkg/session" 13 | ) 14 | 15 | func TestRedisLock(t *testing.T) { 16 | s, err := miniredis.Run() 17 | if err != nil { 18 | panic(err) 19 | } 20 | defer s.Close() 21 | 22 | client := redis.NewClient(&redis.Options{ 23 | Network: "tcp", 24 | Addr: s.Addr(), 25 | }) 26 | 27 | key := "some-key" 28 | ctx := context.Background() 29 | lock := session.NewRedisLock(client, key) 30 | 31 | err = lock.Acquire(ctx, time.Minute) 32 | assert.NoError(t, err) 33 | 34 | err = lock.Acquire(ctx, time.Minute) 35 | assert.Error(t, err) 36 | assert.ErrorIs(t, err, session.ErrAcquireLock) 37 | 38 | err = lock.Release(ctx) 39 | assert.NoError(t, err) 40 | } 41 | -------------------------------------------------------------------------------- /pkg/config/ratelimit.go: -------------------------------------------------------------------------------- 1 | package config 2 | 3 | import ( 4 | "time" 5 | 6 | flag "github.com/spf13/pflag" 7 | ) 8 | 9 | type RateLimit struct { 10 | Enabled bool `json:"enabled"` 11 | Logins int `json:"logins"` 12 | Window time.Duration `json:"window"` 13 | } 14 | 15 | const ( 16 | RateLimitEnabled = "ratelimit.enabled" 17 | RateLimitLogins = "ratelimit.logins" 18 | RateLimitWindow = "ratelimit.window" 19 | ) 20 | 21 | func rateLimitFlags() { 22 | flag.Bool(RateLimitEnabled, true, "Enable rate limiting per user-agent.") 23 | flag.Int(RateLimitLogins, 5, 
"Maximum permitted login attempts within 'ratelimit.window' before rate limiting.") 24 | flag.Duration(RateLimitWindow, 5*time.Second, "Time window for counting consecutive attempts towards rate limit. "+ 25 | "Each attempt within the window will increment the attempt counter and reset the window. "+ 26 | "If the window expires with no additional attempts, the counter is discarded.") 27 | } 28 | -------------------------------------------------------------------------------- /charts/wonderwall/templates/fa-ingress.yaml: -------------------------------------------------------------------------------- 1 | {{- if and .Values.azure.enabled .Values.azure.forwardAuth.enabled }} 2 | apiVersion: networking.k8s.io/v1 3 | kind: Ingress 4 | metadata: 5 | annotations: 6 | nginx.ingress.kubernetes.io/proxy-buffer-size: 16k 7 | nginx.ingress.kubernetes.io/enable-global-auth: "false" 8 | prometheus.io/path: /oauth2/ping 9 | prometheus.io/scrape: "true" 10 | labels: 11 | {{- include "wonderwall.labelsForwardAuth" . | nindent 4 }} 12 | name: {{ include "wonderwall.fullname" . }}-fa 13 | spec: 14 | ingressClassName: {{ .Values.azure.forwardAuth.ingressClassName }} 15 | rules: 16 | - host: {{ .Values.azure.forwardAuth.ssoDomain }} 17 | http: 18 | paths: 19 | - backend: 20 | service: 21 | name: {{ include "wonderwall.fullname" .
}}-fa 22 | port: 23 | number: 80 24 | path: / 25 | pathType: ImplementationSpecific 26 | {{- end }} 27 | -------------------------------------------------------------------------------- /internal/o11y/logging/logging.go: -------------------------------------------------------------------------------- 1 | package logging 2 | 3 | import ( 4 | "fmt" 5 | "time" 6 | 7 | log "github.com/sirupsen/logrus" 8 | ) 9 | 10 | func textFormatter() log.Formatter { 11 | return &log.TextFormatter{ 12 | DisableTimestamp: false, 13 | FullTimestamp: true, 14 | TimestampFormat: time.RFC3339Nano, 15 | } 16 | } 17 | 18 | func jsonFormatter() log.Formatter { 19 | return &log.JSONFormatter{ 20 | TimestampFormat: time.RFC3339Nano, 21 | } 22 | } 23 | 24 | func Setup(level, format string) error { 25 | switch format { 26 | case "json": 27 | log.SetFormatter(jsonFormatter()) 28 | case "text": 29 | log.SetFormatter(textFormatter()) 30 | default: 31 | return fmt.Errorf("log format '%s' is not recognized", format) 32 | } 33 | 34 | logLevel, err := log.ParseLevel(level) 35 | if err != nil { 36 | return fmt.Errorf("while setting log level: %s", err) 37 | } 38 | 39 | log.SetLevel(logLevel) 40 | log.Tracef("Trace logging enabled") 41 | 42 | return nil 43 | } 44 | -------------------------------------------------------------------------------- /charts/wonderwall/templates/fa-azureapp.yaml: -------------------------------------------------------------------------------- 1 | {{- if and .Values.azure.enabled .Values.azure.forwardAuth.enabled }} 2 | --- 3 | apiVersion: nais.io/v1 4 | kind: AzureAdApplication 5 | metadata: 6 | {{- if .Values.resourceSuffix }} 7 | name: {{ include "wonderwall.fullname" . }}-fa-{{ .Values.resourceSuffix }} 8 | {{- else }} 9 | name: {{ include "wonderwall.fullname" . }}-fa 10 | {{- end }} 11 | labels: 12 | {{- include "wonderwall.labelsForwardAuth" . 
| nindent 4 }} 13 | spec: 14 | secretName: {{ .Values.azure.forwardAuth.clientSecretName }} 15 | allowAllUsers: true 16 | {{- if .Values.azure.forwardAuth.groupIds }} 17 | claims: 18 | groups: 19 | {{- range .Values.azure.forwardAuth.groupIds }} 20 | - id: {{ . }} 21 | {{- end }} 22 | {{- end }} 23 | logoutUrl: "{{ include "wonderwall.azure.forwardAuthURL" . }}/oauth2/logout/frontchannel" 24 | replyUrls: 25 | - url: "{{- include "wonderwall.azure.forwardAuthURL" . }}/oauth2/callback" 26 | tenant: nav.no 27 | {{- end }} 28 | -------------------------------------------------------------------------------- /pkg/session/store_redis_test.go: -------------------------------------------------------------------------------- 1 | package session_test 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/alicebob/miniredis/v2" 7 | "github.com/redis/go-redis/v9" 8 | "github.com/stretchr/testify/assert" 9 | 10 | "github.com/nais/wonderwall/pkg/session" 11 | ) 12 | 13 | func TestRedis(t *testing.T) { 14 | crypter := makeCrypter(t) 15 | data := makeData() 16 | encryptedData, err := data.Encrypt(crypter) 17 | assert.NoError(t, err) 18 | 19 | s, err := miniredis.Run() 20 | if err != nil { 21 | panic(err) 22 | } 23 | defer s.Close() 24 | 25 | client := redis.NewClient(&redis.Options{ 26 | Network: "tcp", 27 | Addr: s.Addr(), 28 | }) 29 | 30 | store := session.NewRedis(client) 31 | key := "key" 32 | 33 | write(t, store, key, encryptedData) 34 | 35 | decrypted := read(t, store, key, encryptedData, crypter) 36 | decryptedEqual(t, data, decrypted) 37 | 38 | data, encryptedData = update(t, store, key, data, crypter) 39 | 40 | decrypted = read(t, store, key, encryptedData, crypter) 41 | decryptedEqual(t, data, decrypted) 42 | 43 | del(t, store, key) 44 | } 45 | -------------------------------------------------------------------------------- /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | updates: 3 | - 
package-ecosystem: "gomod" 4 | directory: "/" 5 | schedule: 6 | interval: "weekly" 7 | day: "monday" 8 | time: "09:00" 9 | timezone: "Europe/Oslo" 10 | groups: 11 | otel: 12 | patterns: 13 | - 'go.opentelemetry.io/*' 14 | redis: 15 | patterns: 16 | - 'github.com/redis/go-redis/*' 17 | cooldown: 18 | default-days: 7 19 | - package-ecosystem: "github-actions" 20 | directory: "/" 21 | schedule: 22 | interval: "weekly" 23 | day: "monday" 24 | time: "09:00" 25 | timezone: "Europe/Oslo" 26 | groups: 27 | gh-actions: 28 | patterns: 29 | - '*' 30 | cooldown: 31 | default-days: 7 32 | - package-ecosystem: "docker" 33 | directory: "/" 34 | schedule: 35 | interval: "weekly" 36 | day: "monday" 37 | time: "09:00" 38 | timezone: "Europe/Oslo" 39 | groups: 40 | docker: 41 | patterns: 42 | - '*' 43 | cooldown: 44 | default-days: 7 45 | -------------------------------------------------------------------------------- /charts/wonderwall-forward-auth/values.yaml: -------------------------------------------------------------------------------- 1 | nameOverride: "" 2 | fullnameOverride: "" 3 | 4 | image: 5 | repository: europe-north1-docker.pkg.dev/nais-io/nais/images/wonderwall 6 | tag: latest 7 | 8 | # mapped by fasit 9 | fasit: 10 | tenant: 11 | name: 12 | 13 | resources: 14 | limits: 15 | cpu: "2" 16 | memory: 512Mi 17 | requests: 18 | cpu: 100m 19 | memory: 64Mi 20 | replicas: 21 | min: 2 22 | max: 4 23 | podDisruptionBudget: 24 | maxUnavailable: 1 25 | ingressClassName: nais-ingress-fa 26 | otel: 27 | endpoint: http://opentelemetry-management-collector.nais-system:4317 28 | 29 | openid: 30 | clientID: 31 | clientSecret: 32 | extraAudience: 33 | extraScopes: 34 | wellKnownURL: https://auth.nais.io/.well-known/openid-configuration 35 | session: 36 | maxLifetime: 10h 37 | # 256 bits key, in standard base64 encoding 38 | cookieEncryptionKey: 39 | cookieName: nais-io-forward-auth 40 | sso: 41 | defaultRedirectURL: 42 | domain: 43 | 44 | valkey: 45 | host: 46 | port: 47 | username: 48 | 
password: 49 | connectionIdleTimeoutSeconds: 299 50 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 NAV 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /internal/o11y/otel/middleware.go: -------------------------------------------------------------------------------- 1 | package otel 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "net/http" 7 | 8 | chi_middleware "github.com/go-chi/chi/v5/middleware" 9 | httpinternal "github.com/nais/wonderwall/internal/http" 10 | "go.opentelemetry.io/otel/attribute" 11 | "go.opentelemetry.io/otel/trace" 12 | ) 13 | 14 | func Middleware(next http.Handler) http.Handler { 15 | fn := func(w http.ResponseWriter, r *http.Request) { 16 | ctx := r.Context() 17 | span := trace.SpanFromContext(ctx) 18 | 19 | attrs := httpinternal.Attributes(r) 20 | for k, v := range attrs { 21 | switch v := v.(type) { 22 | case bool: 23 | span.SetAttributes(attribute.Bool(k, v)) 24 | default: 25 | span.SetAttributes(attribute.String(k, fmt.Sprint(v))) 26 | } 27 | } 28 | 29 | // Override request ID with trace ID if available. 30 | if span.SpanContext().HasTraceID() { 31 | id := span.SpanContext().TraceID().String() 32 | ctx = context.WithValue(ctx, chi_middleware.RequestIDKey, id) 33 | next.ServeHTTP(w, r.WithContext(ctx)) 34 | } else { 35 | next.ServeHTTP(w, r) 36 | } 37 | } 38 | 39 | return http.HandlerFunc(fn) 40 | } 41 | -------------------------------------------------------------------------------- /charts/wonderwall/templates/idporten-idportenclient.yaml: -------------------------------------------------------------------------------- 1 | {{- if .Values.idporten.enabled }} 2 | --- 3 | apiVersion: nais.io/v1 4 | kind: IDPortenClient 5 | metadata: 6 | {{- if .Values.resourceSuffix }} 7 | name: {{ include "wonderwall.fullname" . }}-idporten-{{ .Values.resourceSuffix }} 8 | {{- else }} 9 | name: {{ include "wonderwall.fullname" . }}-idporten 10 | {{- end }} 11 | labels: 12 | {{- include "wonderwall.labelsIdporten" . 
| nindent 4 }} 13 | annotations: 14 | "digdir.nais.io/preserve": "true" 15 | "helm.sh/resource-policy": "keep" 16 | spec: 17 | clientURI: "{{ .Values.idporten.ssoDefaultRedirectURL }}" 18 | redirectURIs: 19 | - "{{- include "wonderwall.idporten.ssoServerURL" . }}/oauth2/callback" 20 | secretName: "{{ .Values.idporten.clientSecretName }}" 21 | frontchannelLogoutURI: "{{ include "wonderwall.idporten.ssoServerURL" . }}/oauth2/logout/frontchannel" 22 | postLogoutRedirectURIs: 23 | - "{{- include "wonderwall.idporten.ssoServerURL" . }}/oauth2/logout/callback" 24 | accessTokenLifetime: {{ .Values.idporten.clientAccessTokenLifetime }} 25 | sessionLifetime: {{ .Values.idporten.clientSessionLifetime }} 26 | {{- end }} 27 | -------------------------------------------------------------------------------- /internal/retry/retry.go: -------------------------------------------------------------------------------- 1 | package retry 2 | 3 | import ( 4 | "context" 5 | "time" 6 | 7 | "github.com/sethvargo/go-retry" 8 | ) 9 | 10 | var RetryableError = retry.RetryableError 11 | 12 | type fibonacci struct { 13 | base time.Duration 14 | max time.Duration 15 | } 16 | 17 | type Option func(*fibonacci) 18 | 19 | func WithBase(d time.Duration) Option { 20 | return func(f *fibonacci) { 21 | f.base = d 22 | } 23 | } 24 | 25 | func WithMax(d time.Duration) Option { 26 | return func(f *fibonacci) { 27 | f.max = d 28 | } 29 | } 30 | 31 | func Do(ctx context.Context, f retry.RetryFunc, opts ...Option) error { 32 | return retry.Do(ctx, fibonacciBackoff(opts...), f) 33 | } 34 | 35 | func DoValue[T any](ctx context.Context, f retry.RetryFuncValue[T], opts ...Option) (T, error) { 36 | return retry.DoValue(ctx, fibonacciBackoff(opts...), f) 37 | } 38 | 39 | func fibonacciBackoff(opts ...Option) retry.Backoff { 40 | f := &fibonacci{ 41 | base: 50 * time.Millisecond, 42 | max: 5 * time.Second, 43 | } 44 | 45 | for _, opt := range opts { 46 | opt(f) 47 | } 48 | 49 | b := retry.NewFibonacci(f.base) 50 | // 
beware: this starts a timer when invoked, on which the max duration is evaluated against 51 | b = retry.WithMaxDuration(f.max, b) 52 | return b 53 | } 54 | -------------------------------------------------------------------------------- /pkg/config/otel.go: -------------------------------------------------------------------------------- 1 | package config 2 | 3 | import ( 4 | "os" 5 | 6 | flag "github.com/spf13/pflag" 7 | "github.com/spf13/viper" 8 | ) 9 | 10 | type OpenTelemetry struct { 11 | Enabled bool `json:"enabled"` 12 | ServiceName string `json:"service-name"` 13 | } 14 | 15 | const ( 16 | OpenTelemetryEnabled = "otel.enabled" 17 | OpenTelemetryServiceName = "otel.service-name" 18 | ) 19 | 20 | func otelFlags() { 21 | flag.Bool(OpenTelemetryEnabled, false, "Enable OpenTelemetry tracing. Automatically enabled if OTEL_EXPORTER_OTLP_ENDPOINT is set.") 22 | flag.String(OpenTelemetryServiceName, "wonderwall", "Service name to use for OpenTelemetry. The OTEL_SERVICE_NAME environment variable overrides this value.") 23 | } 24 | 25 | func resolveOtel() { 26 | otelEndpoint := os.Getenv("OTEL_EXPORTER_OTLP_ENDPOINT") 27 | if otelEndpoint != "" { 28 | logger.Debugf("config: OTLP endpoint set to %q, enabling OpenTelemetry", otelEndpoint) 29 | viper.Set(OpenTelemetryEnabled, "true") 30 | } 31 | 32 | otelServiceName := os.Getenv("OTEL_SERVICE_NAME") 33 | if otelServiceName != "" { 34 | logger.Debugf("config: OTEL_SERVICE_NAME set to %q; overriding %q flag", otelServiceName, OpenTelemetryServiceName) 35 | viper.Set(OpenTelemetryServiceName, otelServiceName) 36 | } 37 | } 38 | -------------------------------------------------------------------------------- /charts/wonderwall/templates/redis.yaml: -------------------------------------------------------------------------------- 1 | {{- if .Values.azure.enabled }} 2 | {{ $provider := "azure" }} 3 | {{ include "common.redis.tpl" (dict "root" . 
"provider" $provider) }} 4 | {{ include "common.serviceintegration.tpl" (dict "root" . "provider" $provider) }} 5 | {{ include "common.aivenapplication.tpl" (dict "root" . "provider" $provider "access" "readwrite" "secretName" .Values.azure.redisSecretName) }} 6 | {{- end }} 7 | 8 | {{- if .Values.idporten.enabled }} 9 | {{ $provider := "idporten" }} 10 | {{ include "common.redis.tpl" (dict "root" . "provider" $provider) }} 11 | {{ include "common.serviceintegration.tpl" (dict "root" . "provider" $provider) }} 12 | {{ include "common.aivenapplication.tpl" (dict "root" . "provider" $provider "access" "readwrite" "secretName" .Values.idporten.redisSecretNames.readwrite) }} 13 | {{ include "common.aivenapplication.tpl" (dict "root" . "provider" $provider "access" "read" "secretName" .Values.idporten.redisSecretNames.read) }} 14 | {{- end }} 15 | 16 | {{- if .Values.openid.enabled }} 17 | {{ $provider := "openid" }} 18 | {{ include "common.redis.tpl" (dict "root" . "provider" $provider) }} 19 | {{ include "common.serviceintegration.tpl" (dict "root" . "provider" $provider) }} 20 | {{ include "common.aivenapplication.tpl" (dict "root" . 
"provider" $provider "access" "readwrite" "secretName" .Values.openid.redisSecretName) }} 21 | {{- end }} 22 | -------------------------------------------------------------------------------- /pkg/openid/client/logout_callback.go: -------------------------------------------------------------------------------- 1 | package client 2 | 3 | import ( 4 | "fmt" 5 | "net/http" 6 | 7 | "github.com/nais/wonderwall/pkg/openid" 8 | urlpkg "github.com/nais/wonderwall/pkg/url" 9 | ) 10 | 11 | type LogoutCallback struct { 12 | *Client 13 | cookie *openid.LogoutCookie 14 | validator urlpkg.Validator 15 | request *http.Request 16 | } 17 | 18 | func NewLogoutCallback(c *Client, r *http.Request, cookie *openid.LogoutCookie, validator urlpkg.Validator) *LogoutCallback { 19 | return &LogoutCallback{ 20 | Client: c, 21 | cookie: cookie, 22 | validator: validator, 23 | request: r, 24 | } 25 | } 26 | 27 | func (in *LogoutCallback) PostLogoutRedirectURI() string { 28 | if in.cookie != nil && in.stateMismatchError() == nil && in.validator.IsValidRedirect(in.request, in.cookie.RedirectTo) { 29 | return in.cookie.RedirectTo 30 | } 31 | 32 | defaultRedirect := in.cfg.Client().PostLogoutRedirectURI() 33 | if defaultRedirect != "" { 34 | return defaultRedirect 35 | } 36 | 37 | ingress, err := urlpkg.MatchingIngress(in.request) 38 | if err != nil { 39 | return "/" 40 | } 41 | 42 | return ingress.String() 43 | } 44 | 45 | func (in *LogoutCallback) stateMismatchError() error { 46 | if in.cookie == nil { 47 | return fmt.Errorf("logout cookie is nil") 48 | } 49 | 50 | return openid.StateMismatchError(in.request.URL.Query(), in.cookie.State) 51 | } 52 | -------------------------------------------------------------------------------- /templates/input.css: -------------------------------------------------------------------------------- 1 | @import 'tailwindcss'; 2 | 3 | @theme { 4 | --color-primary-50: #ffe6e6; 5 | --color-primary-100: #ffc2c2; 6 | --color-primary-200: #f68282; 7 | --color-primary-300: 
#f25c5c; 8 | --color-primary-400: #de2e2e; 9 | --color-primary-500: #c30000; 10 | --color-primary-600: #ad0000; 11 | --color-primary-700: #8c0000; 12 | --color-primary-800: #5c0000; 13 | --color-primary-900: #260000; 14 | --color-primary-950: #180000; 15 | 16 | --color-action-50: #e6f0ff; 17 | --color-action-100: #cce1ff; 18 | --color-action-200: #99c3ff; 19 | --color-action-300: #66a5f4; 20 | --color-action-400: #3386e0; 21 | --color-action-500: #0067c5; 22 | --color-action-600: #0056b4; 23 | --color-action-700: #00459c; 24 | --color-action-800: #00347d; 25 | --color-action-900: #002252; 26 | --color-action-950: #00131a; 27 | } 28 | 29 | /* 30 | The default border color has changed to `currentColor` in Tailwind CSS v4, 31 | so we've added these compatibility styles to make sure everything still 32 | looks the same as it did with Tailwind CSS v3. 33 | 34 | If we ever want to remove these styles, we need to add an explicit border 35 | color utility to any element that depends on these defaults. 
36 | */ 37 | @layer base { 38 | *, 39 | ::after, 40 | ::before, 41 | ::backdrop, 42 | ::file-selector-button { 43 | border-color: var(--color-gray-200, currentColor); 44 | } 45 | } 46 | -------------------------------------------------------------------------------- /mise.toml: -------------------------------------------------------------------------------- 1 | [tools] 2 | go = "1.25.5" 3 | 4 | [tasks.build] 5 | run = "go build -a -o bin/wonderwall ./cmd/wonderwall" 6 | 7 | [tasks.local] 8 | depends = ["fmt"] 9 | env = { OTEL_EXPORTER_OTLP_ENDPOINT = "http://localhost:4317" } 10 | run = '''go run cmd/wonderwall/main.go \ 11 | --openid.client-id=bogus \ 12 | --openid.client-secret=not-so-secret \ 13 | --openid.well-known-url=http://localhost:8888/default/.well-known/openid-configuration \ 14 | --ingress=http://localhost:3000 \ 15 | --bind-address=127.0.0.1:3000 \ 16 | --upstream-host=localhost:4000 \ 17 | --redis.uri=redis://localhost:6379 \ 18 | --log-level=info \ 19 | --log-format=text 20 | ''' 21 | 22 | [tasks.fmt] 23 | run = "go tool gofumpt -w ./" 24 | 25 | [tasks.test] 26 | depends = ["fmt"] 27 | run = "go test -count=1 -shuffle=on ./... 
-coverprofile cover.out" 28 | 29 | [tasks.check] 30 | depends = ["fmt"] 31 | run = [ 32 | "go vet ./...", 33 | "go tool staticcheck ./...", 34 | "go tool govulncheck -show=traces ./...", 35 | "go tool ratchet lint .github/workflows/*.yaml" 36 | ] 37 | 38 | [tasks.actions-update] 39 | description = "Upgrade all github actions to latest version satisfying their version tag" 40 | run = "go tool ratchet update .github/workflows/*.yaml" 41 | 42 | [tasks.actions-upgrade] 43 | description = "Upgrade all github actions to latest" 44 | run = "go tool ratchet upgrade .github/workflows/*.yaml" 45 | -------------------------------------------------------------------------------- /pkg/handler/handler_sso_server.go: -------------------------------------------------------------------------------- 1 | package handler 2 | 3 | import ( 4 | "net/http" 5 | 6 | "github.com/nais/wonderwall/pkg/config" 7 | "github.com/nais/wonderwall/pkg/cookie" 8 | "github.com/nais/wonderwall/pkg/router" 9 | "github.com/nais/wonderwall/pkg/url" 10 | ) 11 | 12 | var _ router.Source = &SSOServer{} 13 | 14 | type SSOServer struct { 15 | *Standalone 16 | } 17 | 18 | func NewSSOServer(cfg *config.Config, handler *Standalone) (*SSOServer, error) { 19 | redirect, err := url.NewSSOServerRedirect(cfg) 20 | if err != nil { 21 | return nil, err 22 | } 23 | 24 | handler.Redirect = redirect 25 | handler.CookieOptions = cookie.DefaultOptions(). 26 | WithPath("/"). 27 | WithDomain(cfg.SSO.Domain). 28 | WithSameSite(cfg.Cookie.SameSite.ToHttp()). 
29 | WithSecure(cfg.Cookie.Secure) 30 | 31 | return &SSOServer{Standalone: handler}, nil 32 | } 33 | 34 | func (s *SSOServer) Logout(w http.ResponseWriter, r *http.Request) { 35 | s.Standalone.Logout(w, r) 36 | } 37 | 38 | func (s *SSOServer) LogoutFrontChannel(w http.ResponseWriter, r *http.Request) { 39 | s.Standalone.LogoutFrontChannel(w, r) 40 | } 41 | 42 | func (s *SSOServer) LogoutLocal(w http.ResponseWriter, r *http.Request) { 43 | s.Standalone.LogoutLocal(w, r) 44 | } 45 | 46 | // Wildcard redirects unhandled requests to the default redirect URL. 47 | func (s *SSOServer) Wildcard(w http.ResponseWriter, r *http.Request) { 48 | http.Redirect(w, r, s.Config.SSO.ServerDefaultRedirectURL, http.StatusFound) 49 | } 50 | -------------------------------------------------------------------------------- /charts/wonderwall-forward-auth/templates/prometheusrule.yaml: -------------------------------------------------------------------------------- 1 | {{- if .Capabilities.APIVersions.Has "monitoring.coreos.com/v1" }} 2 | --- 3 | apiVersion: monitoring.coreos.com/v1 4 | kind: PrometheusRule 5 | metadata: 6 | name: {{ include "wonderwall-forward-auth.fullname" . }} 7 | labels: 8 | {{- include "wonderwall-forward-auth.labels" . | nindent 4 }} 9 | spec: 10 | groups: 11 | - name: "wonderwall-forward-auth" 12 | rules: 13 | - alert: wonderwall-forward-auth (Zitadel) reports a high amount of internal errors 14 | expr: sum(increase(requests_total{service="{{ include "wonderwall-forward-auth.fullname" . }}", namespace="{{ .Release.Namespace }}", code="500"}[5m])) > 30 15 | for: 5m 16 | annotations: 17 | summary: Wonderwall has responded with HTTP 500 for a high amount of requests within the last 5 minutes. 18 | consequence: This probably means that end-users are having trouble with authentication. 
19 | action: | 20 | * Check the logs and metrics in the dashboard 21 | * Check Aiven Valkey (session store) and verify Aiven network connectivity 22 | * Check the Zitadel dashboard: 23 | dashboard_url: "https://monitoring.nais.io/d/ben86a369fj7kd?var-tenant={{ .Values.fasit.tenant.name }}" 24 | labels: 25 | severity: critical 26 | namespace: {{ .Release.Namespace }} 27 | {{ end }} 28 | -------------------------------------------------------------------------------- /pkg/openid/cookies.go: -------------------------------------------------------------------------------- 1 | package openid 2 | 3 | import ( 4 | "encoding/json" 5 | "fmt" 6 | "net/http" 7 | 8 | "github.com/nais/wonderwall/internal/crypto" 9 | "github.com/nais/wonderwall/pkg/cookie" 10 | ) 11 | 12 | type LoginCookie struct { 13 | Acr string `json:"acr"` 14 | CodeVerifier string `json:"code_verifier"` 15 | Nonce string `json:"nonce"` 16 | RedirectURI string `json:"redirect_uri"` 17 | Referer string `json:"referer"` 18 | State string `json:"state"` 19 | } 20 | 21 | type LogoutCookie struct { 22 | State string `json:"state"` 23 | RedirectTo string `json:"redirect_to"` 24 | } 25 | 26 | func GetLoginCookie(r *http.Request, crypter crypto.Crypter) (*LoginCookie, error) { 27 | loginCookieJson, err := cookie.GetDecrypted(r, cookie.Login, crypter) 28 | if err != nil { 29 | return nil, err 30 | } 31 | 32 | var loginCookie LoginCookie 33 | err = json.Unmarshal([]byte(loginCookieJson), &loginCookie) 34 | if err != nil { 35 | return nil, fmt.Errorf("unmarshalling: %w", err) 36 | } 37 | 38 | return &loginCookie, nil 39 | } 40 | 41 | func GetLogoutCookie(r *http.Request, crypter crypto.Crypter) (*LogoutCookie, error) { 42 | logoutCookieJson, err := cookie.GetDecrypted(r, cookie.Logout, crypter) 43 | if err != nil { 44 | return nil, err 45 | } 46 | 47 | var logoutCookie LogoutCookie 48 | err = json.Unmarshal([]byte(logoutCookieJson), &logoutCookie) 49 | if err != nil { 50 | return nil, fmt.Errorf("unmarshalling: %w",
err) 51 | } 52 | 53 | return &logoutCookie, nil 54 | } 55 | -------------------------------------------------------------------------------- /pkg/openid/acr/acr_test.go: -------------------------------------------------------------------------------- 1 | package acr 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stretchr/testify/assert" 7 | ) 8 | 9 | func TestValidateAcr(t *testing.T) { 10 | for _, tt := range []struct { 11 | name string 12 | expected string 13 | actual string 14 | wantErr bool 15 | }{ 16 | {"no mapping found, not equal", "some-value", "some-other-value", true}, 17 | {"no mapping found, expected equals actual", "some-value", "some-value", false}, 18 | {"Level3, higher acr accepted", "Level3", "idporten-loa-high", false}, 19 | {"Level3, no matching value", "Level3", "Level2", true}, 20 | {"Level3 -> idporten-loa-substantial", "Level3", "idporten-loa-substantial", false}, 21 | {"idporten-loa-substantial", "idporten-loa-substantial", "idporten-loa-substantial", false}, 22 | {"idporten-loa-substantial, higher acr accepted", "idporten-loa-substantial", "idporten-loa-high", false}, 23 | {"Level4, lower acr not accepted", "Level4", "idporten-loa-substantial", true}, 24 | {"Level4, no matching value", "Level4", "Level5", true}, 25 | {"Level4 -> idporten-loa-high", "Level4", "idporten-loa-high", false}, 26 | {"idporten-loa-high", "idporten-loa-high", "idporten-loa-high", false}, 27 | {"idporten-loa-high, lower acr not accepted", "idporten-loa-high", "idporten-loa-substantial", true}, 28 | } { 29 | t.Run(tt.name, func(t *testing.T) { 30 | err := Validate(tt.expected, tt.actual) 31 | if tt.wantErr { 32 | assert.Error(t, err) 33 | } else { 34 | assert.NoError(t, err) 35 | } 36 | }) 37 | } 38 | } 39 | -------------------------------------------------------------------------------- /pkg/session/store_memory.go: -------------------------------------------------------------------------------- 1 | package session 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 |
"sync" 7 | "time" 8 | ) 9 | 10 | type memorySessionStore struct { 11 | lock sync.Mutex 12 | sessions map[string]*EncryptedData 13 | } 14 | 15 | var _ Store = &memorySessionStore{} 16 | 17 | func NewMemory() Store { 18 | return &memorySessionStore{ 19 | sessions: make(map[string]*EncryptedData), 20 | } 21 | } 22 | 23 | func (s *memorySessionStore) Read(_ context.Context, key string) (*EncryptedData, error) { 24 | s.lock.Lock() 25 | defer s.lock.Unlock() 26 | 27 | data, ok := s.sessions[key] 28 | if !ok { 29 | return nil, fmt.Errorf("%w: no such session: %s", ErrNotFound, key) 30 | } 31 | 32 | return data, nil 33 | } 34 | 35 | func (s *memorySessionStore) Write(_ context.Context, key string, value *EncryptedData, _ time.Duration) error { 36 | s.lock.Lock() 37 | defer s.lock.Unlock() 38 | 39 | s.sessions[key] = value 40 | return nil 41 | } 42 | 43 | func (s *memorySessionStore) Delete(_ context.Context, keys ...string) error { 44 | s.lock.Lock() 45 | defer s.lock.Unlock() 46 | 47 | for _, key := range keys { 48 | delete(s.sessions, key) 49 | } 50 | 51 | return nil 52 | } 53 | 54 | func (s *memorySessionStore) Update(_ context.Context, key string, value *EncryptedData) error { 55 | s.lock.Lock() 56 | defer s.lock.Unlock() 57 | 58 | // check existence and write under one lock, so a session deleted concurrently cannot be resurrected by this update 59 | if _, ok := s.sessions[key]; !ok { 60 | return fmt.Errorf("%w: no such session: %s", ErrNotFound, key) 61 | } 62 | 63 | s.sessions[key] = value 64 | return nil 65 | } 66 | 67 | func (s *memorySessionStore) MakeLock(_ string) Lock { 68 | return NewNoOpLock() 69 | } 70 | -------------------------------------------------------------------------------- /internal/http/middleware.go: -------------------------------------------------------------------------------- 1 | package http 2 | 3 | import ( 4 | "encoding/json" 5 | "net/http" 6 | 7 | "go.opentelemetry.io/otel/attribute" 8 | "go.opentelemetry.io/otel/trace" 9 | ) 10 | 11 | // DisallowNonNavigationalRequests checks if the request is non-navigational, and if so, responds with a 401.
12 | // We do this to separate between redirects for browser navigation and redirects for resource requests. 13 | // 14 | // This should only be used for endpoints that are only supposed to be _navigated to_ from a browser. 15 | // The 401 response prevents redirecting non-navigation requests to the identity provider, which usually results in 16 | // a CORS error for typical Fetch or XHR requests from the browser. 17 | // 18 | // This depends on the presence of the Fetch metadata headers, mostly present in modern browsers. 19 | // For compatibility with older browsers, requests without these headers are still allowed to pass through. 20 | func DisallowNonNavigationalRequests(next http.Handler) http.Handler { 21 | return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 22 | if HasSecFetchMetadata(r) && !IsNavigationRequest(r) { 23 | span := trace.SpanFromContext(r.Context()) 24 | span.SetAttributes(attribute.Bool("request.disallowed", true)) 25 | 26 | // The request path is client-controlled: encode the body with encoding/json instead of string concatenation, 27 | // so quotes, backslashes or control characters in the path cannot produce malformed or injected JSON. 28 | // Marshalling a map[string]string cannot fail, and its keys are emitted in sorted order. 29 | body, _ := json.Marshal(map[string]string{ 30 | "error": "unauthenticated", 31 | "error_description": "this is an interactive endpoint; user-agents must be navigated to this endpoint", 32 | "error_path": r.URL.Path, 33 | }) 34 | 35 | w.Header().Set("Content-Type", "application/json") 36 | w.WriteHeader(http.StatusUnauthorized) 37 | _, _ = w.Write(body) 38 | return 39 | } 40 | 41 | next.ServeHTTP(w, r) 42 | }) 43 | } 44 | -------------------------------------------------------------------------------- /internal/crypto/jwk.go: -------------------------------------------------------------------------------- 1 | package crypto 2 | 3 | import ( 4 | "crypto/rand" 5 | "crypto/rsa" 6 | "fmt" 7 | 8 | "github.com/lestrrat-go/jwx/v3/jwa" 9 | "github.com/lestrrat-go/jwx/v3/jwk" 10 | ) 11 | 12 | type JwkSet struct { 13 | Private jwk.Set 14 | Public jwk.Set 15 | } 16 | 17 | func NewJwk() (jwk.Key, error) { 18 | privateKey, err := rsa.GenerateKey(rand.Reader, 2048) 19 | if err != nil { 20 | return nil, fmt.Errorf("generating key: %w", err) 21 | } 22 | 23 | key, err :=
jwk.Import(privateKey) 24 | if err != nil { 25 | return nil, fmt.Errorf("importing key: %w", err) 26 | } 27 | 28 | err = key.Set(jwk.AlgorithmKey, jwa.RS256().String()) 29 | if err != nil { 30 | return nil, fmt.Errorf("setting algorithm: %w", err) 31 | } 32 | 33 | err = key.Set(jwk.KeyTypeKey, jwa.RSA().String()) 34 | if err != nil { 35 | return nil, fmt.Errorf("setting key type: %w", err) 36 | } 37 | 38 | err = jwk.AssignKeyID(key) 39 | if err != nil { 40 | return nil, fmt.Errorf("assigning key id: %w", err) 41 | } 42 | 43 | return key, nil 44 | } 45 | 46 | func NewJwkSet() (*JwkSet, error) { 47 | key, err := NewJwk() 48 | if err != nil { 49 | return nil, fmt.Errorf("creating jwk: %w", err) 50 | } 51 | 52 | privateKeys := jwk.NewSet() 53 | err = privateKeys.AddKey(key) 54 | if err != nil { 55 | return nil, fmt.Errorf("adding key to set: %w", err) 56 | } 57 | 58 | publicKeys, err := jwk.PublicSetOf(privateKeys) 59 | if err != nil { 60 | return nil, fmt.Errorf("creating public set: %w", err) 61 | } 62 | 63 | return &JwkSet{ 64 | Private: privateKeys, 65 | Public: publicKeys, 66 | }, nil 67 | } 68 | -------------------------------------------------------------------------------- /pkg/session/id.go: -------------------------------------------------------------------------------- 1 | package session 2 | 3 | import ( 4 | "fmt" 5 | "net/http" 6 | 7 | "github.com/nais/wonderwall/internal/crypto" 8 | "github.com/nais/wonderwall/pkg/openid" 9 | openidconfig "github.com/nais/wonderwall/pkg/openid/config" 10 | ) 11 | 12 | // ExternalID returns the external session ID, derived from the given request or id_token; e.g. `sid` or `session_state`. 13 | // If none are present, a generated ID is returned. 14 | func ExternalID(r *http.Request, cfg openidconfig.Provider, idToken *openid.IDToken) (string, error) { 15 | // 1. check for 'sid' claim in id_token 16 | sessionID, err := idToken.Sid() 17 | if err == nil { 18 | return sessionID, nil 19 | } 20 | // 1a.
error if sid claim is required according to openid config 21 | if cfg.SidClaimRequired() { 22 | return "", err 23 | } 24 | 25 | // 2. check for session_state in callback params 26 | sessionID, err = getSessionStateFrom(r) 27 | if err == nil { 28 | return sessionID, nil 29 | } 30 | // 2a. error if session_state is required according to openid config 31 | if cfg.SessionStateRequired() { 32 | return "", err 33 | } 34 | 35 | // 3. generate ID if all else fails 36 | sessionID, err = crypto.Text(64) 37 | if err != nil { 38 | return "", fmt.Errorf("generating session ID: %w", err) 39 | } 40 | return sessionID, nil 41 | } 42 | 43 | func getSessionStateFrom(r *http.Request) (string, error) { 44 | params := r.URL.Query() 45 | 46 | sessionState := params.Get("session_state") 47 | if len(sessionState) == 0 { 48 | return "", fmt.Errorf("missing required 'session_state' in params") 49 | } 50 | return sessionState, nil 51 | } 52 | -------------------------------------------------------------------------------- /internal/crypto/crypter_test.go: -------------------------------------------------------------------------------- 1 | package crypto_test 2 | 3 | import ( 4 | "crypto/rand" 5 | "testing" 6 | 7 | "github.com/stretchr/testify/assert" 8 | 9 | "github.com/nais/wonderwall/internal/crypto" 10 | ) 11 | 12 | var ( 13 | plaintext = []byte("foo bar, this is a very nice plaintext") 14 | key = make([]byte, 32) 15 | ) 16 | 17 | const ( 18 | // Run this many iterations to make sure the IV is not re-used 19 | ivIterations = 50000 20 | ) 21 | 22 | // Generate a new encryption key on every test run. 23 | func init() { 24 | _, err := rand.Read(key) 25 | if err != nil { 26 | panic(err) 27 | } 28 | } 29 | 30 | // Test that encryption with a 256-bit key works, 31 | // and that the ciphertext differs from one message to the next.
32 | func TestEncrypt(t *testing.T) { 33 | var cur, prev []byte 34 | var err error 35 | 36 | crypter := crypto.NewCrypter(key) 37 | 38 | for i := ivIterations; i != 0; i-- { 39 | cur, err = crypter.Encrypt(plaintext) 40 | assert.NoError(t, err) 41 | assert.NotNil(t, cur) 42 | assert.NotEqual(t, prev, cur, "IV re-used") 43 | prev = make([]byte, len(cur)) 44 | copy(prev, cur) 45 | } 46 | } 47 | 48 | // Test that encrypted messages can be decrypted. 49 | func TestDecrypt(t *testing.T) { 50 | crypter := crypto.NewCrypter(key) 51 | 52 | ciphertext, err := crypter.Encrypt(plaintext) 53 | assert.NoError(t, err) 54 | 55 | decrypted, err := crypter.Decrypt(ciphertext) 56 | assert.NoError(t, err) 57 | assert.Equal(t, plaintext, decrypted) 58 | } 59 | 60 | // BenchmarkEncrypt measures the cost of encrypting a short plaintext. 61 | // Encryption errors fail the benchmark instead of being ignored, which would otherwise silently time the error path. 62 | func BenchmarkEncrypt(b *testing.B) { 63 | crypter := crypto.NewCrypter(key) 64 | 65 | for n := 0; n < b.N; n++ { 66 | if _, err := crypter.Encrypt(plaintext); err != nil { 67 | b.Fatal(err) 68 | } 69 | } 70 | } 71 | -------------------------------------------------------------------------------- /pkg/config/session.go: -------------------------------------------------------------------------------- 1 | package config 2 | 3 | import ( 4 | "fmt" 5 | "time" 6 | 7 | flag "github.com/spf13/pflag" 8 | ) 9 | 10 | type Session struct { 11 | ForwardAuth bool `json:"forward-auth"` 12 | ForwardAuthSetHeaders bool `json:"forward-auth-set-headers"` 13 | Inactivity bool `json:"inactivity"` 14 | InactivityTimeout time.Duration `json:"inactivity-timeout"` 15 | MaxLifetime time.Duration `json:"max-lifetime"` 16 | } 17 | 18 | func (s *Session) Validate() error { 19 | if s.ForwardAuthSetHeaders && !s.ForwardAuth { 20 | return fmt.Errorf("%q must be enabled when %q is enabled", SessionForwardAuth, SessionForwardAuthSetHeaders) 21 | } 22 | return nil 23 | } 24 | 25 | const ( 26 | SessionForwardAuth = "session.forward-auth" 27 | SessionForwardAuthSetHeaders = "session.forward-auth-set-headers" 28 | SessionInactivity = "session.inactivity" 29 | SessionInactivityTimeout = "session.inactivity-timeout" 30 |
SessionMaxLifetime = "session.max-lifetime" 31 | ) 32 | 33 | func sessionFlags() { 34 | flag.Bool(SessionForwardAuth, false, "Enable endpoint for forward authentication.") 35 | flag.Bool(SessionForwardAuthSetHeaders, false, "Set 'X-Wonderwall-Forward-Auth-Token' header for responses from forward-auth endpoint.") 36 | flag.Bool(SessionInactivity, false, "Automatically expire user sessions if they have not refreshed their tokens within a given duration.") 37 | flag.Duration(SessionInactivityTimeout, 30*time.Minute, "Inactivity timeout for user sessions.") 38 | flag.Duration(SessionMaxLifetime, 10*time.Hour, "Max lifetime for user sessions.") 39 | } 40 | -------------------------------------------------------------------------------- /pkg/mock/config.go: -------------------------------------------------------------------------------- 1 | package mock 2 | 3 | import ( 4 | "time" 5 | 6 | "github.com/nais/wonderwall/pkg/config" 7 | "github.com/nais/wonderwall/pkg/ingress" 8 | openidconfig "github.com/nais/wonderwall/pkg/openid/config" 9 | ) 10 | 11 | const ( 12 | Ingress = "http://wonderwall" 13 | ) 14 | 15 | func Config() *config.Config { 16 | return &config.Config{ 17 | EncryptionKey: `G8Roe6AcoBpdr5GhO3cs9iORl4XIC8eq`, // 256 bits key 18 | Ingresses: []string{Ingress}, 19 | OpenID: config.OpenID{ 20 | ACRValues: "idporten-loa-high", 21 | ClientID: "client-id", 22 | IDTokenSigningAlg: "RS256", 23 | PostLogoutRedirectURI: "https://google.com", 24 | Provider: "test", 25 | Scopes: []string{"some-scope"}, 26 | UILocales: "nb", 27 | }, 28 | Session: config.Session{ 29 | MaxLifetime: time.Hour, 30 | }, 31 | } 32 | } 33 | 34 | type TestConfiguration struct { 35 | TestClient *TestClientConfiguration 36 | TestProvider *TestProviderConfiguration 37 | } 38 | 39 | func (c *TestConfiguration) Client() openidconfig.Client { 40 | return c.TestClient 41 | } 42 | 43 | func (c *TestConfiguration) Provider() openidconfig.Provider { 44 | return c.TestProvider 45 | } 46 | 47 | func
NewTestConfiguration(cfg *config.Config) *TestConfiguration { 48 | return &TestConfiguration{ 49 | TestClient: clientConfiguration(cfg), 50 | TestProvider: providerConfiguration(cfg), 51 | } 52 | } 53 | 54 | func Ingresses(cfg *config.Config) *ingress.Ingresses { 55 | parsed, err := ingress.ParseIngresses(cfg) 56 | if err != nil { 57 | panic(err) 58 | } 59 | 60 | return parsed 61 | } 62 | -------------------------------------------------------------------------------- /docker-compose.example.yml: -------------------------------------------------------------------------------- 1 | services: 2 | redis: 3 | image: redis:7 4 | ports: 5 | - "6379:6379" 6 | mock-oauth2-server: 7 | image: ghcr.io/navikt/mock-oauth2-server:3.0.1 8 | ports: 9 | - "8888:8080" 10 | environment: 11 | JSON_CONFIG: "{\"interactiveLogin\":false}" 12 | wonderwall: 13 | image: ghcr.io/nais/wonderwall:latest 14 | ports: 15 | - "3000:3000" 16 | command: > 17 | --openid.client-id=bogus 18 | --openid.client-secret=not-so-secret 19 | --openid.well-known-url=http://localhost:8888/default/.well-known/openid-configuration 20 | --ingress=http://localhost:3000 21 | --bind-address=0.0.0.0:3000 22 | --upstream-host=upstream:4000 23 | --redis.uri=redis://redis:6379 24 | --log-level=debug 25 | --log-format=text 26 | restart: on-failure 27 | extra_hosts: 28 | # Wonderwall needs to both reach and redirect user agents to the mock-oauth2-server: 29 | # - 'mock-oauth2-server:8888' resolves from the container, but is not resolvable for user agents at the host (e.g. during redirects). 30 | # - 'localhost:8888' allows user agents to resolve redirects to the mock-oauth2-server, but breaks connectivity from the container itself. 31 | # This additional mapping allows the container to reach the mock-oauth2-server at 'localhost' through the host network, as well as allowing user agents to correctly resolve redirects.
32 | - localhost:host-gateway 33 | upstream: 34 | image: mendhak/http-https-echo:30 35 | ports: 36 | - "4000:4000" 37 | environment: 38 | HTTP_PORT: 4000 39 | JWT_HEADER: Authorization 40 | LOG_IGNORE_PATH: / 41 | -------------------------------------------------------------------------------- /pkg/handler/autologin/autologin.go: -------------------------------------------------------------------------------- 1 | package autologin 2 | 3 | import ( 4 | "net/http" 5 | "strings" 6 | "sync" 7 | 8 | "github.com/bmatcuk/doublestar/v4" 9 | 10 | "github.com/nais/wonderwall/pkg/config" 11 | ) 12 | 13 | var DefaultIgnorePatterns = []string{ 14 | "/favicon.ico", 15 | "/robots.txt", 16 | } 17 | 18 | // maxCachedPaths bounds the number of distinct request paths whose NeedsLogin result is memoized. 19 | // Request paths are chosen by (possibly unauthenticated) clients, so an unbounded cache would let them grow 20 | // memory usage without limit by requesting unique paths. Paths beyond the limit are still evaluated, just not cached. 21 | const maxCachedPaths = 10000 22 | 23 | type AutoLogin struct { 24 | Enabled bool 25 | IgnorePatterns []string 26 | cache sync.Map 27 | // cacheMu serializes insertions into cache so that cacheSize stays exact. 28 | cacheMu sync.Mutex 29 | cacheSize int 30 | } 31 | 32 | // NeedsLogin reports whether the request should trigger an automatic login. 33 | // It returns false if the request is authenticated, auto-login is disabled, or the normalized path 34 | // (leading slash ensured, trailing slash trimmed except for "/") matches one of IgnorePatterns. 35 | func (a *AutoLogin) NeedsLogin(r *http.Request, isAuthenticated bool) bool { 36 | if isAuthenticated || !a.Enabled { 37 | return false 38 | } 39 | 40 | path := r.URL.Path 41 | if !strings.HasPrefix(path, "/") { 42 | path = "/" + path 43 | } 44 | 45 | if path != "/" { 46 | path = strings.TrimSuffix(path, "/") 47 | } 48 | 49 | if result, found := a.cache.Load(path); found { 50 | return result.(bool) 51 | } 52 | 53 | needsLogin := true 54 | for _, pattern := range a.IgnorePatterns { 55 | if match, _ := doublestar.Match(pattern, path); match { 56 | needsLogin = false 57 | break 58 | } 59 | } 60 | 61 | a.remember(path, needsLogin) 62 | return needsLogin 63 | } 64 | 65 | // remember memoizes the result for path unless the cache already holds maxCachedPaths entries. 66 | func (a *AutoLogin) remember(path string, needsLogin bool) { 67 | a.cacheMu.Lock() 68 | defer a.cacheMu.Unlock() 69 | 70 | if a.cacheSize >= maxCachedPaths { 71 | return 72 | } 73 | if _, loaded := a.cache.LoadOrStore(path, needsLogin); !loaded { 74 | a.cacheSize++ 75 | } 76 | } 77 | 78 | // New builds an AutoLogin from cfg. The ignore patterns are DefaultIgnorePatterns followed by 79 | // cfg.AutoLoginIgnorePaths, with empty entries skipped, trailing slashes trimmed (except for "/") and duplicates removed. 80 | func New(cfg *config.Config) (*AutoLogin, error) { 81 | seen := make(map[string]bool) 82 | patterns := make([]string, 0) 83 | 84 | // build a fresh slice: appending directly to the exported DefaultIgnorePatterns could write into its backing array 85 | candidates := make([]string, 0, len(DefaultIgnorePatterns)+len(cfg.AutoLoginIgnorePaths)) 86 | candidates = append(candidates, DefaultIgnorePatterns...) 87 | candidates = append(candidates, cfg.AutoLoginIgnorePaths...) 88 | 89 | for _, path := range candidates { 90 | if len(path) == 0 { 91 | continue 92 | } 93 | 94 | if path != "/" { 95 | path = strings.TrimSuffix(path, "/") 96 | } 97 | 98 | if _, found := seen[path]; !found { 99 | seen[path] = true 100 | patterns = append(patterns, path) 101 | } 102 | } 103 | 104 | return &AutoLogin{ 105 | Enabled: cfg.AutoLogin, 106 | IgnorePatterns: patterns, 107 | cache: sync.Map{}, 108 | }, nil 109 | } 110 |
-------------------------------------------------------------------------------- /flake.nix: -------------------------------------------------------------------------------- 1 | { 2 | description = "Wonderwall"; 3 | 4 | inputs = { 5 | nixpkgs.url = "github:NixOS/nixpkgs/nixpkgs-unstable"; 6 | flake-utils.url = "github:numtide/flake-utils"; 7 | gitignore = { 8 | url = "github:hercules-ci/gitignore.nix"; 9 | inputs.nixpkgs.follows = "nixpkgs"; 10 | }; 11 | }; 12 | 13 | outputs = {self, ...} @ inputs: 14 | inputs.flake-utils.lib.eachDefaultSystem (system: let 15 | pkgs = import inputs.nixpkgs {localSystem = {inherit system;};}; 16 | name = "wonderwall"; 17 | wonderwall = pkgs.buildGoModule { 18 | inherit name; 19 | # nativeBuildInputs = with pkgs; [ 20 | # golangci-lint 21 | # ]; 22 | # GOLANGCI_LINT = "${pkgs.golangci-lint}"; 23 | src = inputs.gitignore.lib.gitignoreSource ./.; 24 | vendorHash = "sha256-3RqVAgA9iJhX0mbwlVMH+NSUz3H9Uobgs8zm1x9fb1o="; # nixpkgs.lib.fakeSha256; 25 | }; 26 | in { 27 | devShells.default = pkgs.mkShell { 28 | inputsFrom = [wonderwall]; 29 | }; 30 | packages = { 31 | inherit wonderwall; 32 | docker = let 33 | imageRef = "europe-north1-docker.pkg.dev/nais-management-233d"; 34 | teamName = "nais"; 35 | dockerTag = 36 | if pkgs.lib.hasAttr "rev" self 37 | then "${builtins.toString self.revCount}-${self.shortRev}" 38 | else "gitDirty"; 39 | in 40 | pkgs.dockerTools.buildImage { 41 | config = {Entrypoint = ["${wonderwall}/bin/${name}"];}; 42 | name = "${imageRef}/${teamName}/${name}"; 43 | tag = "${dockerTag}"; 44 | }; 45 | }; 46 | packages.default = wonderwall; 47 |
formatter = pkgs.alejandra; 48 | }); 49 | } 50 | -------------------------------------------------------------------------------- /templates/error.gohtml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 8 | Teknisk feil 9 | 10 | 11 |
12 |
13 |
14 |

15 | Beklager, noe gikk galt. 16 |

17 |

18 | Statuskode {{.HttpStatusCode}} 19 |

20 |

21 | En teknisk feil gjør at siden er utilgjengelig. Dette skyldes ikke noe du gjorde.
22 | Vent litt og prøv igjen. 23 |

24 |
25 | {{if ne .HttpStatusCode 400}} 26 | 27 | Prøv igjen 28 | 29 | {{end}} 30 | 31 | Gå til forsiden 32 | 33 |
34 |

35 | ID: {{.CorrelationID}} 36 |

37 |
38 |
39 |
40 | 41 | 42 | -------------------------------------------------------------------------------- /pkg/openid/acr/acr.go: -------------------------------------------------------------------------------- 1 | package acr 2 | 3 | import ( 4 | "fmt" 5 | ) 6 | 7 | const ( 8 | IDPortenLevel3 = "Level3" 9 | IDPortenLevelSubstantial = "idporten-loa-substantial" 10 | IDPortenLevel4 = "Level4" 11 | IDPortenLevelHigh = "idporten-loa-high" 12 | ) 13 | 14 | // IDPortenLegacyMapping is a translation table of valid acr_values that maps values from "old" to "new" ID-porten. 15 | var IDPortenLegacyMapping = map[string]string{ 16 | IDPortenLevel3: IDPortenLevelSubstantial, 17 | IDPortenLevel4: IDPortenLevelHigh, 18 | } 19 | 20 | // acceptedValuesMapping maps a required ACR (authentication context class reference) value to the values that satisfy it. 21 | // Each required value is associated with the values regarded as equivalent or greater in terms of assurance level. 22 | // Example: 23 | // - if we require an ACR value of "idporten-loa-substantial", then both "idporten-loa-substantial" and "idporten-loa-high" are accepted values. 24 | // - if we require an ACR value of "idporten-loa-high", then only "idporten-loa-high" is an acceptable value.
25 | var acceptedValuesMapping = map[string][]string{ 26 | IDPortenLevelSubstantial: {IDPortenLevelSubstantial, IDPortenLevelHigh}, 27 | IDPortenLevelHigh: {IDPortenLevelHigh}, 28 | } 29 | // Validate returns nil if actual satisfies the expected ACR value. A legacy expected value is first translated via IDPortenLegacyMapping; an expected value without an entry in acceptedValuesMapping must equal actual exactly. Note that actual is compared as-is (legacy values in actual are not translated). 30 | func Validate(expected, actual string) error { 31 | if translated, found := IDPortenLegacyMapping[expected]; found { 32 | expected = translated 33 | } 34 | 35 | acceptedValues, found := acceptedValuesMapping[expected] 36 | if !found { 37 | if expected == actual { 38 | return nil 39 | } 40 | return fmt.Errorf("invalid acr: got %q, expected %q", actual, expected) 41 | } 42 | 43 | for _, accepted := range acceptedValues { 44 | if actual == accepted { 45 | return nil 46 | } 47 | } 48 | 49 | return fmt.Errorf("invalid acr: got %q, must be one of %s", actual, acceptedValues) 50 | } 51 | -------------------------------------------------------------------------------- /pkg/middleware/context.go: -------------------------------------------------------------------------------- 1 | package middleware 2 | 3 | import ( 4 | "context" 5 | "net/http" 6 | 7 | "github.com/nais/wonderwall/pkg/ingress" 8 | ) 9 | // contextKey is an unexported key type, so values stored by this package cannot collide with context keys defined in other packages. 10 | type contextKey string 11 | 12 | const ( 13 | ctxAccessToken = contextKey("AccessToken") 14 | ctxIDToken = contextKey("IDToken") 15 | ctxIngress = contextKey("Ingress") 16 | ctxPath = contextKey("Path") 17 | ) 18 | 19 | func AccessTokenFrom(ctx context.Context) (string, bool) { 20 | accessToken, ok := ctx.Value(ctxAccessToken).(string) 21 | return accessToken, ok 22 | } 23 | 24 | func WithAccessToken(ctx context.Context, accessToken string) context.Context { 25 | return context.WithValue(ctx, ctxAccessToken, accessToken) 26 | } 27 | 28 | func IDTokenFrom(ctx context.Context) (string, bool) { 29 | idToken, ok := ctx.Value(ctxIDToken).(string) 30 | return idToken, ok 31 | } 32 | 33 | func WithIDToken(ctx context.Context, idToken string) context.Context { 34 | return context.WithValue(ctx, ctxIDToken, idToken) 35 | } 36 | 37 | func IngressFrom(ctx context.Context) (ingress.Ingress,
bool) { 38 | i, ok := ctx.Value(ctxIngress).(ingress.Ingress) 39 | return i, ok 40 | } 41 | 42 | func WithIngress(ctx context.Context, ingress ingress.Ingress) context.Context { 43 | return context.WithValue(ctx, ctxIngress, ingress) 44 | } 45 | 46 | func RequestWithIngress(r *http.Request, ing ingress.Ingress) *http.Request { 47 | ctx := r.Context() 48 | ctx = WithIngress(ctx, ing) 49 | return r.WithContext(ctx) 50 | } 51 | 52 | func PathFrom(ctx context.Context) (string, bool) { 53 | path, ok := ctx.Value(ctxPath).(string) 54 | return path, ok 55 | } 56 | 57 | func WithPath(ctx context.Context, path string) context.Context { 58 | return context.WithValue(ctx, ctxPath, path) 59 | } 60 | 61 | func RequestWithPath(r *http.Request, path string) *http.Request { 62 | ctx := r.Context() 63 | ctx = WithPath(ctx, path) 64 | return r.WithContext(ctx) 65 | } 66 | -------------------------------------------------------------------------------- /pkg/openid/client/logout.go: -------------------------------------------------------------------------------- 1 | package client 2 | 3 | import ( 4 | "encoding/json" 5 | "fmt" 6 | "net/http" 7 | 8 | "github.com/nais/wonderwall/internal/crypto" 9 | "github.com/nais/wonderwall/pkg/cookie" 10 | "github.com/nais/wonderwall/pkg/openid" 11 | urlpkg "github.com/nais/wonderwall/pkg/url" 12 | ) 13 | 14 | type Logout struct { 15 | *Client 16 | Cookie *openid.LogoutCookie 17 | logoutCallbackURL string 18 | } 19 | // NewLogout derives the logout callback URL from r and generates a random state (crypto.Text(32)) for the logout cookie. 20 | func NewLogout(c *Client, r *http.Request) (*Logout, error) { 21 | logoutCallbackURL, err := urlpkg.LogoutCallback(r) 22 | if err != nil { 23 | return nil, fmt.Errorf("generating logout callback url: %w", err) 24 | } 25 | 26 | state, err := crypto.Text(32) 27 | if err != nil { 28 | return nil, fmt.Errorf("generating state: %w", err) 29 | } 30 | 31 | logoutCookie := &openid.LogoutCookie{ 32 | State: state, 33 | } 34 | 35 | return &Logout{ 36 | Client: c, 37 | Cookie: logoutCookie, 38 | logoutCallbackURL: logoutCallbackURL, 39 | },
nil 40 | } 41 | // SingleLogoutURL returns the provider's end-session endpoint URL with post_logout_redirect_uri and state set, plus id_token_hint when idToken is non-empty. 42 | func (in *Logout) SingleLogoutURL(idToken string) string { 43 | endSessionEndpoint := in.cfg.Provider().EndSessionEndpointURL() // NOTE(review): RawQuery is overwritten below; assumes this returns a copy rather than a shared *url.URL — confirm 44 | v := endSessionEndpoint.Query() 45 | v.Set("post_logout_redirect_uri", in.logoutCallbackURL) 46 | v.Set("state", in.Cookie.State) 47 | 48 | if len(idToken) > 0 { 49 | v.Set("id_token_hint", idToken) 50 | } 51 | 52 | endSessionEndpoint.RawQuery = v.Encode() 53 | return endSessionEndpoint.String() 54 | } 55 | // SetCookie records canonicalRedirect in the logout cookie, JSON-encodes the cookie and writes it encrypted via cookie.EncryptAndSet. 56 | func (in *Logout) SetCookie(w http.ResponseWriter, opts cookie.Options, crypter crypto.Crypter, canonicalRedirect string) error { 57 | in.Cookie.RedirectTo = canonicalRedirect 58 | 59 | logoutCookieJson, err := json.Marshal(in.Cookie) 60 | if err != nil { 61 | return fmt.Errorf("marshalling logout cookie: %w", err) 62 | } 63 | 64 | value := string(logoutCookieJson) 65 | return cookie.EncryptAndSet(w, cookie.Logout, value, opts, crypter) 66 | } 67 | -------------------------------------------------------------------------------- /charts/wonderwall-forward-auth/templates/networkpolicy.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | apiVersion: networking.k8s.io/v1 3 | kind: NetworkPolicy 4 | metadata: 5 | labels: 6 | {{- include "wonderwall-forward-auth.labels" . | nindent 4 }} 7 | name: {{ include "wonderwall-forward-auth.fullname" .
}} 8 | spec: 9 | ingress: 10 | - from: 11 | - namespaceSelector: 12 | matchLabels: 13 | kubernetes.io/metadata.name: nais-system 14 | podSelector: 15 | matchLabels: 16 | app.kubernetes.io/name: prometheus 17 | - from: 18 | - namespaceSelector: 19 | matchLabels: 20 | kubernetes.io/metadata.name: nais-system 21 | podSelector: 22 | matchLabels: 23 | app.kubernetes.io/name: alloy 24 | - from: 25 | - namespaceSelector: 26 | matchLabels: 27 | kubernetes.io/metadata.name: nais-system 28 | podSelector: 29 | matchLabels: 30 | nais.io/ingressClass: {{ .Values.ingressClassName }} 31 | podSelector: 32 | matchLabels: 33 | {{- include "wonderwall-forward-auth.selectorLabels" .
| nindent 6 }} 34 | policyTypes: 35 | - Ingress 36 | {{- if .Capabilities.APIVersions.Has "networking.gke.io/v1alpha3" }} 37 | --- 38 | apiVersion: networking.gke.io/v1alpha3 39 | kind: FQDNNetworkPolicy 40 | metadata: 41 | name: {{ include "wonderwall-forward-auth.fullname" . }}-fqdn 42 | labels: 43 | {{- include "wonderwall-forward-auth.labels" . | nindent 4 }} 44 | annotations: 45 | fqdnnetworkpolicies.networking.gke.io/aaaa-lookups: "skip" 46 | spec: 47 | egress: 48 | - ports: 49 | - port: 443 50 | protocol: TCP 51 | to: 52 | - fqdns: 53 | - auth.nais.io 54 | podSelector: 55 | matchLabels: 56 | {{- include "wonderwall-forward-auth.selectorLabels" . | nindent 6 }} 57 | policyTypes: 58 | - Egress 59 | {{- end }} 60 | -------------------------------------------------------------------------------- /pkg/server/server.go: -------------------------------------------------------------------------------- 1 | package server 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "net/http" 7 | "os" 8 | "os/signal" 9 | "syscall" 10 | "time" 11 | 12 | "github.com/go-chi/chi/v5" 13 | log "github.com/sirupsen/logrus" 14 | 15 | "github.com/nais/wonderwall/pkg/config" 16 | ) 17 | 18 | func Start(cfg *config.Config, r chi.Router) error { 19 | server := http.Server{ 20 | Addr: cfg.BindAddress, 21 | Handler: r, 22 | } 23 | 24 | serverCtx, serverStopCtx := context.WithCancel(context.Background()) 25 | 26 | sig := make(chan os.Signal, 1) 27 | signal.Notify(sig, syscall.SIGHUP, syscall.SIGINT, syscall.SIGTERM, syscall.SIGQUIT) 28 | go func() { 29 | s := <-sig 30 | log.Infof("server: received %q; waiting for %s before starting graceful shutdown...", s, cfg.ShutdownWaitBeforePeriod) 31 | time.Sleep(cfg.ShutdownWaitBeforePeriod) 32 | 33 | // the total terminationGracePeriodSeconds in Kubernetes starts immediately when SIGTERM is sent, so we need to subtract the wait-before period to exit before SIGKILL 34 | shutdownTimeout := cfg.ShutdownGracefulPeriod - cfg.ShutdownWaitBeforePeriod 35 | shutdownCtx, shutdownStopCtx := context.WithTimeout(serverCtx, shutdownTimeout) 36 | 37 | go func() { 38 | <-shutdownCtx.Done() 39 | if errors.Is(shutdownCtx.Err(), context.DeadlineExceeded) { 40 | log.Fatalf("server: graceful shutdown timed out after %s; forcing exit.", shutdownTimeout) 41 | } 42 | }() 43 | 44 | log.Infof("server: starting graceful shutdown (will timeout after %s)...", shutdownTimeout) 45 | err := server.Shutdown(shutdownCtx) 46 | if err != nil { 47 | log.Fatal(err) 48 | } 49 | shutdownStopCtx() 50 | serverStopCtx() 51 | }() 52 | 53 | log.Infof("server: listening on %s", cfg.BindAddress) 54 | err := server.ListenAndServe() 55 | if err !=
nil && !errors.Is(err, http.ErrServerClosed) { 56 | return err 57 | } 58 | 59 | <-serverCtx.Done() 60 | log.Infof("server: shutdown completed") 61 | return nil 62 | } 63 | -------------------------------------------------------------------------------- /pkg/session/store.go: -------------------------------------------------------------------------------- 1 | package session 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "time" 7 | 8 | "github.com/prometheus/client_golang/prometheus" 9 | "github.com/redis/go-redis/extra/redisotel/v9" 10 | "github.com/redis/go-redis/extra/redisprometheus/v9" 11 | log "github.com/sirupsen/logrus" 12 | 13 | "github.com/nais/wonderwall/pkg/config" 14 | ) 15 | 16 | type Store interface { 17 | Write(ctx context.Context, key string, value *EncryptedData, expiration time.Duration) error 18 | Read(ctx context.Context, key string) (*EncryptedData, error) 19 | Delete(ctx context.Context, keys ...string) error 20 | Update(ctx context.Context, key string, value *EncryptedData) error 21 | 22 | MakeLock(key string) Lock 23 | } 24 | 25 | // NewStore returns a Redis-backed Store when cfg.Redis has an address or URI configured, and an in-memory Store otherwise. 26 | // The Redis connection is verified with a PING (30 second timeout) before the store is returned. 27 | func NewStore(cfg *config.Config) (Store, error) { 28 | if len(cfg.Redis.Address) == 0 && len(cfg.Redis.URI) == 0 { 29 | log.Warnf("Redis not configured, using in-memory session backing store; not suitable for multi-pod deployments!") 30 | return NewMemory(), nil 31 | } 32 | 33 | redisClient, err := cfg.Redis.Client() 34 | if err != nil { 35 | return nil, fmt.Errorf("failed to create Redis Client: %w", err) 36 | } 37 | 38 | ctx, cancel := context.WithTimeout(context.Background(), time.Second*30) 39 | defer cancel() 40 | 41 | err = redisClient.Ping(ctx).Err() 42 | if err != nil { 43 | // the client is discarded, so release its connection pool 44 | _ = redisClient.Close() 45 | return nil, fmt.Errorf("failed to connect to configured Redis: %w", err) 46 | } 47 | 48 | if cfg.OpenTelemetry.Enabled { 49 | opts := []redisotel.TracingOption{ 50 | redisotel.WithDBStatement(false),
51 | redisotel.WithCallerEnabled(false), 52 | } 53 | if err := redisotel.InstrumentTracing(redisClient, opts...); err != nil { 54 | _ = redisClient.Close() 55 | return nil, fmt.Errorf("failed to instrument Redis Client: %w", err) 56 | } 57 | log.Infof("session: using redis as backing store with OpenTelemetry instrumentation") 58 | } else { 59 | log.Infof("session: using redis as backing store") 60 | } 61 | 62 | // Register pool metrics only for a client that is actually returned. Registration failures (e.g. a collector 63 | // already registered by an earlier call) are logged instead of silently dropped, but do not fail the store. 64 | collector := redisprometheus.NewCollector("wonderwall", "", redisClient) 65 | if err := prometheus.Register(collector); err != nil { 66 | log.Warnf("session: failed to register redis metrics collector: %v", err) 67 | } 68 | 69 | return NewRedis(redisClient), nil 70 | } 71 | -------------------------------------------------------------------------------- /charts/wonderwall/templates/networkpolicy.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: networking.k8s.io/v1 2 | kind: NetworkPolicy 3 | metadata: 4 | labels: 5 | {{- include "wonderwall.labels" . | nindent 4 }} 6 | name: {{ include "wonderwall.fullname" . }} 7 | spec: 8 | egress: 9 | - to: 10 | - namespaceSelector: {} 11 | podSelector: 12 | matchLabels: 13 | k8s-app: kube-dns 14 | - to: 15 | - namespaceSelector: 16 | matchLabels: 17 | linkerd.io/is-control-plane: "true" 18 | - to: 19 | - namespaceSelector: 20 | matchLabels: 21 | kubernetes.io/metadata.name: nais-system 22 | podSelector: 23 | matchLabels: 24 | app.kubernetes.io/name: tempo 25 | ingress: 26 | - from: 27 | - namespaceSelector: 28 | matchLabels: 29 | kubernetes.io/metadata.name: nais-system 30 | podSelector: 31 | matchLabels: 32 | app.kubernetes.io/name: prometheus 33 | - from: 34 | - namespaceSelector: 35 | matchLabels: 36 | kubernetes.io/metadata.name: nais-system 37 | podSelector: 38 | matchLabels: 39 | app.kubernetes.io/name: alloy 40 | - from: 41 | - namespaceSelector: 42 | matchLabels: 43 | kubernetes.io/metadata.name: nais-system 44 | podSelector: 45 | matchLabels: 46 | nais.io/ingressClass: {{ .Values.azure.forwardAuth.ingressClassName }} 47 | - from: 48 | - namespaceSelector: 49 | matchLabels: 50 | kubernetes.io/metadata.name: nais-system 51 | podSelector: 52 | matchLabels: 53 | nais.io/ingressClass: {{ .Values.idporten.ingressClassName
}} 54 | - from: 55 | - namespaceSelector: 56 | matchLabels: 57 | linkerd.io/is-control-plane: "true" 58 | podSelector: 59 | matchLabels: 60 | {{- include "wonderwall.selectorLabels" . | nindent 6 }} 61 | policyTypes: 62 | - Ingress 63 | - Egress 64 | -------------------------------------------------------------------------------- /pkg/session/lock.go: -------------------------------------------------------------------------------- 1 | package session 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "fmt" 7 | "time" 8 | 9 | "github.com/bsm/redislock" 10 | "github.com/nais/wonderwall/internal/o11y/otel" 11 | "github.com/redis/go-redis/v9" 12 | "go.opentelemetry.io/otel/attribute" 13 | ) 14 | 15 | const ( 16 | KeyTemplate = "%s.lock" 17 | ) 18 | 19 | var ErrAcquireLock = errors.New("could not acquire lock") 20 | 21 | type Lock interface { 22 | Acquire(ctx context.Context, duration time.Duration) error 23 | Release(ctx context.Context) error 24 | } 25 | 26 | var _ Lock = &RedisLock{} 27 | 28 | type RedisLock struct { 29 | locker *redislock.Client 30 | lock *redislock.Lock 31 | key string 32 | } 33 | 34 | func NewRedisLock(client redis.Cmdable, key string) *RedisLock { 35 | return &RedisLock{ 36 | locker: redislock.New(client), 37 | key: key, 38 | } 39 | } 40 | 41 | func (r *RedisLock) Acquire(ctx context.Context, duration time.Duration) error { 42 | ctx, span := otel.StartSpan(ctx, "RedisLock.Acquire") 43 | defer span.End() 44 | span.SetAttributes(attribute.String("redis.lock_duration", duration.String())) 45 | span.SetAttributes(attribute.Bool("redis.lock_acquired", false)) 46 | 47 | lock, err := r.locker.Obtain(ctx, lockKey(r.key), duration, nil) 48 | if errors.Is(err, redislock.ErrNotObtained) { 49 | return ErrAcquireLock 50 | } 51 | if err != nil { 52 | return err 53 | } 54 | 55 | r.lock = lock 56 | span.SetAttributes(attribute.Bool("redis.lock_acquired", true)) 57 | return nil 58 | } 59 | 60 | func (r *RedisLock) Release(ctx context.Context) error { 61 | ctx, span
:= otel.StartSpan(ctx, "RedisLock.Release") 62 | defer span.End() 63 | 64 | // Release without a held lock (Acquire never called, or it failed) must not dereference a nil lock. 65 | if r.lock == nil { 66 | return redislock.ErrLockNotHeld 67 | } 68 | 69 | // forget the lock once released, so a later failed Acquire cannot leave a stale lock behind to be released again 70 | lock := r.lock 71 | r.lock = nil 72 | return lock.Release(ctx) 73 | } 74 | 75 | var _ Lock = &NoOpLock{} 76 | 77 | type NoOpLock struct{} 78 | 79 | func NewNoOpLock() *NoOpLock { 80 | return new(NoOpLock) 81 | } 82 | 83 | func (n *NoOpLock) Acquire(_ context.Context, _ time.Duration) error { 84 | return nil 85 | } 86 | 87 | func (n *NoOpLock) Release(_ context.Context) error { 88 | return nil 89 | } 90 | 91 | func lockKey(key string) string { 92 | return fmt.Sprintf(KeyTemplate, key) 93 | } 94 | -------------------------------------------------------------------------------- /charts/wonderwall-forward-auth/Feature.yaml: -------------------------------------------------------------------------------- 1 | environmentKinds: 2 | - management 3 | dependencies: 4 | - allOf: 5 | - nais-netpols-management 6 | values: 7 | openid.clientID: 8 | computed: 9 | template: | 10 | {{ .Env.wonderwall_forward_auth_zitadel_client_id | quote }} 11 | openid.clientSecret: 12 | computed: 13 | template: | 14 | {{ .Env.wonderwall_forward_auth_zitadel_client_secret | quote }} 15 | openid.extraAudience: 16 | description: Comma separated list of additional audiences for id_token validation. 17 | computed: 18 | template: | 19 | {{ .Env.wonderwall_forward_auth_zitadel_project_id | quote }} 20 | openid.extraScopes: 21 | description: Comma separated list of additional scopes to request from the OpenID provider. 22 | computed: 23 | template: | 24 | "urn:zitadel:iam:org:id:{{ .Env.zitadel_organization_id }}" 25 | replicas.min: 26 | config: 27 | type: int 28 | replicas.max: 29 | config: 30 | type: int 31 | session.cookieEncryptionKey: 32 | description: Cookie encryption key, 256 bits (e.g. 32 ASCII characters) encoded with standard base64.
33 | computed: 34 | template: | 35 | {{ .Env.wonderwall_forward_auth_encryption_key | quote }} 36 | sso.domain: 37 | description: Domain for forward auth 38 | computed: 39 | template: | 40 | {{ .Tenant.Name }}.cloud.nais.io 41 | sso.defaultRedirectURL: 42 | description: Default redirect URL for forward auth 43 | computed: 44 | template: | 45 | {{ printf "https://%s" (subdomain . "console") | quote }} 46 | valkey.host: 47 | computed: 48 | template: | 49 | {{ .Env.wonderwall_forward_auth_valkey_host | quote }} 50 | valkey.port: 51 | computed: 52 | template: | 53 | {{ .Env.wonderwall_forward_auth_valkey_port | quote }} 54 | valkey.username: 55 | computed: 56 | template: | 57 | {{ .Env.wonderwall_forward_auth_valkey_username | quote }} 58 | valkey.password: 59 | computed: 60 | template: | 61 | {{ .Env.wonderwall_forward_auth_valkey_password | quote }} 62 | -------------------------------------------------------------------------------- /charts/wonderwall-forward-auth/templates/_helpers.tpl: -------------------------------------------------------------------------------- 1 | {{/* 2 | Expand the name of the chart. 3 | */}} 4 | {{- define "wonderwall-forward-auth.name" -}} 5 | {{- default .Chart.Name .Values.nameOverride | trunc 63 | trimSuffix "-" }} 6 | {{- end }} 7 | 8 | {{/* 9 | Create a default fully qualified app name. 10 | We truncate at 63 chars because some Kubernetes name fields are limited to this (by the DNS naming spec). 11 | If release name contains chart name it will be used as a full name. 
12 | */}} 13 | {{- define "wonderwall-forward-auth.fullname" -}} 14 | {{- if .Values.fullnameOverride }} 15 | {{- .Values.fullnameOverride | trunc 63 | trimSuffix "-" }} 16 | {{- else }} 17 | {{- $name := default .Chart.Name .Values.nameOverride }} 18 | {{- if contains $name .Release.Name }} 19 | {{- .Release.Name | trunc 63 | trimSuffix "-" }} 20 | {{- else }} 21 | {{- printf "%s-%s" .Release.Name $name | trunc 63 | trimSuffix "-" }} 22 | {{- end }} 23 | {{- end }} 24 | {{- end }} 25 | 26 | {{/* 27 | Create chart name and version as used by the chart label. 28 | */}} 29 | {{- define "wonderwall-forward-auth.chart" -}} 30 | {{- printf "%s-%s" .Chart.Name .Chart.Version | replace "+" "_" | trunc 63 | trimSuffix "-" }} 31 | {{- end }} 32 | 33 | {{/* 34 | Common labels 35 | */}} 36 | {{- define "wonderwall-forward-auth.labels" -}} 37 | helm.sh/chart: {{ include "wonderwall-forward-auth.chart" . }} 38 | {{ include "wonderwall-forward-auth.selectorLabels" . }} 39 | {{- if .Chart.AppVersion }} 40 | app.kubernetes.io/version: {{ .Chart.AppVersion | quote }} 41 | {{- end }} 42 | app.kubernetes.io/managed-by: {{ .Release.Service }} 43 | {{- end }} 44 | 45 | {{/* 46 | Selector labels 47 | */}} 48 | {{- define "wonderwall-forward-auth.selectorLabels" -}} 49 | app.kubernetes.io/name: {{ include "wonderwall-forward-auth.name" . }} 50 | app.kubernetes.io/instance: {{ .Release.Name }} 51 | {{- end }} 52 | 53 | {{/* 54 | Create the name of the service account to use 55 | */}} 56 | {{- define "wonderwall-forward-auth.serviceAccountName" -}} 57 | {{- if .Values.serviceAccount.create }} 58 | {{- default (include "wonderwall-forward-auth.fullname" .) 
.Values.serviceAccount.name }} 59 | {{- else }} 60 | {{- default "default" .Values.serviceAccount.name }} 61 | {{- end }} 62 | {{- end }} 63 | -------------------------------------------------------------------------------- /pkg/openid/config/provider_test.go: -------------------------------------------------------------------------------- 1 | package config_test 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/nais/wonderwall/pkg/config" 7 | "github.com/stretchr/testify/assert" 8 | 9 | "github.com/nais/wonderwall/pkg/mock" 10 | openidconfig "github.com/nais/wonderwall/pkg/openid/config" 11 | ) 12 | 13 | func TestProviderMetadata_Validate(t *testing.T) { 14 | metadata := &openidconfig.ProviderMetadata{ 15 | ACRValuesSupported: openidconfig.Supported{"idporten-loa-substantial", "idporten-loa-high"}, 16 | UILocalesSupported: openidconfig.Supported{"nb", "nb", "en", "se"}, 17 | IDTokenSigningAlgValuesSupported: openidconfig.Supported{"RS256"}, 18 | } 19 | 20 | for _, tt := range []struct { 21 | name string 22 | config config.OpenID 23 | assertion assert.ErrorAssertionFunc 24 | }{ 25 | { 26 | name: "happy path", 27 | config: config.OpenID{ACRValues: "idporten-loa-high", UILocales: "nb"}, 28 | assertion: assert.NoError, 29 | }, 30 | { 31 | name: "invalid acr", 32 | config: config.OpenID{ACRValues: "Level5"}, 33 | assertion: assert.Error, 34 | }, 35 | { 36 | name: "invalid locale", 37 | config: config.OpenID{UILocales: "de"}, 38 | assertion: assert.Error, 39 | }, 40 | { 41 | name: "has acr translation for Level4", 42 | config: config.OpenID{ACRValues: "Level4"}, 43 | assertion: assert.NoError, 44 | }, 45 | { 46 | name: "has acr translation for Level3", 47 | config: config.OpenID{ACRValues: "Level3"}, 48 | assertion: assert.NoError, 49 | }, 50 | { 51 | name: "invalid signing algorithm", 52 | config: config.OpenID{IDTokenSigningAlg: "HS256"}, 53 | assertion: assert.Error, 54 | }, 55 | } { 56 | t.Run(tt.name, func(t *testing.T) { 57 | cfg := mock.Config() 58 | if 
tt.config.ACRValues != "" { 59 | cfg.OpenID.ACRValues = tt.config.ACRValues 60 | } 61 | if tt.config.UILocales != "" { 62 | cfg.OpenID.UILocales = tt.config.UILocales 63 | } 64 | if tt.config.IDTokenSigningAlg != "" { 65 | cfg.OpenID.IDTokenSigningAlg = tt.config.IDTokenSigningAlg 66 | } 67 | 68 | err := metadata.Validate(cfg.OpenID) 69 | tt.assertion(t, err) 70 | }) 71 | } 72 | } 73 | -------------------------------------------------------------------------------- /charts/wonderwall/values.yaml: -------------------------------------------------------------------------------- 1 | nameOverride: "" 2 | fullnameOverride: "" 3 | 4 | image: 5 | repository: europe-north1-docker.pkg.dev/nais-io/nais/images/wonderwall 6 | tag: latest 7 | imagePullSecrets: [] 8 | 9 | aiven: 10 | project: 11 | prometheusEndpointId: 12 | redisPlan: 13 | azure: 14 | enabled: false 15 | redisSecretName: wonderwall-azure-redis-rw 16 | sessionMaxLifetime: 10h 17 | forwardAuth: 18 | enabled: false 19 | replicasMin: 2 20 | replicasMax: 4 21 | clientSecretName: azure-sso-server 22 | ingressClassName: nais-ingress-fa 23 | # 256 bits key, in standard base64 encoding 24 | sessionCookieEncryptionKey: 25 | sessionCookieName: forwardauth 26 | ssoDefaultRedirectURL: 27 | ssoDomain: 28 | ssoServerSecretName: wonderwall-azure-sso-server 29 | groupIds: [] # [""] - additional group IDs to grant access to 30 | idporten: 31 | enabled: false 32 | clientAccessTokenLifetime: 3600 33 | clientSessionLifetime: 21600 34 | clientSecretName: idporten-sso-server 35 | ingressClassName: nais-ingress-external 36 | openidAcrValues: idporten-loa-high 37 | openidLocale: nb 38 | openidPostLogoutRedirectURL: 39 | openidResourceIndicator: 40 | redisSecretNames: 41 | read: wonderwall-idporten-redis-ro 42 | readwrite: wonderwall-idporten-redis-rw 43 | replicasMax: 4 44 | replicasMin: 2 45 | sessionCookieName: 46 | # 256 bits key, in standard base64 encoding 47 | sessionCookieEncryptionKey: 48 | sessionInactivity: true 49 | 
sessionInactivityTimeout: 1h 50 | sessionMaxLifetime: 6h 51 | ssoServerHost: 52 | # secret for configuring Wonderwall server itself 53 | ssoServerSecretName: wonderwall-idporten-sso-server 54 | ssoDefaultRedirectURL: 55 | ssoDomain: 56 | openid: 57 | enabled: true 58 | redisSecretName: wonderwall-openid-redis-rw 59 | # https:///.well-known/openid-configuration 60 | wellKnownUrl: 61 | redis: 62 | connectionIdleTimeout: 299 63 | resources: 64 | limits: 65 | cpu: "2" 66 | memory: 512Mi 67 | requests: 68 | cpu: 100m 69 | memory: 64Mi 70 | resourceSuffix: "" 71 | podDisruptionBudget: 72 | maxUnavailable: 1 73 | otel: 74 | endpoint: http://opentelemetry-collector.nais-system:4317 75 | -------------------------------------------------------------------------------- /flake.lock: -------------------------------------------------------------------------------- 1 | { 2 | "nodes": { 3 | "flake-utils": { 4 | "inputs": { 5 | "systems": "systems" 6 | }, 7 | "locked": { 8 | "lastModified": 1710146030, 9 | "narHash": "sha256-SZ5L6eA7HJ/nmkzGG7/ISclqe6oZdOZTNoesiInkXPQ=", 10 | "owner": "numtide", 11 | "repo": "flake-utils", 12 | "rev": "b1d9ab70662946ef0850d488da1c9019f3a9752a", 13 | "type": "github" 14 | }, 15 | "original": { 16 | "owner": "numtide", 17 | "repo": "flake-utils", 18 | "type": "github" 19 | } 20 | }, 21 | "gitignore": { 22 | "inputs": { 23 | "nixpkgs": [ 24 | "nixpkgs" 25 | ] 26 | }, 27 | "locked": { 28 | "lastModified": 1709087332, 29 | "narHash": "sha256-HG2cCnktfHsKV0s4XW83gU3F57gaTljL9KNSuG6bnQs=", 30 | "owner": "hercules-ci", 31 | "repo": "gitignore.nix", 32 | "rev": "637db329424fd7e46cf4185293b9cc8c88c95394", 33 | "type": "github" 34 | }, 35 | "original": { 36 | "owner": "hercules-ci", 37 | "repo": "gitignore.nix", 38 | "type": "github" 39 | } 40 | }, 41 | "nixpkgs": { 42 | "locked": { 43 | "lastModified": 1722141560, 44 | "narHash": "sha256-Ul3rIdesWaiW56PS/Ak3UlJdkwBrD4UcagCmXZR9Z7Y=", 45 | "owner": "NixOS", 46 | "repo": "nixpkgs", 47 | "rev": 
"038fb464fcfa79b4f08131b07f2d8c9a6bcc4160", 48 | "type": "github" 49 | }, 50 | "original": { 51 | "owner": "NixOS", 52 | "ref": "nixpkgs-unstable", 53 | "repo": "nixpkgs", 54 | "type": "github" 55 | } 56 | }, 57 | "root": { 58 | "inputs": { 59 | "flake-utils": "flake-utils", 60 | "gitignore": "gitignore", 61 | "nixpkgs": "nixpkgs" 62 | } 63 | }, 64 | "systems": { 65 | "locked": { 66 | "lastModified": 1681028828, 67 | "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=", 68 | "owner": "nix-systems", 69 | "repo": "default", 70 | "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e", 71 | "type": "github" 72 | }, 73 | "original": { 74 | "owner": "nix-systems", 75 | "repo": "default", 76 | "type": "github" 77 | } 78 | } 79 | }, 80 | "root": "root", 81 | "version": 7 82 | } 83 | -------------------------------------------------------------------------------- /pkg/session/session_reader.go: -------------------------------------------------------------------------------- 1 | package session 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "fmt" 7 | "net/http" 8 | 9 | "github.com/nais/wonderwall/internal/crypto" 10 | "github.com/nais/wonderwall/internal/o11y/otel" 11 | "github.com/nais/wonderwall/internal/retry" 12 | "github.com/nais/wonderwall/pkg/config" 13 | "go.opentelemetry.io/otel/attribute" 14 | "go.opentelemetry.io/otel/trace" 15 | ) 16 | 17 | var _ Reader = &reader{} 18 | 19 | type reader struct { 20 | cfg *config.Config 21 | cookieCrypter crypto.Crypter 22 | store Store 23 | } 24 | 25 | func NewReader(cfg *config.Config, cookieCrypter crypto.Crypter) (Reader, error) { 26 | store, err := NewStore(cfg) 27 | if err != nil { 28 | return nil, err 29 | } 30 | 31 | return &reader{ 32 | cfg: cfg, 33 | cookieCrypter: cookieCrypter, 34 | store: store, 35 | }, nil 36 | } 37 | 38 | func (in *reader) Get(r *http.Request) (*Session, error) { 39 | r, span := otel.StartSpanFromRequest(r, "Session.Get") 40 | defer span.End() 41 | 42 | ticket, err := getTicket(r, 
in.cookieCrypter) 43 | if err != nil { 44 | return nil, err 45 | } 46 | 47 | return in.getForTicket(r.Context(), ticket) 48 | } 49 | 50 | func (in *reader) getForTicket(ctx context.Context, ticket *Ticket) (*Session, error) { 51 | span := trace.SpanFromContext(ctx) 52 | span.SetAttributes(attribute.Bool("session.valid_session", false)) 53 | 54 | encrypted, err := retry.DoValue(ctx, func(ctx context.Context) (*EncryptedData, error) { 55 | encrypted, err := in.store.Read(ctx, ticket.Key()) 56 | if errors.Is(err, ErrNotFound) { 57 | return nil, err 58 | } 59 | if err != nil { 60 | return nil, retry.RetryableError(err) 61 | } 62 | return encrypted, nil 63 | }) 64 | if err != nil { 65 | return nil, fmt.Errorf("reading from store: %w", err) 66 | } 67 | 68 | data, err := encrypted.Decrypt(ticket.Crypter()) 69 | if err != nil { 70 | return nil, fmt.Errorf("%w: decrypting session data: %w", ErrInvalid, err) 71 | } 72 | 73 | sess := NewSession(data, ticket) 74 | if sess != nil { 75 | span.SetAttributes(attribute.String("session.id", sess.ExternalSessionID())) 76 | } 77 | err = data.Validate() 78 | if err != nil { 79 | return sess, err 80 | } 81 | 82 | span.SetAttributes(attribute.Bool("session.valid_session", true)) 83 | data.Metadata.SetSpanAttributes(span) 84 | return sess, nil 85 | } 86 | -------------------------------------------------------------------------------- /pkg/cookie/options_test.go: -------------------------------------------------------------------------------- 1 | package cookie_test 2 | 3 | import ( 4 | "net/http" 5 | "testing" 6 | 7 | "github.com/stretchr/testify/assert" 8 | 9 | "github.com/nais/wonderwall/pkg/cookie" 10 | ) 11 | 12 | func TestDefaultOptions(t *testing.T) { 13 | opts := cookie.DefaultOptions() 14 | 15 | assert.Equal(t, http.SameSiteLaxMode, opts.SameSite) 16 | assert.True(t, opts.Secure) 17 | assert.Empty(t, opts.Domain) 18 | assert.Empty(t, opts.Path) 19 | } 20 | 21 | func TestOptions_WithDomain(t *testing.T) { 22 | domain := 
".some.domain" 23 | opts := cookie.Options{}.WithDomain(domain) 24 | 25 | assert.Equal(t, ".some.domain", opts.Domain) 26 | 27 | opts = cookie.Options{ 28 | Domain: ".domain", 29 | } 30 | newOpts := opts.WithDomain(".some.other.domain") 31 | 32 | assert.Equal(t, ".domain", opts.Domain, "original options should be unchanged") 33 | assert.Equal(t, ".some.other.domain", newOpts.Domain, "copy of options should have new value") 34 | } 35 | 36 | func TestOptions_WithPath(t *testing.T) { 37 | path := "/some/path" 38 | opts := cookie.Options{}.WithPath(path) 39 | 40 | assert.Equal(t, "/some/path", opts.Path) 41 | 42 | opts = cookie.Options{ 43 | Path: "/some/path", 44 | } 45 | newOpts := opts.WithPath("/some/other/path") 46 | 47 | assert.Equal(t, "/some/path", opts.Path, "original options should be unchanged") 48 | assert.Equal(t, "/some/other/path", newOpts.Path, "copy of options should have new value") 49 | } 50 | 51 | func TestOptions_WithSameSite(t *testing.T) { 52 | sameSite := http.SameSiteDefaultMode 53 | opts := cookie.Options{}.WithSameSite(sameSite) 54 | 55 | assert.Equal(t, http.SameSiteDefaultMode, opts.SameSite) 56 | 57 | opts = cookie.Options{ 58 | SameSite: http.SameSiteLaxMode, 59 | } 60 | newOpts := opts.WithSameSite(sameSite) 61 | 62 | assert.Equal(t, http.SameSiteLaxMode, opts.SameSite, "original options should be unchanged") 63 | assert.Equal(t, http.SameSiteDefaultMode, newOpts.SameSite, "copy of options should have new value") 64 | } 65 | 66 | func TestOptions_WithSecure(t *testing.T) { 67 | opts := cookie.Options{}.WithSecure(true) 68 | 69 | assert.True(t, opts.Secure) 70 | 71 | opts = cookie.Options{ 72 | Secure: false, 73 | } 74 | newOpts := opts.WithSecure(true) 75 | 76 | assert.False(t, opts.Secure, "original options should be unchanged") 77 | assert.True(t, newOpts.Secure, "copy of options should have new value") 78 | } 79 | -------------------------------------------------------------------------------- /pkg/mock/client.go: 
-------------------------------------------------------------------------------- 1 | package mock 2 | 3 | import ( 4 | "github.com/lestrrat-go/jwx/v3/jwk" 5 | "github.com/nais/wonderwall/internal/crypto" 6 | "github.com/nais/wonderwall/pkg/config" 7 | openidconfig "github.com/nais/wonderwall/pkg/openid/config" 8 | "github.com/nais/wonderwall/pkg/openid/scopes" 9 | ) 10 | 11 | type TestClientConfiguration struct { 12 | *config.Config 13 | clientJwk jwk.Key 14 | trustedAudiences map[string]bool 15 | } 16 | 17 | var _ openidconfig.Client = (*TestClientConfiguration)(nil) 18 | 19 | func (c *TestClientConfiguration) ACRValues() string { 20 | return c.Config.OpenID.ACRValues 21 | } 22 | 23 | func (c *TestClientConfiguration) Audiences() map[string]bool { 24 | return c.trustedAudiences 25 | } 26 | 27 | func (c *TestClientConfiguration) AuthMethod() openidconfig.AuthMethod { 28 | return openidconfig.AuthMethodPrivateKeyJWT 29 | } 30 | 31 | func (c *TestClientConfiguration) ClientID() string { 32 | return c.Config.OpenID.ClientID 33 | } 34 | 35 | func (c *TestClientConfiguration) ClientJWK() jwk.Key { 36 | return c.clientJwk 37 | } 38 | 39 | func (c *TestClientConfiguration) ClientSecret() string { 40 | return c.Config.OpenID.ClientSecret 41 | } 42 | 43 | func (c *TestClientConfiguration) NewClientAuthJWTType() bool { 44 | return c.Config.OpenID.NewClientAuthJWTType 45 | } 46 | 47 | func (c *TestClientConfiguration) SetPostLogoutRedirectURI(uri string) { 48 | c.Config.OpenID.PostLogoutRedirectURI = uri 49 | } 50 | 51 | func (c *TestClientConfiguration) PostLogoutRedirectURI() string { 52 | return c.Config.OpenID.PostLogoutRedirectURI 53 | } 54 | 55 | func (c *TestClientConfiguration) ResourceIndicator() string { 56 | return c.Config.OpenID.ResourceIndicator 57 | } 58 | 59 | func (c *TestClientConfiguration) Scopes() scopes.Scopes { 60 | return scopes.DefaultScopes().WithAdditional(c.Config.OpenID.Scopes...) 
61 | } 62 | 63 | func (c *TestClientConfiguration) UILocales() string { 64 | return c.Config.OpenID.UILocales 65 | } 66 | 67 | func (c *TestClientConfiguration) WellKnownURL() string { 68 | return c.Config.OpenID.WellKnownURL 69 | } 70 | 71 | func clientConfiguration(cfg *config.Config) *TestClientConfiguration { 72 | key, err := crypto.NewJwk() 73 | if err != nil { 74 | panic(err) 75 | } 76 | 77 | return &TestClientConfiguration{ 78 | Config: cfg, 79 | clientJwk: key, 80 | trustedAudiences: cfg.OpenID.TrustedAudiences(), 81 | } 82 | } 83 | -------------------------------------------------------------------------------- /pkg/config/redis.go: -------------------------------------------------------------------------------- 1 | package config 2 | 3 | import ( 4 | "crypto/tls" 5 | "time" 6 | 7 | "github.com/redis/go-redis/v9" 8 | flag "github.com/spf13/pflag" 9 | ) 10 | 11 | type Redis struct { 12 | Address string `json:"address"` 13 | Username string `json:"username"` 14 | Password string `json:"password"` 15 | TLS bool `json:"tls"` 16 | URI string `json:"uri"` 17 | ConnectionIdleTimeout int `json:"connection-idle-timeout"` 18 | } 19 | 20 | func (r *Redis) Client() (*redis.Client, error) { 21 | opts := &redis.Options{ 22 | Network: "tcp", 23 | Addr: r.Address, 24 | } 25 | 26 | if r.TLS { 27 | opts.TLSConfig = &tls.Config{} 28 | } 29 | 30 | if r.URI != "" { 31 | var err error 32 | 33 | opts, err = redis.ParseURL(r.URI) 34 | if err != nil { 35 | return nil, err 36 | } 37 | } 38 | 39 | opts.MinIdleConns = 1 40 | opts.MaxRetries = 5 41 | 42 | if r.Username != "" { 43 | opts.Username = r.Username 44 | } 45 | 46 | if r.Password != "" { 47 | opts.Password = r.Password 48 | } 49 | 50 | if r.ConnectionIdleTimeout > 0 { 51 | opts.ConnMaxIdleTime = time.Duration(r.ConnectionIdleTimeout) * time.Second 52 | } else if r.ConnectionIdleTimeout == -1 { 53 | opts.ConnMaxIdleTime = -1 54 | } 55 | 56 | return redis.NewClient(opts), nil 57 | } 58 | 59 | const ( 60 | RedisAddress = 
"redis.address" 61 | RedisPassword = "redis.password" 62 | RedisTLS = "redis.tls" 63 | RedisUsername = "redis.username" 64 | RedisURI = "redis.uri" 65 | RedisConnectionIdleTimeout = "redis.connection-idle-timeout" 66 | ) 67 | 68 | func redisFlags() { 69 | flag.String(RedisURI, "", "Redis URI string. An empty value will fall back to 'redis-address'.") 70 | flag.String(RedisAddress, "", "Deprecated: prefer using 'redis.uri'. Address of the Redis instance (host:port). An empty value will use in-memory session storage. Does not override address set by 'redis.uri'.") 71 | flag.String(RedisPassword, "", "Password for Redis. Overrides password set by 'redis.uri'.") 72 | flag.Bool(RedisTLS, true, "Whether or not to use TLS for connecting to Redis. Does not override TLS config set by 'redis.uri'.") 73 | flag.String(RedisUsername, "", "Username for Redis. Overrides username set by 'redis.uri'.") 74 | flag.Int(RedisConnectionIdleTimeout, 0, "Idle timeout for Redis connections, in seconds. If non-zero, the value should be less than the client timeout configured at the Redis server. A value of -1 disables timeout. If zero, the default value from go-redis is used (30 minutes). 
Overrides options set by 'redis.uri'.") 75 | } 76 | -------------------------------------------------------------------------------- /pkg/config/cookie.go: -------------------------------------------------------------------------------- 1 | package config 2 | 3 | import ( 4 | "fmt" 5 | "net/http" 6 | "net/url" 7 | "slices" 8 | "strings" 9 | 10 | flag "github.com/spf13/pflag" 11 | ) 12 | 13 | type Cookie struct { 14 | Prefix string `json:"prefix"` 15 | SameSite SameSite `json:"same-site"` 16 | Secure bool `json:"secure"` 17 | } 18 | 19 | func (c *Cookie) Validate(cfg *Config) error { 20 | if err := c.SameSite.Validate(); err != nil { 21 | return err 22 | } 23 | 24 | if c.Secure { 25 | return nil 26 | } 27 | 28 | for _, ingress := range cfg.Ingresses { 29 | u, err := url.ParseRequestURI(ingress) 30 | if err != nil { 31 | return fmt.Errorf("parsing ingress URL %q: %w", ingress, err) 32 | } 33 | 34 | if !strings.EqualFold(u.Hostname(), "localhost") { 35 | return fmt.Errorf("ingress %q is not localhost (was %q); cannot disable secure cookies", ingress, u.Hostname()) 36 | } 37 | 38 | if u.Scheme != "http" { 39 | return fmt.Errorf("ingress %q is not HTTP (was %q); cannot disable secure cookies", ingress, u.Scheme) 40 | } 41 | } 42 | 43 | logger.Warn("secure cookies are disabled; not suitable for production use!") 44 | return nil 45 | } 46 | 47 | type SameSite string 48 | 49 | const ( 50 | SameSiteLax SameSite = "Lax" 51 | SameSiteNone SameSite = "None" 52 | SameSiteStrict SameSite = "Strict" 53 | ) 54 | 55 | // ToHttp returns the equivalent http.SameSite value for the SameSite attribute. 
56 | func (s SameSite) ToHttp() http.SameSite { 57 | switch s { 58 | case SameSiteNone: 59 | return http.SameSiteNoneMode 60 | case SameSiteStrict: 61 | return http.SameSiteStrictMode 62 | default: 63 | return http.SameSiteLaxMode 64 | } 65 | } 66 | 67 | func (s SameSite) Validate() error { 68 | all := []SameSite{ 69 | SameSiteLax, 70 | SameSiteNone, 71 | SameSiteStrict, 72 | } 73 | 74 | if slices.Contains(all, s) { 75 | return nil 76 | } 77 | return fmt.Errorf("%q must be one of %q (was %q)", CookieSameSite, all, s) 78 | } 79 | 80 | const ( 81 | CookiePrefix = "cookie.prefix" 82 | CookieSameSite = "cookie.same-site" 83 | CookieSecure = "cookie.secure" 84 | EncryptionKey = "encryption-key" 85 | ) 86 | 87 | func cookieFlags() { 88 | flag.String(CookiePrefix, "io.nais.wonderwall", "Prefix for cookie names.") 89 | flag.String(CookieSameSite, string(SameSiteLax), "SameSite attribute for session cookies.") 90 | flag.Bool(CookieSecure, true, "Set secure flag on session cookies. Can only be disabled when `ingress` only consist of localhost hosts. 
Generally, disabling this is only necessary when using Safari.") 91 | flag.String(EncryptionKey, "", "Base64 encoded 256-bit cookie encryption key; must be identical in instances that share session store.") 92 | } 93 | -------------------------------------------------------------------------------- /pkg/openid/client/logout_test.go: -------------------------------------------------------------------------------- 1 | package client_test 2 | 3 | import ( 4 | "net/url" 5 | "testing" 6 | 7 | "github.com/stretchr/testify/assert" 8 | 9 | "github.com/nais/wonderwall/pkg/mock" 10 | "github.com/nais/wonderwall/pkg/openid/client" 11 | ) 12 | 13 | const ( 14 | LogoutCallbackURI = mock.Ingress + "/oauth2/logout/callback" 15 | PostLogoutRedirectURI = "http://some-other-url" 16 | EndSessionEndpoint = "http://provider/endsession" 17 | ) 18 | 19 | func TestLogout_SingleLogoutURL(t *testing.T) { 20 | t.Run("with id_token", func(t *testing.T) { 21 | logout := newLogout(t) 22 | idToken := "some-id-token" 23 | state := logout.Cookie.State 24 | 25 | raw := logout.SingleLogoutURL(idToken) 26 | assert.NotEmpty(t, raw) 27 | 28 | logoutUrl, err := url.Parse(raw) 29 | assert.NoError(t, err) 30 | 31 | query := logoutUrl.Query() 32 | assert.Len(t, query, 3) 33 | 34 | assert.Contains(t, query, "id_token_hint") 35 | assert.Equal(t, idToken, query.Get("id_token_hint")) 36 | 37 | assert.Contains(t, query, "post_logout_redirect_uri") 38 | assert.Equal(t, LogoutCallbackURI, query.Get("post_logout_redirect_uri")) 39 | 40 | assert.Contains(t, query, "state") 41 | assert.Equal(t, state, query.Get("state")) 42 | 43 | logoutUrl.RawQuery = "" 44 | assert.Equal(t, EndSessionEndpoint, logoutUrl.String()) 45 | }) 46 | 47 | t.Run("without id_token", func(t *testing.T) { 48 | logout := newLogout(t) 49 | idToken := "" 50 | state := logout.Cookie.State 51 | 52 | raw := logout.SingleLogoutURL(idToken) 53 | assert.NotEmpty(t, raw) 54 | 55 | logoutUrl, err := url.Parse(raw) 56 | assert.NoError(t, err) 57 | 58 | 
query := logoutUrl.Query() 59 | assert.Len(t, query, 2) 60 | 61 | assert.NotContains(t, query, "id_token_hint") 62 | assert.Equal(t, idToken, query.Get("id_token_hint")) 63 | 64 | assert.Contains(t, query, "post_logout_redirect_uri") 65 | assert.Equal(t, LogoutCallbackURI, query.Get("post_logout_redirect_uri")) 66 | 67 | assert.Contains(t, query, "state") 68 | assert.Equal(t, state, query.Get("state")) 69 | 70 | logoutUrl.RawQuery = "" 71 | assert.Equal(t, EndSessionEndpoint, logoutUrl.String()) 72 | }) 73 | } 74 | 75 | func newLogout(t *testing.T) *client.Logout { 76 | cfg := mock.Config() 77 | 78 | openidCfg := mock.NewTestConfiguration(cfg) 79 | openidCfg.TestClient.SetPostLogoutRedirectURI(PostLogoutRedirectURI) 80 | openidCfg.TestProvider.SetEndSessionEndpoint(EndSessionEndpoint) 81 | ingresses := mock.Ingresses(cfg) 82 | 83 | req := mock.NewGetRequest(mock.Ingress+"/oauth2/logout", ingresses) 84 | 85 | logout, err := newTestClientWithConfig(openidCfg).Logout(req) 86 | assert.NoError(t, err) 87 | 88 | return logout 89 | } 90 | -------------------------------------------------------------------------------- /pkg/url/url.go: -------------------------------------------------------------------------------- 1 | package url 2 | 3 | import ( 4 | "errors" 5 | "net/http" 6 | "net/url" 7 | 8 | mw "github.com/nais/wonderwall/pkg/middleware" 9 | "github.com/nais/wonderwall/pkg/router/paths" 10 | ) 11 | 12 | const ( 13 | RedirectQueryParameter = "redirect" 14 | ) 15 | 16 | var ErrNoMatchingIngress = errors.New("request host does not match any configured ingresses") 17 | 18 | // Login constructs a URL string that points to the login path for the given target URL. 19 | // The given redirect string should point to the location to be redirected to after login. 
20 | func Login(target *url.URL, redirect string) string { 21 | u := target.JoinPath(paths.OAuth2, paths.Login) 22 | 23 | if len(redirect) > 0 { 24 | v := u.Query() 25 | v.Set(RedirectQueryParameter, redirect) 26 | u.RawQuery = v.Encode() 27 | } 28 | 29 | return u.String() 30 | } 31 | 32 | // LoginRelative constructs the relative URL with an absolute path that points to the application's login path, given an optional path prefix. 33 | // The given redirect string should point to the location to be redirected to after login. 34 | func LoginRelative(prefix, redirect string) string { 35 | u := new(url.URL) 36 | u.Path = prefix 37 | 38 | if prefix == "" { 39 | u.Path = "/" 40 | } 41 | 42 | return Login(u, redirect) 43 | } 44 | 45 | // Logout constructs a URL string that points to the logout path for the given target URL. 46 | // The given redirect string should point to the location to be redirected to after logout. 47 | func Logout(target *url.URL, redirect string) string { 48 | u := target.JoinPath(paths.OAuth2, paths.Logout) 49 | 50 | if len(redirect) > 0 { 51 | v := u.Query() 52 | v.Set(RedirectQueryParameter, redirect) 53 | u.RawQuery = v.Encode() 54 | } 55 | 56 | return u.String() 57 | } 58 | 59 | func LoginCallback(r *http.Request) (string, error) { 60 | return makeCallbackURL(r, paths.LoginCallback) 61 | } 62 | 63 | func LogoutCallback(r *http.Request) (string, error) { 64 | return makeCallbackURL(r, paths.LogoutCallback) 65 | } 66 | 67 | func makeCallbackURL(r *http.Request, callbackPath string) (string, error) { 68 | u, err := MatchingIngress(r) 69 | if err != nil { 70 | return "", err 71 | } 72 | 73 | return u.JoinPath(paths.OAuth2, callbackPath).String(), nil 74 | } 75 | 76 | func MatchingPath(r *http.Request) *url.URL { 77 | u := &url.URL{} 78 | 79 | p, found := mw.PathFrom(r.Context()) 80 | if found && len(p) > 0 { 81 | u.Path = p 82 | } else { 83 | u.Path = "/" 84 | } 85 | 86 | return u 87 | } 88 | 89 | func MatchingIngress(r *http.Request) (*url.URL, 
error) { 90 | ing, found := mw.IngressFrom(r.Context()) 91 | if !found { 92 | return nil, ErrNoMatchingIngress 93 | } 94 | 95 | return ing.NewURL(), nil 96 | } 97 | -------------------------------------------------------------------------------- /pkg/openid/client/logout_callback_test.go: -------------------------------------------------------------------------------- 1 | package client_test 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stretchr/testify/assert" 7 | 8 | "github.com/nais/wonderwall/pkg/config" 9 | "github.com/nais/wonderwall/pkg/mock" 10 | "github.com/nais/wonderwall/pkg/openid" 11 | "github.com/nais/wonderwall/pkg/openid/client" 12 | "github.com/nais/wonderwall/pkg/url" 13 | ) 14 | 15 | func TestLogoutCallback_PostLogoutRedirectURI(t *testing.T) { 16 | const defaultState = "some-state" 17 | const defaultRedirectURI = "http://some-fancy-logout-page" 18 | 19 | for _, tt := range []struct { 20 | name string 21 | emptyDefaultURI bool 22 | cookie *openid.LogoutCookie 23 | expected string 24 | }{ 25 | { 26 | name: "happy path", 27 | expected: defaultRedirectURI, 28 | }, 29 | { 30 | name: "empty default uri", 31 | emptyDefaultURI: true, 32 | expected: mock.Ingress, 33 | }, 34 | { 35 | name: "state mismatch", 36 | cookie: &openid.LogoutCookie{ 37 | State: "some-other-state", 38 | }, 39 | expected: defaultRedirectURI, 40 | }, 41 | { 42 | name: "happy path, redirect in cookie", 43 | cookie: &openid.LogoutCookie{ 44 | State: defaultState, 45 | RedirectTo: "http://wonderwall/some/path", 46 | }, 47 | expected: "http://wonderwall/some/path", 48 | }, 49 | { 50 | name: "empty redirect in cookie", 51 | cookie: &openid.LogoutCookie{ 52 | State: defaultState, 53 | RedirectTo: "", 54 | }, 55 | expected: defaultRedirectURI, 56 | }, 57 | { 58 | name: "state mismatch, with redirect in cookie", 59 | cookie: &openid.LogoutCookie{ 60 | State: "some-other-state", 61 | RedirectTo: "http://wonderwall/some/path", 62 | }, 63 | expected: defaultRedirectURI, 64 | }, 65 | { 66 
| name: "invalid redirect in cookie", 67 | cookie: &openid.LogoutCookie{ 68 | State: defaultState, 69 | RedirectTo: "http://not-wonderwall/some/path", 70 | }, 71 | expected: defaultRedirectURI, 72 | }, 73 | } { 74 | t.Run(tt.name, func(t *testing.T) { 75 | cfg := mock.Config() 76 | cfg.OpenID.PostLogoutRedirectURI = defaultRedirectURI 77 | 78 | if tt.emptyDefaultURI { 79 | cfg.OpenID.PostLogoutRedirectURI = "" 80 | } 81 | 82 | lc := newLogoutCallback(cfg, defaultState, tt.cookie) 83 | 84 | uri := lc.PostLogoutRedirectURI() 85 | assert.NotEmpty(t, uri) 86 | assert.Equal(t, tt.expected, uri) 87 | }) 88 | } 89 | } 90 | 91 | func newLogoutCallback(cfg *config.Config, state string, cookie *openid.LogoutCookie) *client.LogoutCallback { 92 | openidCfg := mock.NewTestConfiguration(cfg) 93 | ingresses := mock.Ingresses(cfg) 94 | validator := url.NewAbsoluteValidator(ingresses.Hosts()) 95 | req := mock.NewGetRequest(mock.Ingress+"/oauth2/logout/callback?state="+state, ingresses) 96 | return newTestClientWithConfig(openidCfg).LogoutCallback(req, cookie, validator) 97 | } 98 | -------------------------------------------------------------------------------- /pkg/session/store_redis.go: -------------------------------------------------------------------------------- 1 | package session 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "fmt" 7 | "time" 8 | 9 | "github.com/nais/wonderwall/internal/o11y/otel" 10 | "github.com/nais/wonderwall/pkg/metrics" 11 | "github.com/redis/go-redis/v9" 12 | "go.opentelemetry.io/otel/attribute" 13 | ) 14 | 15 | type redisSessionStore struct { 16 | client redis.Cmdable 17 | } 18 | 19 | var _ Store = &redisSessionStore{} 20 | 21 | func NewRedis(client redis.Cmdable) Store { 22 | return &redisSessionStore{ 23 | client: client, 24 | } 25 | } 26 | 27 | func (s *redisSessionStore) Read(ctx context.Context, key string) (*EncryptedData, error) { 28 | ctx, span := otel.StartSpan(ctx, "RedisSessionStore.Read") 29 | defer span.End() 30 | 
span.SetAttributes(attribute.Bool("redis.key_exists", false)) 31 | 32 | encryptedData := &EncryptedData{} 33 | err := metrics.ObserveRedisLatency(metrics.RedisOperationRead, func() error { 34 | return s.client.Get(ctx, key).Scan(encryptedData) 35 | }) 36 | if err == nil { 37 | span.SetAttributes(attribute.Bool("redis.key_exists", true)) 38 | return encryptedData, nil 39 | } 40 | 41 | if errors.Is(err, redis.Nil) { 42 | return nil, fmt.Errorf("%w: %w", ErrNotFound, err) 43 | } 44 | 45 | return nil, err 46 | } 47 | 48 | func (s *redisSessionStore) Write(ctx context.Context, key string, value *EncryptedData, expiration time.Duration) error { 49 | ctx, span := otel.StartSpan(ctx, "RedisSessionStore.Write") 50 | defer span.End() 51 | span.SetAttributes(attribute.String("redis.key_expiry", expiration.String())) 52 | 53 | err := metrics.ObserveRedisLatency(metrics.RedisOperationWrite, func() error { 54 | return s.client.Set(ctx, key, value, expiration).Err() 55 | }) 56 | if err != nil { 57 | return err 58 | } 59 | 60 | return nil 61 | } 62 | 63 | func (s *redisSessionStore) Delete(ctx context.Context, keys ...string) error { 64 | ctx, span := otel.StartSpan(ctx, "RedisSessionStore.Delete") 65 | defer span.End() 66 | 67 | err := metrics.ObserveRedisLatency(metrics.RedisOperationDelete, func() error { 68 | return s.client.Del(ctx, keys...).Err() 69 | }) 70 | if err == nil { 71 | return nil 72 | } 73 | 74 | if errors.Is(err, redis.Nil) { 75 | return fmt.Errorf("%w: %w", ErrNotFound, err) 76 | } 77 | 78 | return err 79 | } 80 | 81 | func (s *redisSessionStore) Update(ctx context.Context, key string, value *EncryptedData) error { 82 | ctx, span := otel.StartSpan(ctx, "RedisSessionStore.Update") 83 | defer span.End() 84 | 85 | _, err := s.Read(ctx, key) 86 | if err != nil { 87 | return err 88 | } 89 | 90 | err = metrics.ObserveRedisLatency(metrics.RedisOperationUpdate, func() error { 91 | return s.client.Set(ctx, key, value, redis.KeepTTL).Err() 92 | }) 93 | if err != nil { 94 
| return err 95 | } 96 | 97 | return nil 98 | } 99 | 100 | func (s *redisSessionStore) MakeLock(key string) Lock { 101 | return NewRedisLock(s.client, key) 102 | } 103 | -------------------------------------------------------------------------------- /pkg/session/ticket.go: -------------------------------------------------------------------------------- 1 | package session 2 | 3 | import ( 4 | "encoding/json" 5 | "errors" 6 | "fmt" 7 | "net/http" 8 | 9 | "github.com/nais/liberator/pkg/keygen" 10 | "go.opentelemetry.io/otel/attribute" 11 | "go.opentelemetry.io/otel/trace" 12 | 13 | "github.com/nais/wonderwall/internal/crypto" 14 | "github.com/nais/wonderwall/pkg/cookie" 15 | ) 16 | 17 | // Ticket contains the user agent's data required to access their associated session. 18 | type Ticket struct { 19 | // SessionKey identifies the session. 20 | SessionKey string `json:"id"` 21 | // EncryptionKey is the data encryption key (DEK) used to encrypt the session's data. 22 | // Its size is equal to the expected key size for the used AEAD, defined in crypto.KeySize. 23 | EncryptionKey []byte `json:"dek"` 24 | crypter crypto.Crypter 25 | } 26 | 27 | func NewTicket(sessionKey string) (*Ticket, error) { 28 | encKey, err := keygen.Keygen(crypto.KeySize) 29 | if err != nil { 30 | return nil, fmt.Errorf("generate encryption key: %w", err) 31 | } 32 | 33 | return &Ticket{SessionKey: sessionKey, EncryptionKey: encKey}, nil 34 | } 35 | 36 | // Crypter returns a crypto.Crypter initialized with the session's data encryption key. 37 | func (c *Ticket) Crypter() crypto.Crypter { 38 | if c.crypter == nil { 39 | c.crypter = crypto.NewCrypter(c.EncryptionKey) 40 | } 41 | return c.crypter 42 | } 43 | 44 | // Key returns the key that identifies the session. 
45 | func (c *Ticket) Key() string { 46 | return c.SessionKey 47 | } 48 | 49 | // SetCookie marshals the Ticket, encrypts the value with the given crypto.Crypter, and writes the resulting cookie to the 50 | // given http.ResponseWriter, applying any cookie.Options to the cookie itself. 51 | func (c *Ticket) SetCookie(w http.ResponseWriter, opts cookie.Options, crypter crypto.Crypter) error { 52 | b, err := json.Marshal(c) 53 | if err != nil { 54 | return fmt.Errorf("marshalling ticket: %w", err) 55 | } 56 | 57 | return cookie.EncryptAndSet(w, cookie.Session, string(b), opts, crypter) 58 | } 59 | 60 | // getTicket returns a Ticket from the session cookie found in the http.Request, given a crypto.Crypter that 61 | // can decrypt the cookie is provided. 62 | func getTicket(r *http.Request, crypter crypto.Crypter) (*Ticket, error) { 63 | span := trace.SpanFromContext(r.Context()) 64 | span.SetAttributes(attribute.Bool("session.valid_ticket", false)) 65 | 66 | ticketJson, err := cookie.GetDecrypted(r, cookie.Session, crypter) 67 | if errors.Is(err, http.ErrNoCookie) { 68 | return nil, fmt.Errorf("ticket: cookie %w", ErrNotFound) 69 | } 70 | if errors.Is(err, cookie.ErrInvalidValue) || errors.Is(err, cookie.ErrDecrypt) { 71 | return nil, fmt.Errorf("ticket: cookie: %w: %w", ErrInvalid, err) 72 | } 73 | if err != nil { 74 | return nil, err 75 | } 76 | 77 | var ticket Ticket 78 | err = json.Unmarshal([]byte(ticketJson), &ticket) 79 | if err != nil { 80 | return nil, fmt.Errorf("ticket: unmarshalling: %w", err) 81 | } 82 | 83 | span.SetAttributes(attribute.Bool("session.valid_ticket", true)) 84 | return &ticket, nil 85 | } 86 | -------------------------------------------------------------------------------- /pkg/config/sso.go: -------------------------------------------------------------------------------- 1 | package config 2 | 3 | import ( 4 | "fmt" 5 | "net/url" 6 | 7 | flag "github.com/spf13/pflag" 8 | ) 9 | 10 | type SSOMode string 11 | 12 | const ( 13 | SSOModeServer 
SSOMode = "server" 14 | SSOModeProxy SSOMode = "proxy" 15 | ) 16 | 17 | type SSO struct { 18 | Enabled bool `json:"enabled"` 19 | Domain string `json:"domain"` 20 | Mode SSOMode `json:"mode"` 21 | SessionCookieName string `json:"session-cookie-name"` 22 | ServerURL string `json:"server-url"` 23 | ServerDefaultRedirectURL string `json:"server-default-redirect-url"` 24 | } 25 | 26 | func (s SSO) IsServer() bool { 27 | return s.Enabled && s.Mode == SSOModeServer 28 | } 29 | 30 | func (s SSO) Validate(c *Config) error { 31 | if !s.Enabled { 32 | return nil 33 | } 34 | 35 | if len(c.Redis.Address) == 0 && len(c.Redis.URI) == 0 { 36 | return fmt.Errorf("at least one of %q or %q must be set when %s is set", RedisAddress, RedisURI, SSOEnabled) 37 | } 38 | 39 | if len(s.SessionCookieName) == 0 { 40 | return fmt.Errorf("%q must not be empty when %s is set", SSOSessionCookieName, SSOEnabled) 41 | } 42 | 43 | switch s.Mode { 44 | case SSOModeProxy: 45 | _, err := url.ParseRequestURI(s.ServerURL) 46 | if err != nil { 47 | return fmt.Errorf("%q must be a valid url: %w", SSOServerURL, err) 48 | } 49 | case SSOModeServer: 50 | if len(s.Domain) == 0 { 51 | return fmt.Errorf("%q cannot be empty", SSODomain) 52 | } 53 | 54 | _, err := url.ParseRequestURI(s.ServerDefaultRedirectURL) 55 | if err != nil { 56 | return fmt.Errorf("%q must be a valid url: %w", SSOServerDefaultRedirectURL, err) 57 | } 58 | default: 59 | return fmt.Errorf("%q must be one of [%q, %q]", SSOModeFlag, SSOModeServer, SSOModeProxy) 60 | } 61 | 62 | return nil 63 | } 64 | 65 | const ( 66 | SSOEnabled = "sso.enabled" 67 | SSODomain = "sso.domain" 68 | SSOModeFlag = "sso.mode" 69 | SSOServerDefaultRedirectURL = "sso.server-default-redirect-url" 70 | SSOSessionCookieName = "sso.session-cookie-name" 71 | SSOServerURL = "sso.server-url" 72 | ) 73 | 74 | func ssoFlags() { 75 | flag.Bool(SSOEnabled, false, "Enable single sign-on mode; one server acting as the OIDC Relying Party, and N proxies. 
The proxies delegate most endpoint operations to the server, and only implements a reverse proxy that reads the user's session data from the shared store.") 76 | flag.String(SSODomain, "", "The domain that the session cookies should be set for, usually the second-level domain name (e.g. example.com).") 77 | flag.String(SSOModeFlag, string(SSOModeServer), "The SSO mode for this instance. Must be one of 'server' or 'proxy'.") 78 | flag.String(SSOSessionCookieName, "", "Session cookie name. Must be the same across all SSO Servers and Proxies.") 79 | flag.String(SSOServerDefaultRedirectURL, "", "The URL that the SSO server should redirect to by default if a given redirect query parameter is invalid.") 80 | flag.String(SSOServerURL, "", "The URL used by the proxy to point to the SSO server instance.") 81 | } 82 | -------------------------------------------------------------------------------- /internal/crypto/crypter.go: -------------------------------------------------------------------------------- 1 | package crypto 2 | 3 | import ( 4 | cryptorand "crypto/rand" 5 | "encoding/base64" 6 | "fmt" 7 | 8 | "github.com/nais/liberator/pkg/keygen" 9 | log "github.com/sirupsen/logrus" 10 | "golang.org/x/crypto/chacha20poly1305" 11 | 12 | "github.com/nais/wonderwall/pkg/config" 13 | ) 14 | 15 | const ( 16 | KeySize = chacha20poly1305.KeySize 17 | 18 | // MaxPlaintextSize is set to 64 MB, which is a fairly generous limit. The implementation in x/crypto/xchacha20poly1305 has a plaintext limit to 256 GB. 19 | // We generally only handle data that is stored within a cookie or a session store, i.e. it should be reasonably small. 20 | // In most cases the data is around 4 KB or less, mostly depending on the length of the tokens returned from the identity provider. 
21 | MaxPlaintextSize = 64 * 1024 * 1024 22 | ) 23 | 24 | type crypter struct { 25 | key []byte 26 | } 27 | 28 | type Crypter interface { 29 | Encrypt([]byte) ([]byte, error) 30 | Decrypt([]byte) ([]byte, error) 31 | } 32 | 33 | func NewCrypter(key []byte) Crypter { 34 | return &crypter{ 35 | key: key, 36 | } 37 | } 38 | 39 | func EncryptionKeyOrGenerate(cfg *config.Config) ([]byte, error) { 40 | key, err := base64.StdEncoding.DecodeString(cfg.EncryptionKey) 41 | if err != nil { 42 | if len(cfg.EncryptionKey) > 0 { 43 | return nil, fmt.Errorf("decode encryption key: %w", err) 44 | } 45 | } 46 | 47 | if len(key) == 0 { 48 | log.Warn("no encryption key was provided, generating a random ephemeral key; sessions will not be able to be decrypted after restart") 49 | key, err = keygen.Keygen(KeySize) 50 | if err != nil { 51 | return nil, fmt.Errorf("generate random encryption key: %w", err) 52 | } 53 | } 54 | 55 | if len(key) != chacha20poly1305.KeySize { 56 | return nil, fmt.Errorf("bad key length (expected %d, got %d)", chacha20poly1305.KeySize, len(key)) 57 | } 58 | 59 | return key, nil 60 | } 61 | 62 | // Encrypt encrypts a plaintext with XChaCha20-Poly1305. 63 | func (c *crypter) Encrypt(plaintext []byte) ([]byte, error) { 64 | aead, err := chacha20poly1305.NewX(c.key) 65 | if err != nil { 66 | return nil, err 67 | } 68 | 69 | plaintextSize := len(plaintext) 70 | if plaintextSize > MaxPlaintextSize { 71 | return nil, fmt.Errorf("crypter: plaintext too large (%d > %d)", plaintextSize, MaxPlaintextSize) 72 | } 73 | 74 | // Select a random nonce, and leave capacity for the ciphertext. 75 | nonce := make([]byte, aead.NonceSize(), aead.NonceSize()+plaintextSize+aead.Overhead()) 76 | _, err = cryptorand.Read(nonce) 77 | if err != nil { 78 | return nil, err 79 | } 80 | 81 | return aead.Seal(nonce, nonce, plaintext, nil), nil 82 | } 83 | 84 | // Decrypt decrypts a ciphertext encrypted with XChaCha20-Poly1305. 
85 | func (c *crypter) Decrypt(ciphertext []byte) ([]byte, error) { 86 | aead, err := chacha20poly1305.NewX(c.key) 87 | if err != nil { 88 | return nil, err 89 | } 90 | 91 | if len(ciphertext) < aead.NonceSize() { 92 | return nil, fmt.Errorf("ciphertext is too short") 93 | } 94 | 95 | // Split nonce and ciphertext. 96 | nonce, encrypted := ciphertext[:aead.NonceSize()], ciphertext[aead.NonceSize():] 97 | return aead.Open(nil, nonce, encrypted, nil) 98 | } 99 | -------------------------------------------------------------------------------- /charts/wonderwall/templates/_resources.yaml: -------------------------------------------------------------------------------- 1 | {{/* 2 | Redis resource template. 3 | Expects a dict as input with the following keys: 4 | - root: The root values, e.g "." 5 | - provider: The identity provider, e.g. "azure" or "idporten" 6 | */}} 7 | {{- define "common.redis.tpl" -}} 8 | {{- $root := .root }} 9 | {{- $provider := .provider }} 10 | {{- $name := include "aiven.redisName" (dict "root" $root "provider" $provider) }} 11 | --- 12 | apiVersion: aiven.io/v1alpha1 13 | kind: Valkey 14 | metadata: 15 | name: {{ $name }} 16 | annotations: 17 | helm.sh/resource-policy: keep 18 | labels: 19 | {{- include "wonderwall.labels" $root | nindent 4 }} 20 | wonderwall.nais.io/provider: {{ $provider }} 21 | spec: 22 | project: {{ $root.Values.aiven.project | required ".Values.aiven.project is required." }} 23 | plan: {{ $root.Values.aiven.redisPlan | required ".Values.aiven.redisPlan is required." }} 24 | maintenanceWindowDow: "sunday" 25 | maintenanceWindowTime: "02:00:00" 26 | terminationProtection: true 27 | userConfig: 28 | valkey_maxmemory_policy: "allkeys-lru" 29 | {{- end }} 30 | 31 | {{/* 32 | Prometheus ServiceIntegration resource template. 33 | Expects a dict as input with the following keys: 34 | - root: The root values, e.g "." 35 | - provider: The identity provider, e.g. 
"azure" or "idporten" 36 | */}} 37 | {{- define "common.serviceintegration.tpl" -}} 38 | {{- $root := .root }} 39 | {{- $provider := .provider }} 40 | {{- $name := include "aiven.serviceintegrationName" (dict "root" $root "provider" $provider) }} 41 | {{- $redisName := include "aiven.redisName" (dict "root" $root "provider" $provider) }} 42 | --- 43 | apiVersion: aiven.io/v1alpha1 44 | kind: ServiceIntegration 45 | metadata: 46 | name: {{ $name }} 47 | labels: 48 | {{- include "wonderwall.labels" $root | nindent 4 }} 49 | wonderwall.nais.io/provider: {{ $provider }} 50 | spec: 51 | sourceServiceName: {{ $redisName }} 52 | project: {{ $root.Values.aiven.project }} 53 | integrationType: prometheus 54 | destinationEndpointId: {{ $root.Values.aiven.prometheusEndpointId | required ".Values.aiven.prometheusEndpointId is required." | splitList "/" | last }} 55 | {{- end }} 56 | 57 | {{/* 58 | AivenApplication resource template. 59 | Expects a dict as input with the following keys: 60 | - root: The root values, e.g "." 61 | - provider: The identity provider, e.g. "azure" or "idporten" 62 | - access: The access level for the Redis instance, e.g. 
"readwrite" or "read" 63 | - secretName: The name of the secret that should be generated 64 | */}} 65 | {{- define "common.aivenapplication.tpl" -}} 66 | {{- $root := .root }} 67 | {{- $provider := .provider }} 68 | {{- $access := .access }} 69 | {{- $secretName := .secretName -}} 70 | {{- $instance := include "aiven.instanceName" (dict "provider" $provider "root" $root) }} 71 | --- 72 | apiVersion: aiven.nais.io/v1 73 | kind: AivenApplication 74 | metadata: 75 | name: {{ $instance }}-{{ $access }} 76 | labels: 77 | {{- include "wonderwall.labels" $root | nindent 4 }} 78 | wonderwall.nais.io/provider: {{ $provider }} 79 | spec: 80 | valkey: 81 | - access: {{ $access }} 82 | instance: {{ $instance }} 83 | secretName: {{ $secretName }} 84 | protected: true 85 | {{- end }} 86 | -------------------------------------------------------------------------------- /pkg/session/store_test.go: -------------------------------------------------------------------------------- 1 | package session_test 2 | 3 | import ( 4 | "context" 5 | "testing" 6 | "time" 7 | 8 | jwtlib "github.com/lestrrat-go/jwx/v3/jwt" 9 | "github.com/nais/liberator/pkg/keygen" 10 | "github.com/stretchr/testify/assert" 11 | 12 | "github.com/nais/wonderwall/internal/crypto" 13 | "github.com/nais/wonderwall/pkg/openid" 14 | "github.com/nais/wonderwall/pkg/session" 15 | ) 16 | 17 | func decryptedEqual(t *testing.T, expected, actual *session.Data) { 18 | assert.Equal(t, expected.AccessToken, actual.AccessToken) 19 | assert.Equal(t, expected.RefreshToken, actual.RefreshToken) 20 | assert.Equal(t, expected.IDToken, actual.IDToken) 21 | assert.Equal(t, expected.ExternalSessionID, actual.ExternalSessionID) 22 | assert.Equal(t, expected.Acr, actual.Acr) 23 | assert.WithinDuration(t, expected.Metadata.Session.CreatedAt, actual.Metadata.Session.CreatedAt, 0) 24 | assert.WithinDuration(t, expected.Metadata.Session.EndsAt, actual.Metadata.Session.EndsAt, 0) 25 | assert.WithinDuration(t, expected.Metadata.Tokens.ExpireAt, 
actual.Metadata.Tokens.ExpireAt, 0) 26 | assert.WithinDuration(t, expected.Metadata.Tokens.RefreshedAt, actual.Metadata.Tokens.RefreshedAt, 0) 27 | } 28 | 29 | func makeCrypter(t *testing.T) crypto.Crypter { 30 | key, err := keygen.Keygen(32) 31 | assert.NoError(t, err) 32 | return crypto.NewCrypter(key) 33 | } 34 | 35 | func makeData() *session.Data { 36 | idToken := jwtlib.New() 37 | idToken.Set("jti", "id-token-jti") 38 | 39 | accessToken := "some-access-token" 40 | refreshToken := "some-refresh-token" 41 | 42 | tokens := &openid.Tokens{ 43 | AccessToken: accessToken, 44 | IDToken: openid.NewIDToken("id_token", idToken), 45 | RefreshToken: refreshToken, 46 | } 47 | 48 | expiresIn := time.Hour 49 | endsIn := time.Hour 50 | 51 | metadata := session.NewMetadata(expiresIn, endsIn) 52 | return session.NewData("myid", tokens, metadata) 53 | } 54 | 55 | func write(t *testing.T, store session.Store, key string, value *session.EncryptedData) { 56 | err := store.Write(context.Background(), key, value, time.Minute) 57 | assert.NoError(t, err) 58 | } 59 | 60 | func read(t *testing.T, store session.Store, key string, encrypted *session.EncryptedData, crypter crypto.Crypter) *session.Data { 61 | result, err := store.Read(context.Background(), key) 62 | assert.NoError(t, err) 63 | assert.Equal(t, encrypted, result) 64 | 65 | decrypted, err := result.Decrypt(crypter) 66 | assert.NoError(t, err) 67 | 68 | return decrypted 69 | } 70 | 71 | func update(t *testing.T, store session.Store, key string, data *session.Data, crypter crypto.Crypter) (*session.Data, *session.EncryptedData) { 72 | data.AccessToken = "new-access-token" 73 | data.RefreshToken = "new-refresh-token" 74 | encryptedData, err := data.Encrypt(crypter) 75 | assert.NoError(t, err) 76 | 77 | err = store.Update(context.Background(), key, encryptedData) 78 | assert.NoError(t, err) 79 | 80 | return data, encryptedData 81 | } 82 | 83 | func del(t *testing.T, store session.Store, key string) { 84 | err := 
store.Delete(context.Background(), key) 85 | assert.NoError(t, err) 86 | 87 | result, err := store.Read(context.Background(), key) 88 | assert.Error(t, err) 89 | assert.ErrorIs(t, err, session.ErrNotFound) 90 | assert.Nil(t, result) 91 | } 92 | -------------------------------------------------------------------------------- /internal/http/request.go: -------------------------------------------------------------------------------- 1 | package http 2 | 3 | import ( 4 | "net/http" 5 | "net/url" 6 | "strings" 7 | 8 | "github.com/nais/wonderwall/pkg/cookie" 9 | ) 10 | 11 | // IsNavigationRequest checks if the request is a navigation request by using Sec-Fetch headers. 12 | // This is used to separate between redirects for browser navigation and redirects for resource requests (e.g., Fetch or XHR). 13 | // We fall back to checking the Accept header if the browser doesn't support fetch metadata. 14 | func IsNavigationRequest(r *http.Request) bool { 15 | // we assume that navigation requests are always GET requests 16 | if r.Method != http.MethodGet { 17 | return false 18 | } 19 | 20 | mode := r.Header.Get("Sec-Fetch-Mode") 21 | dest := r.Header.Get("Sec-Fetch-Dest") 22 | if mode == "" && dest == "" { 23 | return Accepts(r, "text/html") 24 | } 25 | 26 | return mode == "navigate" && dest == "document" 27 | } 28 | 29 | func HasSecFetchMetadata(r *http.Request) bool { 30 | return r.Header.Get("Sec-Fetch-Mode") != "" && r.Header.Get("Sec-Fetch-Dest") != "" 31 | } 32 | 33 | func Accepts(r *http.Request, accepted ...string) bool { 34 | // iterate over all Accept headers 35 | for _, header := range r.Header.Values("Accept") { 36 | // iterate over all comma-separated values in a single Accept header 37 | for _, v := range strings.Split(header, ",") { 38 | v = strings.ToLower(v) 39 | v = strings.TrimSpace(v) 40 | v = strings.Split(v, ";")[0] 41 | 42 | for _, accept := range accepted { 43 | if v == accept { 44 | return true 45 | } 46 | } 47 | } 48 | } 49 | 50 | return false 51 | } 
52 | 53 | // Attributes returns a map of interesting properties for the request. 54 | func Attributes(r *http.Request) map[string]any { 55 | return map[string]any{ 56 | "request.cookies": nonEmptyRequestCookies(r), 57 | "request.host": r.Host, 58 | "request.is_navigational": IsNavigationRequest(r), 59 | "request.method": r.Method, 60 | "request.path": r.URL.Path, 61 | "request.protocol": r.Proto, 62 | "request.referer": refererStripped(r), 63 | "request.sec_fetch_dest": r.Header.Get("Sec-Fetch-Dest"), 64 | "request.sec_fetch_mode": r.Header.Get("Sec-Fetch-Mode"), 65 | "request.sec_fetch_site": r.Header.Get("Sec-Fetch-Site"), 66 | "request.user_agent": r.UserAgent(), 67 | } 68 | } 69 | 70 | func nonEmptyRequestCookies(r *http.Request) string { 71 | result := make([]string, 0) 72 | 73 | for _, c := range r.Cookies() { 74 | if !isRelevantCookie(c.Name) || len(c.Value) <= 0 { 75 | continue 76 | } 77 | 78 | result = append(result, c.Name) 79 | } 80 | 81 | return strings.Join(result, ", ") 82 | } 83 | 84 | func isRelevantCookie(name string) bool { 85 | switch name { 86 | case cookie.Session, 87 | cookie.Login, 88 | cookie.Logout: 89 | return true 90 | } 91 | 92 | return false 93 | } 94 | 95 | func refererStripped(r *http.Request) string { 96 | referer := r.Referer() 97 | refererUrl, err := url.Parse(referer) 98 | if err == nil { 99 | refererUrl.RawQuery = "" 100 | refererUrl.RawFragment = "" 101 | referer = refererUrl.String() 102 | } 103 | 104 | return referer 105 | } 106 | -------------------------------------------------------------------------------- /pkg/middleware/logentry.go: -------------------------------------------------------------------------------- 1 | package middleware 2 | 3 | import ( 4 | "fmt" 5 | "net/http" 6 | "strings" 7 | "time" 8 | 9 | "github.com/go-chi/chi/v5/middleware" 10 | log "github.com/sirupsen/logrus" 11 | "go.opentelemetry.io/otel/trace" 12 | 13 | httpinternal "github.com/nais/wonderwall/internal/http" 14 | 
"github.com/nais/wonderwall/pkg/router/paths" 15 | ) 16 | 17 | type logger struct { 18 | Logger *log.Logger 19 | Provider string 20 | } 21 | 22 | // Logger provides a middleware that logs requests and responses. 23 | func Logger(provider string) logger { 24 | return logger{ 25 | Logger: log.StandardLogger(), 26 | Provider: provider, 27 | } 28 | } 29 | 30 | // LogEntryFrom returns a log entry from the request context. 31 | func LogEntryFrom(r *http.Request) *log.Entry { 32 | ctx := r.Context() 33 | entry, ok := ctx.Value(middleware.LogEntryCtxKey).(*logEntryAdapter) 34 | if ok { 35 | return entry.Logger 36 | } 37 | 38 | return log.NewEntry(log.StandardLogger()). 39 | WithField("fallback_logger", true). 40 | WithFields(httpinternal.Attributes(r)). 41 | WithFields(traceFields(r)) 42 | } 43 | 44 | func (l *logger) Handler(next http.Handler) http.Handler { 45 | fn := func(w http.ResponseWriter, r *http.Request) { 46 | entry := l.newLogEntry(r) 47 | ww := middleware.NewWrapResponseWriter(w, r.ProtoMajor) 48 | 49 | if !strings.HasSuffix(r.URL.Path, paths.Ping) { 50 | t1 := time.Now() 51 | defer func() { 52 | entry.Write(ww.Status(), ww.BytesWritten(), ww.Header(), time.Since(t1), nil) 53 | }() 54 | } 55 | 56 | next.ServeHTTP(ww, middleware.WithLogEntry(r, entry)) 57 | } 58 | return http.HandlerFunc(fn) 59 | } 60 | 61 | func (l *logger) newLogEntry(r *http.Request) *logEntryAdapter { 62 | return &logEntryAdapter{ 63 | requestFields: httpinternal.Attributes(r), 64 | Logger: l.Logger.WithContext(r.Context()). 65 | WithField("provider", l.Provider). 
66 | WithFields(traceFields(r)), 67 | } 68 | } 69 | 70 | // logEntryAdapter implements [middleware.LogEntry] 71 | type logEntryAdapter struct { 72 | Logger *log.Entry 73 | requestFields log.Fields 74 | } 75 | 76 | func (l *logEntryAdapter) Write(status, bytes int, _ http.Header, elapsed time.Duration, _ any) { 77 | responseFields := log.Fields{ 78 | "response_status": status, 79 | "response_bytes": bytes, 80 | "response_elapsed_ms": float64(elapsed.Nanoseconds()) / 1000000.0, // in milliseconds, with fractional 81 | } 82 | 83 | l.Logger.WithFields(l.requestFields). 84 | WithFields(responseFields). 85 | Debugf("response: %d %s", status, http.StatusText(status)) 86 | } 87 | 88 | func (l *logEntryAdapter) Panic(v interface{}, _ []byte) { 89 | stacktrace := "#" 90 | 91 | fields := log.Fields{ 92 | "stacktrace": stacktrace, 93 | "error": fmt.Sprintf("%+v", v), 94 | } 95 | 96 | l.Logger = l.Logger.WithFields(fields) 97 | } 98 | 99 | func traceFields(r *http.Request) log.Fields { 100 | fields := log.Fields{} 101 | span := trace.SpanFromContext(r.Context()) 102 | if span.SpanContext().HasTraceID() { 103 | fields["trace_id"] = span.SpanContext().TraceID().String() 104 | } else { 105 | fields["correlation_id"] = middleware.GetReqID(r.Context()) 106 | } 107 | 108 | return fields 109 | } 110 | -------------------------------------------------------------------------------- /pkg/session/session.go: -------------------------------------------------------------------------------- 1 | package session 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "fmt" 7 | "net/http" 8 | "time" 9 | 10 | "github.com/nais/wonderwall/internal/crypto" 11 | "github.com/nais/wonderwall/pkg/cookie" 12 | "github.com/nais/wonderwall/pkg/openid" 13 | ) 14 | 15 | var ( 16 | ErrInactive = errors.New("is inactive") 17 | ErrInvalid = errors.New("session is invalid") 18 | ErrInvalidExternal = errors.New("session has invalid state at identity provider") 19 | ErrNotFound = errors.New("not found") 20 | ) 21 | 22 | 
// Reader knows how to read a session. 23 | type Reader interface { 24 | // Get returns the session for a given http.Request, or an error if the session is invalid or not found. 25 | Get(r *http.Request) (*Session, error) 26 | } 27 | 28 | // Writer knows how to create, update and delete a session. 29 | type Writer interface { 30 | // Create creates and stores a session in the Store. 31 | Create(r *http.Request, tokens *openid.Tokens, sessionLifetime time.Duration) (*Session, error) 32 | // Delete deletes a session for a given Session. 33 | Delete(ctx context.Context, session *Session) error 34 | // DeleteForExternalID deletes a session for a given external session ID (e.g. front-channel logout). 35 | DeleteForExternalID(ctx context.Context, id string) error 36 | // Refresh refreshes the user's tokens and returns the updated session. If the session should not be 37 | // refreshed, it will return the existing session without modifications. 38 | Refresh(r *http.Request, sess *Session) (*Session, error) 39 | } 40 | 41 | // Manager is both a Reader and a Writer. 42 | type Manager interface { 43 | Reader 44 | Writer 45 | 46 | // GetOrRefresh returns the session for a given http.Request. If the tokens within the session are expired and the 47 | // session is still valid, it will automatically attempt to refresh and update the session. 
48 | GetOrRefresh(r *http.Request) (*Session, error) 49 | } 50 | 51 | type Session struct { 52 | data *Data 53 | ticket *Ticket 54 | } 55 | 56 | func (in *Session) AccessToken() (string, error) { 57 | if in.data != nil && in.data.HasActiveAccessToken() { 58 | return in.data.AccessToken, nil 59 | } 60 | 61 | return "", fmt.Errorf("%w: access token is expired", ErrInvalid) 62 | } 63 | 64 | func (in *Session) Acr() string { 65 | if in.data != nil { 66 | return in.data.Acr 67 | } 68 | return "" 69 | } 70 | 71 | func (in *Session) ExternalSessionID() string { 72 | if in.data != nil { 73 | return in.data.ExternalSessionID 74 | } 75 | 76 | return "" 77 | } 78 | 79 | func (in *Session) IDToken() string { 80 | return in.data.IDToken 81 | } 82 | 83 | func (in *Session) MetadataVerbose() MetadataVerbose { 84 | return in.data.Metadata.Verbose() 85 | } 86 | 87 | func (in *Session) SetCookie(w http.ResponseWriter, opts cookie.Options, crypter crypto.Crypter) error { 88 | return in.ticket.SetCookie(w, opts, crypter) 89 | } 90 | 91 | func (in *Session) canRefresh() bool { 92 | return in.data != nil && in.data.HasRefreshToken() && !in.data.Metadata.IsRefreshOnCooldown() 93 | } 94 | 95 | func (in *Session) encrypt() (*EncryptedData, error) { 96 | return in.data.Encrypt(in.ticket.Crypter()) 97 | } 98 | 99 | func (in *Session) key() string { 100 | return in.ticket.Key() 101 | } 102 | 103 | func (in *Session) shouldRefresh() bool { 104 | return in.data != nil && in.data.Metadata.ShouldRefresh() 105 | } 106 | 107 | func NewSession(data *Data, ticket *Ticket) *Session { 108 | return &Session{data: data, ticket: ticket} 109 | } 110 | -------------------------------------------------------------------------------- /docs/sessions.md: -------------------------------------------------------------------------------- 1 | # Session Management 2 | 3 | When a user authenticates themselves, they receive a session. 
Sessions are stored server-side; we only store a session identifier at the end-user's user agent. 4 | 5 | A session has three states: 6 | 7 | - _active_ - the session is valid 8 | - _inactive_ - the session has reached the _inactivity timeout_ and is considered invalid 9 | - _expired_ - the session has reached its _maximum lifetime_ and is considered invalid 10 | 11 | Requests with an _invalid_ session are considered _unauthenticated_. 12 | 13 | ## Session Metadata 14 | 15 | User agents can access their own session metadata by using [the `/oauth2/session` endpoint](endpoints.md#oauth2session). 16 | 17 | ## Session Expiry 18 | 19 | Every session has a maximum lifetime. 20 | The lifetime is indicated by the `session.ends_at` and `session.ends_in_seconds` fields in the session metadata. 21 | 22 | When the session reaches the maximum lifetime, it is considered to be _expired_, after which the user is essentially unauthenticated. 23 | A new session must be acquired by redirecting the user to [the `/oauth2/login` endpoint](endpoints.md#oauth2login) again. 24 | 25 | The maximum lifetime can be configured with the `session.max-lifetime` flag. 26 | 27 | ## Session Refreshing 28 | 29 | The tokens within the session will usually expire before the session itself. 30 | This is indicated by the `tokens.expire_at` and `tokens.expire_in_seconds` fields in the session metadata. 31 | 32 | If you've configured a session lifetime that is longer than the token expiry, you'll probably want to _refresh_ the tokens to avoid redirecting end-users to the `/oauth2/login` endpoint whenever the access tokens have expired. 33 | 34 | ### Automatic vs Manual Refresh 35 | 36 | The behaviour for refreshing depends on the [runtime mode](configuration.md#modes) for Wonderwall. 37 | 38 | In standalone mode, tokens are automatically refreshed. 39 | Tokens will at the _earliest_ automatically be renewed 5 minutes before they expire. 
40 | If the token already _has_ expired, a refresh attempt is still automatically triggered as long as the session itself has not ended and is not marked as inactive. 41 | 42 | Automatic refreshes happen whenever the end-user visits or requests any path that is proxied to the upstream application. 43 | 44 | In SSO mode, tokens cannot be automatically refreshed. They must be refreshed by performing a request to [the `/oauth2/session/refresh` endpoint](endpoints.md#oauth2sessionrefresh). 45 | 46 | ## Session Inactivity 47 | 48 | A session can be marked as _inactive_ before it _expires_ (reaches the maximum lifetime). 49 | This happens if the time since the last _refresh_ exceeds the given _inactivity timeout_. 50 | 51 | An _inactive_ session _cannot_ be refreshed; a new session must be acquired by redirecting the user to the `/oauth2/login` endpoint. 52 | This is useful if you want to ensure that an end-user must re-authenticate with the identity provider if they've been away from an authenticated session for some time. 53 | 54 | Inactivity support is enabled with the `session.inactivity` option. 55 | 56 | The activity state of the session is indicated by the `session.active` field in the session metadata. 57 | 58 | The time until the session is marked as inactive is indicated by the `session.timeout_at` and `session.timeout_in_seconds` fields in the session metadata. 59 | 60 | The timeout is configured with `session.inactivity-timeout`. 61 | If this timeout is shorter than the token expiry, the session metadata fields `tokens.expire_at` and `tokens.expire_in_seconds` will be reduced accordingly to reflect the inactivity timeout.
62 | -------------------------------------------------------------------------------- /charts/wonderwall/Feature.yaml: -------------------------------------------------------------------------------- 1 | dependencies: 2 | - allOf: 3 | - aiven-operator 4 | - aivenator 5 | - mutilator 6 | - replicator 7 | environmentKinds: 8 | - tenant 9 | - legacy 10 | timeout: "1800s" 11 | values: 12 | aiven.project: 13 | description: Aiven project for Redis. 14 | computed: 15 | template: '"{{ .Env.aiven_project }}"' 16 | aiven.redisPlan: 17 | description: Aiven plan for Redis. 18 | required: true 19 | config: 20 | type: string 21 | aiven.prometheusEndpointId: 22 | description: Aiven Prometheus integration endpoint ID. 23 | computed: 24 | template: '"{{ .Env.aiven_prometheus_endpoint_id }}"' 25 | azure.enabled: 26 | description: Enable Azure AD. Requires Azurerator to be enabled. 27 | config: 28 | type: bool 29 | azure.forwardAuth.enabled: 30 | description: Enables forward auth server. Requires Azurerator and loadbalancer-fa to be enabled. 31 | config: 32 | type: bool 33 | azure.forwardAuth.groupIds: 34 | description: Additional group IDs to grant access to 35 | config: 36 | type: string_array 37 | azure.forwardAuth.sessionCookieEncryptionKey: 38 | description: Cookie encryption key, 256 bits (e.g. 32 ASCII characters) encoded with standard base64. 39 | config: 40 | type: string 41 | secret: true 42 | azure.forwardAuth.ssoDomain: 43 | description: Cookie domain for forward auth 44 | config: 45 | type: string 46 | azure.forwardAuth.ssoDefaultRedirectURL: 47 | description: Default redirect URL for forward auth 48 | config: 49 | type: string 50 | idporten.enabled: 51 | description: Enable ID-porten. Requires Digdirator to be enabled. 52 | config: 53 | type: bool 54 | idporten.openidResourceIndicator: 55 | description: Resource indicator for audience-restricted tokens. 
56 | config: 57 | type: string 58 | idporten.openidPostLogoutRedirectURL: 59 | description: Where to redirect the user after global logout. 60 | config: 61 | type: string 62 | idporten.replicasMax: 63 | description: Maximum replicas for SSO server. 64 | config: 65 | type: int 66 | idporten.replicasMin: 67 | description: Minimum replicas for SSO server. 68 | config: 69 | type: int 70 | idporten.sessionCookieEncryptionKey: 71 | description: Cookie encryption key, 256 bits (e.g. 32 ASCII characters) encoded with standard base64. 72 | config: 73 | type: string 74 | secret: true 75 | idporten.sessionCookieName: 76 | description: Cookie name for SSO sessions. 77 | config: 78 | type: string 79 | idporten.ssoDefaultRedirectURL: 80 | description: Fallback URL for invalid SSO redirects. 81 | config: 82 | type: string 83 | idporten.ssoDomain: 84 | description: Allowed domain for SSO (for cookies, CORS and redirect URL validation). 85 | config: 86 | type: string 87 | idporten.ssoServerHost: 88 | description: Host for SSO server. 89 | config: 90 | type: string 91 | idporten.ingressClassName: 92 | description: Ingress class for SSO server. 93 | config: 94 | type: string 95 | image.tag: 96 | config: 97 | type: string 98 | openid.wellKnownUrl: 99 | description: Well-known URL to generic identity provider. Optional. Only needed if a default global provider is desired. 100 | config: 101 | type: string 102 | ignoreKind: 103 | - legacy 104 | resourceSuffix: 105 | description: Suffix for resources that may conflict in parallel environments. 
106 | config: 107 | type: string 108 | -------------------------------------------------------------------------------- /pkg/openid/provider/provider.go: -------------------------------------------------------------------------------- 1 | package provider 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "sync" 7 | "time" 8 | 9 | "github.com/lestrrat-go/httprc/v3" 10 | "github.com/lestrrat-go/jwx/v3/jwa" 11 | "github.com/lestrrat-go/jwx/v3/jwk" 12 | "github.com/nais/wonderwall/internal/o11y/otel" 13 | openidconfig "github.com/nais/wonderwall/pkg/openid/config" 14 | "go.opentelemetry.io/otel/attribute" 15 | ) 16 | 17 | const ( 18 | JwkMinimumRefreshInterval = 5 * time.Second 19 | ) 20 | 21 | type JwksProvider struct { 22 | config openidconfig.Provider 23 | jwksCache *jwk.Cache 24 | jwksLock *jwksLock 25 | } 26 | 27 | type jwksLock struct { 28 | lastRefresh time.Time 29 | sync.Mutex 30 | } 31 | 32 | func (p *JwksProvider) GetPublicJwkSet(ctx context.Context) (*jwk.Set, error) { 33 | url := p.config.JwksURI() 34 | set, err := p.jwksCache.Lookup(ctx, url) 35 | if err != nil { 36 | return nil, fmt.Errorf("provider: fetching jwks: %w", err) 37 | } 38 | 39 | set, err = ensureJwkSetWithAlg(set, p.config.IDTokenSigningAlg()) 40 | if err != nil { 41 | return nil, fmt.Errorf("provider: mutating jwks: %w", err) 42 | } 43 | 44 | return &set, nil 45 | } 46 | 47 | func (p *JwksProvider) RefreshPublicJwkSet(ctx context.Context) (*jwk.Set, error) { 48 | ctx, span := otel.StartSpan(ctx, "JwksProvider.RefreshPublicJwkSet") 49 | defer span.End() 50 | p.jwksLock.Lock() 51 | defer p.jwksLock.Unlock() 52 | 53 | // redirect to cache if recently refreshed to avoid overwhelming provider 54 | diff := time.Since(p.jwksLock.lastRefresh) 55 | if diff < JwkMinimumRefreshInterval { 56 | span.SetAttributes(attribute.Bool("jwks.cooldown", true)) 57 | return p.GetPublicJwkSet(ctx) 58 | } 59 | 60 | p.jwksLock.lastRefresh = time.Now() 61 | 62 | url := p.config.JwksURI() 63 | set, err := 
p.jwksCache.Refresh(ctx, url) 64 | if err != nil { 65 | return nil, fmt.Errorf("provider: refreshing jwks: %w", err) 66 | } 67 | 68 | set, err = ensureJwkSetWithAlg(set, p.config.IDTokenSigningAlg()) 69 | if err != nil { 70 | return nil, fmt.Errorf("provider: mutating jwks: %w", err) 71 | } 72 | 73 | span.SetAttributes(attribute.Bool("jwks.refreshed", true)) 74 | return &set, nil 75 | } 76 | 77 | func NewJwksProvider(ctx context.Context, openidCfg openidconfig.Config) (*JwksProvider, error) { 78 | providerCfg := openidCfg.Provider() 79 | 80 | uri := providerCfg.JwksURI() 81 | cache, err := jwk.NewCache(ctx, httprc.NewClient()) 82 | if err != nil { 83 | return nil, fmt.Errorf("creating jwks cache: %w", err) 84 | } 85 | 86 | if err := cache.Register(ctx, uri); err != nil { 87 | return nil, fmt.Errorf("registering jwks provider uri to cache: %w", err) 88 | } 89 | 90 | return &JwksProvider{ 91 | config: providerCfg, 92 | jwksCache: cache, 93 | jwksLock: &jwksLock{}, 94 | }, nil 95 | } 96 | 97 | func ensureJwkSetWithAlg(set jwk.Set, expectedAlg jwa.KeyAlgorithm) (jwk.Set, error) { 98 | for i := 0; i < set.Len(); i++ { 99 | key, ok := set.Key(i) 100 | if !ok { 101 | continue 102 | } 103 | 104 | alg, ok := key.Algorithm() 105 | if ok { 106 | // drop keys with "alg=none" 107 | if alg == jwa.NoSignature() { 108 | if err := set.RemoveKey(key); err != nil { 109 | return nil, fmt.Errorf("removing key: %w", err) 110 | } 111 | } 112 | 113 | // don't mutate keys with a valid algorithm 114 | continue 115 | } 116 | 117 | // set "alg" to expected algorithm for keys that don't have one set 118 | if err := key.Set(jwk.AlgorithmKey, expectedAlg); err != nil { 119 | return nil, fmt.Errorf("setting key algorithm: %w", err) 120 | } 121 | } 122 | 123 | return set, nil 124 | } 125 | -------------------------------------------------------------------------------- /pkg/middleware/prometheus.go: -------------------------------------------------------------------------------- 1 | // This code 
was originally written by Rene Zbinden and modified by Vladimir Konovalov. 2 | // Copied from https://github.com/766b/chi-prometheus and further adapted. 3 | 4 | package middleware 5 | 6 | import ( 7 | "net/http" 8 | "strconv" 9 | "strings" 10 | "time" 11 | 12 | chi_middleware "github.com/go-chi/chi/v5/middleware" 13 | "github.com/prometheus/client_golang/prometheus" 14 | 15 | "github.com/nais/wonderwall/pkg/metrics" 16 | "github.com/nais/wonderwall/pkg/router/paths" 17 | ) 18 | 19 | var defaultBuckets = []float64{.001, .01, .05, .1, .5, 1, 1.5, 2, 2.5, 3, 3.5, 4, 4.5, 5} 20 | 21 | const ( 22 | serviceName = "wonderwall" 23 | reqsName = "requests_total" 24 | latencyName = "request_duration_seconds" 25 | ) 26 | 27 | // PrometheusMiddleware is a handler that exposes prometheus metrics for the number of requests, 28 | // the latency and the response size, partitioned by status code, method and HTTP path. 29 | type PrometheusMiddleware struct { 30 | reqs *prometheus.CounterVec 31 | latency *prometheus.HistogramVec 32 | } 33 | 34 | // Prometheus returns a new PrometheusMiddleware handler. 
35 | func Prometheus(provider string, buckets ...float64) *PrometheusMiddleware { 36 | var m PrometheusMiddleware 37 | m.reqs = prometheus.NewCounterVec( 38 | prometheus.CounterOpts{ 39 | Name: reqsName, 40 | Help: "How many HTTP requests processed, partitioned by status code, method and HTTP path.", 41 | ConstLabels: prometheus.Labels{"service": serviceName, metrics.LabelProvider: provider}, 42 | }, 43 | []string{"code", "method", "path", "host"}, 44 | ) 45 | 46 | if len(buckets) == 0 { 47 | buckets = defaultBuckets 48 | } 49 | m.latency = prometheus.NewHistogramVec(prometheus.HistogramOpts{ 50 | Name: latencyName, 51 | Help: "How long it took to process the request, partitioned by status code, method and HTTP path.", 52 | ConstLabels: prometheus.Labels{"service": serviceName, metrics.LabelProvider: provider}, 53 | Buckets: buckets, 54 | }, 55 | []string{"code", "method", "path", "host"}, 56 | ) 57 | 58 | prometheus.Register(m.reqs) 59 | prometheus.Register(m.latency) 60 | 61 | return &m 62 | } 63 | 64 | func (m *PrometheusMiddleware) Initialize(path, method string, code int) { 65 | m.reqs.WithLabelValues( 66 | strconv.Itoa(code), 67 | method, 68 | path, 69 | ) 70 | } 71 | 72 | func (m *PrometheusMiddleware) Handler(next http.Handler) http.Handler { 73 | relevantPaths := map[string]bool{ 74 | paths.OAuth2 + paths.Login: true, 75 | paths.OAuth2 + paths.LoginCallback: true, 76 | paths.OAuth2 + paths.Logout: true, 77 | paths.OAuth2 + paths.LogoutCallback: true, 78 | paths.OAuth2 + paths.LogoutFrontChannel: true, 79 | paths.OAuth2 + paths.LogoutLocal: true, 80 | paths.OAuth2 + paths.Ping: false, 81 | paths.OAuth2 + paths.Session: true, 82 | paths.OAuth2 + paths.Session + paths.Refresh: true, 83 | paths.OAuth2 + paths.Session + paths.ForwardAuth: true, 84 | } 85 | 86 | fn := func(w http.ResponseWriter, r *http.Request) { 87 | found := false 88 | for path, relevant := range relevantPaths { 89 | if strings.HasSuffix(r.URL.Path, path) && relevant { 90 | found = true 91 | 
break 92 | } 93 | } 94 | 95 | if !found { 96 | next.ServeHTTP(w, r) 97 | return 98 | } 99 | 100 | start := time.Now() 101 | ww := chi_middleware.NewWrapResponseWriter(w, r.ProtoMajor) 102 | next.ServeHTTP(ww, r) 103 | statusCode := strconv.Itoa(ww.Status()) 104 | duration := time.Since(start) 105 | m.reqs.WithLabelValues(statusCode, r.Method, r.URL.Path, r.Host).Inc() 106 | m.latency.WithLabelValues(statusCode, r.Method, r.URL.Path, r.Host).Observe(duration.Seconds()) 107 | } 108 | return http.HandlerFunc(fn) 109 | } 110 | -------------------------------------------------------------------------------- /pkg/openid/client/login_callback.go: -------------------------------------------------------------------------------- 1 | package client 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "fmt" 7 | "net/http" 8 | 9 | "github.com/nais/wonderwall/internal/o11y/otel" 10 | "github.com/nais/wonderwall/pkg/openid" 11 | "go.opentelemetry.io/otel/attribute" 12 | "go.opentelemetry.io/otel/trace" 13 | ) 14 | 15 | var ( 16 | ErrCallbackIdentityProvider = errors.New("callback: identity provider error") 17 | ErrCallbackInvalidCookie = errors.New("callback: invalid cookie") 18 | ErrCallbackInvalidState = errors.New("callback: invalid state") 19 | ErrCallbackInvalidIssuer = errors.New("callback: invalid issuer") 20 | ErrCallbackRedeemTokens = errors.New("callback: redeeming tokens") 21 | ) 22 | 23 | func (c *Client) LoginCallback(r *http.Request, cookie *openid.LoginCookie) (*openid.Tokens, error) { 24 | span := trace.SpanFromContext(r.Context()) 25 | if cookie == nil { 26 | return nil, fmt.Errorf("%w: %s", ErrCallbackInvalidCookie, "cookie is nil") 27 | } 28 | 29 | query := r.URL.Query() 30 | 31 | if oauthError := query.Get("error"); len(oauthError) > 0 { 32 | oauthErrorDescription := query.Get("error_description") 33 | span.SetAttributes(attribute.String("login.oauth_error", oauthError), attribute.String("login.oauth_error_description", oauthErrorDescription)) 34 | return nil, 
fmt.Errorf("%w: %s: %s", ErrCallbackIdentityProvider, oauthError, oauthErrorDescription) 35 | } 36 | 37 | if err := openid.StateMismatchError(query, cookie.State); err != nil { 38 | span.SetAttributes(attribute.Bool("login.state_mismatch", true)) 39 | return nil, fmt.Errorf("%w: %s", ErrCallbackInvalidState, err) 40 | } 41 | 42 | if err := c.authorizationServerIssuerIdentification(query.Get("iss")); err != nil { 43 | span.SetAttributes(attribute.Bool("login.issuer_mismatch", true)) 44 | return nil, fmt.Errorf("%w: %s", ErrCallbackInvalidIssuer, err) 45 | } 46 | 47 | tokens, err := c.redeemTokens(r.Context(), query.Get("code"), cookie) 48 | if err != nil { 49 | return nil, fmt.Errorf("%w: %s", ErrCallbackRedeemTokens, err) 50 | } 51 | 52 | return tokens, nil 53 | } 54 | 55 | // Verify iss parameter if provider supports RFC 9207 - OAuth 2.0 Authorization Server Issuer Identification 56 | func (c *Client) authorizationServerIssuerIdentification(iss string) error { 57 | if !c.cfg.Provider().AuthorizationResponseIssParameterSupported() { 58 | return nil 59 | } 60 | 61 | if len(iss) == 0 { 62 | return fmt.Errorf("missing issuer parameter") 63 | } 64 | 65 | expectedIss := c.cfg.Provider().Issuer() 66 | if iss != expectedIss { 67 | return fmt.Errorf("issuer mismatch: expected %q, got %q", expectedIss, iss) 68 | } 69 | 70 | return nil 71 | } 72 | 73 | func (c *Client) redeemTokens(ctx context.Context, code string, cookie *openid.LoginCookie) (*openid.Tokens, error) { 74 | ctx, span := otel.StartSpan(ctx, "Client.RedeemTokens") 75 | defer span.End() 76 | clientAuth, err := c.ClientAuthenticationParams() 77 | if err != nil { 78 | return nil, err 79 | } 80 | 81 | payload := openid.ExchangeAuthorizationCodeParams( 82 | c.cfg.Client().ClientID(), 83 | code, 84 | cookie.CodeVerifier, 85 | cookie.RedirectURI, 86 | ).With(clientAuth).AuthCodeOptions() 87 | 88 | rawTokens, err := c.AuthCodeGrant(ctx, code, payload) 89 | if err != nil { 90 | return nil, fmt.Errorf("exchanging 
authorization code for token: %w", err) 91 | } 92 | span.SetAttributes(attribute.Int64("oauth.token_expires_in_seconds", rawTokens.ExpiresIn)) 93 | 94 | jwkSet, err := c.jwksProvider.GetPublicJwkSet(ctx) 95 | if err != nil { 96 | return nil, fmt.Errorf("getting jwks: %w", err) 97 | } 98 | 99 | tokens, err := openid.NewTokens(rawTokens, jwkSet, c.cfg, cookie) 100 | if err != nil { 101 | // JWKS might not be up to date, so we'll want to force a refresh for the next attempt 102 | _, _ = c.jwksProvider.RefreshPublicJwkSet(ctx) 103 | return nil, fmt.Errorf("parsing tokens: %w", err) 104 | } 105 | 106 | return tokens, nil 107 | } 108 | -------------------------------------------------------------------------------- /pkg/ingress/ingress.go: -------------------------------------------------------------------------------- 1 | package ingress 2 | 3 | import ( 4 | "fmt" 5 | "net/http" 6 | "net/url" 7 | "strings" 8 | 9 | "github.com/nais/wonderwall/pkg/config" 10 | ) 11 | 12 | const ( 13 | XForwardedHost = "X-Forwarded-Host" 14 | ) 15 | 16 | type Ingresses struct { 17 | ingressMap map[string]Ingress 18 | hosts []string 19 | paths []string 20 | urls []string 21 | } 22 | 23 | func ParseIngresses(cfg *config.Config) (*Ingresses, error) { 24 | ingresses := cfg.Ingresses 25 | if len(ingresses) == 0 { 26 | return nil, fmt.Errorf("must have at least 1 ingress") 27 | } 28 | 29 | seen := make(map[string]Ingress) 30 | 31 | for _, raw := range ingresses { 32 | ingress, err := ParseIngress(raw) 33 | if err != nil { 34 | return nil, fmt.Errorf("parsing ingress '%s': %w", raw, err) 35 | } 36 | 37 | if _, found := seen[ingress.String()]; !found { 38 | seen[ingress.String()] = *ingress 39 | } 40 | } 41 | 42 | return &Ingresses{ 43 | ingressMap: seen, 44 | hosts: mapIngresses(seen, Ingress.Host), 45 | paths: mapIngresses(seen, Ingress.Path), 46 | urls: mapIngresses(seen, Ingress.String), 47 | }, nil 48 | } 49 | 50 | func (i *Ingresses) Hosts() []string { 51 | return i.hosts 52 | } 53 | 54 | 
func (i *Ingresses) Paths() []string { 55 | return i.paths 56 | } 57 | 58 | func (i *Ingresses) Strings() []string { 59 | return i.urls 60 | } 61 | 62 | func (i *Ingresses) MatchingIngress(r *http.Request) (Ingress, bool) { 63 | for _, ingress := range i.ingressMap { 64 | hostMatch := ingress.Host() == r.Host || ingress.Host() == r.Header.Get(XForwardedHost) 65 | pathMatch := ingress.Path() == i.MatchingPath(r) 66 | 67 | if hostMatch && pathMatch { 68 | return ingress, true 69 | } 70 | } 71 | 72 | return Ingress{}, false 73 | } 74 | 75 | func (i *Ingresses) MatchingPath(r *http.Request) string { 76 | reqPath := r.URL.Path 77 | result := "" 78 | 79 | for _, p := range i.Paths() { 80 | if len(p) == 0 { 81 | continue 82 | } 83 | 84 | if strings.HasPrefix(reqPath, p) && len(p) > len(result) { 85 | result = p 86 | } 87 | } 88 | 89 | return result 90 | } 91 | 92 | func (i *Ingresses) Single() Ingress { 93 | var res Ingress 94 | 95 | for _, v := range i.ingressMap { 96 | res = v 97 | break 98 | } 99 | 100 | return res 101 | } 102 | 103 | func mapIngresses(ingresses map[string]Ingress, fn func(i Ingress) string) []string { 104 | seen := make(map[string]bool, 0) 105 | result := make([]string, 0) 106 | 107 | for _, ingress := range ingresses { 108 | value := fn(ingress) 109 | 110 | if _, found := seen[value]; !found { 111 | seen[value] = true 112 | result = append(result, value) 113 | } 114 | } 115 | 116 | return result 117 | } 118 | 119 | func ParseIngress(ingress string) (*Ingress, error) { 120 | if len(ingress) == 0 { 121 | return nil, fmt.Errorf("ingress cannot be empty") 122 | } 123 | 124 | u, err := url.ParseRequestURI(ingress) 125 | if err != nil { 126 | return nil, err 127 | } 128 | 129 | if len(u.Host) == 0 { 130 | return nil, fmt.Errorf("must have non-empty host") 131 | } 132 | 133 | err = mustScheme(u) 134 | if err != nil { 135 | return nil, err 136 | } 137 | 138 | u.Path = strings.TrimRight(u.Path, "/") 139 | 140 | return &Ingress{ 141 | URL: u, 142 | }, nil 143 
| } 144 | 145 | func mustScheme(u *url.URL) error { 146 | validSchemes := []string{"http", "https"} 147 | 148 | valid := false 149 | for _, scheme := range validSchemes { 150 | if u.Scheme == scheme { 151 | valid = true 152 | } 153 | } 154 | 155 | if !valid { 156 | return fmt.Errorf("invalid URL scheme, must be one of %s", validSchemes) 157 | } 158 | 159 | return nil 160 | } 161 | 162 | type Ingress struct { 163 | *url.URL 164 | } 165 | 166 | func (i Ingress) Path() string { 167 | return i.URL.Path 168 | } 169 | 170 | func (i Ingress) Host() string { 171 | return i.URL.Host 172 | } 173 | 174 | func (i Ingress) String() string { 175 | return i.URL.String() 176 | } 177 | 178 | func (i Ingress) NewURL() *url.URL { 179 | u := *i.URL 180 | return &u 181 | } 182 | -------------------------------------------------------------------------------- /cmd/wonderwall/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "net/http" 7 | 8 | "github.com/KimMachineGun/automemlimit/memlimit" 9 | "github.com/nais/wonderwall/internal/crypto" 10 | "github.com/nais/wonderwall/internal/o11y/otel" 11 | "github.com/nais/wonderwall/pkg/config" 12 | "github.com/nais/wonderwall/pkg/cookie" 13 | "github.com/nais/wonderwall/pkg/handler" 14 | "github.com/nais/wonderwall/pkg/metrics" 15 | openidconfig "github.com/nais/wonderwall/pkg/openid/config" 16 | "github.com/nais/wonderwall/pkg/openid/provider" 17 | "github.com/nais/wonderwall/pkg/router" 18 | "github.com/nais/wonderwall/pkg/server" 19 | log "github.com/sirupsen/logrus" 20 | ) 21 | 22 | func main() { 23 | err := run() 24 | if err != nil { 25 | log.Fatalf("Fatal error: %s", err) 26 | } 27 | } 28 | 29 | func run() error { 30 | cfg, err := config.Initialize() 31 | if err != nil { 32 | return err 33 | } 34 | 35 | if _, err := memlimit.SetGoMemLimitWithOpts(); err != nil { 36 | log.Debugf("setting GOMEMLIMIT: %+v", err) 37 | } 38 | 39 | key, err := 
crypto.EncryptionKeyOrGenerate(cfg) 40 | if err != nil { 41 | return err 42 | } 43 | 44 | crypt := crypto.NewCrypter(key) 45 | 46 | ctx, cancel := context.WithCancel(context.Background()) 47 | defer cancel() 48 | 49 | if cfg.Cookie.Prefix != cookie.DefaultPrefix { 50 | cookie.ConfigureCookieNamesWithPrefix(cfg.Cookie.Prefix) 51 | } 52 | 53 | if cfg.SSO.Enabled { 54 | cookie.ConfigureCookieNamesWithPrefix(cfg.SSO.SessionCookieName) 55 | cookie.Session = cfg.SSO.SessionCookieName 56 | } 57 | 58 | if cfg.OpenTelemetry.Enabled { 59 | otelShutdown, err := otel.Setup(ctx, cfg) 60 | if err != nil { 61 | return fmt.Errorf("initializing OpenTelemetry: %w", err) 62 | } 63 | defer otelShutdown(ctx) 64 | } 65 | 66 | var src router.Source 67 | 68 | if cfg.SSO.Enabled { 69 | switch cfg.SSO.Mode { 70 | case config.SSOModeServer: 71 | src, err = ssoServer(ctx, cfg, crypt) 72 | case config.SSOModeProxy: 73 | src, err = ssoProxy(cfg, crypt) 74 | default: 75 | return fmt.Errorf("invalid SSO mode: %q", cfg.SSO.Mode) 76 | } 77 | } else { 78 | src, err = standalone(ctx, cfg, crypt) 79 | } 80 | if err != nil { 81 | return fmt.Errorf("initializing routing handler: %w", err) 82 | } 83 | 84 | r := router.New(src, cfg) 85 | 86 | if cfg.MetricsBindAddress != "" { 87 | go func() { 88 | log.Debugf("metrics: listening on %s", cfg.MetricsBindAddress) 89 | err := metrics.Handle(cfg.MetricsBindAddress, cfg.OpenID.Provider) 90 | if err != nil { 91 | log.Fatalf("fatal: metrics server error: %s", err) 92 | } 93 | }() 94 | } 95 | 96 | if cfg.ProbeBindAddress != "" { 97 | go func() { 98 | mux := http.NewServeMux() 99 | healthz := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 100 | w.WriteHeader(http.StatusOK) 101 | w.Write([]byte("ok")) 102 | }) 103 | mux.HandleFunc("/", healthz) 104 | mux.HandleFunc("/healthz", healthz) 105 | log.Debugf("probe: listening on %s", cfg.ProbeBindAddress) 106 | err := http.ListenAndServe(cfg.ProbeBindAddress, mux) 107 | if err != nil { 108 | 
log.Fatalf("fatal: probe server error: %s", err) 109 | } 110 | }() 111 | } 112 | 113 | return server.Start(cfg, r) 114 | } 115 | 116 | func standalone(ctx context.Context, cfg *config.Config, crypt crypto.Crypter) (*handler.Standalone, error) { 117 | openidConfig, err := openidconfig.NewConfig(cfg) 118 | if err != nil { 119 | return nil, err 120 | } 121 | 122 | jwksProvider, err := provider.NewJwksProvider(ctx, openidConfig) 123 | if err != nil { 124 | return nil, err 125 | } 126 | 127 | return handler.NewStandalone(cfg, jwksProvider, openidConfig, crypt) 128 | } 129 | 130 | func ssoServer(ctx context.Context, cfg *config.Config, crypt crypto.Crypter) (*handler.SSOServer, error) { 131 | h, err := standalone(ctx, cfg, crypt) 132 | if err != nil { 133 | return nil, err 134 | } 135 | 136 | return handler.NewSSOServer(cfg, h) 137 | } 138 | 139 | func ssoProxy(cfg *config.Config, crypt crypto.Crypter) (*handler.SSOProxy, error) { 140 | return handler.NewSSOProxy(cfg, crypt) 141 | } 142 | -------------------------------------------------------------------------------- /pkg/url/validator.go: -------------------------------------------------------------------------------- 1 | package url 2 | 3 | import ( 4 | "net/http" 5 | "net/url" 6 | "regexp" 7 | "strings" 8 | 9 | mw "github.com/nais/wonderwall/pkg/middleware" 10 | ) 11 | 12 | // Used to check final redirects are not susceptible to open redirects. 13 | // Matches //, /\ and both of these with whitespace in between (eg / / or / \). 
14 | var invalidRedirectRegex = regexp.MustCompile(`[/\\](?:[\s\v]*|\.{1,2})[/\\]`) 15 | 16 | var _ Validator = &AbsoluteValidator{} 17 | 18 | type Validator interface { 19 | IsValidRedirect(r *http.Request, redirect string) bool 20 | } 21 | 22 | type AbsoluteValidator struct { 23 | allowedDomains []string 24 | } 25 | 26 | func NewAbsoluteValidator(allowedDomains []string) *AbsoluteValidator { 27 | return &AbsoluteValidator{allowedDomains: allowedDomains} 28 | } 29 | 30 | // IsValidRedirect validates that the given redirect string is a valid absolute URL. 31 | // It must use the 'http' or 'https' scheme. 32 | // It must point to a host that matches the configured list of allowed domains. 33 | func (v *AbsoluteValidator) IsValidRedirect(r *http.Request, redirect string) bool { 34 | u, ok := parsableRequestURI(r, redirect) 35 | if !ok { 36 | return false 37 | } 38 | 39 | if !isRelativeURL(u) && isValidScheme(u) && isAllowedHost(u, v.allowedDomains) { 40 | return true 41 | } 42 | 43 | if isRelativeURL(u) { 44 | mw.LogEntryFrom(r).Infof("validator: not an absolute URL") 45 | return false 46 | } 47 | 48 | if !isValidScheme(u) { 49 | mw.LogEntryFrom(r).Infof("validator: invalid scheme; must be one of ['http', 'https']") 50 | return false 51 | } 52 | 53 | if !isAllowedHost(u, v.allowedDomains) { 54 | mw.LogEntryFrom(r).Infof("validator: host does not match any allowlisted domains: %q", v.allowedDomains) 55 | return false 56 | } 57 | 58 | return false 59 | } 60 | 61 | var _ Validator = &RelativeValidator{} 62 | 63 | type RelativeValidator struct{} 64 | 65 | func NewRelativeValidator() *RelativeValidator { 66 | return &RelativeValidator{} 67 | } 68 | 69 | // IsValidRedirect validates that the given redirect string is a valid relative URL. 70 | // It must be an absolute path (i.e. has a leading '/'). 
71 | func (v *RelativeValidator) IsValidRedirect(r *http.Request, redirect string) bool { 72 | u, ok := parsableRequestURI(r, redirect) 73 | if !ok { 74 | return false 75 | } 76 | 77 | if isRelativeURL(u) && isValidAbsolutePath(u.String()) { 78 | return true 79 | } 80 | 81 | mw.LogEntryFrom(r).Infof("validator: not a valid relative URL") 82 | return false 83 | } 84 | 85 | func parsableRequestURI(r *http.Request, redirect string) (*url.URL, bool) { 86 | if redirect == "" { 87 | mw.LogEntryFrom(r).Debugf("validator: redirect is empty") 88 | return nil, false 89 | } 90 | 91 | u, err := url.ParseRequestURI(redirect) 92 | if err != nil { 93 | mw.LogEntryFrom(r).Infof("validator: %+v", err) 94 | return nil, false 95 | } 96 | 97 | return u, true 98 | } 99 | 100 | func isAllowedHost(u *url.URL, allowedDomains []string) bool { 101 | host := u.Host 102 | hostname := u.Hostname() 103 | 104 | if host == "" || hostname == "" || len(allowedDomains) == 0 { 105 | return false 106 | } 107 | 108 | for _, allowed := range allowedDomains { 109 | if isAllowedDomain(u, allowed) { 110 | return true 111 | } 112 | } 113 | 114 | return false 115 | } 116 | 117 | func isValidScheme(u *url.URL) bool { 118 | return u.Scheme == "http" || u.Scheme == "https" 119 | } 120 | 121 | func isRelativeURL(u *url.URL) bool { 122 | return u.Scheme == "" && u.Host == "" 123 | } 124 | 125 | func isValidAbsolutePath(redirect string) bool { 126 | return strings.HasPrefix(redirect, "/") && !strings.HasPrefix(redirect, "//") && !invalidRedirectRegex.MatchString(redirect) 127 | } 128 | 129 | func isAllowedDomain(u *url.URL, allowed string) bool { 130 | if len(allowed) == 0 { 131 | return false 132 | } 133 | 134 | host := u.Host 135 | hostname := u.Hostname() 136 | 137 | // exact match on host:port or host 138 | if host == allowed || hostname == allowed { 139 | return true 140 | } 141 | 142 | // subdomain of allowed domain 143 | if !strings.HasPrefix(allowed, ".") { 144 | allowed = "." 
+ allowed 145 | } 146 | return strings.HasSuffix(host, allowed) 147 | } 148 | -------------------------------------------------------------------------------- /docs/usage.md: -------------------------------------------------------------------------------- 1 | # Usage 2 | 3 | The contract for using Wonderwall is fairly straightforward. 4 | 5 | For any endpoint that requires authentication: 6 | 7 | 1. Validate the `Authorization` header, and any tokens within. 8 | 2. If the `Authorization` header is missing, redirect the user to the [login endpoint](#1-login). 9 | 3. If the JWT `access_token` in the `Authorization` header is invalid or expired, redirect the user to 10 | the [login endpoint](#1-login). 11 | 4. If you need to log out a user, redirect the user to the [logout endpoint](#2-logout). 12 | 13 | Note that Wonderwall does not validate the `access_token` that is attached; this is the responsibility of the upstream application. 14 | Wonderwall only validates the `id_token` in accordance with the OpenID Connect Core specifications. 15 | 16 | ## Scenarios 17 | 18 | ### 1. Login 19 | 20 | When you must authenticate a user, redirect to the user to [the `/oauth2/login` endpoint](endpoints.md#oauth2login). 21 | 22 | #### 1.1. Autologin 23 | 24 | The `auto-login` option will configure Wonderwall to enforce authentication for **all** requests, except for the paths that are explicitly [excluded](configuration.md#auto-login-ignore-paths). 25 | 26 | If the user is _unauthenticated_ or has an [_inactive_ or _expired_ session](sessions.md), all requests will be short-circuited (i.e. return early and **not** proxied to your application). 27 | The short-circuited response depends on whether the request is a _top-level navigation_ request or not. 
28 | 29 | A _top-level navigation request_ is a `GET` request that has the [Fetch metadata request headers](https://developer.mozilla.org/en-US/docs/Glossary/Fetch_metadata_request_header) `Sec-Fetch-Dest=document` and `Sec-Fetch-Mode=navigate`. 30 | If the user agent does not support the Fetch metadata headers, we look for an `Accept` header that includes `text/html`, which all major browsers send for navigation requests. 31 | Older browsers such as Internet Explorer 8 send neither, so navigation detection does not work for them; hopefully you're not in a position that requires supporting such browsers. 32 | 33 | A top-level navigation request results in an HTTP 302 Found response with the `Location` header pointing to [the `/oauth2/login` endpoint](endpoints.md#oauth2login). 34 | The `redirect` parameter in the login URL is set to the value found in the `Referer` header, so that the user is redirected back to their intended location after login. 35 | If the `Referer` header is empty, the `redirect` parameter is set to the matching ingress path for the original request. 36 | 37 | Other requests are considered non-navigational requests and result in an HTTP 401 Unauthorized response with the `Location` header set as described above. 38 | 39 | For defence in depth, you should still check the `Authorization` header for a token and validate the token even when using auto-login. 40 | 41 | ### 2. Logout 42 | 43 | When you must log out a user, redirect the user to [the `/oauth2/logout` endpoint](endpoints.md#oauth2logout). 44 | 45 | The user's session with the sidecar will be cleared, and the user will be redirected to the identity provider for 46 | global/single-logout, if logged in with SSO (single sign-on) at the identity provider. 47 | 48 | #### 2.1 Local Logout 49 | 50 | If you only want to perform a _local logout_ for the user, perform a `GET` request from the user's browser / user agent to [the `/oauth2/logout/local` endpoint](endpoints.md#oauth2logoutlocal).
51 | 52 | This will only clear the user's local session (i.e. remove the cookies) with the sidecar, without performing global logout at the identity provider. 53 | The endpoint responds with a HTTP 204 after successful logout. It will **not** respond with a redirect. 54 | 55 | A local logout is useful for scenarios where users frequently switch between multiple accounts. 56 | This means that they do not have to re-enter their credentials (e.g. username, password, 2FA) between each local logout, as they still have an SSO-session logged in with the identity provider. 57 | If the user is using a shared device with other users, only performing a local logout is thus a security risk. 58 | 59 | **Ensure you understand the difference in intentions between the two logout endpoints. If you're unsure, use `/oauth2/logout`.** 60 | 61 | ### 3. Advanced: Session Management 62 | 63 | See the [session management](sessions.md) page for details. 64 | -------------------------------------------------------------------------------- /pkg/openid/client/client_test.go: -------------------------------------------------------------------------------- 1 | package client_test 2 | 3 | import ( 4 | "encoding/base64" 5 | "encoding/json" 6 | "strings" 7 | "testing" 8 | "time" 9 | 10 | "github.com/lestrrat-go/jwx/v3/jwa" 11 | "github.com/lestrrat-go/jwx/v3/jws" 12 | "github.com/lestrrat-go/jwx/v3/jwt" 13 | "github.com/stretchr/testify/assert" 14 | 15 | "github.com/nais/wonderwall/pkg/mock" 16 | "github.com/nais/wonderwall/pkg/openid/client" 17 | ) 18 | 19 | func TestClientAuthenticationAssertion(t *testing.T) { 20 | cfg := mock.Config() 21 | cfg.OpenID.ClientID = "some-client-id" 22 | 23 | openidConfig := mock.NewTestConfiguration(cfg) 24 | openidConfig.TestProvider.SetIssuer("some-issuer") 25 | c := newTestClientWithConfig(openidConfig) 26 | 27 | expiry := 30 * time.Second 28 | jwtAssertion, err := c.ClientAuthenticationAssertion(expiry) 29 | assert.NoError(t, err) 30 | 31 | 
assertFlattenedAudience(t, jwtAssertion) 32 | 33 | key := openidConfig.Client().ClientJWK() 34 | publicKey, err := key.PublicKey() 35 | assert.NoError(t, err) 36 | 37 | alg, ok := publicKey.Algorithm() 38 | assert.True(t, ok) 39 | 40 | opts := []jwt.ParseOption{ 41 | jwt.WithKey(alg, publicKey), 42 | jwt.WithRequiredClaim(jwt.IssuedAtKey), 43 | jwt.WithRequiredClaim(jwt.ExpirationKey), 44 | jwt.WithRequiredClaim(jwt.JwtIDKey), 45 | } 46 | assertion, err := jwt.ParseString(jwtAssertion, opts...) 47 | assert.NoError(t, err) 48 | 49 | aud, ok := assertion.Audience() 50 | assert.True(t, ok) 51 | assert.ElementsMatch(t, []string{"some-issuer"}, aud) 52 | 53 | iss, ok := assertion.Issuer() 54 | assert.True(t, ok) 55 | assert.Equal(t, "some-client-id", iss) 56 | 57 | sub, ok := assertion.Subject() 58 | assert.True(t, ok) 59 | assert.Equal(t, "some-client-id", sub) 60 | 61 | iat, ok := assertion.IssuedAt() 62 | assert.True(t, ok) 63 | assert.True(t, iat.Before(time.Now())) 64 | 65 | exp, ok := assertion.Expiration() 66 | assert.True(t, ok) 67 | assert.True(t, exp.After(time.Now())) 68 | assert.True(t, exp.Before(time.Now().Add(expiry))) 69 | 70 | msg, err := jws.ParseString(jwtAssertion) 71 | assert.NoError(t, err) 72 | assert.Len(t, msg.Signatures(), 1) 73 | headers := msg.Signatures()[0].ProtectedHeaders() 74 | 75 | typ, ok := headers.Type() 76 | assert.True(t, ok) 77 | assert.Equal(t, "JWT", typ) 78 | 79 | alg, ok = headers.Algorithm() 80 | assert.True(t, ok) 81 | assert.Equal(t, jwa.RS256(), alg) 82 | 83 | expectedKid, ok := key.KeyID() 84 | assert.True(t, ok) 85 | kid, ok := headers.KeyID() 86 | assert.True(t, ok) 87 | assert.Equal(t, expectedKid, kid) 88 | } 89 | 90 | func TestClientAuthenticationAssertionHeader(t *testing.T) { 91 | cfg := mock.Config() 92 | cfg.OpenID.ClientID = "some-client-id" 93 | cfg.OpenID.NewClientAuthJWTType = true 94 | 95 | openidConfig := mock.NewTestConfiguration(cfg) 96 | openidConfig.TestProvider.SetIssuer("some-issuer") 97 | c := 
newTestClientWithConfig(openidConfig) 98 | 99 | expiry := 30 * time.Second 100 | jwtAssertion, err := c.ClientAuthenticationAssertion(expiry) 101 | assert.NoError(t, err) 102 | 103 | msg, err := jws.ParseString(jwtAssertion) 104 | assert.NoError(t, err) 105 | assert.Len(t, msg.Signatures(), 1) 106 | headers := msg.Signatures()[0].ProtectedHeaders() 107 | 108 | typ, ok := headers.Type() 109 | assert.True(t, ok) 110 | assert.Equal(t, "client-authentication+jwt", typ) 111 | } 112 | 113 | // assertFlattenedAudience asserts that the raw JWT assertion has a flattened audience claim, i.e. aud is a string value. 114 | // We do this as the jwx library only exposes the audience as a slice of strings for parsed JWTs. 115 | func assertFlattenedAudience(t *testing.T, jwtAssertion string) { 116 | parts := strings.Split(jwtAssertion, ".") 117 | assert.Len(t, parts, 3) 118 | 119 | rawClaims, err := base64.RawURLEncoding.DecodeString(parts[1]) 120 | assert.NoError(t, err) 121 | 122 | claims := make(map[string]any) 123 | err = json.Unmarshal(rawClaims, &claims) 124 | assert.NoError(t, err) 125 | 126 | assert.Equal(t, "some-issuer", claims["aud"]) 127 | } 128 | 129 | func newTestClientWithConfig(config *mock.TestConfiguration) *client.Client { 130 | jwksProvider := mock.NewTestJwksProvider() 131 | return client.NewClient(config, jwksProvider) 132 | } 133 | -------------------------------------------------------------------------------- /internal/o11y/otel/otel.go: -------------------------------------------------------------------------------- 1 | package otel 2 | 3 | import ( 4 | "context" 5 | "net/http" 6 | "time" 7 | 8 | "github.com/nais/wonderwall/pkg/config" 9 | log "github.com/sirupsen/logrus" 10 | "github.com/uptrace/opentelemetry-go-extra/otellogrus" 11 | "go.opentelemetry.io/otel" 12 | "go.opentelemetry.io/otel/attribute" 13 | "go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc" 14 | "go.opentelemetry.io/otel/propagation" 15 | 
"go.opentelemetry.io/otel/sdk/resource" 16 | tracesdk "go.opentelemetry.io/otel/sdk/trace" 17 | "go.opentelemetry.io/otel/semconv/v1.37.0" 18 | "go.opentelemetry.io/otel/trace" 19 | "go.opentelemetry.io/otel/trace/noop" 20 | ) 21 | 22 | const ( 23 | // Maximum time the batch span processor waits before exporting buffered spans to the collector (passed to tracesdk.WithBatchTimeout). 24 | batchTimeout = 5 * time.Second 25 | ) 26 | 27 | var tracer = noop.NewTracerProvider().Tracer("noop") 28 | // Setup installs the global OpenTelemetry text map propagator and tracer provider, replaces the package-level tracer, and registers a logrus hook that attaches panic/fatal/error/warn log entries to the active span as events. The returned function shuts down the tracer provider and calls log.Fatalf if that fails. 29 | func Setup(ctx context.Context, cfg *config.Config) (func(context.Context), error) { 30 | prop := newPropagator() 31 | otel.SetTextMapPropagator(prop) 32 | 33 | res, err := newResource(attributesFrom(cfg)) 34 | if err != nil { 35 | return nil, err 36 | } 37 | 38 | tracerProvider, err := newTraceProvider(ctx, res) 39 | if err != nil { 40 | return nil, err 41 | } 42 | otel.SetTracerProvider(tracerProvider) 43 | tracer = tracerProvider.Tracer(cfg.OpenTelemetry.ServiceName) 44 | 45 | log.Debug("opentelemetry: initialized configuration") 46 | shutdown := func(ctx context.Context) { 47 | if err := tracerProvider.Shutdown(ctx); err != nil { 48 | log.Fatalf("fatal: otel shutdown error: %+v", err) 49 | } 50 | } 51 | 52 | // Add OpenTelemetry logging hook to logrus. 53 | // This attaches logs to the associated span in the given log context as events. 54 | log.AddHook(otellogrus.NewHook(otellogrus.WithLevels( 55 | log.PanicLevel, 56 | log.FatalLevel, 57 | log.ErrorLevel, 58 | log.WarnLevel, 59 | ))) 60 | 61 | return shutdown, nil 62 | } 63 | // StartSpan starts a span named spanName using the package-level tracer, which is a no-op tracer until Setup has run. Callers must end the returned span. 64 | func StartSpan(ctx context.Context, spanName string) (context.Context, trace.Span) { 65 | return tracer.Start(ctx, spanName) 66 | } 67 | 68 | // StartSpanFromRequest starts a span from an incoming HTTP request and returns the request with the updated context. Callers must end the returned span.
69 | func StartSpanFromRequest(r *http.Request, spanName string) (*http.Request, trace.Span) { 70 | ctx := r.Context() 71 | ctx, span := StartSpan(ctx, spanName) 72 | return r.WithContext(ctx), span 73 | } 74 | 75 | func AddErrorEvent(span trace.Span, eventName, errType string, err error) { 76 | span.AddEvent(eventName, trace.WithAttributes( 77 | semconv.ExceptionTypeKey.String(errType), 78 | semconv.ExceptionMessageKey.String(err.Error()), 79 | )) 80 | } 81 | 82 | func attributesFrom(cfg *config.Config) []attribute.KeyValue { 83 | attrs := []attribute.KeyValue{ 84 | semconv.ServiceName(cfg.OpenTelemetry.ServiceName), 85 | semconv.ServiceVersion(cfg.Version), 86 | attribute.String("wonderwall.identity_provider.name", string(cfg.OpenID.Provider)), 87 | attribute.String("wonderwall.identity_provider.url", cfg.OpenID.WellKnownURL), 88 | attribute.Bool("wonderwall.autologin", cfg.AutoLogin), 89 | attribute.Bool("wonderwall.sso", cfg.SSO.Enabled), 90 | } 91 | if cfg.SSO.Enabled { 92 | attrs = append(attrs, attribute.String("wonderwall.sso.mode", string(cfg.SSO.Mode))) 93 | } 94 | return attrs 95 | } 96 | 97 | func newResource(attributes []attribute.KeyValue) (*resource.Resource, error) { 98 | return resource.Merge( 99 | resource.Default(), 100 | resource.NewWithAttributes( 101 | semconv.SchemaURL, 102 | attributes..., 103 | ), 104 | ) 105 | } 106 | 107 | func newPropagator() propagation.TextMapPropagator { 108 | return propagation.NewCompositeTextMapPropagator( 109 | propagation.TraceContext{}, 110 | propagation.Baggage{}, 111 | ) 112 | } 113 | 114 | func newTraceProvider(ctx context.Context, res *resource.Resource) (*tracesdk.TracerProvider, error) { 115 | traceExporter, err := otlptracegrpc.New(ctx) 116 | if err != nil { 117 | return nil, err 118 | } 119 | 120 | traceProvider := tracesdk.NewTracerProvider( 121 | tracesdk.WithBatcher( 122 | traceExporter, 123 | tracesdk.WithBatchTimeout(batchTimeout), 124 | ), 125 | tracesdk.WithResource(res), 126 | ) 127 | return 
traceProvider, nil 128 | } 129 | -------------------------------------------------------------------------------- /pkg/cookie/cookie.go: -------------------------------------------------------------------------------- 1 | package cookie 2 | 3 | import ( 4 | "encoding/base64" 5 | "errors" 6 | "fmt" 7 | "net/http" 8 | "time" 9 | 10 | "github.com/nais/wonderwall/internal/crypto" 11 | ) 12 | 13 | const ( 14 | DefaultPrefix = "io.nais.wonderwall" 15 | ) 16 | 17 | var ( 18 | Login = login(DefaultPrefix) 19 | LoginCount = loginCount(DefaultPrefix) 20 | Logout = logout(DefaultPrefix) 21 | Retry = retry(DefaultPrefix) 22 | Session = session(DefaultPrefix) 23 | ErrInvalidValue = errors.New("invalid value") 24 | ErrDecrypt = errors.New("unable to decrypt, key or scheme mismatch") 25 | ) 26 | 27 | type Cookie struct { 28 | *http.Cookie 29 | } 30 | 31 | func (in *Cookie) Encrypt(crypter crypto.Crypter) (*Cookie, error) { 32 | plaintext := []byte(in.Cookie.Value) 33 | ciphertext, err := crypter.Encrypt(plaintext) 34 | if err != nil { 35 | return nil, fmt.Errorf("unable to encrypt cookie '%s': %w", in.Cookie.Name, err) 36 | } 37 | 38 | value := base64.RawURLEncoding.EncodeToString(ciphertext) 39 | in.Cookie.Value = value 40 | return in, nil 41 | } 42 | 43 | func (in *Cookie) Decrypt(crypter crypto.Crypter) (string, error) { 44 | ciphertext, err := base64.RawURLEncoding.DecodeString(in.Value) 45 | if err != nil { 46 | return "", fmt.Errorf("%w: named '%s': %w", ErrInvalidValue, in.Name, err) 47 | } 48 | 49 | plaintext, err := crypter.Decrypt(ciphertext) 50 | if err != nil { 51 | return "", fmt.Errorf("%w: named '%s': %w", ErrDecrypt, in.Name, err) 52 | } 53 | 54 | return string(plaintext), err 55 | } 56 | 57 | func Clear(w http.ResponseWriter, name string, opts Options) { 58 | expires := time.Unix(0, 0) 59 | maxAge := -1 60 | 61 | cookie := &http.Cookie{ 62 | Expires: expires, 63 | HttpOnly: true, 64 | MaxAge: maxAge, 65 | Name: name, 66 | Path: "/", 67 | SameSite: opts.SameSite, 
68 | Secure: opts.Secure, 69 | } 70 | 71 | if len(opts.Domain) > 0 { 72 | cookie.Domain = opts.Domain 73 | } 74 | 75 | if len(opts.Path) > 0 { 76 | cookie.Path = opts.Path 77 | } 78 | 79 | http.SetCookie(w, cookie) 80 | } 81 | 82 | func Get(r *http.Request, key string) (*Cookie, error) { 83 | cookie, err := r.Cookie(key) 84 | if err != nil { 85 | return nil, fmt.Errorf("no cookie named '%s': %w", key, err) 86 | } 87 | 88 | return &Cookie{cookie}, nil 89 | } 90 | 91 | func GetDecrypted(r *http.Request, key string, crypter crypto.Crypter) (string, error) { 92 | encryptedCookie, err := Get(r, key) 93 | if err != nil { 94 | return "", err 95 | } 96 | 97 | return encryptedCookie.Decrypt(crypter) 98 | } 99 | 100 | func Make(name, value string, opts Options) *Cookie { 101 | cookie := &http.Cookie{ 102 | HttpOnly: true, 103 | Name: name, 104 | Path: "/", 105 | SameSite: opts.SameSite, 106 | Secure: opts.Secure, 107 | Value: value, 108 | } 109 | 110 | if len(opts.Domain) > 0 { 111 | cookie.Domain = opts.Domain 112 | } 113 | 114 | if len(opts.Path) > 0 { 115 | cookie.Path = opts.Path 116 | } 117 | 118 | return &Cookie{cookie} 119 | } 120 | 121 | func Set(w http.ResponseWriter, cookie *Cookie) { 122 | http.SetCookie(w, cookie.Cookie) 123 | } 124 | 125 | func EncryptAndSet(w http.ResponseWriter, key, value string, opts Options, crypter crypto.Crypter) error { 126 | encryptedCookie, err := Make(key, value, opts).Encrypt(crypter) 127 | if err != nil { 128 | return err 129 | } 130 | 131 | Set(w, encryptedCookie) 132 | return nil 133 | } 134 | // ConfigureCookieNamesWithPrefix re-derives every package-level cookie name (Login, LoginCount, Logout, Retry, Session) from the given prefix. 135 | func ConfigureCookieNamesWithPrefix(prefix string) { 136 | Login, LoginCount = login(prefix), loginCount(prefix) 137 | Logout = logout(prefix) 138 | Retry = retry(prefix) 139 | Session = session(prefix) 140 | } 141 | 142 | func withPrefix(prefix, s string) string { 143 | return fmt.Sprintf("%s.%s", prefix, s) 144 | } 145 | 146 | func login(prefix string) string { 147 | return withPrefix(prefix, "callback") 148 | } 149 | 150 | func loginCount(prefix string) string {
151 | return withPrefix(prefix, "logincount") 152 | } 153 | 154 | func logout(prefix string) string { 155 | return withPrefix(prefix, "logout") 156 | } 157 | 158 | func retry(prefix string) string { 159 | return withPrefix(prefix, "retry") 160 | } 161 | 162 | func session(prefix string) string { 163 | return withPrefix(prefix, "session") 164 | } 165 | -------------------------------------------------------------------------------- /hack/dashboard.yaml: -------------------------------------------------------------------------------- 1 | title: Wonderwall 2 | editable: true 3 | tags: [generated, yaml] 4 | auto_refresh: 1m 5 | time: ["now-24h", "now"] 6 | timezone: default # valid values are: utc, browser, default 7 | 8 | # Render to JSON using https://github.com/K-Phoen/grabana v0.17.0 or newer 9 | # Import into Grafana using UI (remember to select folder) 10 | 11 | variables: 12 | - custom: 13 | name: env 14 | default: dev 15 | values_map: 16 | dev: dev 17 | prod: prod 18 | - query: 19 | name: namespace 20 | label: Namespace 21 | datasource: $env-gcp 22 | request: "label_values(kube_pod_container_info{container=\"wonderwall\"}, namespace)" 23 | include_all: true 24 | default_all: true 25 | all_value: ".*" 26 | - datasource: 27 | name: ds 28 | type: prometheus 29 | regex: $env-gcp 30 | include_all: true 31 | hide: variable 32 | - query: 33 | name: redis_op 34 | label: Redis Operation 35 | datasource: $env-gcp 36 | request: "label_values(wonderwall_redis_latency_bucket, operation)" 37 | include_all: true 38 | default_all: true 39 | hide: variable 40 | 41 | rows: 42 | - name: Versions 43 | collapse: false 44 | panels: 45 | - single_stat: 46 | title: Sidecar versions in use 47 | datasource: $ds 48 | transparent: true 49 | span: 12 50 | targets: 51 | - prometheus: 52 | query: count(label_replace(kube_pod_container_info{container="wonderwall",namespace=~"$namespace"}, "version", "$1", "image", ".*:(.*)")) by (version) 53 | legend: "{{ version }}" 54 | instant: true 55 | - 
name: Resource usage 56 | collapse: false 57 | panels: 58 | - graph: 59 | title: Memory usage - $ds 60 | datasource: $ds 61 | transparent: true 62 | targets: 63 | - prometheus: 64 | query: sum(container_memory_working_set_bytes{container="wonderwall",namespace=~"$namespace"}) by (pod, namespace) 65 | legend: "working set {{ pod }} in {{ namespace }}" 66 | - prometheus: 67 | query: sum(container_memory_rss{container="wonderwall",namespace=~"$namespace"}) by (pod, namespace) 68 | legend: "Resident set size {{ pod }} in {{ namespace }}" 69 | - graph: 70 | title: CPU usage - $ds 71 | datasource: $ds 72 | transparent: true 73 | targets: 74 | - prometheus: 75 | query: sum(irate(container_cpu_usage_seconds_total{container="wonderwall",namespace=~"$namespace"}[2m])) by (pod, namespace) 76 | legend: "{{ pod }} in {{ namespace }}" 77 | - name: HTTP 78 | collapse: false 79 | panels: 80 | - graph: 81 | title: HTTP requests 82 | datasource: $ds 83 | transparent: true 84 | targets: 85 | - prometheus: 86 | query: sum(rate(requests_total{job="wonderwall",namespace=~"$namespace"}[5m])) by (code) 87 | legend: "{{ code }}" 88 | - graph: 89 | title: HTTP latency 90 | datasource: $ds 91 | transparent: true 92 | targets: 93 | - prometheus: 94 | query: sum(rate(request_duration_seconds_sum{job="wonderwall",namespace=~"$namespace"}[2m])) by (path) / sum(rate(request_duration_seconds_count{job="wonderwall",namespace=~"$namespace"}[2m])) by (path) 95 | legend: "{{ path }}" 96 | - name: Redis Latency 97 | collapse: false 98 | panels: 99 | - heatmap: 100 | # Must be done manually in Grafana after import: Set max datapoints to 25 101 | title: $redis_op 102 | datasource: $ds 103 | repeat: redis_op 104 | data_format: time_series_buckets 105 | hide_zero_buckets: true 106 | transparent: true 107 | span: 4 108 | tooltip: 109 | show: true 110 | showhistogram: false 111 | decimals: 0 112 | yaxis: 113 | unit: "dtdurations" 114 | decimals: 0 115 | targets: 116 | - prometheus: 117 | query:
sum(increase(wonderwall_redis_latency_bucket{operation="$redis_op",namespace=~"$namespace"}[$__interval])) by (le) 118 | legend: "{{ le }}" 119 | format: heatmap 120 | -------------------------------------------------------------------------------- /pkg/openid/config/client.go: -------------------------------------------------------------------------------- 1 | package config 2 | 3 | import ( 4 | "fmt" 5 | 6 | "github.com/lestrrat-go/jwx/v3/jwk" 7 | log "github.com/sirupsen/logrus" 8 | 9 | "github.com/nais/wonderwall/pkg/config" 10 | "github.com/nais/wonderwall/pkg/openid/scopes" 11 | ) 12 | 13 | type AuthMethod string 14 | 15 | const ( 16 | AuthMethodPrivateKeyJWT AuthMethod = "private_key_jwt" 17 | AuthMethodClientSecret AuthMethod = "client_secret" 18 | ) 19 | 20 | type Client interface { 21 | ACRValues() string 22 | Audiences() map[string]bool 23 | AuthMethod() AuthMethod 24 | ClientID() string 25 | ClientJWK() jwk.Key 26 | ClientSecret() string 27 | NewClientAuthJWTType() bool 28 | PostLogoutRedirectURI() string 29 | ResourceIndicator() string 30 | Scopes() scopes.Scopes 31 | UILocales() string 32 | WellKnownURL() string 33 | } 34 | 35 | type client struct { 36 | config.OpenID 37 | authMethod AuthMethod 38 | clientJwk jwk.Key 39 | trustedAudiences map[string]bool 40 | } 41 | 42 | var _ Client = (*client)(nil) 43 | 44 | func (in *client) ACRValues() string { 45 | return in.OpenID.ACRValues 46 | } 47 | 48 | func (in *client) Audiences() map[string]bool { 49 | return in.trustedAudiences 50 | } 51 | 52 | func (in *client) AuthMethod() AuthMethod { 53 | return in.authMethod 54 | } 55 | 56 | func (in *client) ClientID() string { 57 | return in.OpenID.ClientID 58 | } 59 | 60 | func (in *client) ClientJWK() jwk.Key { 61 | return in.clientJwk 62 | } 63 | 64 | func (in *client) ClientSecret() string { 65 | return in.OpenID.ClientSecret 66 | } 67 | 68 | func (in *client) NewClientAuthJWTType() bool { 69 | return in.OpenID.NewClientAuthJWTType 70 | } 71 | 72 | func (in 
*client) PostLogoutRedirectURI() string { 73 | return in.OpenID.PostLogoutRedirectURI 74 | } 75 | 76 | func (in *client) ResourceIndicator() string { 77 | return in.OpenID.ResourceIndicator 78 | } 79 | 80 | func (in *client) Scopes() scopes.Scopes { 81 | return scopes.DefaultScopes().WithAdditional(in.OpenID.Scopes...) 82 | } 83 | 84 | func (in *client) UILocales() string { 85 | return in.OpenID.UILocales 86 | } 87 | 88 | func (in *client) WellKnownURL() string { 89 | return in.OpenID.WellKnownURL 90 | } 91 | 92 | func NewClientConfig(cfg *config.Config) (Client, error) { 93 | c := &client{ 94 | OpenID: cfg.OpenID, 95 | trustedAudiences: cfg.OpenID.TrustedAudiences(), 96 | } 97 | 98 | if len(cfg.OpenID.ClientJWK) == 0 && len(cfg.OpenID.ClientSecret) == 0 { 99 | return nil, fmt.Errorf("missing required config: at least one of %q or %q must be set", config.OpenIDClientJWK, config.OpenIDClientSecret) 100 | } 101 | 102 | if len(cfg.OpenID.ClientSecret) > 0 { 103 | c.authMethod = AuthMethodClientSecret 104 | } 105 | 106 | if len(cfg.OpenID.ClientJWK) > 0 { 107 | if c.authMethod == AuthMethodClientSecret { 108 | log.WithField("logger", "wonderwall.config").Debug("both client JWK and client secret were set; using client JWK...") 109 | } 110 | 111 | clientJwk, err := jwk.ParseKey([]byte(cfg.OpenID.ClientJWK)) 112 | if err != nil { 113 | return nil, fmt.Errorf("parsing client JWK: %w", err) 114 | } 115 | 116 | c.clientJwk = clientJwk 117 | c.authMethod = AuthMethodPrivateKeyJWT 118 | } 119 | 120 | var clientConfig Client 121 | switch cfg.OpenID.Provider { 122 | case config.ProviderIDPorten: 123 | clientConfig = c.IDPorten() 124 | case config.ProviderAzure: 125 | clientConfig = c.Azure() 126 | case "": 127 | return nil, fmt.Errorf("missing required config %q", config.OpenIDProvider) 128 | default: 129 | clientConfig = c 130 | } 131 | 132 | if len(clientConfig.ClientID()) == 0 { 133 | return nil, fmt.Errorf("missing required config %q", config.OpenIDClientID) 134 | } 135 | 
136 | if len(clientConfig.WellKnownURL()) == 0 { 137 | return nil, fmt.Errorf("missing required config %q", config.OpenIDWellKnownURL) 138 | } 139 | 140 | return clientConfig, nil 141 | } 142 | 143 | type azure struct { 144 | *client 145 | } 146 | 147 | func (in *client) Azure() *azure { 148 | return &azure{ 149 | client: in, 150 | } 151 | } 152 | 153 | func (in *azure) Scopes() scopes.Scopes { 154 | return scopes.DefaultScopes(). 155 | WithAzureScope(in.OpenID.ClientID). 156 | WithOfflineAccess(). 157 | WithAdditional(in.OpenID.Scopes...) 158 | } 159 | 160 | type idporten struct { 161 | *client 162 | } 163 | 164 | func (in *client) IDPorten() *idporten { 165 | return &idporten{ 166 | client: in, 167 | } 168 | } 169 | -------------------------------------------------------------------------------- /pkg/config/config_test.go: -------------------------------------------------------------------------------- 1 | package config_test 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/nais/wonderwall/pkg/config" 7 | "github.com/stretchr/testify/assert" 8 | "github.com/stretchr/testify/require" 9 | ) 10 | 11 | func TestConfig_Validate(t *testing.T) { 12 | type test struct { 13 | name string 14 | mutate func(cfg *config.Config) 15 | } 16 | 17 | run := func(name string, base *config.Config, errorCases []test) { 18 | t.Run(name, func(t *testing.T) { 19 | t.Run("happy path", func(t *testing.T) { 20 | assert.NoError(t, base.Validate()) 21 | }) 22 | 23 | for _, tt := range errorCases { 24 | t.Run(tt.name, func(t *testing.T) { 25 | cfg := *base 26 | tt.mutate(&cfg) 27 | assert.Error(t, cfg.Validate()) 28 | }) 29 | } 30 | }) 31 | } 32 | 33 | base, err := config.Initialize() 34 | require.NoError(t, err) 35 | 36 | run("default", base, []test{ 37 | { 38 | "invalid value for cookie.same-site", 39 | func(cfg *config.Config) { 40 | cfg.Cookie.SameSite = "invalid" 41 | }, 42 | }, 43 | { 44 | "upstream ip must be set if port is set", 45 | func(cfg *config.Config) { 46 | cfg.UpstreamIP = 
"" 47 | cfg.UpstreamPort = 8080 48 | }, 49 | }, 50 | { 51 | "upstream port must be set if ip is set", 52 | func(cfg *config.Config) { 53 | cfg.UpstreamIP = "127.0.0.1" 54 | cfg.UpstreamPort = 0 55 | }, 56 | }, 57 | { 58 | "upstream port must not exceed 65535", 59 | func(cfg *config.Config) { 60 | cfg.UpstreamIP = "127.0.0.1" 61 | cfg.UpstreamPort = 65536 62 | }, 63 | }, 64 | { 65 | "upstream port must not be negative", 66 | func(cfg *config.Config) { 67 | cfg.UpstreamIP = "127.0.0.1" 68 | cfg.UpstreamPort = -1 69 | }, 70 | }, 71 | { 72 | "shutdown graceful period must be greater than wait before period", 73 | func(cfg *config.Config) { 74 | cfg.ShutdownGracefulPeriod = 1 75 | cfg.ShutdownWaitBeforePeriod = 1 76 | }, 77 | }, 78 | { 79 | "secure cookies cannot be disabled for non-localhost ingress", 80 | func(cfg *config.Config) { 81 | cfg.Cookie.Secure = false 82 | cfg.Ingresses = []string{"http://not-localhost.example"} 83 | }, 84 | }, 85 | { 86 | "secure cookies cannot be disabled for secure ingress", 87 | func(cfg *config.Config) { 88 | cfg.Cookie.Secure = false 89 | cfg.Ingresses = []string{"https://localhost:3000"} 90 | }, 91 | }, 92 | }) 93 | 94 | server := *base 95 | server.SSO.Enabled = true 96 | server.SSO.Mode = config.SSOModeServer 97 | server.SSO.Domain = "example.com" 98 | server.SSO.SessionCookieName = "some-cookie" 99 | server.SSO.ServerDefaultRedirectURL = "https://default.local" 100 | server.Redis.Address = "localhost:6379" 101 | 102 | run("sso server", &server, []test{ 103 | { 104 | "missing redis", 105 | func(cfg *config.Config) { 106 | cfg.Redis = config.Redis{} 107 | }, 108 | }, 109 | { 110 | "missing cookie name", 111 | func(cfg *config.Config) { 112 | cfg.SSO.SessionCookieName = "" 113 | }, 114 | }, 115 | { 116 | "missing session cookie name", 117 | func(cfg *config.Config) { 118 | cfg.SSO.SessionCookieName = "" 119 | }, 120 | }, 121 | { 122 | "missing domain", 123 | func(cfg *config.Config) { 124 | cfg.SSO.Domain = "" 125 | }, 126 | }, 127 | 
{ 128 | "invalid default redirect url", 129 | func(cfg *config.Config) { 130 | cfg.SSO.ServerDefaultRedirectURL = "invalid" 131 | }, 132 | }, 133 | { 134 | "invalid mode", 135 | func(cfg *config.Config) { 136 | cfg.SSO.Mode = "invalid" 137 | }, 138 | }, 139 | }) 140 | 141 | proxy := *base 142 | proxy.SSO.Enabled = true 143 | proxy.SSO.Mode = config.SSOModeProxy 144 | proxy.SSO.ServerURL = "https://sso-server.local" 145 | proxy.SSO.SessionCookieName = "some-cookie" 146 | proxy.Redis.Address = "localhost:6379" 147 | 148 | run("sso proxy", &proxy, []test{ 149 | { 150 | "missing redis", 151 | func(cfg *config.Config) { 152 | cfg.Redis = config.Redis{} 153 | }, 154 | }, 155 | { 156 | "missing cookie name", 157 | func(cfg *config.Config) { 158 | cfg.SSO.SessionCookieName = "" 159 | }, 160 | }, 161 | { 162 | "missing session cookie name", 163 | func(cfg *config.Config) { 164 | cfg.SSO.SessionCookieName = "" 165 | }, 166 | }, 167 | { 168 | "invalid server url", 169 | func(cfg *config.Config) { 170 | cfg.SSO.ServerURL = "invalid" 171 | }, 172 | }, 173 | { 174 | "invalid mode", 175 | func(cfg *config.Config) { 176 | cfg.SSO.Mode = "invalid" 177 | }, 178 | }, 179 | }) 180 | } 181 | -------------------------------------------------------------------------------- /charts/wonderwall/templates/prometheusrule.yaml: -------------------------------------------------------------------------------- 1 | {{/* Wonderwall PrometheusRule alerts: one document per provider (ID-porten, Azure AD), each rendered only when that provider is enabled in values. */}}{{ if .Values.idporten.enabled }} 2 | --- 3 | apiVersion: monitoring.coreos.com/v1 4 | kind: PrometheusRule 5 | metadata: 6 | name: {{ include "wonderwall.fullname" . }}-idporten-alerts 7 | labels: 8 | {{- include "wonderwall.labels" . 
| nindent 4 }} 9 | spec: 10 | groups: 11 | - name: "wonderwall-idporten" 12 | rules: 13 | - alert: Wonderwall sidecars for ID-porten reports a high amount of internal errors 14 | expr: sum(increase(requests_total{job="nais-system/monitoring-wonderwall", code="500", provider="idporten"}[5m])) > 10 15 | for: 5m 16 | annotations: 17 | summary: Wonderwall has responded with HTTP 500 for a high amount of requests within the last 5 minutes. 18 | consequence: This probably means that end-users are having trouble with authentication. 19 | action: | 20 | * Check the logs for errors: 21 | * Check DigDir status: / 22 | * Check Aiven Redis (session store) and verify Aiven network connectivity 23 | * Check with DigDir in [#nav-digdir](https://nav-it.slack.com/archives/C013RTT99G9) 24 | dashboard_url: "https://monitoring.nais.io/d/wQPQ7uHnz/wonderwall?var-provider=idporten" 25 | labels: 26 | severity: critical 27 | namespace: {{ .Release.Namespace }} 28 | - alert: Wonderwall SSO server for ID-porten reports a high amount of internal errors 29 | expr: sum(increase(requests_total{app="wonderwall-idporten", namespace="{{ .Release.Namespace }}", code="500"}[5m])) > 10 30 | for: 5m 31 | annotations: 32 | summary: Wonderwall has responded with HTTP 500 for a high amount of requests within the last 5 minutes. 33 | consequence: This probably means that end-users are having trouble with authentication.
34 | action: | 35 | * Check the logs for errors: 36 | * Check DigDir status: / 37 | * Check Aiven Redis (session store) and verify Aiven network connectivity 38 | * Check with DigDir in [#nav-digdir](https://nav-it.slack.com/archives/C013RTT99G9) 39 | dashboard_url: "https://monitoring.nais.io/d/wQPQ7uHnz/wonderwall?var-provider=idporten&var-namespace={{ .Release.Namespace }}"{{/* link to the release namespace, the same namespace the expr of this alert filters on */}} 40 | labels: 41 | severity: critical 42 | namespace: {{ .Release.Namespace }} 43 | {{ end }} 44 | {{ if .Values.azure.enabled }} 45 | --- 46 | apiVersion: monitoring.coreos.com/v1 47 | kind: PrometheusRule 48 | metadata: 49 | name: {{ include "wonderwall.fullname" . }}-azure-alerts 50 | labels: 51 | {{- include "wonderwall.labels" . | nindent 4 }} 52 | spec: 53 | groups: 54 | - name: "wonderwall-azure" 55 | rules: 56 | - alert: Wonderwall for Azure AD reports a high amount of internal errors 57 | expr: sum(increase(requests_total{job="nais-system/monitoring-wonderwall", code="500", provider="azure"}[5m])) > 30 58 | for: 5m 59 | annotations: 60 | summary: Wonderwall has responded with HTTP 500 for a high amount of requests within the last 5 minutes. 61 | consequence: This probably means that end-users are having trouble with authentication. 62 | action: | 63 | * Check the logs for errors: 64 | * Check Azure status: https://status.azure.com/nb-no/status 65 | * Check Aiven Redis (session store) and verify Aiven network connectivity 66 | dashboard_url: "https://monitoring.nais.io/d/wQPQ7uHnz/wonderwall?var-provider=azure" 67 | labels: 68 | severity: critical 69 | namespace: {{ .Release.Namespace }} 70 | {{- if .Values.azure.forwardAuth.enabled }} 71 | - alert: Wonderwall-fa (forward-auth / ansatt) for Azure AD reports a high amount of internal errors 72 | expr: sum(increase(requests_total{app="wonderwall-fa", namespace="{{ .Release.Namespace }}", code="500"}[5m])) > 30 73 | for: 5m 74 | annotations: 75 | summary: Wonderwall has responded with HTTP 500 for a high amount of requests within the last 5 minutes.
76 | consequence: This probably means that end-users are having trouble with authentication. 77 | action: | 78 | * Check the logs for errors: 79 | * Check Azure status: https://status.azure.com/nb-no/status 80 | * Check Aiven Redis (session store) and verify Aiven network connectivity 81 | dashboard_url: "https://monitoring.nais.io/d/wQPQ7uHnz/wonderwall?var-provider=azure&var-namespace={{ .Release.Namespace }}"{{/* link to the release namespace, the same namespace the expr of this alert filters on */}} 82 | labels: 83 | severity: critical 84 | namespace: {{ .Release.Namespace }} 85 | {{- end }} 86 | {{ end }} 87 | --------------------------------------------------------------------------------