├── .dockerignore
├── docs
├── index.md
├── changelog.md
├── assets
│ ├── ds-symbol-negative-mono.png
│ └── ds-symbol-positive-mono.png
├── architecture
│ ├── filtering-data.md
│ └── middleware-stack.md
├── overrides
│ └── partials
│ │ └── integrations
│ │ └── analytics
│ │ └── plausible.html
└── user-guide
│ ├── tips.md
│ ├── getting-started.md
│ ├── deployment.md
│ └── route-level-auth.md
├── examples
├── custom-integration
│ ├── .python-version
│ ├── pyproject.toml
│ ├── Dockerfile
│ ├── README.md
│ ├── docker-compose.yaml
│ └── src
│ │ └── custom_integration.py
└── opa
│ ├── policies
│ └── stac
│ │ └── policy.rego
│ ├── docker-compose.yaml
│ └── README.md
├── src
└── stac_auth_proxy
│ ├── utils
│ ├── __init__.py
│ ├── stac.py
│ ├── filters.py
│ ├── cache.py
│ ├── middleware.py
│ └── requests.py
│ ├── filters
│ ├── __init__.py
│ ├── template.py
│ └── opa.py
│ ├── handlers
│ ├── __init__.py
│ ├── healthz.py
│ ├── swagger_ui.py
│ └── reverse_proxy.py
│ ├── lambda.py
│ ├── __init__.py
│ ├── __main__.py
│ ├── middleware
│ ├── AddProcessTimeHeaderMiddleware.py
│ ├── __init__.py
│ ├── RemoveRootPathMiddleware.py
│ ├── Cql2ApplyFilterQueryStringMiddleware.py
│ ├── UpdateOpenApiMiddleware.py
│ ├── Cql2ApplyFilterBodyMiddleware.py
│ ├── AuthenticationExtensionMiddleware.py
│ ├── Cql2RewriteLinksFilterMiddleware.py
│ ├── ProcessLinksMiddleware.py
│ ├── Cql2BuildFilterMiddleware.py
│ ├── Cql2ValidateResponseBodyMiddleware.py
│ └── EnforceAuthMiddleware.py
│ ├── config.py
│ ├── lifespan.py
│ └── app.py
├── helm
├── Chart.yaml
├── templates
│ ├── service.yaml
│ ├── serviceaccount.yaml
│ ├── ingress.yaml
│ ├── _helpers.tpl
│ ├── deployment.yaml
│ └── NOTES.txt
└── values.yaml
├── .vscode
├── settings.json
└── launch.json
├── .github
└── workflows
│ ├── conventional-commits-prs.yaml
│ ├── publish-pypi.yaml
│ ├── release-please.yml
│ ├── publish-helm.yaml
│ ├── docs.yaml
│ ├── cicd.yaml
│ └── publish-docker.yaml
├── .coveragerc
├── .pre-commit-config.yaml
├── tests
├── test_configure_app.py
├── test_proxy.py
├── test_remove_root_path.py
├── utils.py
├── test_cache.py
├── test_filters_opa.py
├── test_defaults.py
├── test_utils.py
├── test_lifespan.py
├── test_auth_extension.py
└── conftest.py
├── LICENSE
├── Dockerfile
├── Makefile
├── README.md
├── pyproject.toml
├── .gitignore
├── mkdocs.yml
└── docker-compose.yaml
/.dockerignore:
--------------------------------------------------------------------------------
1 | .pgdata
--------------------------------------------------------------------------------
/docs/index.md:
--------------------------------------------------------------------------------
1 | ../README.md
--------------------------------------------------------------------------------
/docs/changelog.md:
--------------------------------------------------------------------------------
1 | ../CHANGELOG.md
--------------------------------------------------------------------------------
/examples/custom-integration/.python-version:
--------------------------------------------------------------------------------
1 | 3.13
2 |
--------------------------------------------------------------------------------
/src/stac_auth_proxy/utils/__init__.py:
--------------------------------------------------------------------------------
1 | """Utils module for stac_auth_proxy."""
2 |
--------------------------------------------------------------------------------
/docs/assets/ds-symbol-negative-mono.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/developmentseed/stac-auth-proxy/HEAD/docs/assets/ds-symbol-negative-mono.png
--------------------------------------------------------------------------------
/docs/assets/ds-symbol-positive-mono.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/developmentseed/stac-auth-proxy/HEAD/docs/assets/ds-symbol-positive-mono.png
--------------------------------------------------------------------------------
/helm/Chart.yaml:
--------------------------------------------------------------------------------
1 | apiVersion: v2
2 | name: stac-auth-proxy
3 | description: A Helm chart for stac-auth-proxy
4 | type: application
5 | version: 0.1.2
6 | appVersion: "1.0.0"
7 |
--------------------------------------------------------------------------------
/.vscode/settings.json:
--------------------------------------------------------------------------------
1 | {
2 | "python.testing.pytestArgs": [
3 | "tests",
4 | "-n",
5 | "auto"
6 | ],
7 | "python.testing.unittestEnabled": false,
8 | "python.testing.pytestEnabled": true
9 | }
--------------------------------------------------------------------------------
/src/stac_auth_proxy/filters/__init__.py:
--------------------------------------------------------------------------------
1 | """CQL2 filter factories."""
2 |
3 | from .opa import Opa
4 | from .template import Template
5 |
6 | __all__ = [
7 | "Opa",
8 | "Template",
9 | ]
10 |
--------------------------------------------------------------------------------
/examples/custom-integration/pyproject.toml:
--------------------------------------------------------------------------------
1 | [project]
2 | name = "custom_integration"
3 | version = "0.1.0"
4 | description = "Add your description here"
5 | readme = "README.md"
6 | requires-python = ">=3.9"
7 | dependencies = []
8 |
--------------------------------------------------------------------------------
/examples/custom-integration/Dockerfile:
--------------------------------------------------------------------------------
1 | ARG STAC_AUTH_PROXY_VERSION
2 | FROM ghcr.io/developmentseed/stac-auth-proxy:${STAC_AUTH_PROXY_VERSION}
3 |
4 | ADD . /opt/stac-auth-proxy-integration
5 |
6 | RUN pip install /opt/stac-auth-proxy-integration
7 |
--------------------------------------------------------------------------------
/src/stac_auth_proxy/handlers/__init__.py:
--------------------------------------------------------------------------------
1 | """Handlers to process requests."""
2 |
3 | from .healthz import HealthzHandler
4 | from .reverse_proxy import ReverseProxyHandler
5 | from .swagger_ui import SwaggerUI
6 |
7 | __all__ = ["ReverseProxyHandler", "HealthzHandler", "SwaggerUI"]
8 |
--------------------------------------------------------------------------------
/examples/opa/policies/stac/policy.rego:
--------------------------------------------------------------------------------
1 | package stac
2 |
3 | default items_cql2 := "\"naip:year\" = 2021"
4 |
5 | items_cql2 := "1=1" if {
6 | input.payload.sub != null
7 | }
8 |
9 | default collections_cql2 := "id = 'naip'"
10 |
11 | collections_cql2 := "1=1" if {
12 | input.payload.sub != null
13 | }
14 |
--------------------------------------------------------------------------------
/examples/custom-integration/README.md:
--------------------------------------------------------------------------------
1 | # Custom Integration Example
2 |
3 | This example demonstrates how to integrate with a custom filter factory.
4 |
5 | ## Running the Example
6 |
7 | From the root directory, run:
8 |
9 | ```sh
10 | docker compose -f docker-compose.yaml -f examples/custom-integration/docker-compose.yaml up
11 | ```
12 |
--------------------------------------------------------------------------------
/examples/custom-integration/docker-compose.yaml:
--------------------------------------------------------------------------------
1 | # This compose file is intended to be run alongside the `docker-compose.yaml` file in the
2 | # root directory.
3 |
4 | services:
5 | proxy:
6 | build:
7 | context: examples/custom-integration
8 | args:
9 | STAC_AUTH_PROXY_VERSION: 0.1.2
10 | environment:
11 | ITEMS_FILTER_CLS: custom_integration:cql2_builder
12 | ITEMS_FILTER_KWARGS: '{"admin_user": "user123"}'
13 |
--------------------------------------------------------------------------------
/helm/templates/service.yaml:
--------------------------------------------------------------------------------
1 | apiVersion: v1
2 | kind: Service
3 | metadata:
4 | name: {{ include "stac-auth-proxy.fullname" . }}
5 | labels:
6 | {{- include "stac-auth-proxy.labels" . | nindent 4 }}
7 | spec:
8 | type: {{ .Values.service.type }}
9 | ports:
10 | - port: {{ .Values.service.port }}
11 | targetPort: http
12 | protocol: TCP
13 | name: http
14 | selector:
15 | {{- include "stac-auth-proxy.selectorLabels" . | nindent 4 }}
--------------------------------------------------------------------------------
/.github/workflows/conventional-commits-prs.yaml:
--------------------------------------------------------------------------------
1 | name: PR Conventional Commit Validation
2 |
3 | on:
4 | pull_request:
5 | types: [opened, synchronize, reopened, edited]
6 |
7 | jobs:
8 | validate-pr-title:
9 | runs-on: ubuntu-latest
10 | steps:
11 | - name: PR Conventional Commit Validation
12 | uses: ytanikin/pr-conventional-commits@1.4.0
13 | with:
14 | task_types: '["feat","fix","docs","test","ci","refactor","perf","chore","revert"]'
15 |
--------------------------------------------------------------------------------
/src/stac_auth_proxy/lambda.py:
--------------------------------------------------------------------------------
1 | """Handler for AWS Lambda."""
2 |
3 | from stac_auth_proxy import create_app
4 |
try:
    from mangum import Mangum
except ImportError as err:
    # mangum is an optional dependency. Chain the original error explicitly so
    # the underlying import failure stays attached as the cause.
    raise ImportError(
        "mangum is required to use the Lambda handler. Install stac-auth-proxy[lambda]."
    ) from err


# Module-level Lambda entry point wrapping the application built by create_app().
handler = Mangum(
    create_app(),
    # NOTE: lifespan="off" skips conformance check and upstream health checks on startup
    lifespan="off",
)
18 |
--------------------------------------------------------------------------------
/.github/workflows/publish-pypi.yaml:
--------------------------------------------------------------------------------
1 | name: Publish to PyPI
2 |
3 | on:
4 | release:
5 | types: [published]
6 | workflow_dispatch:
7 |
8 | permissions:
9 | contents: read
10 | id-token: write
11 |
12 | jobs:
13 | publish-to-pypi:
14 | runs-on: ubuntu-latest
15 | steps:
16 | - uses: actions/checkout@v4
17 | - uses: actions/setup-python@v3
18 | - uses: astral-sh/setup-uv@v4
19 | with:
20 | enable-cache: true
21 | - run: uv build
22 | - run: uv publish
23 |
--------------------------------------------------------------------------------
/helm/templates/serviceaccount.yaml:
--------------------------------------------------------------------------------
1 | {{- if .Values.serviceAccount.create -}}
2 | apiVersion: v1
3 | kind: ServiceAccount
4 | metadata:
5 | name: {{ include "stac-auth-proxy.serviceAccountName" . }}
6 | labels:
7 | {{- include "stac-auth-proxy.labels" . | nindent 4 }}
8 | {{- with .Values.serviceAccount.annotations }}
9 | annotations:
10 | {{- toYaml . | nindent 4 }}
11 | {{- end }}
12 | {{- with .Values.serviceAccount.imagePullSecrets }}
13 | imagePullSecrets:
14 | {{- toYaml . | nindent 2 }}
15 | {{- end }}
16 | {{- end }}
--------------------------------------------------------------------------------
/src/stac_auth_proxy/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | STAC Auth Proxy package.
3 |
4 | This package contains the components for the STAC authentication and proxying system.
5 | It includes FastAPI routes for handling authentication, authorization, and interaction
6 | with some internal STAC API.
7 | """
8 |
9 | from .app import configure_app, create_app
10 | from .config import Settings
11 | from .lifespan import build_lifespan
12 |
13 | __all__ = [
14 | "build_lifespan",
15 | "create_app",
16 | "configure_app",
17 | "Settings",
18 | ]
19 |
--------------------------------------------------------------------------------
/examples/opa/docker-compose.yaml:
--------------------------------------------------------------------------------
1 | services:
2 | proxy:
3 | environment:
4 | ITEMS_FILTER_CLS: stac_auth_proxy.filters:Opa
5 | ITEMS_FILTER_ARGS: '["http://opa:8181", "stac/items_cql2"]'
6 | COLLECTIONS_FILTER_CLS: stac_auth_proxy.filters:Opa
7 | COLLECTIONS_FILTER_ARGS: '["http://opa:8181", "stac/collections_cql2"]'
8 |
9 | opa:
10 | image: openpolicyagent/opa:latest
11 | command: "run --server --addr=:8181 --watch /policies"
12 | ports:
13 | - "8181:8181"
14 | volumes:
15 | - ./examples/opa/policies:/policies
16 |
--------------------------------------------------------------------------------
/src/stac_auth_proxy/utils/stac.py:
--------------------------------------------------------------------------------
1 | """STAC-specific utilities."""
2 |
3 | from itertools import chain
4 |
5 |
def get_links(data: dict) -> chain[dict]:
    """
    Get all links from a STAC response.

    Yields the links found directly on ``data`` first, followed by the links of
    every entry under ``features`` and then ``collections``. Keys that are
    missing or explicitly ``None`` (JSON ``null``) are treated as empty rather
    than raising.
    """
    return chain(
        # Item/Collection
        data.get("links") or [],
        # Collections/Items/Search
        (
            link
            for prop in ("features", "collections")
            for object_with_links in data.get(prop) or []
            for link in object_with_links.get("links") or []
        ),
    )
19 |
--------------------------------------------------------------------------------
/src/stac_auth_proxy/__main__.py:
--------------------------------------------------------------------------------
1 | """Entry point for running the module without customized code."""
2 |
3 | import os
4 |
5 | import uvicorn
6 | from uvicorn.config import LOGGING_CONFIG
7 |
# Start from uvicorn's default logging config and register a DEBUG-level logger
# for this package that writes through uvicorn's "default" handler.
_log_config = dict(LOGGING_CONFIG)
_log_config["loggers"] = {
    **LOGGING_CONFIG["loggers"],
    __package__: {
        "level": "DEBUG",
        "handlers": ["default"],
    },
}

# The app is referenced by import string and built via its factory function.
uvicorn.run(
    f"{__package__}.app:create_app",
    factory=True,
    reload=True,
    host="0.0.0.0",
    port=int(os.getenv("PORT", 8000)),
    log_config=_log_config,
)
25 |
--------------------------------------------------------------------------------
/.vscode/launch.json:
--------------------------------------------------------------------------------
1 | {
2 | // Use IntelliSense to learn about possible attributes.
3 | // Hover to view descriptions of existing attributes.
4 | // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
5 | "version": "0.2.0",
6 | "configurations": [
7 | {
8 | "name": "Python Debugger: FastAPI",
9 | "type": "debugpy",
10 | "request": "launch",
11 | "module": "uvicorn",
12 | "args": [
13 | "stac_auth_proxy.app:create_app",
14 | "--reload",
15 | "--factory"
16 | ],
17 | "jinja": true,
18 | "cwd": "${workspaceFolder}/src",
19 | "justMyCode": false
20 | }
21 | ]
22 | }
--------------------------------------------------------------------------------
/.coveragerc:
--------------------------------------------------------------------------------
1 | [run]
2 | branch = True
3 | source = src/stac_auth_proxy
4 | omit =
5 | */tests/*
6 | */test_*
7 | */__pycache__/*
8 | */venv/*
9 | */build/*
10 | */dist/*
11 | */htmlcov/*
12 | */lambda.py
13 | */__main__.py
14 |
15 | [report]
16 | exclude_lines =
17 | pragma: no cover
18 | def __repr__
19 | if self.debug:
20 | if settings.DEBUG
21 | raise AssertionError
22 | raise NotImplementedError
23 | if 0:
24 | if __name__ == .__main__.:
25 | class .*\bProtocol\):
26 | @(abc\.)?abstractmethod
27 | # Have to re-enable the standard pragma
28 | pragma: no cover
29 |
30 | [html]
31 | directory = htmlcov
32 |
33 | [xml]
34 | output = coverage.xml
35 |
--------------------------------------------------------------------------------
/.github/workflows/release-please.yml:
--------------------------------------------------------------------------------
1 | on:
2 | push:
3 | branches:
4 | - main
5 |
6 | permissions:
7 | contents: write
8 | pull-requests: write
9 |
10 | name: release-please
11 |
12 | jobs:
13 | release-please:
14 | runs-on: ubuntu-latest
15 | outputs:
16 | release_created: ${{ steps.release.outputs.release_created }}
17 | steps:
18 | - uses: actions/create-github-app-token@v2.0.6
19 | id: app-token
20 | with:
21 | app-id: ${{ secrets.DS_RELEASE_BOT_ID }}
22 | private-key: ${{ secrets.DS_RELEASE_BOT_PRIVATE_KEY }}
23 |
24 | - uses: googleapis/release-please-action@v4
25 | id: release
26 | with:
27 | release-type: python
28 | token: ${{ steps.app-token.outputs.token }}
29 |
--------------------------------------------------------------------------------
/examples/opa/README.md:
--------------------------------------------------------------------------------
1 | # Open Policy Agent (OPA) Integration
2 |
3 | This example demonstrates how to integrate with an Open Policy Agent (OPA) to authorize requests to a STAC API.
4 |
5 | ## Running the Example
6 |
7 | From the root directory, run:
8 |
9 | ```sh
10 | docker compose -f docker-compose.yaml -f examples/opa/docker-compose.yaml up
11 | ```
12 |
13 | ## Testing OPA
14 |
15 | ```sh
16 | ▶ curl -X POST "http://localhost:8181/v1/data/stac/items_cql2" \
17 | -H "Content-Type: application/json" \
18 | -d '{"input":{"payload": null}}'
19 | {"result":"\"naip:year\" = 2021"}
20 | ```
21 |
22 | ```sh
23 | ▶ curl -X POST "http://localhost:8181/v1/data/stac/items_cql2" \
24 | -H "Content-Type: application/json" \
25 | -d '{"input":{"payload": {"sub": "user1"}}}'
26 | {"result":"1=1"}
27 | ```
28 |
--------------------------------------------------------------------------------
/src/stac_auth_proxy/filters/template.py:
--------------------------------------------------------------------------------
1 | """Generate CQL2 filter expressions via Jinja2 templating."""
2 |
3 | from dataclasses import dataclass, field
4 | from typing import Any
5 |
6 | from jinja2 import BaseLoader, Environment
7 |
8 |
@dataclass
class Template:
    """
    Generate CQL2 filter expressions via Jinja2 templating.

    The template source is compiled once when the instance is created and is
    rendered with the supplied context on every call.
    """

    template_str: str
    # NOTE: despite its name and annotation, this attribute is assigned the
    # result of ``Environment.from_string`` (the compiled template), not the
    # Environment itself. The name is kept for backwards compatibility.
    env: Environment = field(init=False)

    def __post_init__(self):
        """Compile ``template_str`` into a reusable Jinja2 template."""
        # Jinja2 expects a loader *instance*. Passing the bare class would break
        # as soon as a template triggers a loader lookup (e.g. via include/extends).
        self.env = Environment(loader=BaseLoader()).from_string(self.template_str)

    async def __call__(self, context: dict[str, Any]) -> str:
        """Render the CQL2 filter expression using ``context`` as template variables."""
        return self.env.render(**context).strip()
23 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | repos:
2 | - repo: https://github.com/abravalheri/validate-pyproject
3 | rev: v0.24.1
4 | hooks:
5 | - id: validate-pyproject
6 |
7 | - repo: https://github.com/PyCQA/isort
8 | rev: 6.0.1
9 | hooks:
10 | - id: isort
11 | language_version: python
12 |
13 | - repo: https://github.com/charliermarsh/ruff-pre-commit
14 | rev: v0.12.11
15 | hooks:
16 | - id: ruff-check
17 | args: ["--fix"]
18 | - id: ruff-format
19 |
20 | - repo: https://github.com/pre-commit/mirrors-mypy
21 | rev: v1.17.1
22 | hooks:
23 | - id: mypy
24 | language_version: python
25 | exclude: tests/.*
26 | additional_dependencies:
27 | - types-simplejson
28 | - types-attrs
29 | - pydantic~=2.0
30 |
--------------------------------------------------------------------------------
/tests/test_configure_app.py:
--------------------------------------------------------------------------------
1 | """Tests for configuring an external FastAPI application."""
2 |
3 | from fastapi import FastAPI
4 | from fastapi.routing import APIRoute
5 |
6 | from stac_auth_proxy import Settings, configure_app
7 |
8 |
def test_configure_app_excludes_proxy_route():
    """Check that `configure_app` registers the health route but no catch-all proxy route."""
    settings = Settings(
        upstream_url="https://example.com",
        oidc_discovery_url="https://example.com/.well-known/openid-configuration",
        wait_for_upstream=False,
        check_conformance=False,
        default_public=True,
    )
    app = FastAPI()

    configure_app(app, settings)

    # Only APIRoute entries are considered when collecting the registered paths.
    api_paths = {
        route.path for route in app.router.routes if isinstance(route, APIRoute)
    }
    assert settings.healthz_prefix in api_paths
    assert "/{path:path}" not in api_paths
25 |
--------------------------------------------------------------------------------
/src/stac_auth_proxy/handlers/healthz.py:
--------------------------------------------------------------------------------
1 | """Health check endpoints."""
2 |
3 | from dataclasses import dataclass, field
4 |
5 | from fastapi import APIRouter
6 | from httpx import AsyncClient
7 |
8 |
@dataclass
class HealthzHandler:
    """Expose health check routes for this API and for the upstream API."""

    upstream_url: str
    router: APIRouter = field(init=False)

    def __post_init__(self):
        """Create the router and register both GET health routes on it."""
        router = APIRouter()
        for path, endpoint in (
            ("", self.healthz),
            ("/upstream", self.healthz_upstream),
        ):
            router.add_api_route(path, endpoint, methods=["GET"])
        self.router = router

    async def healthz(self):
        """Report that this API is up."""
        return dict(status="ok")

    async def healthz_upstream(self):
        """Fetch the upstream URL and report its status code; error statuses raise."""
        async with AsyncClient() as http:
            upstream_response = await http.get(self.upstream_url)
            upstream_response.raise_for_status()
            return dict(status="ok", code=upstream_response.status_code)
32 |
--------------------------------------------------------------------------------
/.github/workflows/publish-helm.yaml:
--------------------------------------------------------------------------------
1 | name: Publish Helm Chart
2 |
3 | on:
4 | push:
5 | branches:
6 | - main
7 | paths:
8 | - 'helm/**'
9 | - '.github/workflows/publish-helm.yaml'
10 | release:
11 | types: [created]
12 |
13 | jobs:
14 | publish-helm:
15 | runs-on: ubuntu-latest
16 | permissions:
17 | contents: read
18 | packages: write
19 |
20 | steps:
21 | - name: Checkout
22 | uses: actions/checkout@v4
23 | with:
24 | fetch-depth: 0
25 |
26 | - name: Install Helm
27 | uses: azure/setup-helm@v3
28 | with:
29 | version: v3.12.1
30 |
31 | - name: Login to GHCR
32 | uses: docker/login-action@v3
33 | with:
34 | registry: ghcr.io
35 | username: ${{ github.actor }}
36 | password: ${{ secrets.GITHUB_TOKEN }}
37 |
38 | - name: Package Helm Chart
39 | run: |
40 | helm package helm/
41 |
42 | - name: Push Helm Chart
43 | run: |
44 | helm push *.tgz oci://ghcr.io/${{ github.repository }}/charts
--------------------------------------------------------------------------------
/.github/workflows/docs.yaml:
--------------------------------------------------------------------------------
1 | name: Documentation
2 |
3 | on:
4 | push:
5 | branches:
6 | - main
7 |
8 | jobs:
9 | test:
10 | name: Test
11 | runs-on: ubuntu-latest
12 | steps:
13 | - name: Checkout main
14 | uses: actions/checkout@v4
15 |
16 | - name: Set up Python 3.11
17 | uses: actions/setup-python@v5
18 | with:
19 | python-version: 3.11
20 |
21 | - uses: astral-sh/setup-uv@v4
22 | with:
23 | enable-cache: true
24 |
25 | - name: Test docs
26 | run: uv run --extra docs mkdocs build --strict
27 |
28 | deploy:
29 | name: Deploy
30 | runs-on: ubuntu-latest
31 | if: github.ref_name == 'main'
32 | steps:
33 | - name: Checkout main
34 | uses: actions/checkout@v4
35 |
36 | - name: Set up Python 3.11
37 | uses: actions/setup-python@v5
38 | with:
39 | python-version: 3.11
40 |
41 | - uses: astral-sh/setup-uv@v4
42 | with:
43 | enable-cache: true
44 |
45 | - name: Deploy docs
46 | run: uv run --extra docs mkdocs gh-deploy --force
47 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 Development Seed
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/src/stac_auth_proxy/middleware/AddProcessTimeHeaderMiddleware.py:
--------------------------------------------------------------------------------
1 | """Middleware to add Server-Timing header with proxy processing time."""
2 |
3 | import time
4 |
5 | from fastapi import Request, Response
6 | from starlette.middleware.base import BaseHTTPMiddleware
7 |
8 | from stac_auth_proxy.utils.requests import build_server_timing_header
9 |
10 |
class AddProcessTimeHeaderMiddleware(BaseHTTPMiddleware):
    """Report the time spent handling a request via the Server-Timing header."""

    async def dispatch(self, request: Request, call_next) -> Response:
        """Time the downstream call and add a ``proxy`` Server-Timing entry."""
        began = time.perf_counter()
        response = await call_next(request)
        elapsed = time.perf_counter() - began

        # Any Server-Timing value already on the response is handed to the
        # helper together with the new entry.
        existing = response.headers.get("Server-Timing")
        response.headers["Server-Timing"] = build_server_timing_header(
            existing,
            name="proxy",
            dur=elapsed,
            desc="Proxy processing time",
        )
        return response
29 |
--------------------------------------------------------------------------------
/examples/custom-integration/src/custom_integration.py:
--------------------------------------------------------------------------------
1 | """
2 | A custom integration example.
3 |
4 | In this example, we're intentionally using a functional pattern but you could also use a
5 | class like we do in the integrations found in stac_auth_proxy.filters.
6 | """
7 |
8 | from typing import Any
9 |
10 |
def cql2_builder(admin_user: str):
    """
    CQL2 builder integration filter.

    Returns an async callable that maps a request context to a CQL2 expression:
    the unrestricted filter ``1=1`` when the token's ``sub`` claim equals
    ``admin_user``, and ``private = true`` otherwise.
    """
    # NOTE: This is where you would set up things like connection pools.
    # NOTE: args/kwargs are passed in via environment variables.

    async def custom_integration_filter(ctx: dict[str, Any]) -> str:
        """
        Generate CQL2 expressions based on the request context.

        Returns a CQL2 expression as a string (cql2-text).
        """
        # NOTE: This is where you would perform a lookup from a database, API, etc.
        # NOTE: ctx is the request context, which includes the payload, headers, etc.

        payload = ctx.get("payload")
        # A missing payload or a payload without a "sub" claim falls through to
        # the restrictive filter instead of raising KeyError.
        if payload and "sub" in payload and payload["sub"] == admin_user:
            return "1=1"
        return "private = true"

    return custom_integration_filter
30 |
--------------------------------------------------------------------------------
/docs/architecture/filtering-data.md:
--------------------------------------------------------------------------------
1 | # Filtering Data
2 |
3 | > [!NOTE]
4 | >
5 | > For more information on using filters to solve authorization needs, see the [user guide](../user-guide/record-level-auth.md).
6 |
7 | ## Example Request Flow for multi-record endpoints
8 |
9 | ```mermaid
10 | sequenceDiagram
11 | Client->>Proxy: GET /collections
12 | Note over Proxy: EnforceAuth checks credentials
13 | Note over Proxy: BuildCql2Filter creates filter
14 | Note over Proxy: ApplyCql2Filter applies filter to request
15 | Proxy->>STAC API: GET /collections?filter=(collection=landsat)
16 | STAC API->>Client: Response
17 | ```
18 |
19 | ## Example Request Flow for single-record endpoints
20 |
21 | The Filter Extension does not apply to fetching individual records. As such, we must validate the record _after_ it is returned from the upstream API but _before_ it is returned to the user:
22 |
23 | ```mermaid
24 | sequenceDiagram
25 | Client->>Proxy: GET /collections/abc123
26 | Note over Proxy: EnforceAuth checks credentials
27 | Note over Proxy: BuildCql2Filter creates filter
28 | Proxy->>STAC API: GET /collections/abc123
29 | Note over Proxy: ApplyCql2Filter validates the response
30 | STAC API->>Client: Response
31 | ```
32 |
--------------------------------------------------------------------------------
/src/stac_auth_proxy/middleware/__init__.py:
--------------------------------------------------------------------------------
1 | """Custom middleware."""
2 |
3 | from .AddProcessTimeHeaderMiddleware import AddProcessTimeHeaderMiddleware
4 | from .AuthenticationExtensionMiddleware import AuthenticationExtensionMiddleware
5 | from .Cql2ApplyFilterBodyMiddleware import Cql2ApplyFilterBodyMiddleware
6 | from .Cql2ApplyFilterQueryStringMiddleware import Cql2ApplyFilterQueryStringMiddleware
7 | from .Cql2BuildFilterMiddleware import Cql2BuildFilterMiddleware
8 | from .Cql2RewriteLinksFilterMiddleware import Cql2RewriteLinksFilterMiddleware
9 | from .Cql2ValidateResponseBodyMiddleware import Cql2ValidateResponseBodyMiddleware
10 | from .EnforceAuthMiddleware import EnforceAuthMiddleware
11 | from .ProcessLinksMiddleware import ProcessLinksMiddleware
12 | from .RemoveRootPathMiddleware import RemoveRootPathMiddleware
13 | from .UpdateOpenApiMiddleware import OpenApiMiddleware
14 |
15 | __all__ = [
16 | "AddProcessTimeHeaderMiddleware",
17 | "AuthenticationExtensionMiddleware",
18 | "Cql2ApplyFilterBodyMiddleware",
19 | "Cql2ApplyFilterQueryStringMiddleware",
20 | "Cql2BuildFilterMiddleware",
21 | "Cql2RewriteLinksFilterMiddleware",
22 | "Cql2ValidateResponseBodyMiddleware",
23 | "EnforceAuthMiddleware",
24 | "OpenApiMiddleware",
25 | "ProcessLinksMiddleware",
26 | "RemoveRootPathMiddleware",
27 | ]
28 |
--------------------------------------------------------------------------------
/helm/templates/ingress.yaml:
--------------------------------------------------------------------------------
1 | {{- if .Values.ingress.enabled -}}
2 | {{- $fullName := include "stac-auth-proxy.fullname" . -}}
3 | {{- $svcPort := .Values.service.port -}}
4 | {{- if and .Values.ingress.className (not (semverCompare ">=1.18-0" .Capabilities.KubeVersion.GitVersion)) }}
5 | {{- if not (hasKey .Values.ingress.annotations "kubernetes.io/ingress.class") }}
6 | {{- $_ := set .Values.ingress.annotations "kubernetes.io/ingress.class" .Values.ingress.className}}
7 | {{- end }}
8 | {{- end }}
9 | apiVersion: networking.k8s.io/v1
10 | kind: Ingress
11 | metadata:
12 | name: {{ $fullName }}
13 | labels:
14 | {{- include "stac-auth-proxy.labels" . | nindent 4 }}
15 | {{- with .Values.ingress.annotations }}
16 | annotations:
17 | {{- toYaml . | nindent 4 }}
18 | {{- end }}
19 | spec:
20 | ingressClassName: {{ .Values.ingress.className }}
21 | {{- if and .Values.ingress.tls.enabled .Values.ingress.host }}
22 | tls:
23 | - hosts:
24 | - {{ .Values.ingress.host }}
25 | secretName: {{ .Values.ingress.tls.secretName | default (printf "%s-tls" .Values.ingress.host) }}
26 | {{- end }}
27 | rules:
28 | {{- if .Values.ingress.host }}
29 | - host: {{ .Values.ingress.host }}
30 | http:
31 | paths:
32 | - path: /
33 | pathType: Prefix
34 | backend:
35 | service:
36 | name: {{ $fullName }}
37 | port:
38 | number: {{ $svcPort }}
39 | {{- end }}
40 | {{- end }}
--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------
# https://github.com/astral-sh/uv-docker-example/blob/c16a61fb3e6ab568ac58d94b73a7d79594a5d570/Dockerfile

# Build stage: resolve and install dependencies (and the project) with uv.
FROM python:3.13-slim AS builder

WORKDIR /app

# Compile bytecode at install time, copy files instead of linking them, and
# install straight into /usr/local so the runtime stage can copy from there.
ENV UV_COMPILE_BYTECODE=1
ENV UV_LINK_MODE=copy
ENV UV_PROJECT_ENVIRONMENT=/usr/local

COPY --from=ghcr.io/astral-sh/uv:latest /uv /uvx /bin/

# Install the project's dependencies using the lockfile and settings
RUN --mount=type=cache,target=/root/.cache/uv \
    --mount=type=bind,source=uv.lock,target=uv.lock \
    --mount=type=bind,source=pyproject.toml,target=pyproject.toml \
    uv sync --frozen --no-install-project --no-dev

# Then, add the rest of the project source code and install it
# Installing separately from its dependencies allows optimal layer caching
# COPY (not ADD): this is a plain local copy; ADD's archive/URL handling is not wanted.
COPY . /app
RUN --mount=type=cache,target=/root/.cache/uv \
    uv sync --frozen --no-dev

# Runtime stage
FROM python:3.13-slim

WORKDIR /app

# Copy installed packages from builder
COPY --from=builder /usr/local/lib/python3.13/site-packages /usr/local/lib/python3.13/site-packages
COPY --from=builder /usr/local/bin /usr/local/bin

# Copy only the source code directory needed at runtime
COPY --from=builder /app/src/stac_auth_proxy /app/src/stac_auth_proxy


# Run as an unprivileged user (uid 1001) that owns /app.
RUN useradd -m -u 1001 -s /bin/bash user && \
    chown -R user:user /app

USER user

# The package is imported from the copied source tree.
ENV PYTHONPATH=/app/src

CMD ["python", "-m", "stac_auth_proxy"]
47 |
--------------------------------------------------------------------------------
/tests/test_proxy.py:
--------------------------------------------------------------------------------
1 | """Test authentication cases for the proxy app."""
2 |
3 | from fastapi.testclient import TestClient
4 | from utils import AppFactory, get_upstream_request
5 |
# Shared factory for the tests below: every endpoint is public by default and no
# per-endpoint public/private overrides are configured.
app_factory = AppFactory(
    oidc_discovery_url="https://example-stac-api.com/.well-known/openid-configuration",
    default_public=True,
    public_endpoints={},
    private_endpoints={},
)
12 |
13 |
async def test_proxied_headers_no_encoding(source_api_server, mock_upstream):
    """Clients that don't accept encoding should not receive it."""
    test_app = app_factory(upstream_url=source_api_server)

    client = TestClient(test_app)
    req = client.build_request(method="GET", url="/", headers={})
    # Drop the header directly rather than deleting from the mapping while
    # iterating over it.
    req.headers.pop("accept-encoding", None)
    assert "accept-encoding" not in req.headers
    client.send(req)

    proxied_request = await get_upstream_request(mock_upstream)
    assert "accept-encoding" not in proxied_request.headers
27 |
28 |
async def test_proxied_headers_with_encoding(source_api_server, mock_upstream):
    """An explicit Accept-Encoding sent by the client is forwarded upstream unchanged."""
    client = TestClient(app_factory(upstream_url=source_api_server))
    outgoing = client.build_request(
        method="GET",
        url="/",
        headers={"accept-encoding": "gzip"},
    )
    client.send(outgoing)

    upstream = await get_upstream_request(mock_upstream)
    assert upstream.headers.get("accept-encoding") == "gzip"
41 |
--------------------------------------------------------------------------------
/.github/workflows/cicd.yaml:
--------------------------------------------------------------------------------
name: CI/CD

on:
  push:
  release:

jobs:
  # Run every pre-commit hook against the whole repository.
  lint:
    runs-on: ubuntu-latest

    steps:
      - uses: actions/checkout@v4
      - uses: actions/setup-python@v5
        with:
          python-version: '3.13'
      - uses: astral-sh/setup-uv@v4
        with:
          enable-cache: true
      # Reuse pre-commit environments until the hook configuration changes.
      - uses: actions/cache@v4
        with:
          path: ~/.cache/pre-commit
          key: ${{ hashFiles('.pre-commit-config.yaml') }}
      - run: uv run pre-commit run --all-files

  # Run the test suite in parallel and enforce the 85% coverage floor.
  test:
    runs-on: ubuntu-latest

    steps:
      - uses: actions/checkout@v4
      - uses: actions/setup-python@v5
        with:
          python-version: '3.13'
      - uses: astral-sh/setup-uv@v4
        with:
          enable-cache: true
      - name: Run tests with coverage
        run: |
          uv run pytest -n auto --cov=src/stac_auth_proxy --cov-report=xml --cov-report=html --cov-report=term-missing --cov-fail-under=85
      - name: Upload coverage reports to Codecov
        uses: codecov/codecov-action@v5
        with:
          files: ./coverage.xml
          flags: unittests
          fail_ci_if_error: false
      # Keep the reports even when an earlier step failed.
      - name: Archive coverage reports
        uses: actions/upload-artifact@v4
        if: always()
        with:
          name: coverage-reports
          path: |
            htmlcov/
            coverage.xml
53 |
--------------------------------------------------------------------------------
/.github/workflows/publish-docker.yaml:
--------------------------------------------------------------------------------
# Builds the repository's Docker image and pushes it to the registry below
# whenever a GitHub release is published.
name: Publish Docker image

on:
  release:
    types: [published]

env:
  REGISTRY: ghcr.io
  IMAGE_NAME: ${{ github.repository }}

jobs:
  build-and-push:
    runs-on: ubuntu-latest
    permissions:
      contents: read
      packages: write

    steps:
      - name: Checkout repository
        uses: actions/checkout@v4

      - name: Set up Docker Buildx
        uses: docker/setup-buildx-action@v3

      - name: Log in to the Container registry
        uses: docker/login-action@v3
        with:
          registry: ${{ env.REGISTRY }}
          username: ${{ github.actor }}
          password: ${{ secrets.GITHUB_TOKEN }}

      # Tag with the release tag name; also tag "latest" unless the release is
      # marked as a prerelease.
      - name: Extract metadata (tags, labels) for Docker
        id: meta
        uses: docker/metadata-action@v5
        with:
          images: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}
          tags: |
            type=raw,value=${{ github.event.release.tag_name }}
            type=raw,value=latest,enable=${{ github.event.release.prerelease == false }}

      # Build for amd64 and arm64, using the GitHub Actions cache backend for layers.
      - name: Build and push Docker image
        uses: docker/build-push-action@v5
        with:
          context: .
          push: true
          platforms: linux/amd64,linux/arm64
          tags: ${{ steps.meta.outputs.tags }}
          labels: ${{ steps.meta.outputs.labels }}
          cache-from: type=gha
          cache-to: type=gha,mode=max
51 |
--------------------------------------------------------------------------------
/src/stac_auth_proxy/filters/opa.py:
--------------------------------------------------------------------------------
1 | """Integration with Open Policy Agent (OPA) to generate CQL2 filters for requests to a STAC API."""
2 |
3 | from dataclasses import dataclass, field
4 | from typing import Any
5 |
6 | import httpx
7 |
8 | from ..utils.cache import MemoryCache, get_value_by_path
9 |
10 |
@dataclass
class Opa:
    """
    Callable that asks Open Policy Agent (OPA) for a CQL2 filter.

    The instance is called with the request context; the decision returned by OPA
    is cached for ``cache_ttl`` seconds under the value found at ``cache_key``
    (a dot-notation path into the context).
    """

    host: str
    decision: str

    client: httpx.AsyncClient = field(init=False)
    cache: MemoryCache = field(init=False)
    cache_key: str = "req.headers.authorization"
    cache_ttl: float = 5.0

    def __post_init__(self):
        """Create the HTTP client bound to the OPA host and the TTL cache."""
        self.cache = MemoryCache(ttl=self.cache_ttl)
        self.client = httpx.AsyncClient(base_url=self.host)

    async def __call__(self, context: dict[str, Any]) -> str:
        """Return the CQL2 filter for ``context``, consulting the cache first."""
        lookup_key = get_value_by_path(context, self.cache_key)
        try:
            return self.cache[lookup_key]
        except KeyError:
            # Missing or expired entry: fall through and ask OPA.
            pass
        fetched = await self._fetch(context)
        self.cache[lookup_key] = fetched
        return fetched

    async def _fetch(self, context: dict[str, Any]) -> str:
        """POST the context to OPA's data API and return the ``result`` field."""
        endpoint = f"/v1/data/{self.decision}"
        payload = {"input": context}
        response = await self.client.post(endpoint, json=payload)
        response.raise_for_status()
        return response.json()["result"]
45 |
--------------------------------------------------------------------------------
/src/stac_auth_proxy/handlers/swagger_ui.py:
--------------------------------------------------------------------------------
1 | """
2 | In order to allow customization of the Swagger UI's OAuth2 configuration, we support
3 | overriding the default handler. This is useful for adding custom parameters such as
4 | `usePkceWithAuthorizationCodeGrant` or `clientId`.
5 |
6 | See:
7 | - https://swagger.io/docs/open-source-tools/swagger-ui/usage/oauth2/
8 | """
9 |
10 | from dataclasses import dataclass, field
11 | from typing import Optional
12 |
13 | from fastapi.openapi.docs import get_swagger_ui_html
14 | from starlette.requests import Request
15 | from starlette.responses import HTMLResponse
16 |
17 |
@dataclass
class SwaggerUI:
    """Configurable replacement for the default Swagger UI page handler."""

    openapi_url: str
    title: Optional[str] = "STAC API"
    init_oauth: dict = field(default_factory=dict)
    parameters: dict = field(default_factory=dict)
    oauth2_redirect_url: str = "/docs/oauth2-redirect"

    async def route(self, req: Request) -> HTMLResponse:
        """
        Render the Swagger UI HTML page.

        The request's ``root_path`` (without trailing slash) is prepended to the
        OpenAPI URL and, when one is configured, to the OAuth2 redirect URL.
        """
        prefix = req.scope.get("root_path", "").rstrip("/")
        redirect = self.oauth2_redirect_url
        return get_swagger_ui_html(
            openapi_url=prefix + self.openapi_url,
            title=f"{self.title} - Swagger UI",
            oauth2_redirect_url=(prefix + redirect) if redirect else redirect,
            init_oauth=self.init_oauth,
            swagger_ui_parameters=self.parameters,
        )
42 |
--------------------------------------------------------------------------------
/src/stac_auth_proxy/utils/filters.py:
--------------------------------------------------------------------------------
1 | """Utility functions."""
2 |
3 | import json
4 | import logging
5 | from typing import Optional
6 | from urllib.parse import parse_qs
7 |
8 | from cql2 import Expr
9 |
10 | logger = logging.getLogger(__name__)
11 |
12 |
def append_qs_filter(qs: str, filter: Expr, filter_lang: Optional[str] = None) -> bytes:
    """
    Add a filter expression to a query string and return the encoded result.

    An existing ``filter`` parameter is combined with the new one. Only the first
    value of each query parameter is kept. The filter language is ``filter_lang``
    if given, else the query string's ``filter-lang``, else ``cql2-text``.
    """
    params = {name: values[0] for name, values in parse_qs(qs).items()}
    lang = filter_lang or params.get("filter-lang", "cql2-text")
    merged = append_body_filter(params, filter, lang)
    return dict_to_query_string(merged).encode("utf-8")
20 |
21 |
def append_body_filter(
    body: dict, filter: Expr, filter_lang: Optional[str] = None
) -> dict:
    """
    Return a copy of ``body`` with the filter expression applied.

    An existing ``filter`` in the body is combined with the new one. The result is
    serialised as text for ``cql2-text`` and as JSON otherwise; the language is
    ``filter_lang`` if given, else the body's ``filter-lang``, else ``cql2-json``.
    """
    existing = body.get("filter")
    lang = filter_lang or body.get("filter-lang", "cql2-json")
    combined = filter + Expr(existing) if existing else filter
    if lang == "cql2-text":
        rendered = combined.to_text()
    else:
        rendered = combined.to_json()
    return {**body, "filter": rendered, "filter-lang": lang}
35 |
36 |
def dict_to_query_string(params: dict) -> str:
    """
    Convert a dictionary to a percent-encoded query string.

    Dict and list values are serialised as compact JSON (unlike the default
    behavior of urllib.parse.urlencode); every other value is converted with
    ``str``. Keys and values are percent-encoded so that characters such as
    ``&``, ``=``, ``+``, ``#`` or spaces inside a value (common in CQL2 filters)
    cannot corrupt the resulting query string.

    Args:
        params: Mapping of query parameter names to values.

    Returns:
        The ``key=value`` pairs joined with ``&``, in the mapping's order.

    """
    from urllib.parse import quote

    parts = []
    for key, val in params.items():
        if isinstance(val, (dict, list)):
            val = json.dumps(val, separators=(",", ":"))
        # safe="" also escapes "/" so the value survives a decode unchanged.
        parts.append(f"{quote(str(key), safe='')}={quote(str(val), safe='')}")
    return "&".join(parts)
48 |
--------------------------------------------------------------------------------
/src/stac_auth_proxy/middleware/RemoveRootPathMiddleware.py:
--------------------------------------------------------------------------------
1 | """Middleware to remove ROOT_PATH from the path of incoming requests."""
2 |
3 | import logging
4 | from dataclasses import dataclass
5 |
6 | from starlette.responses import Response
7 | from starlette.types import ASGIApp, Receive, Scope, Send
8 |
9 | logger = logging.getLogger(__name__)
10 |
11 |
@dataclass
class RemoveRootPathMiddleware:
    """
    Middleware to remove the root path of the request before it is sent to the upstream
    server.

    IMPORTANT: This middleware must be placed early in the middleware chain (ie late in
    the order of declaration) so that it trims the root_path from the request path before
    any middleware that may need to use the request path (e.g. EnforceAuthMiddleware).

    The root path is matched on a path-segment boundary: with ``root_path="/api"``,
    ``/api`` and ``/api/items`` match but ``/apifoo`` does not. A trailing slash on
    ``root_path`` is ignored.
    """

    app: ASGIApp
    root_path: str

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """
        Strip the root path from ``scope["path"]`` and call the wrapped app.

        Non-HTTP scopes are passed through untouched. HTTP requests whose path is
        not under the configured root path are answered with a 404 and never
        reach the wrapped app.
        """
        if scope["type"] != "http":
            return await self.app(scope, receive, send)

        path = scope["path"]
        # Normalise so that "/api" and "/api/" behave the same and "/" means "no prefix".
        prefix = self.root_path.rstrip("/")

        # Require a segment boundary after the prefix so "/apifoo" is not treated
        # as being under "/api".
        under_root = path == prefix or path.startswith(prefix + "/")
        if prefix and not under_root:
            response = Response("Not Found", status_code=404)
            logger.error(
                f"Root path {self.root_path!r} not found in path {scope['path']!r}"
            )
            await response(scope, receive, send)
            return

        # raw_path is set from the full, unstripped path; path loses the prefix and
        # falls back to "/" when nothing remains.
        scope["raw_path"] = path.encode()
        scope["path"] = path[len(prefix) :] or "/"

        return await self.app(scope, receive, send)
46 |
--------------------------------------------------------------------------------
/docs/overrides/partials/integrations/analytics/plausible.html:
--------------------------------------------------------------------------------
1 |
6 |
7 |
8 |
54 |
--------------------------------------------------------------------------------
/helm/templates/_helpers.tpl:
--------------------------------------------------------------------------------
{{/*
Expand the name of the chart.
Uses .Values.nameOverride when set, otherwise the chart name; truncated to 63
characters with any trailing "-" removed.
*/}}
{{- define "stac-auth-proxy.name" -}}
{{- default .Chart.Name .Values.nameOverride | trunc 63 | trimSuffix "-" }}
{{- end }}

{{/*
Create a default fully qualified app name.
Precedence: .Values.fullnameOverride, then the release name alone when it already
contains the chart name, otherwise "<release>-<name>". Always truncated to 63
characters with any trailing "-" removed.
*/}}
{{- define "stac-auth-proxy.fullname" -}}
{{- if .Values.fullnameOverride }}
{{- .Values.fullnameOverride | trunc 63 | trimSuffix "-" }}
{{- else }}
{{- $name := default .Chart.Name .Values.nameOverride }}
{{- if contains $name .Release.Name }}
{{- .Release.Name | trunc 63 | trimSuffix "-" }}
{{- else }}
{{- printf "%s-%s" .Release.Name $name | trunc 63 | trimSuffix "-" }}
{{- end }}
{{- end }}
{{- end }}

{{/*
Create chart name and version as used by the chart label.
"+" in the version is replaced with "_".
*/}}
{{- define "stac-auth-proxy.chart" -}}
{{- printf "%s-%s" .Chart.Name .Chart.Version | replace "+" "_" | trunc 63 | trimSuffix "-" }}
{{- end }}

{{/*
Common labels
Chart label, the selector labels, the app version (when the chart defines one)
and the managing service.
*/}}
{{- define "stac-auth-proxy.labels" -}}
helm.sh/chart: {{ include "stac-auth-proxy.chart" . }}
{{ include "stac-auth-proxy.selectorLabels" . }}
{{- if .Chart.AppVersion }}
app.kubernetes.io/version: {{ .Chart.AppVersion | quote }}
{{- end }}
app.kubernetes.io/managed-by: {{ .Release.Service }}
{{- end }}

{{/*
Selector labels
The name/instance pair used both as pod labels and as the Deployment selector.
*/}}
{{- define "stac-auth-proxy.selectorLabels" -}}
app.kubernetes.io/name: {{ include "stac-auth-proxy.name" . }}
app.kubernetes.io/instance: {{ .Release.Name }}
{{- end }}

{{/*
Create the name of the service account to use
When serviceAccount.create is set, defaults to the fullname; otherwise defaults
to "default". An explicit serviceAccount.name wins in both cases.
*/}}
{{- define "stac-auth-proxy.serviceAccountName" -}}
{{- if .Values.serviceAccount.create }}
{{- default (include "stac-auth-proxy.fullname" .) .Values.serviceAccount.name }}
{{- else }}
{{- default "default" .Values.serviceAccount.name }}
{{- end }}
{{- end }}

{{/*
Render env var value based on type
Strings are quoted as-is; any other value is serialised to JSON and then quoted,
so the rendered value is always a YAML string.
*/}}
{{- define "stac-auth-proxy.envValue" -}}
{{- if kindIs "string" . -}}
{{- . | quote -}}
{{- else -}}
{{- . | toJson | quote -}}
{{- end -}}
{{- end -}}
72 |
--------------------------------------------------------------------------------
/helm/templates/deployment.yaml:
--------------------------------------------------------------------------------
{{- /*
Deployment for the proxy. Pod scheduling, security contexts, resources, extra
volumes and environment variables all come from values.
*/ -}}
apiVersion: apps/v1
kind: Deployment
metadata:
  name: {{ include "stac-auth-proxy.fullname" . }}
  labels:
    {{- include "stac-auth-proxy.labels" . | nindent 4 }}
spec:
  replicas: {{ .Values.replicaCount }}
  selector:
    matchLabels:
      {{- include "stac-auth-proxy.selectorLabels" . | nindent 6 }}
  template:
    metadata:
      labels:
        {{- include "stac-auth-proxy.selectorLabels" . | nindent 8 }}
    spec:
      serviceAccountName: {{ include "stac-auth-proxy.serviceAccountName" . }}
      securityContext:
        {{- toYaml .Values.securityContext | nindent 8 }}
      {{- /* initContainers is optional; the block is omitted when the value is unset or empty. */}}
      {{- with .Values.initContainers }}
      initContainers:
        {{- toYaml . | nindent 8 }}
      {{- end }}
      containers:
        - name: {{ .Chart.Name }}
          securityContext:
            {{- toYaml .Values.containerSecurityContext | nindent 12 }}
          {{- /* Image tag falls back to the chart's appVersion when image.tag is empty. */}}
          image: "{{ .Values.image.repository }}:{{ .Values.image.tag | default .Chart.AppVersion }}"
          imagePullPolicy: {{ .Values.image.pullPolicy }}
          ports:
            {{- /* The container port is fixed at 8000; it is not read from .Values.service.port. */}}
            - name: http
              containerPort: 8000
              protocol: TCP
          resources:
            {{- toYaml .Values.resources | nindent 12 }}
          env:
            {{- /* Each .Values.env entry becomes one env var; values are rendered as strings by the envValue helper. */}}
            {{- range $key, $value := .Values.env }}
            - name: {{ $key }}
              value: {{ include "stac-auth-proxy.envValue" $value }}
            {{- end }}
          {{- with .Values.extraVolumeMounts }}
          volumeMounts:
            {{- toYaml . | nindent 12 }}
          {{- end }}

      {{- with .Values.extraVolumes }}
      volumes:
        {{- toYaml . | nindent 8 }}
      {{- end }}
      {{- with .Values.nodeSelector }}
      nodeSelector:
        {{- toYaml . | nindent 8 }}
      {{- end }}
      {{- with .Values.affinity }}
      affinity:
        {{- toYaml . | nindent 8 }}
      {{- end }}
      {{- with .Values.tolerations }}
      tolerations:
        {{- toYaml . | nindent 8 }}
      {{- end }}
62 |
--------------------------------------------------------------------------------
/src/stac_auth_proxy/middleware/Cql2ApplyFilterQueryStringMiddleware.py:
--------------------------------------------------------------------------------
1 | """Middleware to inject CQL2 filters into the query string for GET/list endpoints."""
2 |
3 | import re
4 | from dataclasses import dataclass
5 | from logging import getLogger
6 | from typing import Optional
7 |
8 | from cql2 import Expr
9 | from starlette.requests import Request
10 | from starlette.types import ASGIApp, Receive, Scope, Send
11 |
12 | from ..utils import filters
13 | from ..utils.middleware import required_conformance
14 |
15 | logger = getLogger(__name__)
16 |
17 |
@required_conformance(
    r"http://www.opengis.net/spec/cql2/1.0/conf/basic-cql2",
    r"http://www.opengis.net/spec/cql2/1.0/conf/cql2-text",
    r"http://www.opengis.net/spec/cql2/1.0/conf/cql2-json",
)
@dataclass(frozen=True)
class Cql2ApplyFilterQueryStringMiddleware:
    """
    ASGI middleware that merges the request's CQL2 filter into its query string.

    The filter is read from ``request.state`` under ``state_key``. Only GET
    requests to endpoints that are not single-record endpoints are rewritten;
    everything else is forwarded unchanged.
    """

    app: ASGIApp
    state_key: str = "cql2_filter"

    # Paths addressing exactly one item or one collection; these are not rewritten.
    single_record_endpoints = [
        r"^/collections/([^/]+)/items/([^/]+)$",
        r"^/collections/([^/]+)$",
    ]

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """Forward the request, injecting the filter into the query string when applicable."""
        if scope["type"] == "http":
            request = Request(scope)
            cql2_filter: Optional[Expr] = getattr(request.state, self.state_key, None)
            applies = (
                bool(cql2_filter)
                and request.method == "GET"
                and not any(
                    re.match(pattern, request.url.path)
                    for pattern in self.single_record_endpoints
                )
            )
            if applies:
                # Work on a copy so the caller's scope mapping is left untouched.
                scope = dict(scope)
                scope["query_string"] = filters.append_qs_filter(
                    request.url.query, cql2_filter
                )
        return await self.app(scope, receive, send)
57 |
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
.PHONY: help test test-coverage test-fast lint format clean install dev-install ci docs

help: ## Show this help message
	@echo "Available commands:"
	@grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | sort | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-20s\033[0m %s\n", $$1, $$2}'

install: ## Install the package
	uv sync

dev-install: ## Install development dependencies
	uv sync --group dev

test: ## Run tests
	uv run pytest

test-fast: ## Run tests in parallel
	uv run pytest -n auto

test-coverage: ## Run tests with coverage
	@echo "🧪 Running tests with coverage..."
	uv run pytest \
		--cov=src/stac_auth_proxy \
		--cov-report=term-missing \
		--cov-report=html \
		--cov-report=xml \
		--cov-fail-under=85 \
		-v
	@echo "✅ Coverage report generated!"
	@echo "📊 HTML report available at: htmlcov/index.html"
	@echo "📄 XML report available at: coverage.xml"
	@if [ "$(CI)" = "true" ]; then \
		echo "🚀 Running in CI environment"; \
	else \
		echo "💻 Running locally - opening HTML report..."; \
		if command -v open >/dev/null 2>&1; then \
			open htmlcov/index.html; \
		elif command -v xdg-open >/dev/null 2>&1; then \
			xdg-open htmlcov/index.html; \
		else \
			echo "Please open htmlcov/index.html in your browser to view the coverage report"; \
		fi; \
	fi

lint: ## Run linting
	uv run pre-commit run ruff-check --all-files
	uv run pre-commit run mypy --all-files

format: ## Format code
	uv run pre-commit run ruff-format --all-files

clean: ## Clean up generated files
	rm -rf htmlcov/
	rm -rf .coverage
	rm -rf coverage.xml
	rm -rf .pytest_cache/
	rm -rf build/
	rm -rf dist/
	rm -rf *.egg-info/
	# `find -delete` cannot remove non-empty directories, so remove each
	# __pycache__ tree recursively and do not descend into it afterwards.
	find . -type d -name __pycache__ -prune -exec rm -rf {} +
	find . -type f -name "*.pyc" -delete

ci: ## Run CI checks locally
	uv run pre-commit run --all-files
	@echo "🧪 Running tests with coverage..."
	uv run pytest \
		-n auto \
		--cov=src/stac_auth_proxy \
		--cov-report=term-missing \
		--cov-report=html \
		--cov-report=xml \
		--cov-fail-under=85 \
		-v
	@echo "✅ CI checks completed!"

docs: ## Serve documentation locally
	uv sync --extra docs
	DYLD_FALLBACK_LIBRARY_PATH=/opt/homebrew/lib uv run mkdocs serve
--------------------------------------------------------------------------------
/tests/test_remove_root_path.py:
--------------------------------------------------------------------------------
1 | """Tests for RemoveRootPathMiddleware."""
2 |
3 | import pytest
4 | from fastapi import FastAPI
5 | from starlette.testclient import TestClient
6 | from starlette.types import Receive, Scope, Send
7 |
8 | from stac_auth_proxy.middleware.RemoveRootPathMiddleware import RemoveRootPathMiddleware
9 |
10 |
class MockASGIApp:
    """ASGI app stand-in that records whether it was called and with which scope."""

    def __init__(self):
        """Start in the not-yet-called state."""
        self.called, self.scope = False, None

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """Remember the scope and mark the app as called; nothing is sent."""
        self.scope = scope
        self.called = True
23 |
24 |
@pytest.mark.asyncio
async def test_remove_root_path_middleware():
    """The root path is stripped from ``path`` while ``raw_path`` keeps the full path."""
    downstream = MockASGIApp()
    request_scope = {
        "type": "http",
        "path": "/api/test",
        "raw_path": b"/api/test",
    }

    # Request under the configured root path
    await RemoveRootPathMiddleware(downstream, root_path="/api")(
        request_scope, None, None
    )

    assert downstream.called
    forwarded = downstream.scope
    assert forwarded["path"] == "/test"
    assert forwarded["raw_path"] == b"/api/test"
41 |
42 |
@pytest.mark.asyncio
async def test_remove_root_path_middleware_non_http():
    """Scopes that are not HTTP reach the wrapped app with their path untouched."""
    downstream = MockASGIApp()
    websocket_scope = {"type": "websocket", "path": "/api/test"}

    await RemoveRootPathMiddleware(downstream, root_path="/api")(
        websocket_scope, None, None
    )

    assert downstream.called
    assert downstream.scope["path"] == "/api/test"
56 |
57 |
@pytest.mark.asyncio
async def test_remove_root_path_middleware_empty_path():
    """A request for exactly the root path is forwarded as '/'."""
    downstream = MockASGIApp()
    request_scope = {
        "type": "http",
        "path": "/api",
        "raw_path": b"/api",
    }

    await RemoveRootPathMiddleware(downstream, root_path="/api")(
        request_scope, None, None
    )

    assert downstream.called
    forwarded = downstream.scope
    assert forwarded["path"] == "/"
    assert forwarded["raw_path"] == b"/api"
73 |
74 |
def test_remove_root_path_middleware_integration():
    """End-to-end check of the middleware mounted on a FastAPI app."""
    app = FastAPI()
    app.add_middleware(RemoveRootPathMiddleware, root_path="/api")

    @app.get("/test")
    async def test_endpoint():
        return {"message": "test"}

    client = TestClient(app)

    # Under the root path the route is reachable
    prefixed = client.get("/api/test")
    assert prefixed.status_code == 200
    assert prefixed.json() == {"message": "test"}

    # Outside the root path the middleware answers 404
    unprefixed = client.get("/test")
    assert unprefixed.status_code == 404  # Should not find the endpoint
94 |
--------------------------------------------------------------------------------
/tests/utils.py:
--------------------------------------------------------------------------------
1 | """Utilities for testing."""
2 |
3 | import json
4 | from dataclasses import dataclass
5 | from typing import Callable, cast
6 | from unittest.mock import MagicMock
7 | from urllib.parse import parse_qs, unquote
8 |
9 | import httpx
10 | from httpx import Headers, Request
11 |
12 | from stac_auth_proxy import Settings, create_app
13 |
14 |
class AppFactory:
    """Build test apps from a fixed set of default settings plus per-call overrides."""

    def __init__(self, **defaults):
        """Remember the settings that every created app starts from."""
        self.defaults = defaults

    def __call__(self, *, upstream_url, **overrides) -> Callable:
        """
        Create an app.

        Overrides take precedence over the factory defaults, and ``upstream_url``
        takes precedence over both.
        """
        settings_data = dict(self.defaults)
        settings_data.update(overrides)
        settings_data["upstream_url"] = upstream_url
        return create_app(Settings.model_validate(settings_data))
33 |
34 |
@dataclass
class SingleChunkAsyncStream(httpx.AsyncByteStream):
    """Mock async stream that returns a single chunk of data."""

    # The complete payload, yielded as one chunk.
    body: bytes

    async def __aiter__(self):
        """Return a single chunk of data."""
        yield self.body
44 |
45 |
def single_chunk_async_stream_response(body: bytes, status_code=200, headers=None):
    """
    Create a response with a single chunk of data.

    Args:
        body: Raw bytes streamed as the response body in one chunk.
        status_code: HTTP status code of the response.
        headers: Response headers. When omitted (or ``None``), a JSON
            ``content-type`` header is used.

    Returns:
        An ``httpx.Response`` whose stream yields ``body`` once.

    """
    # Build the default per call instead of sharing one mutable dict as the
    # parameter default across every invocation.
    if headers is None:
        headers = {"content-type": "application/json"}
    return httpx.Response(
        stream=SingleChunkAsyncStream(body),
        status_code=status_code,
        headers=headers,
    )
55 |
56 |
def parse_query_string(qs: str) -> dict:
    """
    Parse a query string into a flat dictionary, keeping the first value per key.

    Python's parse_qs leaves JSON values as strings
    (e.g. parse_qs('foo={"x":"y"}') == {'foo': ['{"x":"y"}']}), so when the query
    declares ``filter-lang=cql2-json`` the ``filter`` value is decoded from JSON to
    make comparisons against expected values straightforward.
    """
    parsed = parse_qs(qs)
    filter_is_json = parsed.get("filter-lang") == ["cql2-json"]

    def _decode(name: str, first_value: str):
        text = unquote(first_value)
        if name == "filter" and filter_is_json:
            return json.loads(text)
        return text

    return {name: _decode(name, values[0]) for name, values in parsed.items()}
73 |
74 |
async def get_upstream_request(mock_upstream: MagicMock) -> "UpstreamRequest":
    """
    Return the body, query params and headers of the single upstream request.

    Asserts that the mocked upstream was called exactly once.
    """
    assert mock_upstream.call_count == 1
    positional_args = cast(list[Request], mock_upstream.call_args[0])
    [request] = positional_args
    raw_query = request.url.query.decode("utf-8")
    return UpstreamRequest(
        body=request._streamed_body.decode(),
        query_params=parse_query_string(raw_query),
        headers=request.headers,
    )
85 |
86 |
@dataclass
class UpstreamRequest:
    """The raw body and query params from the single upstream request."""

    # Request body decoded to text.
    body: str
    # Query string parsed by parse_query_string.
    query_params: dict
    # Headers of the request as sent upstream.
    headers: Headers
94 |
--------------------------------------------------------------------------------
/src/stac_auth_proxy/utils/cache.py:
--------------------------------------------------------------------------------
1 | """Cache utilities."""
2 |
3 | from dataclasses import dataclass, field
4 | from time import time
5 | from typing import Any
6 |
7 | from stac_auth_proxy.utils.filters import logger
8 |
9 |
@dataclass
class MemoryCache:
    """
    In-memory cache with dict-style access and a per-entry time-to-live.

    Entries are stored with the time they were written. Reading an entry older
    than ``ttl`` seconds removes it and raises ``KeyError``, exactly as if it had
    never been stored. Expired entries are also pruned in bulk on writes, at most
    once per ``ttl`` seconds.

    Keys are only ever logged or put into exception messages in truncated form,
    because callers may use secrets (e.g. an Authorization header) as keys.
    """

    ttl: float = 5.0
    # key -> (value, time the value was stored)
    cache: dict[Any, tuple[Any, float]] = field(default_factory=dict)
    _last_pruned: float = field(default_factory=time)

    def __getitem__(self, key: Any) -> Any:
        """
        Get a value from the cache if it is not expired.

        Raises:
            KeyError: If the key is absent or its entry has expired.

        """
        try:
            result, timestamp = self.cache[key]
        except KeyError:
            msg = f"{self._key_str(key)!r} not in cache."
            logger.debug(msg)
            raise KeyError(msg) from None

        if (time() - timestamp) > self.ttl:
            msg = f"{self._key_str(key)!r} in cache, but expired."
            del self.cache[key]
            logger.debug(msg)
            # Raise with the truncated message, never the raw key.
            raise KeyError(msg)

        logger.debug(f"{self._key_str(key)} in cache, returning cached result.")
        return result

    def __setitem__(self, key: Any, value: Any):
        """Set a value in the cache, then prune expired entries if one is due."""
        self.cache[key] = (value, time())
        self._prune()

    def __contains__(self, key: Any) -> bool:
        """Check if a key is in the cache and is not expired."""
        try:
            self[key]
        except KeyError:
            return False
        return True

    def get(self, key: Any) -> Any:
        """Get a value from the cache, or None if it is absent or expired."""
        try:
            return self[key]
        except KeyError:
            return None

    def _prune(self):
        """Drop expired entries; does nothing if the last prune was under ``ttl`` seconds ago."""
        now = time()
        if now - self._last_pruned < self.ttl:
            return
        cutoff = now - self.ttl
        self.cache = {
            k: (v, time_entered)
            for k, (v, time_entered) in self.cache.items()
            if time_entered > cutoff
        }
        self._last_pruned = now

    @staticmethod
    def _key_str(key: Any) -> str:
        """Get a string representation of a key, truncated to 9 characters plus '...'."""
        text = str(key)
        return text if len(text) < 10 else f"{text[:9]}..."
70 |
71 |
def get_value_by_path(obj: dict, path: str, default: Any = None) -> Any:
    """
    Get a value from a dictionary using dot notation.

    Args:
        obj: The dictionary to search in
        path: The dot notation path (e.g. "payload.sub")
        default: Default value to return if path doesn't exist

    Returns:
        The value at the specified path, or ``default`` if any segment of the
        path is missing, an intermediate value is ``None``, or an intermediate
        value does not support ``.get``.

    """
    # Private marker so a missing key is distinguishable from any stored value
    # and the lookup never continues into ``default`` itself.
    missing = object()
    current: Any = obj
    try:
        for key in path.split("."):
            if current is None:
                return default
            current = current.get(key, missing)
            if current is missing:
                return default
        return current
    except (AttributeError, KeyError, TypeError):
        return default
94 |
--------------------------------------------------------------------------------
/helm/values.yaml:
--------------------------------------------------------------------------------
1 | # Default values for stac-auth-proxy
2 |
3 | replicaCount: 1
4 |
5 | image:
6 | repository: ghcr.io/developmentseed/stac-auth-proxy
7 | pullPolicy: IfNotPresent
8 | tag: "latest"
9 |
10 | service:
11 | type: ClusterIP
12 | port: 8000
13 |
14 | ingress:
15 | enabled: true
16 | className: "nginx"
17 | annotations:
18 | cert-manager.io/cluster-issuer: "letsencrypt-prod"
19 | host: "stac-proxy.example.com" # This should be overridden in production
20 | tls:
21 | enabled: true
22 | secretName: "" # If empty, will be auto-generated as "{host}-tls"
23 |
24 | resources:
25 | limits:
26 | cpu: 500m
27 | memory: 512Mi
28 | requests:
29 | cpu: 200m
30 | memory: 256Mi
31 |
32 | # Pod-level security context
33 | securityContext:
34 | runAsNonRoot: true
35 | runAsUser: 1000
36 | runAsGroup: 1000
37 |
38 | # Container-level security context
39 | containerSecurityContext:
40 | allowPrivilegeEscalation: false
41 | capabilities:
42 | drop:
43 | - ALL
44 |
45 | nodeSelector: {}
46 | tolerations: []
47 | affinity: {}
48 |
49 | # Additional volumes to mount
50 | extraVolumes: []
51 | # Example:
52 | # extraVolumes:
53 | # - name: filters
54 | # configMap:
55 | # name: stac-auth-proxy-filters
56 |
57 | # Additional volume mounts for the container
58 | extraVolumeMounts: []
59 | # Example:
60 | # extraVolumeMounts:
61 | # - name: filters
62 | # mountPath: /app/src/stac_auth_proxy/custom_filters.py
63 | # subPath: custom_filters.py
64 | # readOnly: true
65 |
66 | # Init containers to run before the main container starts
67 | # initContainers: []
68 | # Example:
69 | # initContainers:
70 | # - name: wait-for-oidc
71 | # image: busybox:1.35
72 | # command: ['sh', '-c', 'until nc -z oidc-server 8080; do sleep 2; done']
73 |
74 | # Environment variables for the application
75 | env:
76 | # Required configuration
77 | UPSTREAM_URL: "" # STAC API URL
78 | OIDC_DISCOVERY_URL: "" # OpenID Connect discovery URL
79 |
80 | # Optional configuration
81 | WAIT_FOR_UPSTREAM: true
82 | HEALTHZ_PREFIX: "/healthz"
83 | OIDC_DISCOVERY_INTERNAL_URL: ""
84 | DEFAULT_PUBLIC: false
85 | PRIVATE_ENDPOINTS: |
86 | {
87 | "^/collections$": ["POST"],
88 | "^/collections/([^/]+)$": ["PUT", "PATCH", "DELETE"],
89 | "^/collections/([^/]+)/items$": ["POST"],
90 | "^/collections/([^/]+)/items/([^/]+)$": ["PUT", "PATCH", "DELETE"],
91 | "^/collections/([^/]+)/bulk_items$": ["POST"]
92 | }
93 | PUBLIC_ENDPOINTS: |
94 | {
95 | "^/api.html$": ["GET"],
96 | "^/api$": ["GET"],
97 | "^/docs/oauth2-redirect": ["GET"],
98 | "^/healthz": ["GET"]
99 | }
100 |
101 |
102 |
103 | serviceAccount:
104 | # Specifies whether a service account should be created
105 | create: true
106 | # Annotations to add to the service account
107 | annotations: {}
108 | # The name of the service account to use.
109 | # If not set and create is true, a name is generated using the fullname template
110 | name: ""
111 | # Image pull secrets to add to the service account
112 | imagePullSecrets: []
113 | # - name: my-registry-secret
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
stac auth proxy
3 |
Reverse proxy to apply auth*n to your STAC API.
4 |
5 |
6 | ---
7 |
8 | [![PyPI - Version][pypi-version-badge]][pypi-link]
9 | [![GHCR - Version][ghcr-version-badge]][ghcr-link]
10 | [![GHCR - Size][ghcr-size-badge]][ghcr-link]
11 | [![codecov][codecov-badge]][codecov-link]
12 | [![Tests][tests-badge]][tests-link]
13 |
14 | STAC Auth Proxy is a proxy API that mediates between the client and your internally accessible STAC API to provide flexible authentication, authorization, and content-filtering mechanisms.
15 |
16 | > [!IMPORTANT]
17 | >
18 | > **We would :heart: to hear from you!**
19 | > Please [join the discussion](https://github.com/developmentseed/eoAPI/discussions/209) and let us know how you're using eoAPI! This helps us improve the project for you and others.
20 | > If you prefer to remain anonymous, you can email us at eoapi@developmentseed.org, and we'll be happy to post a summary on your behalf.
21 |
22 | ## ✨Features✨
23 |
24 | - **🔐 Authentication:** Apply [OpenID Connect (OIDC)](https://openid.net/developers/how-connect-works/) token validation and optional scope checks to specified endpoints and methods
25 | - **🛂 Content Filtering:** Use CQL2 filters via the [Filter Extension](https://github.com/stac-api-extensions/filter?tab=readme-ov-file) to tailor API responses based on request context (e.g. user role)
26 | - **🤝 External Policy Integration:** Integrate with external systems (e.g. [Open Policy Agent (OPA)](https://www.openpolicyagent.org/)) to generate CQL2 filters dynamically from policy decisions
27 | - **🧩 Authentication Extension:** Add the [Authentication Extension](https://github.com/stac-extensions/authentication) to API responses to expose auth-related metadata
28 | - **📘 OpenAPI Augmentation:** Enhance the [OpenAPI spec](https://swagger.io/specification/) with security details to keep auto-generated docs and UIs (e.g., [Swagger UI](https://swagger.io/tools/swagger-ui/)) accurate
29 | - **🗜️ Response Compression:** Optimize response sizes using [`starlette-cramjam`](https://github.com/developmentseed/starlette-cramjam/)
30 |
31 | ## Documentation
32 |
33 | [Full documentation is available on the website](https://developmentseed.org/stac-auth-proxy).
34 |
35 | Head to [Getting Started](https://developmentseed.org/stac-auth-proxy/user-guide/getting-started/) to dig in.
36 |
37 | [pypi-version-badge]: https://badge.fury.io/py/stac-auth-proxy.svg
38 | [pypi-link]: https://pypi.org/project/stac-auth-proxy/
39 | [ghcr-version-badge]: https://ghcr-badge.egpl.dev/developmentseed/stac-auth-proxy/latest_tag?color=%2344cc11&ignore=latest&label=image+version&trim=
40 | [ghcr-size-badge]: https://ghcr-badge.egpl.dev/developmentseed/stac-auth-proxy/size?color=%2344cc11&tag=latest&label=image+size&trim=
41 | [ghcr-link]: https://github.com/developmentseed/stac-auth-proxy/pkgs/container/stac-auth-proxy
42 | [codecov-badge]: https://codecov.io/gh/developmentseed/stac-auth-proxy/branch/main/graph/badge.svg
43 | [codecov-link]: https://codecov.io/gh/developmentseed/stac-auth-proxy
44 | [tests-badge]: https://github.com/developmentseed/stac-auth-proxy/actions/workflows/cicd.yaml/badge.svg
45 | [tests-link]: https://github.com/developmentseed/stac-auth-proxy/actions/workflows/cicd.yaml
46 |
--------------------------------------------------------------------------------
/docs/user-guide/tips.md:
--------------------------------------------------------------------------------
1 | # Tips
2 |
3 | ## CORS
4 |
5 | The STAC Auth Proxy does not modify the [CORS response headers](https://developer.mozilla.org/en-US/docs/Web/HTTP/Guides/CORS#the_http_response_headers) from the upstream STAC API. All CORS configuration must be handled by the upstream API.
6 |
7 | Because the STAC Auth Proxy introduces authentication, the upstream API’s CORS settings may need adjustment to support credentials. In most cases, this means:
8 |
9 | - [`Access-Control-Allow-Credentials`](https://developer.mozilla.org/en-US/docs/Web/HTTP/Reference/Headers/Access-Control-Allow-Credentials) must be `true`
10 | - [`Access-Control-Allow-Origin`](https://developer.mozilla.org/en-US/docs/Web/HTTP/Reference/Headers/Access-Control-Allow-Origin) must _not_ be `*`[^CORSNotSupportingCredentials]
11 |
12 | [^CORSNotSupportingCredentials]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Guides/CORS/Errors/CORSNotSupportingCredentials
13 |
14 | ## Root Paths
15 |
16 | The proxy can be optionally served from a non-root path (e.g., `/api/v1`). Additionally, the proxy can optionally proxy requests to an upstream API served from a non-root path (e.g., `/stac`). To handle this, the proxy will:
17 |
18 | - Remove the `ROOT_PATH` from incoming requests before forwarding to the upstream API
19 | - Remove the proxy's prefix from all links in STAC API responses
20 | - Add the `ROOT_PATH` prefix to all links in STAC API responses
21 | - Update the OpenAPI specification to include the `ROOT_PATH` in the servers field
22 | - Handle requests that don't match the `ROOT_PATH` with a 404 response
23 |
24 | ## Non-OIDC Workaround
25 |
26 | If the upstream server utilizes RS256 JWTs but does not utilize a proper OIDC server, the proxy can be configured to work around this by setting the `OIDC_DISCOVERY_URL` to a statically-hosted OIDC discovery document that points to a valid JWKS endpoint.
27 |
28 | ## Swagger UI Direct JWT Input
29 |
30 | Rather than performing the login flow, the Swagger UI can be configured to accept a JWT directly as input with the following configuration:
31 |
32 | ```sh
33 | OPENAPI_AUTH_SCHEME_NAME=jwtAuth
34 | OPENAPI_AUTH_SCHEME_OVERRIDE='{
35 | "type": "http",
36 | "scheme": "bearer",
37 | "bearerFormat": "JWT",
38 | "description": "Paste your raw JWT here. This API uses Bearer token authorization."
39 | }'
40 | ```
41 |
42 | ## Non-proxy Configuration
43 |
44 | While the STAC Auth Proxy is designed to work out-of-the-box as an application, it might not address every project's needs. When the need for customization arises, the codebase can instead be treated as a library of components that can be used to augment a FastAPI server.
45 |
46 | This may look something like the following:
47 |
48 | ```py
49 | from fastapi import FastAPI
50 | from stac_fastapi.api.app import StacApi
51 | from stac_auth_proxy import configure_app, Settings as StacAuthSettings
52 |
53 | # Create Auth Settings
54 | auth_settings = StacAuthSettings(
55 | upstream_url='https://stac-server', # Dummy value, we don't make use of this value in non-proxy mode
56 | oidc_discovery_url='https://auth-server/.well-known/openid-configuration',
57 | )
58 |
59 | # Setup App
60 | app = FastAPI( ... )
61 |
62 | # Apply STAC Auth Proxy middleware
63 | configure_app(app, auth_settings)
64 |
65 | # Setup STAC API
66 | api = StacApi( app, ... )
67 | ```
68 |
69 | > [!IMPORTANT]
70 | > Avoid using `build_lifespan()` when operating in non-proxy mode, as we are unable to check for the non-existent upstream API.
71 |
--------------------------------------------------------------------------------
/src/stac_auth_proxy/middleware/UpdateOpenApiMiddleware.py:
--------------------------------------------------------------------------------
1 | """Middleware to add auth information to the OpenAPI spec served by upstream API."""
2 |
3 | import re
4 | from dataclasses import dataclass
5 | from typing import Any, Optional
6 |
7 | from starlette.datastructures import Headers
8 | from starlette.requests import Request
9 | from starlette.types import ASGIApp, Scope
10 |
11 | from ..config import EndpointMethods
12 | from ..utils.middleware import JsonResponseMiddleware
13 | from ..utils.requests import find_match
14 |
15 |
@dataclass(frozen=True)
class OpenApiMiddleware(JsonResponseMiddleware):
    """
    Middleware to add auth information to the OpenAPI spec served by upstream API.

    For successful JSON responses whose request path matches
    ``openapi_spec_path``, the spec is rewritten to replace ``servers``,
    register a security scheme, and attach a security requirement to every
    operation that ``find_match`` reports as private.
    """

    app: ASGIApp
    # Pattern applied with re.match to the request path
    openapi_spec_path: str
    # Used as openIdConnectUrl unless auth_scheme_override is provided
    oidc_discovery_url: str
    private_endpoints: EndpointMethods
    public_endpoints: EndpointMethods
    default_public: bool
    # When non-empty, becomes the single entry of the spec's "servers" list
    root_path: str = ""
    auth_scheme_name: str = "oidcAuth"
    # Replaces the generated openIdConnect security scheme when truthy
    auth_scheme_override: Optional[dict] = None

    json_content_type_expr: str = r"application/(vnd\.oai\.openapi\+json?|json)"

    def should_transform_response(self, request: Request, scope: Scope) -> bool:
        """Only transform successful JSON responses for the OpenAPI spec path."""
        return (
            all(
                re.match(expr, val)
                for expr, val in [
                    (self.openapi_spec_path, request.url.path),
                    (
                        self.json_content_type_expr,
                        Headers(scope=scope).get("content-type", ""),
                    ),
                ]
            )
            and 200 <= scope["status"] < 300
        )

    def transform_json(self, data: dict[str, Any], request: Request) -> dict[str, Any]:
        """
        Augment the OpenAPI spec with auth information.

        Args:
            data: The parsed OpenAPI document; it is modified in place
            request: The incoming request (unused here)

        Returns:
        -------
            The same ``data`` object after modification

        """
        # Remove any existing servers field from upstream API
        # This ensures we don't have conflicting server declarations
        data.pop("servers", None)

        # Add servers field with root path if root_path is set
        if self.root_path:
            data["servers"] = [{"url": self.root_path}]

        # Add security scheme
        components = data.setdefault("components", {})
        security_schemes = components.setdefault("securitySchemes", {})
        security_schemes[self.auth_scheme_name] = self.auth_scheme_override or {
            "type": "openIdConnect",
            "openIdConnectUrl": self.oidc_discovery_url,
        }

        # Operation keys of an OpenAPI Path Item. "options" is left out on
        # purpose: OPTIONS requests are not authenticated,
        # https://fetch.spec.whatwg.org/#cors-protocol-and-credentials
        http_methods = {"get", "put", "post", "delete", "head", "patch", "trace"}

        # Add security to private endpoints. A spec without "paths" is
        # tolerated rather than raising KeyError.
        for path, path_item in (data.get("paths") or {}).items():
            if not isinstance(path_item, dict):
                continue
            for method, config in path_item.items():
                # A Path Item may also hold non-operation members such as
                # "parameters" (a list) or "summary" (a string); those cannot
                # carry a "security" entry and must be skipped.
                if method.lower() not in http_methods or not isinstance(config, dict):
                    continue
                match = find_match(
                    path,
                    method,
                    self.private_endpoints,
                    self.public_endpoints,
                    self.default_public,
                )
                if match.is_private:
                    config.setdefault("security", []).append(
                        {self.auth_scheme_name: match.required_scopes}
                    )
        return data
85 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [project]
2 | authors = [{name = "Anthony Lukach", email = "anthonylukach@gmail.com"}]
3 | classifiers = [
4 | "Programming Language :: Python :: 3",
5 | "Programming Language :: Python :: 3.8",
6 | "License :: OSI Approved :: MIT License",
7 | ]
8 | dependencies = [
9 | "boto3>=1.37.16",
10 | "brotli>=1.1.0",
11 | "cql2>=0.4.2",
12 | "cryptography>=44.0.1",
13 | "fastapi>=0.115.5",
14 | "httpx[http2]>=0.28.0",
15 | "jinja2>=3.1.4",
16 | "pydantic-settings>=2.6.1",
17 | "pyjwt>=2.10.1",
18 | "starlette-cramjam>=0.4.0",
19 | "uvicorn>=0.32.1",
20 | ]
21 | description = "STAC authentication proxy with FastAPI"
22 | keywords = ["STAC", "FastAPI", "Authentication", "Proxy"]
23 | license = {file = "LICENSE"}
24 | name = "stac-auth-proxy"
25 | readme = "README.md"
26 | requires-python = ">=3.10"
27 | version = "0.11.0"
28 |
29 | [project.optional-dependencies]
30 | docs = [
31 | "griffe-fieldz>=0.3.0",
32 | "griffe-inherited-docstrings>=1.1.1",
33 | "markdown-gfm-admonition>=0.1.1",
34 | "mkdocs>=1.6.1",
35 | "mkdocs-api-autonav>=0.3.0",
36 | "mkdocs-material[imaging]>=9.6.16",
37 | "mkdocstrings[python]>=0.30.0",
38 | ]
39 | lambda = [
40 | "mangum>=0.19.0",
41 | ]
42 |
43 | [tool.coverage.run]
44 | branch = true
45 | source = ["src/stac_auth_proxy"]
46 | omit = [
47 | "*/tests/*",
48 | "*/test_*",
49 | "*/__pycache__/*",
50 | "*/venv/*",
51 | "*/build/*",
52 | "*/dist/*",
53 | "*/htmlcov/*",
54 | "*/lambda.py", # Lambda entry point not tested in unit tests
55 | ]
56 |
57 | [tool.coverage.report]
58 | exclude_lines = [
59 | "pragma: no cover",
60 | "def __repr__",
61 | "if self.debug:",
62 | "if settings.DEBUG",
63 | "raise AssertionError",
64 | "raise NotImplementedError",
65 | "if 0:",
66 | "if __name__ == .__main__.:",
67 | "class .*\\bProtocol\\):",
68 | "@(abc\\.)?abstractmethod",
69 | ]
70 |
71 | [tool.coverage.html]
72 | directory = "htmlcov"
73 |
74 | [tool.coverage.xml]
75 | output = "coverage.xml"
76 |
77 | [tool.isort]
78 | known_first_party = ["stac_auth_proxy"]
79 | profile = "black"
80 |
81 | [tool.ruff.lint]
82 | ignore = ["E501", "D203", "D205", "D212"]
83 | select = ["D", "E", "F"]
84 |
85 | [build-system]
86 | build-backend = "hatchling.build"
87 | requires = ["hatchling>=1.12.0"]
88 |
89 | [dependency-groups]
90 | dev = [
91 | "jwcrypto>=1.5.6",
92 | "mypy>=1.3.0",
93 | "pre-commit>=3.5.0",
94 | "pytest-asyncio>=0.25.1",
95 | "pytest-cov>=5.0.0",
96 | "pytest-xdist>=3.6.1",
97 | "pytest>=8.3.3",
98 | "ruff>=0.0.238",
99 | "starlette-cramjam>=0.4.0",
100 | "types-simplejson",
101 | "types-attrs",
102 | ]
103 |
104 | [tool.pytest.ini_options]
105 | asyncio_default_fixture_loop_scope = "function"
106 | asyncio_mode = "auto"
107 | testpaths = ["tests"]
108 | python_files = ["test_*.py", "*_test.py"]
109 | python_classes = ["Test*"]
110 | python_functions = ["test_*"]
111 | addopts = [
112 | "--strict-markers",
113 | "--strict-config",
114 | "--verbose",
115 | "--tb=short",
116 | "--cov=src/stac_auth_proxy",
117 | "--cov-report=term-missing",
118 | "--cov-report=html",
119 | "--cov-report=xml",
120 | "--cov-fail-under=85",
121 | ]
122 | markers = [
123 | "slow: marks tests as slow (deselect with '-m \"not slow\"')",
124 | "integration: marks tests as integration tests",
125 | "unit: marks tests as unit tests",
126 | ]
127 |
--------------------------------------------------------------------------------
/tests/test_cache.py:
--------------------------------------------------------------------------------
1 | """Tests for cache utilities."""
2 |
3 | from unittest.mock import patch
4 |
5 | import pytest
6 |
7 | from stac_auth_proxy.utils.cache import MemoryCache, get_value_by_path
8 |
9 |
def test_memory_cache_basic_operations():
    """Exercise item assignment, indexing, membership and ``get()`` on a fresh cache."""
    cache = MemoryCache(ttl=5.0)  # 5 second TTL
    stored_key, stored_value = "test_key", "test_value"
    absent_key = "non_existent"

    # Store a value, then read it back by index and by membership
    cache[stored_key] = stored_value
    assert cache[stored_key] == stored_value
    assert stored_key in cache

    # Indexing an unknown key raises
    with pytest.raises(KeyError):
        _ = cache[absent_key]

    # get() yields the value, or None for an unknown key
    assert cache.get(stored_key) == stored_value
    assert cache.get(absent_key) is None
28 |
29 |
def test_memory_cache_expiration():
    """An entry is readable right after insertion and gone once the TTL has elapsed."""
    cache = MemoryCache(ttl=5.0)  # 5 second TTL
    entry_key, entry_value = "test_key", "test_value"

    with patch("stac_auth_proxy.utils.cache.time") as clock:
        # Insert while the clock reads t=1000
        clock.return_value = 1000.0
        cache[entry_key] = entry_value
        assert cache[entry_key] == entry_value

        # Move the clock to t=1006, i.e. 6 seconds later and past the TTL
        clock.return_value = 1006.0

        # Indexing the expired key raises
        with pytest.raises(KeyError):
            cache[entry_key]

        # Membership reports the expired key as absent
        assert entry_key not in cache
51 |
52 |
def test_memory_cache_pruning():
    """Inserting after the TTL has elapsed leaves only the newly added entry."""
    cache = MemoryCache(ttl=5.0)  # 5 second TTL
    stale_keys = ("key1", "key2")
    fresh_key = "key3"
    payload = "test_value"

    with patch("stac_auth_proxy.utils.cache.time") as clock:
        # Populate while the clock reads t=1000
        clock.return_value = 1000.0
        for stale_key in stale_keys:
            cache[stale_key] = payload

        # Move the clock to t=1006, i.e. 6 seconds later and past the TTL
        clock.return_value = 1006.0

        # Adding another item is what triggers pruning
        cache[fresh_key] = payload

        # The old entries are gone, the new one is present
        for stale_key in stale_keys:
            assert stale_key not in cache
        assert fresh_key in cache
76 |
77 |
def test_memory_cache_key_str():
    """Short keys come back from ``_key_str`` unchanged; long ones are truncated."""
    cache = MemoryCache()

    # A 3-character key is returned as-is
    assert cache._key_str("123") == "123"

    # A 10-character key is cut to 9 characters plus an ellipsis
    assert cache._key_str("1234567890") == "123456789..."
89 |
90 |
@pytest.mark.parametrize(
    ("obj", "path", "default", "expected"),
    [
        # Basic path
        ({"a": {"b": 1}}, "a.b", None, 1),
        # Nested path
        ({"a": {"b": {"c": 2}}}, "a.b.c", None, 2),
        # Non-existent path
        ({"a": {"b": 1}}, "a.c", None, None),
        # Default value
        ({"a": {"b": 1}}, "a.c", "default", "default"),
        # None in path
        ({"a": None}, "a.b", None, None),
        # Empty path
        ({"a": 1}, "", None, None),
        # Complex object
        ({"a": {"b": [1, 2, 3]}}, "a.b", None, [1, 2, 3]),
    ],
)
def test_get_value_by_path(obj, path, default, expected):
    """Check that dot-notation lookup returns the expected value for each case."""
    result = get_value_by_path(obj, path, default)
    assert result == expected
113 |
--------------------------------------------------------------------------------
/helm/templates/NOTES.txt:
--------------------------------------------------------------------------------
1 | Thank you for installing {{ .Chart.Name }}.
2 |
3 | Your STAC Auth Proxy has been deployed with the following configuration:
4 |
5 | 1. Application Access:
6 | {{- if .Values.ingress.enabled }}
7 | {{- if .Values.ingress.host }}
8 | Your proxy is available at:
9 | {{- if .Values.ingress.tls.enabled }}
10 | https://{{ .Values.ingress.host }}
11 | {{- else }}
12 | http://{{ .Values.ingress.host }}
13 | {{- end }}
14 | {{- end }}
15 | {{- else if contains "NodePort" .Values.service.type }}
16 | Get the application URL by running these commands:
17 | export NODE_PORT=$(kubectl get --namespace {{ .Release.Namespace }} -o jsonpath="{.spec.ports[0].nodePort}" services {{ include "stac-auth-proxy.fullname" . }})
18 | export NODE_IP=$(kubectl get nodes --namespace {{ .Release.Namespace }} -o jsonpath="{.items[0].status.addresses[0].address}")
19 | echo http://$NODE_IP:$NODE_PORT
20 | {{- else if contains "LoadBalancer" .Values.service.type }}
21 | Get the application URL by running these commands:
22 | NOTE: It may take a few minutes for the LoadBalancer IP to be available.
23 | You can watch the status by running:
24 | kubectl get svc --namespace {{ .Release.Namespace }} {{ include "stac-auth-proxy.fullname" . }} -w
25 |
26 | Once ready, get the external IP/hostname with:
27 | export SERVICE_IP=$(kubectl get svc --namespace {{ .Release.Namespace }} {{ include "stac-auth-proxy.fullname" . }} --template "{{"{{ range (index .status.loadBalancer.ingress 0) }}{{.}}{{ end }}"}}")
28 | echo http://$SERVICE_IP:{{ .Values.service.port }}
29 | {{- else }}
30 | The service is accessible within the cluster at:
31 | {{ include "stac-auth-proxy.fullname" . }}.{{ .Release.Namespace }}.svc.cluster.local:{{ .Values.service.port }}
32 | {{- end }}
33 |
34 | 2. Configuration Details:
35 | - Upstream STAC API: {{ .Values.env.UPSTREAM_URL }}
36 | - OIDC Discovery URL: {{ .Values.env.OIDC_DISCOVERY_URL }}
37 | - Health Check Endpoint: {{ .Values.env.HEALTHZ_PREFIX | default "/healthz" }}
38 | - Default Public Access: {{ .Values.env.DEFAULT_PUBLIC | default "false" }}
39 |
40 | 3. Verify the deployment:
41 | kubectl get pods --namespace {{ .Release.Namespace }} -l "app.kubernetes.io/name={{ include "stac-auth-proxy.name" . }},app.kubernetes.io/instance={{ .Release.Name }}"
42 |
43 | 4. View the logs:
44 | kubectl logs --namespace {{ .Release.Namespace }} -l "app.kubernetes.io/name={{ include "stac-auth-proxy.name" . }},app.kubernetes.io/instance={{ .Release.Name }}"
45 |
46 | 5. Health check:
47 | {{- if .Values.ingress.enabled }}
48 | {{- if .Values.ingress.host }}
49 | {{- if .Values.ingress.tls.enabled }}
50 | curl https://{{ .Values.ingress.host }}{{ .Values.env.HEALTHZ_PREFIX | default "/healthz" }}
51 | {{- else }}
52 | curl http://{{ .Values.ingress.host }}{{ .Values.env.HEALTHZ_PREFIX | default "/healthz" }}
53 | {{- end }}
54 | {{- end }}
55 | {{- else }}
56 | kubectl port-forward --namespace {{ .Release.Namespace }} service/{{ include "stac-auth-proxy.fullname" . }} 8000:{{ .Values.service.port }}
57 | curl http://localhost:8000{{ .Values.env.HEALTHZ_PREFIX | default "/healthz" }}
58 | {{- end }}
59 |
60 | For more information about STAC Auth Proxy, please visit:
61 | https://github.com/developmentseed/stac-auth-proxy
62 |
63 | {{- if or (not .Values.env.UPSTREAM_URL) (not .Values.env.OIDC_DISCOVERY_URL) }}
64 | WARNING: Some required configuration values are not set. Please ensure you have configured:
65 | {{- if not .Values.env.UPSTREAM_URL }}
66 | - env.UPSTREAM_URL
67 | {{- end }}
68 | {{- if not .Values.env.OIDC_DISCOVERY_URL }}
69 | - env.OIDC_DISCOVERY_URL
70 | {{- end }}
71 | {{- end }}
--------------------------------------------------------------------------------
/tests/test_filters_opa.py:
--------------------------------------------------------------------------------
1 | """Test OPA filter integration."""
2 |
3 | from unittest.mock import AsyncMock, MagicMock, patch
4 |
5 | import pytest
6 | from httpx import AsyncClient, Response
7 |
8 | from stac_auth_proxy.filters.opa import Opa
9 |
10 |
@pytest.fixture
def opa_filter_factory():
    """Provide an ``Opa`` filter configured with a local host and the stac/filter decision."""
    opa = Opa(host="http://localhost:8181", decision="stac/filter")
    return opa
15 |
16 |
@pytest.fixture
def mock_opa_response():
    """Build a stand-in httpx ``Response`` whose JSON payload carries a filter result."""
    fake_response = MagicMock(spec=Response)
    # raise_for_status() hands back the response itself instead of raising
    fake_response.raise_for_status.return_value = fake_response
    fake_response.json.return_value = {"result": "collection = 'test'"}
    return fake_response
24 |
25 |
@pytest.mark.asyncio
async def test_opa_initialization(opa_filter_factory):
    """Verify the attributes of a freshly constructed ``Opa`` instance."""
    expected_attributes = {
        "host": "http://localhost:8181",
        "decision": "stac/filter",
        "cache_key": "req.headers.authorization",
        "cache_ttl": 5.0,
    }
    for attribute, expected in expected_attributes.items():
        assert getattr(opa_filter_factory, attribute) == expected

    assert isinstance(opa_filter_factory.client, AsyncClient)
    assert opa_filter_factory.cache is not None
35 |
36 |
@pytest.mark.asyncio
async def test_opa_cache_hit(opa_filter_factory, mock_opa_response):
    """A repeated call with the same context is answered without a second OPA request."""
    context = {"req": {"headers": {"authorization": "test-token"}}}
    expected_filter = "collection = 'test'"

    # Replace the HTTP client's post() so no real request is made
    with patch.object(
        opa_filter_factory.client, "post", new_callable=AsyncMock
    ) as post:
        post.return_value = mock_opa_response

        # The first evaluation goes out to OPA
        first = await opa_filter_factory(context)
        assert first == expected_filter
        assert post.call_count == 1

        # The second evaluation must not trigger another post()
        second = await opa_filter_factory(context)
        assert second == expected_filter
        assert post.call_count == 1  # Still 1, no new call made
57 |
58 |
@pytest.mark.asyncio
async def test_opa_cache_miss(opa_filter_factory, mock_opa_response):
    """Changing the authorization header causes a second OPA request."""
    headers = {"authorization": "test-token"}
    context = {"req": {"headers": headers}}
    expected_filter = "collection = 'test'"

    with patch.object(
        opa_filter_factory.client, "post", new_callable=AsyncMock
    ) as post:
        post.return_value = mock_opa_response

        # Evaluate with the first token
        first = await opa_filter_factory(context)
        assert first == expected_filter
        assert post.call_count == 1

        # Swap the token inside the same context and evaluate again
        headers["authorization"] = "different-token"
        second = await opa_filter_factory(context)
        assert second == expected_filter
        assert post.call_count == 2  # New call made
79 |
80 |
@pytest.mark.asyncio
async def test_opa_error_handling(opa_filter_factory):
    """An exception from ``raise_for_status`` surfaces to the caller."""
    context = {"req": {"headers": {"authorization": "test-token"}}}

    # A response whose raise_for_status() raises
    failing_response = MagicMock(spec=Response)
    failing_response.raise_for_status.side_effect = Exception("Internal server error")

    with patch.object(
        opa_filter_factory.client, "post", new_callable=AsyncMock
    ) as post:
        post.return_value = failing_response

        with pytest.raises(Exception):
            await opa_filter_factory(context)
96 |
--------------------------------------------------------------------------------
/src/stac_auth_proxy/middleware/Cql2ApplyFilterBodyMiddleware.py:
--------------------------------------------------------------------------------
1 | """Middleware to augment the request body with a CQL2 filter for POST/PUT/PATCH requests."""
2 |
3 | import json
4 | from dataclasses import dataclass
5 | from logging import getLogger
6 | from typing import Optional
7 |
8 | from cql2 import Expr
9 | from starlette.requests import Request
10 | from starlette.types import ASGIApp, Receive, Scope, Send
11 |
12 | from ..utils import filters
13 | from ..utils.middleware import required_conformance
14 |
# Module-level logger, named after this module's import path.
logger = getLogger(__name__)
16 |
17 |
@required_conformance(
    r"http://www.opengis.net/spec/cql2/1.0/conf/basic-cql2",
    r"http://www.opengis.net/spec/cql2/1.0/conf/cql2-text",
    r"http://www.opengis.net/spec/cql2/1.0/conf/cql2-json",
)
@dataclass(frozen=True)
class Cql2ApplyFilterBodyMiddleware:
    """
    Middleware to augment the request body with a CQL2 filter for POST/PUT/PATCH requests.

    The filter is read from ``request.state`` under ``state_key``. When no
    filter is present, or the request is not an HTTP POST/PUT/PATCH, the
    request is passed through untouched.
    """

    app: ASGIApp
    # Attribute name on request.state that holds the CQL2 filter
    state_key: str = "cql2_filter"

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """
        Apply the CQL2 filter to the request body.

        Reads the whole request body, merges the filter into it via
        ``filters.append_body_filter`` and forwards the re-serialised JSON
        downstream. Responds with 400 when the body is not valid JSON or is
        not a JSON object.
        """
        if scope["type"] != "http":
            return await self.app(scope, receive, send)

        request = Request(scope)
        cql2_filter: Optional[Expr] = getattr(request.state, self.state_key, None)
        if not cql2_filter:
            return await self.app(scope, receive, send)

        if request.method not in ["POST", "PUT", "PATCH"]:
            return await self.app(scope, receive, send)

        # Collect the body chunks and join once at the end
        chunks = []
        more_body = True
        while more_body:
            message = await receive()
            if message["type"] == "http.disconnect":
                # The client went away before the body was complete. Without
                # this exit the loop would never terminate, because only
                # "http.request" messages update more_body.
                return
            if message["type"] == "http.request":
                chunks.append(message.get("body", b""))
                more_body = message.get("more_body", False)
        body = b"".join(chunks)

        try:
            body_json = json.loads(body) if body else {}
        except json.JSONDecodeError:
            logger.warning("Failed to parse request body as JSON")
            response = self._bad_request(
                "ParseError", "Request body must be valid JSON."
            )
            await response(scope, receive, send)
            return

        if not isinstance(body_json, dict):
            logger.warning("Request body must be a JSON object")
            response = self._bad_request(
                "TypeError", "Request body must be a JSON object."
            )
            await response(scope, receive, send)
            return

        new_body = json.dumps(
            filters.append_body_filter(body_json, cql2_filter)
        ).encode("utf-8")

        # Patch content-length in the headers. The header list is filtered
        # rather than round-tripped through a dict so that headers which occur
        # more than once are all preserved.
        scope = dict(scope)
        scope["headers"] = [
            (name, value)
            for name, value in scope["headers"]
            if name.lower() != b"content-length"
        ] + [(b"content-length", str(len(new_body)).encode("latin1"))]

        body_delivered = False

        async def new_receive():
            nonlocal body_delivered
            if body_delivered:
                # The rewritten body has been handed over; later events (such
                # as a disconnect) come from the original channel.
                return await receive()
            body_delivered = True
            return {
                "type": "http.request",
                "body": new_body,
                "more_body": False,
            }

        await self.app(scope, new_receive, send)

    @staticmethod
    def _bad_request(code: str, description: str):
        """Build a 400 JSON response carrying ``code`` and ``description``."""
        from starlette.responses import JSONResponse

        return JSONResponse(
            {
                "code": code,
                "description": description,
            },
            status_code=400,
        )
99 |
--------------------------------------------------------------------------------
/tests/test_defaults.py:
--------------------------------------------------------------------------------
1 | """Basic test cases for the proxy app."""
2 |
3 | import pytest
4 | from fastapi.testclient import TestClient
5 | from utils import AppFactory
6 |
# Shared factory for building the app under test; each test passes the
# remaining settings (upstream URL, endpoint rules, default_public) when
# calling it.
app_factory = AppFactory(
    oidc_discovery_url="https://example-stac-api.com/.well-known/openid-configuration"
)
10 |
11 |
@pytest.mark.parametrize(
    ("path", "method", "expected_status"),
    [
        ("/", "GET", 200),
        ("/conformance", "GET", 200),
        ("/queryables", "GET", 200),
        ("/search", "GET", 200),
        ("/search", "POST", 200),
        ("/collections", "GET", 200),
        ("/collections", "POST", 401),
        ("/collections/example-collection", "GET", 200),
        ("/collections/example-collection", "PUT", 401),
        ("/collections/example-collection", "DELETE", 401),
        ("/collections/example-collection/items", "GET", 200),
        ("/collections/example-collection/items", "POST", 401),
        ("/collections/example-collection/items/example-item", "GET", 200),
        ("/collections/example-collection/items/example-item", "PUT", 401),
        ("/collections/example-collection/items/example-item", "DELETE", 401),
        ("/collections/example-collection/bulk_items", "POST", 401),
        ("/api.html", "GET", 200),
        ("/api", "GET", 200),
    ],
)
def test_default_public_true(source_api_server, path, method, expected_status):
    """
    With default_public=True and the transaction endpoints listed as private,
    unauthenticated requests succeed everywhere except on those endpoints.
    """
    transaction_endpoints = {
        r"^/collections$": ["POST"],
        r"^/collections/([^/]+)$": ["PUT", "PATCH", "DELETE"],
        r"^/collections/([^/]+)/items$": ["POST"],
        r"^/collections/([^/]+)/items/([^/]+)$": ["PUT", "PATCH", "DELETE"],
        r"^/collections/([^/]+)/bulk_items$": ["POST"],
    }
    test_app = app_factory(
        upstream_url=source_api_server,
        public_endpoints={},
        private_endpoints=transaction_endpoints,
        default_public=True,
    )
    response = TestClient(test_app).request(method=method, url=path)
    assert response.status_code == expected_status
55 |
56 |
@pytest.mark.parametrize(
    "path,method,expected_status",
    [
        ("/", "GET", 401),
        ("/conformance", "GET", 401),
        ("/queryables", "GET", 401),
        ("/search", "GET", 401),
        ("/search", "POST", 401),
        ("/collections", "GET", 401),
        ("/collections", "POST", 401),
        ("/collections/example-collection", "GET", 401),
        ("/collections/example-collection", "PUT", 401),
        ("/collections/example-collection", "DELETE", 401),
        ("/collections/example-collection/items", "GET", 401),
        ("/collections/example-collection/items", "POST", 401),
        ("/collections/example-collection/items/example-item", "GET", 401),
        ("/collections/example-collection/items/example-item", "PUT", 401),
        ("/collections/example-collection/items/example-item", "DELETE", 401),
        ("/collections/example-collection/bulk_items", "POST", 401),
        ("/api.html", "GET", 200),
        ("/api", "GET", 200),
    ],
)
def test_default_public_false(source_api_server, path, method, expected_status):
    """
    When default_public=false and private_endpoints aren't set, all endpoints should
    require authentication except for those listed in public_endpoints.
    """
    # Only the OpenAPI spec and Swagger UI routes are explicitly public here.
    open_endpoints = {"/api.html": ["GET"], "/api": ["GET"]}
    proxy_app = app_factory(
        upstream_url=source_api_server,
        public_endpoints=open_endpoints,
        private_endpoints={},
        default_public=False,
    )
    status_code = TestClient(proxy_app).request(method=method, url=path).status_code
    assert status_code == expected_status
94 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
110 | .pdm.toml
111 | .pdm-python
112 | .pdm-build/
113 |
114 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
115 | __pypackages__/
116 |
117 | # Celery stuff
118 | celerybeat-schedule
119 | celerybeat.pid
120 |
121 | # SageMath parsed files
122 | *.sage.py
123 |
124 | # Environments
125 | .env
126 | .venv
127 | env/
128 | venv/
129 | ENV/
130 | env.bak/
131 | venv.bak/
132 |
133 | # Spyder project settings
134 | .spyderproject
135 | .spyproject
136 |
137 | # Rope project settings
138 | .ropeproject
139 |
140 | # mkdocs documentation
141 | /site
142 |
143 | # mypy
144 | .mypy_cache/
145 | .dmypy.json
146 | dmypy.json
147 |
148 | # Pyre type checker
149 | .pyre/
150 |
151 | # pytype static type analyzer
152 | .pytype/
153 |
154 | # Cython debug symbols
155 | cython_debug/
156 |
157 | # PyCharm
158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
160 | # and can be added to the global gitignore or merged into this file. For a more nuclear
161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
162 | #.idea/
163 | .pgdata
164 |
--------------------------------------------------------------------------------
/src/stac_auth_proxy/middleware/AuthenticationExtensionMiddleware.py:
--------------------------------------------------------------------------------
1 | """Middleware to add auth information to item response served by upstream API."""
2 |
3 | import logging
4 | import re
5 | from dataclasses import dataclass, field
6 | from typing import Any
7 | from urllib.parse import urlparse
8 |
9 | from starlette.datastructures import Headers
10 | from starlette.requests import Request
11 | from starlette.types import ASGIApp, Scope
12 |
13 | from ..config import EndpointMethods
14 | from ..utils.middleware import JsonResponseMiddleware
15 | from ..utils.requests import find_match
16 | from ..utils.stac import get_links
17 |
18 | logger = logging.getLogger(__name__)
19 |
20 |
@dataclass
class AuthenticationExtensionMiddleware(JsonResponseMiddleware):
    """
    Middleware to add the authentication extension to the response.

    For responses selected by ``should_transform_response``, the JSON body gets
    the extension URL appended to ``stac_extensions``, an ``auth:schemes`` entry
    describing the OpenID Connect provider, and an ``auth:refs`` annotation on
    every link whose path is private for ``GET`` requests.
    """

    app: ASGIApp

    # Route-level auth configuration, used to decide which link targets are private.
    default_public: bool
    private_endpoints: EndpointMethods
    public_endpoints: EndpointMethods

    # Discovery URL published in the ``auth:schemes`` entry.
    oidc_discovery_url: str
    # Key of the scheme in ``auth:schemes``; also the value added to ``auth:refs``.
    auth_scheme_name: str = "oidc"
    # NOTE(review): not referenced by any method of this class -- confirm whether
    # it is meant to override the generated scheme definition.
    auth_scheme: dict[str, Any] = field(default_factory=dict)
    extension_url: str = (
        "https://stac-extensions.github.io/authentication/v1.1.0/schema.json"
    )

    # Matches both ``application/json`` and ``application/geo+json``.
    json_content_type_expr: str = r"application/(geo\+)?json"

    def should_transform_response(self, request: Request, scope: Scope) -> bool:
        """
        Determine if the response should be transformed.

        A response qualifies when the request path is a catalog, collections,
        collection, items, item or search URL, the ``content-type`` header is
        JSON, and the status code is 2xx.

        NOTE(review): ``scope`` is read for ``status`` and headers, so it is
        presumably the response-start message -- confirm against
        ``JsonResponseMiddleware``.
        """
        # catalog, collections, collection, items, item, search
        path_expr = r"^(/|/collections(/[^/]+(/items(/[^/]+)?)?)?|/search)$"
        content_type = Headers(scope=scope).get("content-type", "")
        return bool(
            re.match(path_expr, request.url.path)
            and re.match(self.json_content_type_expr, content_type)
            and 200 <= scope["status"] < 300
        )

    def transform_json(self, data: dict[str, Any], request: Request) -> dict[str, Any]:
        """
        Augment the STAC document with auth information.

        Args:
            data: Parsed JSON response body; modified in place.
            request: The incoming request (not used by this implementation).

        Returns:
            The same ``data`` object, after augmentation.
        """
        extensions = data.setdefault("stac_extensions", [])
        if self.extension_url not in extensions:
            extensions.append(self.extension_url)

        # auth:schemes
        # ---
        # A property that contains all of the scheme definitions used by Assets and
        # Links in the STAC Item or Collection.
        # - Catalogs
        # - Collections
        # - Item Properties
        scheme_loc = data["properties"] if "properties" in data else data
        schemes = scheme_loc.setdefault("auth:schemes", {})
        schemes[self.auth_scheme_name] = {
            "type": "openIdConnect",
            "openIdConnectUrl": self.oidc_discovery_url,
        }

        # auth:refs
        # ---
        # Annotate links with "auth:refs": [auth_scheme]
        for link in get_links(data):
            if "href" not in link:
                logger.warning("Link %s has no href", link)
                continue
            # Links are followed with GET, so privacy is evaluated for that method.
            match = find_match(
                path=urlparse(link["href"]).path,
                method="GET",
                private_endpoints=self.private_endpoints,
                public_endpoints=self.public_endpoints,
                default_public=self.default_public,
            )
            if match.is_private:
                refs = link.setdefault("auth:refs", [])
                # Skip when the link already references this scheme, so the
                # list never accumulates duplicates.
                if self.auth_scheme_name not in refs:
                    refs.append(self.auth_scheme_name)

        return data
102 |
--------------------------------------------------------------------------------
/docs/user-guide/getting-started.md:
--------------------------------------------------------------------------------
1 | # Getting Started
2 |
3 | STAC Auth Proxy is a reverse proxy that adds authentication and authorization to your STAC API. It sits between clients and your STAC API, validating tokens to authenticate requests and applying custom authorization rules.
4 |
5 | ## Core Requirements
6 |
7 | To get started with STAC Auth Proxy, you need to provide two essential pieces of information:
8 |
9 | ### 1. OIDC Discovery URL
10 |
11 | You need a valid OpenID Connect (OIDC) discovery URL that points to your identity provider's configuration. This URL typically follows the pattern:
12 |
13 | ```
14 | https://your-auth-provider.com/.well-known/openid-configuration
15 | ```
16 |
17 | > [!TIP]
18 | >
19 | > Common OIDC providers include:
20 | >
21 | > - **Auth0**: `https://{tenant-id}.auth0.com/.well-known/openid-configuration`
22 | > - **AWS Cognito**: `https://cognito-idp.{region}.amazonaws.com/{user-pool-id}/.well-known/openid-configuration`
23 | > - **Azure AD**: `https://login.microsoftonline.com/{tenant-id}/v2.0/.well-known/openid-configuration`
24 | > - **Google**: `https://accounts.google.com/.well-known/openid-configuration`
25 | > - **Keycloak**: `https://{keycloak-server}/auth/realms/{realm-id}/.well-known/openid-configuration`
26 |
27 | ### 2. Upstream STAC API URL
28 |
29 | You need the URL to your upstream STAC API that the proxy will protect:
30 |
31 | ```
32 | https://your-stac-api.com/stac
33 | ```
34 |
35 | This should be a valid STAC API that conforms to the STAC specification.
36 |
37 | ## Quick Start
38 |
39 | Here's a minimal example to get you started:
40 |
41 | ### Using Docker
42 |
43 | ```bash
44 | docker run -p 8000:8000 \
45 | -e UPSTREAM_URL=https://your-stac-api.com/stac \
46 | -e OIDC_DISCOVERY_URL=https://your-auth-provider.com/.well-known/openid-configuration \
47 | ghcr.io/developmentseed/stac-auth-proxy:latest
48 | ```
49 |
50 | ### Using Python
51 |
52 | 1. Install the package:
53 | ```bash
54 | pip install stac-auth-proxy
55 | ```
56 | 2. Set environment variables:
57 | ```bash
58 | export UPSTREAM_URL=https://your-stac-api.com/stac
59 | export OIDC_DISCOVERY_URL=https://your-auth-provider.com/.well-known/openid-configuration
60 | ```
61 | 3. Run the proxy:
62 | ```bash
63 | python -m stac_auth_proxy
64 | ```
65 |
66 | ### Using Docker Compose
67 |
68 | For development and experimentation, the codebase (i.e., within the repository, not within the Docker or Python distributions) ships with a `docker-compose.yaml` file, allowing the proxy to be run locally alongside various supporting services: the database, the STAC API, and a Mock OIDC provider.
69 |
70 | #### pgSTAC Backend
71 |
72 | Run the application stack with a pgSTAC backend using [stac-fastapi-pgstac](https://github.com/stac-utils/stac-fastapi-pgstac):
73 |
74 | ```sh
75 | docker compose up
76 | ```
77 |
78 | #### OpenSearch Backend
79 |
80 | Run the application stack with an OpenSearch backend using [stac-fastapi-elasticsearch-opensearch](https://github.com/stac-utils/stac-fastapi-elasticsearch-opensearch):
81 |
82 | ```sh
83 | docker compose --profile os up
84 | ```
85 |
86 | The proxy will start on `http://localhost:8000` by default.
87 |
88 | ## What Happens Next?
89 |
90 | Once the proxy starts successfully:
91 |
92 | 1. **Health Check**: The proxy verifies your upstream STAC API is accessible
93 | 2. **Conformance Check**: It ensures your STAC API conforms to required specifications
94 | 3. **OIDC Discovery**: It fetches and validates your OIDC provider configuration
95 | 4. **Ready**: The proxy is now ready to handle requests
96 |
97 | ## Testing Your Setup
98 |
99 | You can test that your proxy is working by accessing the health endpoint:
100 |
101 | ```bash
102 | curl http://localhost:8000/healthz
103 | ```
104 |
105 | ## Next Steps
106 |
107 | - [Configuration Guide](configuration.md) - Learn about all available configuration options
108 | - [Route-Level Authentication](route-level-auth.md) - Configure which endpoints require authentication
109 | - [Record-Level Authentication](record-level-auth.md) - Set up content filtering based on user permissions
110 |
--------------------------------------------------------------------------------
/tests/test_utils.py:
--------------------------------------------------------------------------------
1 | """Tests for request utilities and test helpers."""
2 |
3 | import pytest
4 | from utils import parse_query_string
5 |
6 | from stac_auth_proxy.utils.requests import (
7 | extract_variables,
8 | get_base_url,
9 | parse_forwarded_header,
10 | )
11 |
12 |
@pytest.mark.parametrize(
    "url, expected",
    (
        ("/collections/123", {"collection_id": "123"}),
        ("/collections/123/items", {"collection_id": "123"}),
        ("/collections/123/bulk_items", {"collection_id": "123"}),
        ("/collections/123/items/456", {"collection_id": "123", "item_id": "456"}),
        ("/collections/123/bulk_items/456", {"collection_id": "123", "item_id": "456"}),
        ("/other/123", {}),
    ),
)
def test_extract_variables(url, expected):
    """Check the variables extracted from collection/item paths and unrelated paths."""
    variables = extract_variables(url)
    assert variables == expected
27 |
28 |
@pytest.mark.parametrize(
    "query, expected",
    (
        ("foo=bar", {"foo": "bar"}),
        (
            'filter={"xyz":"abc"}&filter-lang=cql2-json',
            {"filter": {"xyz": "abc"}, "filter-lang": "cql2-json"},
        ),
    ),
)
def test_parse_query_string(query, expected):
    """Validate the query-string parsing helper used by the test suite."""
    parsed = parse_query_string(query)
    assert parsed == expected
42 |
43 |
@pytest.mark.parametrize(
    "header, expected",
    (
        # Basic Forwarded header parsing
        (
            "for=192.0.2.43; by=203.0.113.60; proto=https; host=api.example.com",
            {
                "for": "192.0.2.43",
                "by": "203.0.113.60",
                "proto": "https",
                "host": "api.example.com",
            },
        ),
        # Multiple for values - should only take the first
        (
            "for=192.0.2.43, for=198.51.100.17; by=203.0.113.60; proto=https; host=api.example.com",
            {
                "for": "192.0.2.43",
                "by": "203.0.113.60",
                "proto": "https",
                "host": "api.example.com",
            },
        ),
        # Quoted values
        (
            'for="192.0.2.43"; by="203.0.113.60"; proto="https"; host="api.example.com"',
            {
                "for": "192.0.2.43",
                "by": "203.0.113.60",
                "proto": "https",
                "host": "api.example.com",
            },
        ),
        # Malformed content
        ("malformed header content", {}),
        # Empty content
        ("", {}),
    ),
)
def test_parse_forwarded_header(header, expected):
    """Check Forwarded header parsing for plain, repeated, quoted and invalid input."""
    assert parse_forwarded_header(header) == expected
87 |
88 |
@pytest.mark.parametrize(
    "headers, expected_url",
    (
        # Forwarded header
        (
            [
                (b"host", b"internal-proxy:8080"),
                (b"forwarded", b"for=192.0.2.43; proto=https; host=api.example.com"),
            ],
            "https://api.example.com/",
        ),
        # X-Forwarded-* headers
        (
            [
                (b"host", b"internal-proxy:8080"),
                (b"x-forwarded-host", b"api.example.com"),
                (b"x-forwarded-proto", b"https"),
            ],
            "https://api.example.com/",
        ),
        # No forwarded headers
        (
            [
                (b"host", b"proxy.example.com"),
            ],
            "http://proxy.example.com/",
        ),
    ),
)
def test_get_base_url(headers, expected_url):
    """Check the base URL derived from Forwarded, X-Forwarded-* and Host headers."""
    from starlette.requests import Request

    # Minimal HTTP scope carrying only the headers under test.
    request = Request(
        {
            "type": "http",
            "method": "GET",
            "path": "/test",
            "headers": headers,
        }
    )
    assert get_base_url(request) == expected_url
132 |
--------------------------------------------------------------------------------
/tests/test_lifespan.py:
--------------------------------------------------------------------------------
1 | """Tests for lifespan module."""
2 |
3 | from dataclasses import dataclass
4 | from unittest.mock import AsyncMock, patch
5 |
6 | import pytest
7 | from fastapi import FastAPI
8 | from fastapi.testclient import TestClient
9 | from starlette.middleware import Middleware
10 | from starlette.types import ASGIApp
11 |
12 | from stac_auth_proxy import build_lifespan
13 | from stac_auth_proxy.lifespan import check_conformance, check_server_health
14 | from stac_auth_proxy.utils.middleware import required_conformance
15 |
16 |
@required_conformance("http://example.com/conformance")
@dataclass
class ExampleMiddleware:
    """Test middleware with required conformance."""

    # Wrapped ASGI application; never invoked by the tests in this module.
    app: ASGIApp
23 |
24 |
async def test_check_server_health_success(source_api_server):
    """Health check against the running source API fixture completes without raising."""
    await check_server_health(source_api_server)
28 |
29 |
async def test_check_server_health_failure():
    """Health check against an unreachable server raises after retrying."""
    # Patch asyncio.sleep so the retry delays do not slow the test down.
    with pytest.raises(RuntimeError) as exc_info, patch("asyncio.sleep") as mock_sleep:
        await check_server_health("http://localhost:9999")

    assert "failed to respond after" in str(exc_info.value)

    # Verify sleep was called with exponential backoff
    assert mock_sleep.call_count > 0
    # First call should be with base delay
    # NOTE: Concurrency issues make this assertion flaky
    # assert mock_sleep.call_args_list[0][0][0] == 1.0
    # Last call should be with max delay
    final_delay = mock_sleep.call_args_list[-1][0][0]
    assert final_delay == 5.0
43 |
44 |
async def test_check_conformance_success(source_api_server, source_api_responses):
    """Conformance check passes for a middleware with a required conformance class."""
    await check_conformance([Middleware(ExampleMiddleware)], source_api_server)
49 |
50 |
async def test_check_conformance_failure(source_api_server, source_api_responses):
    """Conformance check raises when the upstream lacks a required class."""
    # Override the conformance response to not include required conformance
    source_api_responses["/conformance"]["GET"] = {"conformsTo": []}

    with pytest.raises(RuntimeError) as exc_info:
        await check_conformance([Middleware(ExampleMiddleware)], source_api_server)

    error_message = str(exc_info.value)
    assert "missing the following conformance classes" in error_message
60 |
61 |
async def test_check_conformance_multiple_middleware(source_api_server):
    """Conformance check passes when several middleware share a requirement."""

    @required_conformance("http://example.com/conformance")
    class TestMiddleware2:
        def __init__(self, app):
            self.app = app

    stack = [Middleware(cls) for cls in (ExampleMiddleware, TestMiddleware2)]
    await check_conformance(stack, source_api_server)
75 |
76 |
async def test_check_conformance_no_required(source_api_server):
    """Conformance check passes for middleware declaring no required conformances."""

    class NoConformanceMiddleware:
        def __init__(self, app):
            self.app = app

    await check_conformance([Middleware(NoConformanceMiddleware)], source_api_server)
86 |
87 |
def test_lifespan_reusable():
    """Ensure the public lifespan handler runs health and conformance checks."""
    upstream_url = "https://example.com"
    oidc_discovery_url = "https://example.com/.well-known/openid-configuration"

    # Replace both checks with async mocks so nothing is actually contacted.
    health_patch = patch(
        "stac_auth_proxy.lifespan.check_server_health",
        new=AsyncMock(),
    )
    conformance_patch = patch(
        "stac_auth_proxy.lifespan.check_conformance",
        new=AsyncMock(),
    )
    with health_patch as mock_health, conformance_patch as mock_conf:
        lifespan = build_lifespan(
            upstream_url=upstream_url,
            oidc_discovery_url=oidc_discovery_url,
        )
        app = FastAPI(lifespan=lifespan)
        # Entering and leaving the client context drives the app's lifespan.
        with TestClient(app):
            pass

        assert mock_health.await_count == 2
        # The conformance check is expected to receive the upstream URL with a
        # trailing slash.
        expected_upstream = upstream_url.rstrip("/") + "/"
        mock_conf.assert_awaited_once_with(app.user_middleware, expected_upstream)
113 |
--------------------------------------------------------------------------------
/mkdocs.yml:
--------------------------------------------------------------------------------
1 | site_name: STAC Auth Proxy
2 | site_description: A reverse proxy to mediate communication between a client and an internally accessible STAC API in order to provide a flexible authentication mechanism.
3 | site_dir: build
4 |
5 | repo_name: developmentseed/stac-auth-proxy
6 | repo_url: https://github.com/developmentseed/stac-auth-proxy
7 | edit_uri: blob/main/docs/
8 | site_url: https://developmentseed.org/stac-auth-proxy
9 |
10 | extra:
11 | analytics:
12 | provider: plausible
13 | domain: developmentseed.org/stac-auth-proxy
14 |
15 | feedback:
16 | title: Was this page helpful?
17 | ratings:
18 | - icon: material/emoticon-happy-outline
19 | name: This page was helpful
20 | data: good
21 | note: Thanks for your feedback!
22 |
23 | - icon: material/emoticon-sad-outline
24 | name: This page could be improved
25 | data: bad
26 | note: Thanks for your feedback!
27 | social:
28 | - icon: fontawesome/brands/github
29 | link: https://github.com/developmentseed
30 | nav:
31 | - Overview: index.md
32 | - User Guide:
33 | - Getting Started: user-guide/getting-started.md
34 | - Configuration: user-guide/configuration.md
35 | - Route-Level Auth: user-guide/route-level-auth.md
36 | - Record-Level Auth: user-guide/record-level-auth.md
37 | - Deployment: user-guide/deployment.md
38 | - Tips: user-guide/tips.md
39 | - Architecture:
40 | - Middleware Stack: architecture/middleware-stack.md
41 | - Filtering Data: architecture/filtering-data.md
42 | - Changelog: changelog.md
43 |
44 | plugins:
45 | - search
46 | - social
47 | - api-autonav:
48 | modules:
49 | - src/stac_auth_proxy
50 | - mkdocstrings:
51 | enable_inventory: true
52 | handlers:
53 | python:
54 | paths:
55 | - src
56 | options:
57 | extensions:
58 | - griffe_fieldz
59 | show_signature_annotations: true
60 | inventories:
61 | - https://docs.python.org/3/objects.inv
62 | - https://docs.pydantic.dev/latest/objects.inv
63 | - https://fastapi.tiangolo.com/objects.inv
64 | - https://www.starlette.io/objects.inv
65 |
66 | theme:
67 | name: material
68 | palette:
69 | # Palette toggle for automatic mode
70 | - media: (prefers-color-scheme)
71 | toggle:
72 | icon: material/brightness-auto
73 | name: Switch to light mode
74 |
75 | # Palette toggle for light mode
76 | - media: "(prefers-color-scheme: light)"
77 | scheme: default
78 | primary: indigo
79 | accent: indigo
80 | toggle:
81 | icon: material/brightness-7
82 | name: Switch to dark mode
83 |
84 | # Palette toggle for dark mode
85 | - media: "(prefers-color-scheme: dark)"
86 | scheme: slate
87 | primary: indigo
88 | accent: indigo
89 | toggle:
90 | icon: material/brightness-4
91 | name: Switch to light mode
92 |
93 | custom_dir: docs/overrides
94 | favicon: assets/ds-symbol-positive-mono.png
95 | logo: assets/ds-symbol-negative-mono.png
96 |
97 | features:
98 | - content.code.annotate
99 | - content.code.copy
100 | - navigation.indexes
101 | - navigation.instant
102 | - navigation.tracking
103 | - search.suggest
104 | - search.share
105 |
106 | # https://github.com/kylebarron/cogeo-mosaic/blob/mkdocs/mkdocs.yml#L50-L75
107 | markdown_extensions:
108 | - attr_list
109 | - codehilite:
110 | guess_lang: false
111 | - def_list
112 | - footnotes
113 | - markdown_gfm_admonition # support github-flavored admonitions
114 | - pymdownx.emoji: # Convert emoji shortcodes to images
115 | emoji_index: !!python/name:material.extensions.emoji.twemoji
116 | emoji_generator: !!python/name:material.extensions.emoji.to_svg
117 | - pymdownx.magiclink: # render hrefs as links
118 | hide_protocol: true
119 | repo_url_shortener: true # handle github links
120 | - pymdownx.superfences:
121 | custom_fences:
122 | - name: mermaid # support mermaid diagrams
123 | class: mermaid
124 | format: !!python/name:pymdownx.superfences.fence_code_format
125 | - toc:
126 | permalink: true # Add permalink hover link to each heading
127 | - pymdownx.tasklist:
128 | custom_checkbox: true
129 |
--------------------------------------------------------------------------------
/docker-compose.yaml:
--------------------------------------------------------------------------------
1 | services:
2 | stac-pg:
3 | profiles: [""] # default profile
4 | image: ghcr.io/stac-utils/stac-fastapi-pgstac:6.1.1
5 | environment:
6 | STAC_FASTAPI_TITLE: stac-fastapi-pgstac + stac-auth-proxy
7 | STAC_FASTAPI_DESCRIPTION: A STAC FastAPI implemented with pgSTAC, protected with STAC Auth Proxy
8 | APP_HOST: 0.0.0.0
9 | APP_PORT: 8001
10 | CORS_HEADERS: "*"
11 | RELOAD: true
12 | POSTGRES_USER: username
13 | POSTGRES_PASS: password
14 | POSTGRES_DBNAME: postgis
15 | POSTGRES_HOST_READER: database-pg
16 | POSTGRES_HOST_WRITER: database-pg
17 | POSTGRES_PORT: 5432
18 | DB_MIN_CONN_SIZE: 1
19 | DB_MAX_CONN_SIZE: 1
20 | USE_API_HYDRATE: ${USE_API_HYDRATE:-false}
21 | ENABLE_TRANSACTIONS_EXTENSIONS: 1
22 | hostname: stac
23 | ports:
24 | - "8001:8001"
25 | depends_on:
26 | - database-pg
27 | command: python -m stac_fastapi.pgstac.app
28 | # command: bash -c "./scripts/wait-for-it.sh database-pg:5432 && python -m stac_fastapi.pgstac.app"
29 |
30 | stac-os:
31 | profiles: ["os"]
32 | container_name: stac-fastapi-os
33 | image: ghcr.io/stac-utils/stac-fastapi-os:v6.1.0
34 | environment:
35 | STAC_FASTAPI_TITLE: stac-fastapi-opensearch + stac-auth-proxy
36 | STAC_FASTAPI_DESCRIPTION: A STAC FastAPI with an Opensearch backend, protected with STAC Auth Proxy
37 | STAC_FASTAPI_LANDING_PAGE_ID: stac-fastapi-opensearch
38 | APP_HOST: 0.0.0.0
39 | APP_PORT: 8001
40 | RELOAD: true
41 | ENVIRONMENT: local
42 | ES_HOST: database-os
43 | ES_PORT: 9200
44 | ES_USE_SSL: false
45 | ES_VERIFY_CERTS: false
46 | BACKEND: opensearch
47 | STAC_FASTAPI_RATE_LIMIT: 200/minute
48 | hostname: stac
49 | ports:
50 | - "8001:8001"
51 | depends_on:
52 | - database-os
53 | command: |
54 | bash -c "./scripts/wait-for-it-es.sh database-os:9200 && python -m stac_fastapi.opensearch.app"
55 |
56 | database-pg:
57 | profiles: [""] # default profile
58 | container_name: database-pg
59 | image: ghcr.io/stac-utils/pgstac:v0.9.5
60 | environment:
61 | POSTGRES_USER: username
62 | POSTGRES_PASSWORD: password
63 | POSTGRES_DB: postgis
64 | PGUSER: username
65 | PGPASSWORD: password
66 | PGDATABASE: postgis
67 | ports:
68 | - "${MY_DOCKER_IP:-127.0.0.1}:5439:5432"
69 | command: postgres -N 500
70 | volumes:
71 | - ./.pgdata:/var/lib/postgresql/data
72 |
73 | database-os:
74 | profiles: ["os"]
75 | container_name: database-os
76 | image: opensearchproject/opensearch:${OPENSEARCH_VERSION:-2.11.1}
77 | hostname: database-os
78 | environment:
79 | cluster.name: stac-cluster
80 | node.name: os01
81 | http.port: 9200
82 | http.cors.allow-headers: X-Requested-With,Content-Type,Content-Length,Accept,Authorization
83 | discovery.type: single-node
84 | plugins.security.disabled: true
85 | OPENSEARCH_JAVA_OPTS: -Xms512m -Xmx512m
86 | ports:
87 | - "9200:9200"
88 |
89 | proxy:
90 | depends_on:
91 | - oidc
92 | build:
93 | context: .
94 | environment:
95 | UPSTREAM_URL: ${UPSTREAM_URL:-http://stac:8001}
96 | OIDC_DISCOVERY_URL: ${OIDC_DISCOVERY_URL:-http://localhost:8888/.well-known/openid-configuration}
97 | OIDC_DISCOVERY_INTERNAL_URL: ${OIDC_DISCOVERY_INTERNAL_URL:-http://oidc:8888/.well-known/openid-configuration}
98 | env_file:
99 | - path: .env
100 | required: false
101 | ports:
102 | - "8000:8000"
103 | volumes:
104 | - ./src:/app/src
105 |
106 | oidc:
107 | image: ghcr.io/alukach/mock-oidc-server:latest
108 | environment:
109 | ISSUER: http://localhost:8888
110 | SCOPES: item:create,item:update,item:delete,collection:create,collection:update,collection:delete
111 | PORT: 8888
112 | ports:
113 | - "8888:8888"
114 |
115 | stac-browser:
116 | image: ghcr.io/radiantearth/stac-browser:latest
117 | ports:
118 | - 8080:8080
119 | environment:
120 | SB_catalogUrl: "http://localhost:8000"
121 | SB_authConfig: |
122 | {
123 | "type": "openIdConnect",
124 | "openIdConnectUrl": "http://localhost:8888/.well-known/openid-configuration",
125 | "oidcOptions": {
126 | "client_id": "stac-browser"
127 | }
128 | }
129 |
--------------------------------------------------------------------------------
/src/stac_auth_proxy/config.py:
--------------------------------------------------------------------------------
1 | """Configuration for the STAC Auth Proxy."""
2 |
3 | import importlib
4 | from typing import Any, Literal, Optional, Sequence, TypeAlias, Union
5 |
6 | from pydantic import BaseModel, Field, field_validator, model_validator
7 | from pydantic.networks import HttpUrl
8 | from pydantic_settings import BaseSettings, SettingsConfigDict
9 |
10 | METHODS = Literal["GET", "POST", "PUT", "DELETE", "PATCH"]
11 | EndpointMethods: TypeAlias = dict[str, Sequence[METHODS]]
12 | EndpointMethodsWithScope: TypeAlias = dict[
13 | str, Sequence[Union[METHODS, tuple[METHODS, str]]]
14 | ]
15 |
16 | _PREFIX_PATTERN = r"^/.*$"
17 |
18 |
def str2list(x: Optional[str] = None) -> Optional[Sequence[str]]:
    """
    Convert a comma-delimited string into a list of values.

    All spaces are removed before splitting. Empty entries, such as those
    produced by a trailing or doubled comma, are discarded so callers never
    receive ``""`` as a value.

    Args:
        x: Comma-delimited string, or ``None``.

    Returns:
        ``None`` when ``x`` is ``None`` or empty; otherwise the list of non-empty
        values (which is empty if ``x`` held only delimiters and spaces).
    """
    if not x:
        return None

    return [value for value in x.replace(" ", "").split(",") if value]
25 |
26 |
class _ClassInput(BaseModel):
    """
    Input model for dynamically loading a class or function.

    ``cls`` is a ``"module.path:attribute"`` reference. Calling the model imports
    the module, looks up the attribute and invokes it with ``args`` and
    ``kwargs``.
    """

    # Import reference in "module.path:attribute" form.
    cls: str
    # Positional and keyword arguments passed to the loaded callable.
    args: Sequence[str] = Field(default_factory=list)
    kwargs: dict[str, str] = Field(default_factory=dict)

    def __call__(self):
        """
        Dynamically load a class and instantiate it with args & kwargs.

        Returns:
            Whatever the referenced callable returns for ``*args, **kwargs``.

        Raises:
            ValueError: If ``cls`` contains no ``:`` separator.
        """
        # Split on the last ":" so everything before it is the module path.
        module_path, separator, attr_name = self.cls.rpartition(":")
        if not separator:
            # Explicit check instead of ``assert``, which is removed under ``-O``.
            raise ValueError(
                f"Invalid class reference {self.cls!r}; "
                "expected the form 'module.path:ClassName'"
            )
        module = importlib.import_module(module_path)
        target = getattr(module, attr_name)
        return target(*self.args, **self.kwargs)
41 |
42 |
class Settings(BaseSettings):
    """Configuration settings for the STAC Auth Proxy."""

    # External URLs
    upstream_url: HttpUrl
    oidc_discovery_url: HttpUrl
    oidc_discovery_internal_url: HttpUrl
    allowed_jwt_audiences: Optional[Sequence[str]] = None

    root_path: str = ""
    override_host: bool = True
    healthz_prefix: str = Field(pattern=_PREFIX_PATTERN, default="/healthz")
    wait_for_upstream: bool = True
    check_conformance: bool = True
    enable_compression: bool = True

    # OpenAPI / Swagger UI
    openapi_spec_endpoint: Optional[str] = Field(
        pattern=_PREFIX_PATTERN, default="/api"
    )
    openapi_auth_scheme_name: str = "oidcAuth"
    openapi_auth_scheme_override: Optional[dict] = None
    swagger_ui_endpoint: Optional[str] = Field(
        pattern=_PREFIX_PATTERN, default="/api.html"
    )
    swagger_ui_init_oauth: dict = Field(default_factory=dict)

    # Auth
    enable_authentication_extension: bool = True
    default_public: bool = False
    public_endpoints: EndpointMethods = {
        r"^/$": ["GET"],
        r"^/api.html$": ["GET"],
        r"^/api$": ["GET"],
        r"^/conformance$": ["GET"],
        r"^/docs/oauth2-redirect": ["GET"],
        r"^/healthz": ["GET"],
    }
    private_endpoints: EndpointMethodsWithScope = {
        # https://github.com/stac-api-extensions/collection-transaction/blob/v1.0.0-beta.1/README.md#methods
        r"^/collections$": ["POST"],
        r"^/collections/([^/]+)$": ["PUT", "PATCH", "DELETE"],
        # https://github.com/stac-api-extensions/transaction/blob/v1.0.0-rc.3/README.md#methods
        r"^/collections/([^/]+)/items$": ["POST"],
        r"^/collections/([^/]+)/items/([^/]+)$": ["PUT", "PATCH", "DELETE"],
        # https://stac-utils.github.io/stac-fastapi/api/stac_fastapi/extensions/third_party/bulk_transactions/#bulktransactionextension
        r"^/collections/([^/]+)/bulk_items$": ["POST"],
    }

    # Filters
    items_filter: Optional[_ClassInput] = None
    items_filter_path: str = r"^(/collections/([^/]+)/items(/[^/]+)?$|/search$)"
    collections_filter: Optional[_ClassInput] = None
    collections_filter_path: str = r"^/collections(/[^/]+)?$"

    model_config = SettingsConfigDict(
        env_nested_delimiter="_",
    )

    @model_validator(mode="before")
    @classmethod
    def _default_oidc_discovery_internal_url(cls, data: Any) -> Any:
        """
        Set the internal OIDC discovery URL to the public URL if not set.

        Non-dict input is returned untouched so that regular validation can
        report it. The input mapping is copied rather than mutated.
        """
        if not isinstance(data, dict):
            return data

        if not data.get("oidc_discovery_internal_url"):
            data = {
                **data,
                "oidc_discovery_internal_url": data.get("oidc_discovery_url"),
            }
        return data

    @field_validator("allowed_jwt_audiences", mode="before")
    @classmethod
    def parse_audience(cls, v) -> Optional[Sequence[str]]:
        """
        Parse a comma separated string list of audiences into a list.

        Values that are not strings (e.g. an already-parsed list, or None) are
        passed through unchanged and left to the field's own validation.
        """
        if isinstance(v, str):
            return str2list(v)
        return v
115 |
--------------------------------------------------------------------------------
/src/stac_auth_proxy/middleware/Cql2RewriteLinksFilterMiddleware.py:
--------------------------------------------------------------------------------
1 | """Middleware to rewrite 'filter' in .links of the JSON response, removing the filter from the request state."""
2 |
3 | import json
4 | from dataclasses import dataclass
5 | from logging import getLogger
6 | from typing import Optional
7 | from urllib.parse import parse_qs, urlencode, urlparse, urlunparse
8 |
9 | from cql2 import Expr
10 | from starlette.requests import Request
11 | from starlette.types import ASGIApp, Message, Receive, Scope, Send
12 |
13 | logger = getLogger(__name__)
14 |
15 |
@dataclass(frozen=True)
class Cql2RewriteLinksFilterMiddleware:
    """
    ASGI middleware to rewrite 'filter' in .links of the JSON response, removing the filter from the request state.

    When a CQL2 filter is present on the request state (under ``state_key``),
    the response is buffered and every link's ``filter`` (in the href query
    string or in a POST ``body``) is replaced with the filter the client
    originally sent, or removed if the client sent none.
    """

    app: ASGIApp
    state_key: str = "cql2_filter"

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """Replace 'filter' in .links of the JSON response to state before we had applied the filter."""
        if scope["type"] != "http":
            return await self.app(scope, receive, send)

        request = Request(scope)
        original_filter = request.query_params.get("filter")
        if getattr(request.state, self.state_key, None) is None:
            # No filter set, just pass through
            return await self.app(scope, receive, send)

        # Intercept the response: hold back the start message and buffer body
        # chunks until the final chunk arrives.
        response_start: Optional[Message] = None
        body_chunks: list[bytes] = []

        async def send_wrapper(message: Message) -> None:
            nonlocal response_start
            if message["type"] == "http.response.start":
                response_start = message
            elif message["type"] == "http.response.body":
                body_chunks.append(message.get("body", b""))
                if not message.get("more_body", False):
                    await self._process_and_send_response(
                        response_start, body_chunks, send, original_filter
                    )
            else:
                await send(message)

        await self.app(scope, receive, send_wrapper)

    async def _process_and_send_response(
        self,
        response_start: Message,
        body_chunks: list[bytes],
        send: Send,
        original_filter: Optional[str],
    ):
        """
        Rewrite link filters in the buffered body and send the response.

        Bodies that are not valid JSON, or whose top-level JSON value is not an
        object, are forwarded unchanged.

        Args:
            response_start: The buffered ``http.response.start`` message.
            body_chunks: All body chunks received from the wrapped app.
            send: The downstream ASGI send callable.
            original_filter: The ``filter`` query parameter sent by the client,
                if any.

        """
        body = b"".join(body_chunks)
        try:
            data = json.loads(body)
        except (ValueError, RecursionError):
            # Covers invalid JSON and undecodable bytes; pass through as-is.
            data = None

        if not isinstance(data, dict):
            # Nothing to rewrite (non-JSON or non-object JSON such as an array).
            await send(response_start)
            await send({"type": "http.response.body", "body": body, "more_body": False})
            return

        cql2_filter = Expr(original_filter) if original_filter else None
        links = data.get("links")
        if isinstance(links, list):
            for link in links:
                if not isinstance(link, dict):
                    continue

                # Handle filter in query string
                href = link.get("href")
                if isinstance(href, str):
                    url = urlparse(href)
                    qs = parse_qs(url.query)
                    if "filter" in qs:
                        if cql2_filter:
                            qs["filter"] = [cql2_filter.to_text()]
                        else:
                            qs.pop("filter", None)
                            qs.pop("filter-lang", None)
                        new_query = urlencode(qs, doseq=True)
                        link["href"] = urlunparse(url._replace(query=new_query))

                # Handle filter in body (for POST links)
                link_body = link.get("body")
                if isinstance(link_body, dict) and "filter" in link_body:
                    if cql2_filter:
                        link_body["filter"] = cql2_filter.to_json()
                    else:
                        link_body.pop("filter", None)
                        link_body.pop("filter-lang", None)

        # Send the modified response
        new_body = json.dumps(data).encode("utf-8")

        # Patch content-length so it matches the re-serialized body
        headers = [
            (k, v) for k, v in response_start["headers"] if k != b"content-length"
        ]
        headers.append((b"content-length", str(len(new_body)).encode("latin1")))
        response_start = dict(response_start)
        response_start["headers"] = headers
        await send(response_start)
        await send({"type": "http.response.body", "body": new_body, "more_body": False})
109 |
--------------------------------------------------------------------------------
/src/stac_auth_proxy/middleware/ProcessLinksMiddleware.py:
--------------------------------------------------------------------------------
1 | """Middleware to remove the application root path from incoming requests and update links in responses."""
2 |
3 | import logging
4 | import re
5 | from dataclasses import dataclass
6 | from typing import Any, Optional
7 | from urllib.parse import ParseResult, urlparse, urlunparse
8 |
9 | from starlette.datastructures import Headers
10 | from starlette.requests import Request
11 | from starlette.types import ASGIApp, Scope
12 |
13 | from ..utils.middleware import JsonResponseMiddleware
14 | from ..utils.requests import get_base_url
15 | from ..utils.stac import get_links
16 |
17 | logger = logging.getLogger(__name__)
18 |
19 |
@dataclass
class ProcessLinksMiddleware(JsonResponseMiddleware):
    """
    Middleware to update links in responses, removing the upstream_url path and adding
    the root_path if it exists.
    """

    app: ASGIApp
    upstream_url: str
    root_path: Optional[str] = None

    json_content_type_expr: str = r"application/(geo\+)?json"

    def should_transform_response(self, request: Request, scope: Scope) -> bool:
        """Only transform responses with JSON content type."""
        content_type = Headers(scope=scope).get("content-type", "")
        return bool(re.match(self.json_content_type_expr, content_type))

    def transform_json(self, data: dict[str, Any], request: Request) -> dict[str, Any]:
        """
        Update links in the response to include root_path.

        Links that fail to parse are logged and left unchanged.
        """
        # Get the client's actual base URL (accounting for load balancers/proxies)
        parsed_req_url = urlparse(get_base_url(request))
        parsed_upstream_url = urlparse(self.upstream_url)

        for link in get_links(data):
            try:
                self._update_link(link, parsed_req_url, parsed_upstream_url)
            except Exception as e:
                logger.error(
                    "Failed to parse link href %r, (ignoring): %s",
                    link.get("href"),
                    str(e),
                )
        return data

    def _update_link(
        self, link: dict[str, Any], request_url: ParseResult, upstream_url: ParseResult
    ) -> None:
        """
        Ensure that link hrefs that are local to upstream url are rewritten as local to
        the proxy.

        Args:
            link: The link object to update in place.
            request_url: The parsed base URL of the client's request.
            upstream_url: The parsed upstream URL.

        """
        if "href" not in link:
            logger.warning("Link %r has no href", link)
            return

        parsed_link = urlparse(link["href"])

        if parsed_link.netloc not in [
            request_url.netloc,
            upstream_url.netloc,
        ]:
            logger.debug(
                "Ignoring link %s because it is not for an endpoint behind this proxy (%s or %s)",
                link["href"],
                request_url.netloc,
                upstream_url.netloc,
            )
            return

        # Normalise the upstream path so that "/stac" and "/stac/" behave the
        # same, and "" / "/" both mean "no prefix".
        upstream_path = upstream_url.path.rstrip("/")

        # Match on a path-segment boundary: "/stac" must not match "/stacks".
        is_descendant = parsed_link.path == upstream_path or parsed_link.path.startswith(
            f"{upstream_path}/"
        )

        # If the link path is not a descendant of the upstream path, don't transform it
        if upstream_path and not is_descendant:
            logger.debug(
                "Ignoring link %s because it is not descendant of upstream path (%s)",
                link["href"],
                upstream_url.path,
            )
            return

        # Replace the upstream host with the client's host
        if parsed_link.netloc == upstream_url.netloc:
            parsed_link = parsed_link._replace(
                netloc=request_url.netloc, scheme=request_url.scheme
            )

        # Remove the upstream prefix from the link path
        if upstream_path:
            parsed_link = parsed_link._replace(
                path=parsed_link.path[len(upstream_path) :]
            )

        # Add the root_path to the link if it exists
        if self.root_path:
            parsed_link = parsed_link._replace(
                path=f"{self.root_path}{parsed_link.path}"
            )

        updated_href = urlunparse(parsed_link)
        if updated_href == link["href"]:
            return

        logger.debug(
            "Rewriting %r link %r to %r",
            link.get("rel"),
            link["href"],
            updated_href,
        )

        link["href"] = updated_href
126 |
--------------------------------------------------------------------------------
/docs/user-guide/deployment.md:
--------------------------------------------------------------------------------
1 | # Deployment
2 |
3 | ## General
4 |
5 | Deploying the STAC Auth Proxy is similar to deploying any other service. In general, we recommend you mirror the architecture of your other systems.
6 |
7 | The core principles of deploying the STAC Auth Proxy are:
8 |
9 | 1. The STAC API should not be available on the public internet
10 | 2. The STAC Auth Proxy should be able to communicate with both the STAC API and the OIDC Server (namely, the discovery endpoint and JWKS endpoint)
11 |
12 | ### Networking Considerations
13 |
14 | #### Hiding the STAC API
15 |
16 | The STAC API should not be directly accessible from the public internet. The STAC Auth Proxy acts as the public-facing endpoint.
17 |
18 | ##### AWS Strategy
19 |
20 | - Place the STAC API in a private subnet
21 | - Place the STAC Auth Proxy in a public subnet with internet access
22 | - Use security groups to restrict access between components
23 |
24 | ##### Kubernetes Strategy
25 |
26 | - Deploy the STAC API as an internal service (ClusterIP)
27 | - Deploy the STAC Auth Proxy with an Ingress for external access
28 | - Use network policies to control traffic flow
29 |
30 | #### Communicating with the OIDC Server
31 |
32 | The STAC Auth Proxy needs to communicate with your OIDC provider for authentication. If your OIDC server is not directly available to the STAC Auth Proxy, use [`OIDC_DISCOVERY_INTERNAL_URL`](configuration.md#oidc_discovery_internal_url) (the [`OIDC_DISCOVERY_URL`](configuration.md#oidc_discovery_url) will still be used for auth in the browser, such as the Swagger UI page).
33 |
34 | ## AWS Lambda
35 |
36 | For AWS Lambda deployments, we recommend using the [Mangum](https://pypi.org/project/mangum/) handler with disabled lifespan events. Such a handler is available at `stac_auth_proxy.lambda:handler`.
37 |
38 | > [!TIP]
39 | >
40 | > If using `stac_auth_proxy.lambda:handler`, be sure to install the `lambda` optional dependencies:
41 | >
42 | > ```bash
43 | > pip install stac_auth_proxy[lambda]
44 | > ```
45 |
46 | ### CDK
47 |
48 | If using [AWS CDK](https://docs.aws.amazon.com/cdk/), a [`StacAuthProxy` Construct](https://developmentseed.org/eoapi-cdk/#stacauthproxylambda-) is made available within the [`eoapi-cdk`](https://github.com/developmentseed/eoapi-cdk) project.
49 |
50 | ## Docker
51 |
52 | The STAC Auth Proxy is available as a [Docker image from the GitHub Container Registry (GHCR)](https://github.com/developmentseed/stac-auth-proxy/pkgs/container/stac-auth-proxy).
53 |
54 | ```bash
55 | # Latest version
56 | docker pull ghcr.io/developmentseed/stac-auth-proxy:latest
57 |
58 | # Specific version
59 | docker pull ghcr.io/developmentseed/stac-auth-proxy:v0.7.1
60 | ```
61 |
62 | ## Kubernetes
63 |
64 | The STAC Auth Proxy can be deployed to Kubernetes via the [Helm Chart available on the GitHub Container Registry (GHCR)](https://github.com/developmentseed/stac-auth-proxy/pkgs/container/stac-auth-proxy%2Fcharts%2Fstac-auth-proxy).
65 |
66 | ### Prerequisites
67 |
68 | - Kubernetes 1.19+
69 | - Helm 3.2.0+
70 |
71 | ### Installation
72 |
73 | ```bash
74 | # Log in to the GHCR OCI registry (the chart is installed directly from its oci:// URL)
75 | helm registry login ghcr.io
76 |
77 | # Install with minimal configuration
78 | helm install stac-auth-proxy oci://ghcr.io/developmentseed/stac-auth-proxy/charts/stac-auth-proxy \
79 | --set env.UPSTREAM_URL=https://your-stac-api.com/stac \
80 | --set env.OIDC_DISCOVERY_URL=https://your-auth-server/.well-known/openid-configuration \
81 | --set ingress.host=stac-proxy.your-domain.com
82 | ```
83 |
84 | ### Configuration
85 |
86 | | Parameter | Description | Required | Default |
87 | | ------------------------ | --------------------------------------------- | -------- | ------- |
88 | | `env.UPSTREAM_URL` | URL of the STAC API to proxy | Yes | - |
89 | | `env.OIDC_DISCOVERY_URL` | OpenID Connect discovery document URL | Yes | - |
90 | | `env` | Environment variables passed to the container | No | `{}` |
91 | | `ingress.enabled` | Enable ingress | No | `true` |
92 | | `ingress.className` | Ingress class name | No | `nginx` |
93 | | `ingress.host` | Hostname for the ingress | No | `""` |
94 | | `ingress.tls.enabled` | Enable TLS for ingress | No | `true` |
95 | | `replicaCount` | Number of replicas | No | `1` |
96 |
97 | For a complete list of values, see the [values.yaml](https://github.com/developmentseed/stac-auth-proxy/blob/main/helm/values.yaml) file.
98 |
99 | ### Management
100 |
101 | ```bash
102 | # Upgrade
103 | helm upgrade stac-auth-proxy oci://ghcr.io/developmentseed/stac-auth-proxy/charts/stac-auth-proxy
104 |
105 | # Uninstall
106 | helm uninstall stac-auth-proxy
107 | ```
108 |
--------------------------------------------------------------------------------
/src/stac_auth_proxy/middleware/Cql2BuildFilterMiddleware.py:
--------------------------------------------------------------------------------
1 | """Middleware to build the Cql2Filter."""
2 |
3 | import logging
4 | import re
5 | from dataclasses import dataclass
6 | from typing import Any, Awaitable, Callable, Optional
7 |
8 | from cql2 import Expr, ValidationError
9 | from starlette.requests import Request
10 | from starlette.responses import Response
11 | from starlette.types import ASGIApp, Receive, Scope, Send
12 |
13 | from ..utils import requests
14 | from ..utils.middleware import required_conformance
15 |
16 | logger = logging.getLogger(__name__)
17 |
18 |
@required_conformance(
    "http://www.opengis.net/spec/cql2/1.0/conf/basic-cql2",
    "http://www.opengis.net/spec/cql2/1.0/conf/cql2-text",
    "http://www.opengis.net/spec/cql2/1.0/conf/cql2-json",
)
@dataclass(frozen=True)
class Cql2BuildFilterMiddleware:
    """
    Middleware to build the Cql2Filter.

    For requests whose path matches one of the configured filter paths, the
    matching filter builder is awaited and the resulting CQL2 expression is
    stored on the request state under ``state_key``.
    """

    app: ASGIApp

    state_key: str = "cql2_filter"

    # Filters
    collections_filter: Optional[Callable] = None
    collections_filter_path: str = r"^/collections(/[^/]+)?$"
    items_filter: Optional[Callable] = None
    items_filter_path: str = r"^(/collections/([^/]+)/items(/[^/]+)?$|/search$)"

    def __post_init__(self):
        """Set required conformances based on the filter functions."""
        required_conformances = set()
        if self.collections_filter:
            logger.debug("Appending required conformance for collections filter")
            # https://github.com/stac-api-extensions/collection-search/blob/4825b4b1cee96bdc0cbfbb342d5060d0031976f0/README.md#L5
            required_conformances.update(
                [
                    "https://api.stacspec.org/v1.0.0/core",
                    r"https://api.stacspec.org/v1\.0\.0(?:-[\w\.]+)?/collection-search",
                    r"https://api.stacspec.org/v1\.0\.0(?:-[\w\.]+)?/collection-search#filter",
                    "http://www.opengis.net/spec/ogcapi-common-2/1.0/conf/simple-query",
                ]
            )
        if self.items_filter:
            logger.debug("Appending required conformance for items filter")
            # https://github.com/stac-api-extensions/filter/blob/c763dbbf0a52210ab8d9866ff048da448d270f93/README.md#conformance-classes
            required_conformances.update(
                [
                    "http://www.opengis.net/spec/ogcapi-features-3/1.0/conf/filter",
                    "http://www.opengis.net/spec/ogcapi-features-3/1.0/conf/features-filter",
                    r"https://api.stacspec.org/v1\.0\.0(?:-[\w\.]+)?/item-search#filter",
                ]
            )

        # Must set required conformances on class
        self.__class__.__required_conformances__ = required_conformances.union(
            getattr(self.__class__, "__required_conformances__", [])
        )

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """
        Build the CQL2 filter, place on the request state.

        If the built filter fails CQL2 validation, a 502 response is sent and
        the wrapped app is not called.
        """
        if scope["type"] != "http":
            return await self.app(scope, receive, send)

        request = Request(scope)

        filter_builder = self._get_filter(request.url.path)
        if not filter_builder:
            return await self.app(scope, receive, send)

        filter_expr = await filter_builder(
            {
                "req": {
                    "path": request.url.path,
                    "method": request.method,
                    "query_params": dict(request.query_params),
                    "path_params": requests.extract_variables(request.url.path),
                    "headers": dict(request.headers),
                },
                # "state" may be absent from the scope if nothing has set it yet
                **scope.get("state", {}),
            }
        )
        cql2_filter = Expr(filter_expr)
        try:
            cql2_filter.validate()
        except ValidationError:
            logger.error("Invalid CQL2 filter: %s", filter_expr)
            # A Response is itself an ASGI app; it must be awaited to be sent.
            # Merely returning it would leave the client without a response.
            response = Response(status_code=502, content="Invalid CQL2 filter")
            return await response(scope, receive, send)
        setattr(request.state, self.state_key, cql2_filter)

        return await self.app(scope, receive, send)

    def _get_filter(
        self, path: str
    ) -> Optional[Callable[..., Awaitable[str | dict[str, Any]]]]:
        """
        Get the CQL2 filter builder for the given path.

        Returns the builder paired with the first path expression that matches
        (which may be None if no builder is configured for it), or None when no
        expression matches.
        """
        endpoint_filters = [
            (self.collections_filter_path, self.collections_filter),
            (self.items_filter_path, self.items_filter),
        ]
        for expr, builder in endpoint_filters:
            if re.match(expr, path):
                return builder
        return None
113 |
--------------------------------------------------------------------------------
/src/stac_auth_proxy/middleware/Cql2ValidateResponseBodyMiddleware.py:
--------------------------------------------------------------------------------
1 | """Middleware to validate the response body with a CQL2 filter for single-record endpoints."""
2 |
3 | import json
4 | import re
5 | from dataclasses import dataclass
6 | from logging import getLogger
7 | from typing import Optional
8 |
9 | from cql2 import Expr
10 | from starlette.requests import Request
11 | from starlette.types import ASGIApp, Message, Receive, Scope, Send
12 |
13 | from ..utils.middleware import required_conformance
14 |
15 | logger = getLogger(__name__)
16 |
17 |
@required_conformance(
    r"http://www.opengis.net/spec/cql2/1.0/conf/basic-cql2",
    r"http://www.opengis.net/spec/cql2/1.0/conf/cql2-text",
    r"http://www.opengis.net/spec/cql2/1.0/conf/cql2-json",
)
@dataclass
class Cql2ValidateResponseBodyMiddleware:
    """
    ASGI middleware to validate the response body with a CQL2 filter for single-record endpoints.

    When a CQL2 filter is present on the request state and the path is a
    single-record endpoint, the response is buffered and matched against the
    filter. Matching records are forwarded; anything else becomes a 404.
    """

    app: ASGIApp
    state_key: str = "cql2_filter"

    # Path expressions of endpoints that return a single item or collection.
    single_record_endpoints = [
        r"^/collections/([^/]+)/items/([^/]+)$",
        r"^/collections/([^/]+)$",
    ]

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """Validate the response body with a CQL2 filter for single-record endpoints."""
        if scope["type"] != "http":
            return await self.app(scope, receive, send)

        request = Request(scope)
        cql2_filter: Optional[Expr] = getattr(request.state, self.state_key, None)
        if not cql2_filter:
            return await self.app(scope, receive, send)

        path = request.url.path
        if not any(re.match(expr, path) for expr in self.single_record_endpoints):
            return await self.app(scope, receive, send)

        # Intercept the response: hold back the start message and buffer body
        # chunks until the final chunk arrives.
        response_start: Optional[Message] = None
        body_chunks: list[bytes] = []

        async def send_wrapper(message: Message) -> None:
            nonlocal response_start
            if message["type"] == "http.response.start":
                response_start = message
            elif message["type"] == "http.response.body":
                body_chunks.append(message.get("body", b""))
                if not message.get("more_body", False):
                    await self._process_and_send_response(
                        response_start, body_chunks, send, cql2_filter
                    )
            else:
                await send(message)

        await self.app(scope, receive, send_wrapper)

    async def _process_and_send_response(
        self, response_start, body_chunks, send, cql2_filter
    ):
        """
        Check the buffered body against the filter and send the outcome.

        Sends a 502 if the body is not valid JSON, the original response if the
        record matches the filter, and a 404 otherwise (including when applying
        the filter raises).
        """
        body = b"".join(body_chunks)
        try:
            body_json = json.loads(body)
        except json.JSONDecodeError:
            logger.warning("Failed to parse response body as JSON")
            await self._send_json_response(
                send,
                status=502,
                content={
                    "code": "ParseError",
                    "description": "Failed to parse response body as JSON",
                },
            )
            return

        try:
            cql2_matches = cql2_filter.matches(body_json)
        except Exception as e:
            cql2_matches = False
            logger.warning("Failed to apply filter: %s", e)

        if not cql2_matches:
            logger.debug("Response did not match filter, returning 404")
            await self._send_json_response(
                send,
                status=404,
                content={"code": "NotFoundError", "description": "Record not found."},
            )
            return

        logger.debug("Response matches filter, returning record")
        # Send the original response start
        await send(response_start)
        # Forward the original bytes untouched: re-serialising the JSON could
        # change the body length and contradict the content-length header
        # already present in the start message.
        await send(
            {
                "type": "http.response.body",
                "body": body,
                "more_body": False,
            }
        )

    async def _send_json_response(self, send, status, content):
        """Send ``content`` as a complete JSON response with the given status."""
        response_bytes = json.dumps(content).encode("utf-8")
        await send(
            {
                "type": "http.response.start",
                "status": status,
                "headers": [
                    (b"content-type", b"application/json"),
                    (b"content-length", str(len(response_bytes)).encode("latin1")),
                ],
            }
        )
        await send(
            {
                "type": "http.response.body",
                "body": response_bytes,
                "more_body": False,
            }
        )
134 |
--------------------------------------------------------------------------------
/docs/user-guide/route-level-auth.md:
--------------------------------------------------------------------------------
1 | # Route-Level Authorization
2 |
3 | Route-level authorization can provide a base layer of security for the simplest use cases. This typically looks like:
4 |
5 | - the entire catalog being private, available only to authenticated users
6 | - most of the catalog being public, available to anonymous or authenticated users. However, a subset of endpoints (typically the [transactions extension](https://github.com/stac-api-extensions/transaction) endpoints) are only available to all or a subset of authenticated users
7 |
8 | ## Configuration Variables
9 |
10 | Route-level authorization is controlled by three key environment variables:
11 |
12 | - **[`DEFAULT_PUBLIC`](../../configuration/#default_public)**: Sets the default access policy for all endpoints
13 | - **[`PUBLIC_ENDPOINTS`](../../configuration/#public_endpoints)**: Marks endpoints as not requiring authentication (used only when `DEFAULT_PUBLIC=false`). By default, we keep the catalog root, OpenAPI spec, Swagger UI, Swagger UI auth redirect, and the proxy health endpoint as public. Note that these are all endpoints that don't serve actual STAC data; they only acknowledge the presence of a STAC catalog. This is defined by a mapping of regex path expressions to arrays of HTTP methods.
14 | - **[`PRIVATE_ENDPOINTS`](../../configuration/#private_endpoints)**: Marks endpoints as requiring authentication. By default, the transactions endpoints are all marked as private. This is defined by a mapping of regex path expressions to arrays of either HTTP methods or tuples of HTTP methods and space-separated required scopes.
15 |
16 | > [!TIP]
17 | >
18 | > Users typically don't need to specify both `PRIVATE_ENDPOINTS` and `PUBLIC_ENDPOINTS`.
19 |
20 | ## Strategies
21 |
22 | ### Private by Default
23 |
24 | Make the entire STAC API private, requiring authentication for all endpoints.
25 |
26 | > [!NOTE]
27 | >
28 | > This is the out-of-the-box configuration of the STAC Auth Proxy.
29 |
30 | **Configuration**
31 |
32 | ```bash
33 | # Set default policy to private
34 | DEFAULT_PUBLIC=false
35 |
36 | # The default public endpoints are typically sufficient. Otherwise, they can be specified.
37 | # PUBLIC_ENDPOINTS='{ ... }'
38 | ```
39 |
40 | **Behavior**
41 |
42 | - All endpoints require authentication by default
43 | - Only explicitly listed endpoints in `PUBLIC_ENDPOINTS` are accessible without authentication. By default, these are endpoints that don't reveal STAC data
44 | - Useful for internal or enterprise STAC APIs where all data should be protected
45 |
46 | ### Public by Default with Protected Write Operations
47 |
48 | Make the STAC API mostly public for read operations, but require authentication for write operations (transactions extension).
49 |
50 | **Configuration**
51 |
52 | ```bash
53 | # Set default policy to public
54 | DEFAULT_PUBLIC=true
55 |
56 | # The default private endpoints are typically sufficient. Otherwise, they can be specified.
57 | # PRIVATE_ENDPOINTS='{ ... }'
58 | ```
59 |
60 | **Behavior**
61 |
62 | - Read operations (GET requests) are accessible to everyone
63 | - Write operations require authentication
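64 | - The default `PRIVATE_ENDPOINTS` value matches this pattern (note that `DEFAULT_PUBLIC` defaults to `false`, so it must be set to `true` explicitly)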
65 | - Ideal for public STAC catalogs where data discovery is open but modifications are restricted
66 |
67 | ### Authenticated Access with Scope-based Authorization
68 |
69 | For a level of control beyond simple anonymous vs. authenticated status, the proxy can be configured so that path/method access requires JWTs containing particular permissions in the form of the [scopes claim](https://datatracker.ietf.org/doc/html/rfc8693#name-scope-scopes-claim).
70 |
71 | **Configuration**
72 |
73 | For granular permissions on a public API:
74 |
75 | ```bash
76 | # Set default policy to public
77 | DEFAULT_PUBLIC=true
78 |
79 | # Require specific scopes for write operations
80 | PRIVATE_ENDPOINTS='{
81 | "^/collections$": [["POST", "collection:create"]],
82 | "^/collections/([^/]+)$": [["PUT", "collection:update"], ["PATCH", "collection:update"], ["DELETE", "collection:delete"]],
83 | "^/collections/([^/]+)/items$": [["POST", "item:create"]],
84 | "^/collections/([^/]+)/items/([^/]+)$": [["PUT", "item:update"], ["PATCH", "item:update"], ["DELETE", "item:delete"]],
85 | "^/collections/([^/]+)/bulk_items$": [["POST", "item:create"]]
86 | }'
87 | ```
88 |
89 | For role-based permissions on a private API:
90 |
91 | ```bash
92 | # Set default policy to private
93 | DEFAULT_PUBLIC=false
94 |
95 | # Require specific scopes for write operations
96 | PRIVATE_ENDPOINTS='{
97 | "^/collections$": [["POST", "admin"]],
98 | "^/collections/([^/]+)$": [["PUT", "admin"], ["PATCH", "admin"], ["DELETE", "admin"]],
99 | "^/collections/([^/]+)/items$": [["POST", "editor"]],
100 | "^/collections/([^/]+)/items/([^/]+)$": [["PUT", "editor"], ["PATCH", "editor"], ["DELETE", "editor"]],
101 | "^/collections/([^/]+)/bulk_items$": [["POST", "editor"]]
102 | }'
103 | ```
104 |
105 | **Behavior**
106 |
107 | - Users must be authenticated AND have the required scope(s)
108 | - Different HTTP methods can require different scopes
109 | - Scopes are checked against the user's JWT scope claim
110 | - Unauthorized requests receive a 401 Unauthorized response
111 |
112 | > [!TIP]
113 | >
114 | > Multiple scopes can be provided in a space-separated format, such as `["POST", "scope_a scope_b scope_c"]`. These scope requirements are applied with AND logic, meaning that the incoming JWT must contain all the mentioned scopes.
115 |
--------------------------------------------------------------------------------
/src/stac_auth_proxy/utils/middleware.py:
--------------------------------------------------------------------------------
1 | """Utilities for middleware response handling."""
2 |
3 | import json
4 | import logging
5 | from abc import ABC, abstractmethod
6 | from typing import Any, Optional
7 |
8 | from starlette.datastructures import MutableHeaders
9 | from starlette.requests import Request
10 | from starlette.responses import JSONResponse
11 | from starlette.types import ASGIApp, Message, Receive, Scope, Send
12 |
13 | logger = logging.getLogger(__name__)
14 |
15 |
class JsonResponseMiddleware(ABC):
    """
    Base class for middleware that transforms JSON response bodies.

    Subclasses decide which responses to transform
    (``should_transform_response``) and how (``transform_json``). Responses
    selected for transformation are fully buffered, transformed, and sent with
    an updated content-length header.
    """

    app: ASGIApp

    # Expected data type for JSON responses. Only responses matching this type will be transformed.
    # If None, all JSON responses will be transformed regardless of type.
    expected_data_type: Optional[type] = dict

    @abstractmethod
    def should_transform_response(
        self, request: Request, scope: Scope
    ) -> bool:  # mypy: ignore
        """
        Determine if this response should be transformed. At a minimum, this
        should check the request's path and content type.

        Returns
        -------
            bool: True if the response should be transformed

        """
        ...

    @abstractmethod
    def transform_json(self, data: Any, request: Request) -> Any:
        """
        Transform the JSON data.

        Args:
            data: The parsed JSON data
            request: The HTTP request object

        Returns:
        -------
            The transformed JSON data

        """
        ...

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """Process the request/response."""
        if scope["type"] != "http":
            return await self.app(scope, receive, send)

        start_message: Optional[Message] = None
        body = b""

        async def transform_response(message: Message) -> None:
            nonlocal start_message
            nonlocal body

            # The first message seen is the response start; keep it so that
            # later body messages are judged against the same headers.
            start_message = start_message or message
            headers = MutableHeaders(scope=start_message)
            request = Request(scope)

            if not self.should_transform_response(
                request=request,
                scope=start_message,
            ):
                # For non-JSON responses, send the start message immediately
                await send(message)
                return

            # Delay sending start message until we've processed the body
            if message["type"] == "http.response.start":
                return

            # "body" is optional in http.response.body messages
            body += message.get("body", b"")

            # Skip body chunks until all chunks have been received
            if message.get("more_body"):
                return

            # Transform the JSON body
            if body:
                try:
                    data = json.loads(body)
                except json.JSONDecodeError as e:
                    logger.error("Error parsing JSON: %s", e)
                    logger.error("Body: %s", body)
                    logger.error("Response scope: %s", scope)
                    response = JSONResponse(
                        {"error": "Received invalid JSON from upstream server"},
                        status_code=502,
                    )
                    await response(scope, receive, send)
                    return

                if self.expected_data_type is None or isinstance(
                    data, self.expected_data_type
                ):
                    transformed = self.transform_json(data, request=request)
                    body = json.dumps(transformed).encode()
                else:
                    logger.warning(
                        "Received JSON response with unexpected data type %r from upstream server (%r %r), "
                        "skipping transformation (expected: %r)",
                        type(data).__name__,
                        request.method,
                        request.url,
                        self.expected_data_type.__name__,
                    )

            # Update content-length header
            headers["content-length"] = str(len(body))
            # Header values are exposed as latin-1 decoded strings; encode them
            # the same way so that non-ASCII bytes round-trip unchanged.
            start_message["headers"] = [
                (key.encode("latin-1"), value.encode("latin-1"))
                for key, value in headers.items()
            ]

            # Send response
            await send(start_message)
            await send(
                {
                    "type": "http.response.body",
                    "body": body,
                    "more_body": False,
                }
            )

        return await self.app(scope, receive, transform_response)
137 |
138 |
def required_conformance(
    *conformances: str,
    attr_name: str = "__required_conformances__",
):
    """
    Build a class decorator that records required conformance classes.

    The decorated middleware class receives an attribute named ``attr_name``
    holding a list of the given conformance strings; the class itself is
    returned unchanged.
    """

    def attach(target):
        # A fresh list per decorated class, so classes never share state.
        required = [*conformances]
        setattr(target, attr_name, required)
        return target

    return attach
150 |
--------------------------------------------------------------------------------
/src/stac_auth_proxy/handlers/reverse_proxy.py:
--------------------------------------------------------------------------------
1 | """Tooling to manage the reverse proxying of requests to an upstream STAC API."""
2 |
3 | import logging
4 | import time
5 | from dataclasses import dataclass, field
6 |
7 | import httpx
8 | from fastapi import Request
9 | from starlette.datastructures import MutableHeaders
10 | from starlette.responses import Response
11 |
12 | from stac_auth_proxy.utils.requests import build_server_timing_header
13 |
14 | logger = logging.getLogger(__name__)
15 |
16 |
@dataclass
class ReverseProxyHandler:
    """
    Reverse proxy functionality.

    Forwards incoming requests to the upstream API through an
    ``httpx.AsyncClient`` and returns the upstream response with its body
    already decompressed.
    """

    upstream: str
    # Optional pre-configured client; one is created in __post_init__ when omitted.
    client: httpx.AsyncClient = None
    timeout: httpx.Timeout = field(default_factory=lambda: httpx.Timeout(timeout=15.0))

    proxy_name: str = "stac-auth-proxy"
    override_host: bool = True
    legacy_forwarded_headers: bool = False

    def __post_init__(self):
        """Initialize the HTTP client."""
        self.client = self.client or httpx.AsyncClient(
            base_url=self.upstream,
            timeout=self.timeout,
            http2=True,
        )

    def _prepare_headers(self, request: Request) -> MutableHeaders:
        """
        Prepare headers for the proxied request. Construct a Forwarded header to inform
        the upstream API about the original request context, which will allow it to
        properly construct URLs in responses (namely, in the Links). If there are
        existing X-Forwarded-*/Forwarded headers (typically, in situations where the
        STAC Auth Proxy is behind a proxy like Traefik or NGINX), we use those values.
        """
        headers = MutableHeaders(request.headers)
        headers.setdefault("Via", f"1.1 {self.proxy_name}")

        proxy_client = headers.get(
            "X-Forwarded-For", request.client.host if request.client else "unknown"
        )
        proxy_proto = headers.get("X-Forwarded-Proto", request.url.scheme)
        proxy_host = headers.get("X-Forwarded-Host", request.url.netloc)
        proxy_path = headers.get("X-Forwarded-Path", request.base_url.path)

        # request.url.port is None when the URL has no explicit port. Converting
        # that with str() would give the truthy string "None", so map it to "".
        raw_port = headers.get("X-Forwarded-Port", request.url.port)
        proxy_port = "" if raw_port is None else str(raw_port)

        # NOTE: If we don't include a port, it's possible that the upstream server may
        # mistakenly use the port from the Host header (which may be the internal port
        # of the upstream server) when constructing URLs.
        forwarded_host = proxy_host
        # The host value may already carry a port (a netloc such as "localhost:8000").
        # Only text after a closing "]" is inspected, so IPv6 literals are not
        # mistaken for host:port pairs.
        host_has_port = ":" in proxy_host.rpartition("]")[2]
        if proxy_port and not host_has_port:
            forwarded_host = f"{forwarded_host}:{proxy_port}"

        headers.setdefault(
            "Forwarded",
            f"for={proxy_client};host={forwarded_host};proto={proxy_proto};path={proxy_path}",
        )

        # NOTE: This is useful if the upstream API does not support the Forwarded header
        # and there were no existing X-Forwarded-* headers on the incoming request.
        if self.legacy_forwarded_headers:
            headers.setdefault("X-Forwarded-For", proxy_client)
            headers.setdefault("X-Forwarded-Host", proxy_host)
            headers.setdefault("X-Forwarded-Path", proxy_path)
            headers.setdefault("X-Forwarded-Proto", proxy_proto)
            # Skip the port header when no port is known rather than sending a bogus value.
            if proxy_port:
                headers.setdefault("X-Forwarded-Port", proxy_port)

        # Set host to the upstream host
        if self.override_host:
            headers["Host"] = self.client.base_url.netloc.decode("utf-8")

        return headers

    async def proxy_request(self, request: Request) -> Response:
        """
        Proxy a request to the upstream STAC API.

        The upstream body is read in full so that HTTPX decompresses it; the
        returned response therefore carries no ``Content-Encoding`` header and a
        ``Content-Length`` that matches the decompressed body.
        """
        headers = self._prepare_headers(request)

        # https://github.com/fastapi/fastapi/discussions/7382#discussioncomment-5136466
        rp_req = self.client.build_request(
            request.method,
            url=httpx.URL(
                path=request.url.path,
                query=request.url.query.encode("utf-8"),
            ),
            headers=headers,
            content=request.stream(),
        )

        # NOTE: HTTPX adds headers, so we need to trim them before sending request.
        # Iterate over a snapshot because entries are deleted during the loop.
        for h in list(rp_req.headers):
            if h not in headers:
                del rp_req.headers[h]

        logger.debug("Proxying request to %s", rp_req.url)

        start_time = time.perf_counter()
        rp_resp = await self.client.send(rp_req, stream=True)
        proxy_time = time.perf_counter() - start_time
        rp_resp.headers["Server-Timing"] = build_server_timing_header(
            rp_resp.headers.get("Server-Timing"),
            name="upstream",
            dur=proxy_time,
            desc="Upstream processing time",
        )
        logger.debug(
            "Received response status %r from %s in %.3fs",
            rp_resp.status_code,
            rp_req.url,
            proxy_time,
        )

        # We read the content here to make use of HTTPX's decompression, ensuring we have
        # non-compressed content for the middleware to work with.
        try:
            content = await rp_resp.aread()
        finally:
            # Release the streamed connection even if reading/decoding failed.
            await rp_resp.aclose()

        if rp_resp.headers.get("Content-Encoding"):
            del rp_resp.headers["Content-Encoding"]
            # The upstream length described the encoded payload; keep it in sync
            # with the bytes that are actually returned.
            if "Content-Length" in rp_resp.headers:
                rp_resp.headers["Content-Length"] = str(len(content))

        return Response(
            content=content,
            status_code=rp_resp.status_code,
            headers=dict(rp_resp.headers),
        )
129 |
--------------------------------------------------------------------------------
/src/stac_auth_proxy/lifespan.py:
--------------------------------------------------------------------------------
1 | """Reusable lifespan handler for FastAPI applications."""
2 |
3 | import asyncio
4 | import logging
5 | import re
6 | from contextlib import asynccontextmanager
7 | from typing import Any
8 |
9 | import httpx
10 | from fastapi import FastAPI
11 | from pydantic import HttpUrl
12 | from starlette.middleware import Middleware
13 |
14 | from .config import Settings
15 |
16 | logger = logging.getLogger(__name__)
17 | __all__ = ["build_lifespan", "check_conformance", "check_server_health"]
18 |
19 |
async def check_server_healths(*urls: str | HttpUrl) -> None:
    """Run the health check against each given URL in turn, then log a summary."""
    logger.info("Running upstream server health checks...")

    # Checks run one after another, in the order the URLs were passed.
    for target in urls:
        await check_server_health(target)

    summary = "\n".join(f" - {target}" for target in urls)
    logger.info("Upstream servers are healthy:\n%s", summary)
29 |
30 |
async def check_server_health(
    url: str | HttpUrl,
    max_retries: int = 10,
    retry_delay: float = 1.0,
    retry_delay_max: float = 5.0,
    timeout: float = 5.0,
) -> None:
    """
    Wait for upstream API to become available.

    Parameters
    ----------
    url : str | HttpUrl
        URL to request; redirects are followed.
    max_retries : int
        Maximum number of attempts before giving up.
    retry_delay : float
        Base delay in seconds; doubled after every failed attempt.
    retry_delay_max : float
        Upper bound in seconds for the delay between attempts.
    timeout : float
        Timeout passed to the HTTP client.

    Raises
    ------
    RuntimeError
        If every attempt failed with a connection error.
    httpx.HTTPStatusError
        If the server answered with an error status (not retried).

    """
    # Convert url to string if it's a HttpUrl
    if isinstance(url, HttpUrl):
        url = str(url)

    last_error: httpx.ConnectError | None = None

    async with httpx.AsyncClient(timeout=timeout, follow_redirects=True) as client:
        for attempt in range(max_retries):
            try:
                response = await client.get(url)
                response.raise_for_status()
            except httpx.ConnectError as e:
                last_error = e
                logger.warning("Upstream health check for %r failed: %s", url, e)

                # Nothing follows the final attempt, so don't announce a retry
                # or delay the failure by sleeping.
                if attempt + 1 >= max_retries:
                    break

                retry_in = min(retry_delay * (2**attempt), retry_delay_max)
                logger.warning(
                    "Upstream API %r not healthy, retrying in %.1fs (attempt %d/%d)",
                    url,
                    retry_in,
                    attempt + 1,
                    max_retries,
                )
                await asyncio.sleep(retry_in)
            else:
                logger.info("Upstream API %r is healthy", url)
                return

    # Chain the last connection error so the root cause stays in the traceback.
    raise RuntimeError(
        f"Upstream API {url!r} failed to respond after {max_retries} attempts"
    ) from last_error
62 |
63 |
async def check_conformance(
    middleware_classes: list[Middleware],
    api_url: str,
    attr_name: str = "__required_conformances__",
    endpoint: str = "/conformance",
):
    """
    Verify that the upstream API declares every conformance class the middleware need.

    Required conformance patterns are read from the ``attr_name`` attribute of each
    middleware class and matched (via ``re.match``) against the ``conformsTo`` list
    returned by ``endpoint``. Raises ``RuntimeError`` listing any unmet patterns.
    """
    # Map each required pattern to the names of the middleware that asked for it.
    required_by: dict[str, list[str]] = {}
    for entry in middleware_classes:
        middleware_cls = entry.cls
        for pattern in getattr(middleware_cls, attr_name, []):
            if pattern not in required_by:
                required_by[pattern] = []
            required_by[pattern].append(middleware_cls.__name__)

    async with httpx.AsyncClient(base_url=api_url) as client:
        response = await client.get(endpoint)
        response.raise_for_status()
        payload = response.json()
        supported = payload.get("conformsTo", [])

    def describe(pattern: str) -> str:
        requesters = ",".join(required_by[pattern])
        return f" - {pattern} [{requesters}]"

    unmet = []
    for pattern in required_by:
        if any(re.match(pattern, candidate) for candidate in supported):
            continue
        unmet.append(pattern)

    if unmet:
        report = ["Upstream catalog is missing the following conformance classes:"]
        report.extend(describe(pattern) for pattern in unmet)
        raise RuntimeError("\n".join(report))

    logger.info(
        "Upstream catalog conforms to the following required conformance classes: \n%s",
        "\n".join(describe(pattern) for pattern in required_by),
    )
108 |
109 |
def build_lifespan(settings: Settings | None = None, **settings_kwargs: Any):
    """
    Build a lifespan handler that performs the startup checks.

    Parameters
    ----------
    settings : Settings | None, optional
        Ready-made settings instance. When omitted, one is created from
        ``settings_kwargs``.
    **settings_kwargs : Any
        Keyword arguments forwarded to ``Settings`` when ``settings`` is not
        provided.

    Returns
    -------
    Callable[[FastAPI], AsyncContextManager[Any]]
        A callable suitable for the ``lifespan`` parameter of ``FastAPI``.

    """
    # Bind the resolved settings to a separate name so the closure below always
    # sees a concrete instance.
    resolved = settings if settings is not None else Settings(**settings_kwargs)

    @asynccontextmanager
    async def lifespan(app: "FastAPI"):
        # Wait for upstream servers to become available
        if resolved.wait_for_upstream:
            await check_server_healths(
                resolved.upstream_url,
                resolved.oidc_discovery_internal_url,
            )

        # Log all middleware connected to the app
        middleware_names = "\n".join(
            f" - {entry.cls.__name__}" for entry in app.user_middleware
        )
        logger.info("Connected middleware:\n%s", middleware_names)

        if resolved.check_conformance:
            await check_conformance(app.user_middleware, str(resolved.upstream_url))

        yield

    return lifespan
154 |
--------------------------------------------------------------------------------
/docs/architecture/middleware-stack.md:
--------------------------------------------------------------------------------
1 | # Middleware Stack
2 |
3 | Aside from the actual communication with the upstream STAC API, the majority of the proxy's functionality occurs within a chain of middlewares. Each request passes through this chain, wherein each middleware performs a specific task. The middleware chain is ordered from last added (first to run) to first added (last to run).
4 |
5 | > [!TIP]
6 | > If you want to apply just the middleware onto your existing FastAPI application, you can do this with [`configure_app`][stac_auth_proxy.configure_app] rather than setting up a separate proxy application.
7 |
8 | > [!IMPORTANT]
9 | > The order of middleware execution is **critical**. For example, `RemoveRootPathMiddleware` must run before `EnforceAuthMiddleware` so that authentication decisions are made on the correct path after root path removal.
10 |
11 | 1. **[`CompressionMiddleware`](https://github.com/developmentseed/starlette-cramjam)**
12 |
13 | - **Enabled if:** [`ENABLE_COMPRESSION`](../../user-guide/configuration#enable_compression) is enabled
14 | - Handles response compression
15 | - Reduces response size for better performance
16 |
17 | 2. **[`RemoveRootPathMiddleware`][stac_auth_proxy.middleware.RemoveRootPathMiddleware]**
18 |
19 | - **Enabled if:** [`ROOT_PATH`](../../user-guide/configuration#root_path) is configured
20 | - Removes the application root path from incoming requests
21 | - Ensures requests are properly routed to upstream API
22 |
23 | 3. **[`ProcessLinksMiddleware`][stac_auth_proxy.middleware.ProcessLinksMiddleware]**
24 |
25 | - **Enabled if:** [`ROOT_PATH`](../../user-guide/configuration#root_path) is set or [`UPSTREAM_URL`](../../user-guide/configuration#upstream_url) path is not `"/"`
26 | - Updates links in JSON responses to handle root path and upstream URL path differences
27 | - Removes upstream URL path from links and adds root path if configured
28 |
29 | 4. **[`EnforceAuthMiddleware`][stac_auth_proxy.middleware.EnforceAuthMiddleware]**
30 |
31 | - **Enabled if:** Always active (core authentication middleware)
32 | - Handles authentication and authorization
33 | - Configurable public/private endpoints via [`PUBLIC_ENDPOINTS`](../../user-guide/configuration#public_endpoints) and [`PRIVATE_ENDPOINTS`](../../user-guide/configuration#private_endpoints)
34 | - OIDC integration via [`OIDC_DISCOVERY_INTERNAL_URL`](../../user-guide/configuration#oidc_discovery_internal_url)
35 | - JWT audience validation via [`ALLOWED_JWT_AUDIENCES`](../../user-guide/configuration#allowed_jwt_audiences)
36 | - Places auth token payload in request state
37 |
38 | 5. **[`AddProcessTimeHeaderMiddleware`][stac_auth_proxy.middleware.AddProcessTimeHeaderMiddleware]**
39 |
40 | - **Enabled if:** Always active (monitoring middleware)
41 | - Adds processing time headers to responses
42 | - Useful for monitoring and debugging
43 |
44 | 6. **[`Cql2BuildFilterMiddleware`][stac_auth_proxy.middleware.Cql2BuildFilterMiddleware]**
45 |
46 | - **Enabled if:** [`ITEMS_FILTER_CLS`](../../user-guide/configuration#items_filter_cls) or [`COLLECTIONS_FILTER_CLS`](../../user-guide/configuration#collections_filter_cls) is configured
47 | - Builds CQL2 filters based on request context/state
48 | - Places [CQL2 expression](http://developmentseed.org/cql2-rs/latest/python/#cql2.Expr) in request state
49 |
50 | 7. **[`Cql2RewriteLinksFilterMiddleware`][stac_auth_proxy.middleware.Cql2RewriteLinksFilterMiddleware]**
51 |
52 | - **Enabled if:** [`ITEMS_FILTER_CLS`](../../user-guide/configuration#items_filter_cls) or [`COLLECTIONS_FILTER_CLS`](../../user-guide/configuration#collections_filter_cls) is configured
53 | - Rewrites filter parameters in response links to remove applied filters
54 | - Ensures links in responses show the original filter state
55 |
56 | 8. **[`Cql2ApplyFilterQueryStringMiddleware`][stac_auth_proxy.middleware.Cql2ApplyFilterQueryStringMiddleware]**
57 |
58 | - **Enabled if:** [`ITEMS_FILTER_CLS`](../../user-guide/configuration#items_filter_cls) or [`COLLECTIONS_FILTER_CLS`](../../user-guide/configuration#collections_filter_cls) is configured
59 | - Retrieves [CQL2 expression](http://developmentseed.org/cql2-rs/latest/python/#cql2.Expr) from request state
60 | - Augments `GET` requests with CQL2 filter by appending to querystring
61 |
62 | 9. **[`Cql2ApplyFilterBodyMiddleware`][stac_auth_proxy.middleware.Cql2ApplyFilterBodyMiddleware]**
63 |
64 | - **Enabled if:** [`ITEMS_FILTER_CLS`](../../user-guide/configuration#items_filter_cls) or [`COLLECTIONS_FILTER_CLS`](../../user-guide/configuration#collections_filter_cls) is configured
65 | - Retrieves [CQL2 expression](http://developmentseed.org/cql2-rs/latest/python/#cql2.Expr) from request state
66 | - Augments `POST`/`PUT`/`PATCH` requests with CQL2 filter by modifying body
67 |
68 | 10. **[`Cql2ValidateResponseBodyMiddleware`][stac_auth_proxy.middleware.Cql2ValidateResponseBodyMiddleware]**
69 |
70 | - **Enabled if:** [`ITEMS_FILTER_CLS`](../../user-guide/configuration#items_filter_cls) or [`COLLECTIONS_FILTER_CLS`](../../user-guide/configuration#collections_filter_cls) is configured
71 | - Retrieves [CQL2 expression](http://developmentseed.org/cql2-rs/latest/python/#cql2.Expr) from request state
72 | - Validates response against CQL2 filter for non-filterable endpoints
73 |
74 | 11. **[`OpenApiMiddleware`][stac_auth_proxy.middleware.OpenApiMiddleware]**
75 |
76 | - **Enabled if:** [`OPENAPI_SPEC_ENDPOINT`](../../user-guide/configuration#openapi_spec_endpoint) is set
77 | - Modifies OpenAPI specification based on endpoint configuration, adding security requirements
78 | - Configurable via [`OPENAPI_AUTH_SCHEME_NAME`](../../user-guide/configuration#openapi_auth_scheme_name) and [`OPENAPI_AUTH_SCHEME_OVERRIDE`](../../user-guide/configuration#openapi_auth_scheme_override)
79 |
80 | 12. **[`AuthenticationExtensionMiddleware`][stac_auth_proxy.middleware.AuthenticationExtensionMiddleware]**
81 | - **Enabled if:** [`ENABLE_AUTHENTICATION_EXTENSION`](../../user-guide/configuration#enable_authentication_extension) is enabled
82 | - Adds authentication extension information to STAC responses
83 | - Annotates links with authentication requirements based on [`PUBLIC_ENDPOINTS`](../../user-guide/configuration#public_endpoints) and [`PRIVATE_ENDPOINTS`](../../user-guide/configuration#private_endpoints)
84 |
--------------------------------------------------------------------------------
/src/stac_auth_proxy/app.py:
--------------------------------------------------------------------------------
1 | """
2 | STAC Auth Proxy API.
3 |
4 | This module defines the FastAPI application for the STAC Auth Proxy, which handles
5 | authentication, authorization, and proxying of requests to some internal STAC API.
6 | """
7 |
8 | import logging
9 | from typing import Any, Optional
10 |
11 | from fastapi import FastAPI
12 | from starlette_cramjam.middleware import CompressionMiddleware
13 |
14 | from .config import Settings
15 | from .handlers import HealthzHandler, ReverseProxyHandler, SwaggerUI
16 | from .lifespan import build_lifespan
17 | from .middleware import (
18 | AddProcessTimeHeaderMiddleware,
19 | AuthenticationExtensionMiddleware,
20 | Cql2ApplyFilterBodyMiddleware,
21 | Cql2ApplyFilterQueryStringMiddleware,
22 | Cql2BuildFilterMiddleware,
23 | Cql2RewriteLinksFilterMiddleware,
24 | Cql2ValidateResponseBodyMiddleware,
25 | EnforceAuthMiddleware,
26 | OpenApiMiddleware,
27 | ProcessLinksMiddleware,
28 | RemoveRootPathMiddleware,
29 | )
30 |
31 | logger = logging.getLogger(__name__)
32 |
33 |
def configure_app(
    app: FastAPI,
    settings: Optional[Settings] = None,
    **settings_kwargs: Any,
) -> FastAPI:
    """
    Attach the proxy's route handlers and middleware stack to a FastAPI app.

    Parameters
    ----------
    app : FastAPI
        The FastAPI app to configure.
    settings : Settings | None, optional
        Ready-made settings instance. When omitted, one is created from
        ``settings_kwargs``.
    **settings_kwargs : Any
        Keyword arguments forwarded to ``Settings`` when ``settings`` is not
        provided.

    Returns
    -------
    FastAPI
        The same ``app`` instance, after configuration.

    """
    if not settings:
        settings = Settings(**settings_kwargs)

    # Endpoint policy arguments shared by several middleware.
    endpoint_policy = {
        "public_endpoints": settings.public_endpoints,
        "private_endpoints": settings.private_endpoints,
        "default_public": settings.default_public,
    }

    #
    # Route Handlers
    #

    # If we have customized Swagger UI Init settings (e.g. a provided client_id)
    # then we need to serve our own Swagger UI in place of the upstream's. This requires
    # that we know the Swagger UI endpoint and the OpenAPI spec endpoint.
    serve_swagger_ui = (
        settings.swagger_ui_endpoint
        and settings.openapi_spec_endpoint
        and settings.swagger_ui_init_oauth
    )
    if serve_swagger_ui:
        swagger_ui = SwaggerUI(
            openapi_url=settings.openapi_spec_endpoint,  # type: ignore
            init_oauth=settings.swagger_ui_init_oauth,
        )
        app.add_route(
            settings.swagger_ui_endpoint,
            swagger_ui.route,
            include_in_schema=False,
        )

    if settings.healthz_prefix:
        healthz = HealthzHandler(upstream_url=str(settings.upstream_url))
        app.include_router(healthz.router, prefix=settings.healthz_prefix)

    #
    # Middleware (order is important, last added = first to run)
    #

    if settings.enable_authentication_extension:
        app.add_middleware(
            AuthenticationExtensionMiddleware,
            oidc_discovery_url=str(settings.oidc_discovery_url),
            **endpoint_policy,
        )

    if settings.openapi_spec_endpoint:
        app.add_middleware(
            OpenApiMiddleware,
            openapi_spec_path=settings.openapi_spec_endpoint,
            oidc_discovery_url=str(settings.oidc_discovery_url),
            root_path=settings.root_path,
            auth_scheme_name=settings.openapi_auth_scheme_name,
            auth_scheme_override=settings.openapi_auth_scheme_override,
            **endpoint_policy,
        )

    if settings.items_filter or settings.collections_filter:
        # Registered in this exact order; the filter builder goes last so that
        # it runs before the middleware that consume its result.
        for filter_middleware in (
            Cql2ValidateResponseBodyMiddleware,
            Cql2ApplyFilterBodyMiddleware,
            Cql2ApplyFilterQueryStringMiddleware,
            Cql2RewriteLinksFilterMiddleware,
        ):
            app.add_middleware(filter_middleware)

        items_filter = settings.items_filter() if settings.items_filter else None
        collections_filter = (
            settings.collections_filter() if settings.collections_filter else None
        )
        app.add_middleware(
            Cql2BuildFilterMiddleware,
            items_filter=items_filter,
            collections_filter=collections_filter,
            collections_filter_path=settings.collections_filter_path,
            items_filter_path=settings.items_filter_path,
        )

    app.add_middleware(AddProcessTimeHeaderMiddleware)

    app.add_middleware(
        EnforceAuthMiddleware,
        oidc_discovery_url=settings.oidc_discovery_internal_url,
        allowed_jwt_audiences=settings.allowed_jwt_audiences,
        **endpoint_policy,
    )

    if settings.root_path or settings.upstream_url.path != "/":
        app.add_middleware(
            ProcessLinksMiddleware,
            upstream_url=str(settings.upstream_url),
            root_path=settings.root_path,
        )

    if settings.root_path:
        app.add_middleware(RemoveRootPathMiddleware, root_path=settings.root_path)

    if settings.enable_compression:
        app.add_middleware(CompressionMiddleware)

    return app
158 |
159 |
def create_app(settings: Optional[Settings] = None) -> FastAPI:
    """
    Build the standalone proxy application.

    Creates a FastAPI app with the startup lifespan, applies routes and middleware
    via ``configure_app``, and registers a catch-all route that forwards every
    request to the upstream API.
    """
    if not settings:
        settings = Settings()

    lifespan = build_lifespan(settings=settings)
    app = FastAPI(
        openapi_url=None,  # Disable OpenAPI schema endpoint, we want to serve upstream's schema
        lifespan=lifespan,
        root_path=settings.root_path,
    )
    if app.root_path:
        logger.debug("Mounted app at %s", app.root_path)

    configure_app(app, settings)

    # Catch-all route: anything not handled above is proxied upstream.
    proxy = ReverseProxyHandler(
        upstream=str(settings.upstream_url),
        override_host=settings.override_host,
    )
    proxied_methods = ["GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"]
    app.add_api_route("/{path:path}", proxy.proxy_request, methods=proxied_methods)

    return app
184 |
--------------------------------------------------------------------------------
/tests/test_auth_extension.py:
--------------------------------------------------------------------------------
1 | """Tests for AuthenticationExtensionMiddleware."""
2 |
3 | import pytest
4 | from starlette.requests import Request
5 |
6 | from stac_auth_proxy.config import EndpointMethods
7 | from stac_auth_proxy.middleware.AuthenticationExtensionMiddleware import (
8 | AuthenticationExtensionMiddleware,
9 | )
10 |
11 |
@pytest.fixture
def oidc_discovery_url():
    """Provide the OIDC discovery URL used by the tests in this module."""
    discovery_url = "https://auth.example.com/discovery"
    return discovery_url
16 |
17 |
@pytest.fixture
def middleware(oidc_discovery_url):
    """Build the middleware under test with empty endpoint configuration."""
    init_kwargs = {
        "app": None,  # We don't need the actual app for these tests
        "default_public": True,
        "private_endpoints": EndpointMethods(),
        "public_endpoints": EndpointMethods(),
        "oidc_discovery_url": oidc_discovery_url,
        "auth_scheme_name": "test_auth",
        "auth_scheme": {},
    }
    return AuthenticationExtensionMiddleware(**init_kwargs)
30 |
31 |
@pytest.fixture
def request_scope():
    """Provide a minimal HTTP GET scope for the root path with no headers."""
    return dict(
        type="http",
        method="GET",
        path="/",
        headers=[],
    )
41 |
42 |
@pytest.fixture(params=[b"application/json", b"application/geo+json"])
def initial_message(request):
    """Build a response-start message, parametrized over JSON content types."""
    response_headers = [
        (b"date", b"Mon, 07 Apr 2025 06:55:37 GMT"),
        (b"server", b"uvicorn"),
        (b"content-length", b"27642"),
        # The content type comes from the fixture parameter.
        (b"content-type", request.param),
        (b"x-upstream-time", b"0.063"),
    ]
    return {
        "type": "http.response.start",
        "status": 200,
        "headers": response_headers,
    }
57 |
58 |
def test_should_transform_response_valid_paths(
    middleware, request_scope, initial_message
):
    """Check that responses for the listed STAC paths are selected for transformation."""
    stac_paths = (
        "/",
        "/collections",
        "/collections/test-collection",
        "/collections/test-collection/items",
        "/collections/test-collection/items/test-item",
        "/search",
    )

    for stac_path in stac_paths:
        # Reuse the same scope, swapping only the path for each case.
        request_scope["path"] = stac_path
        req = Request(request_scope)
        assert middleware.should_transform_response(req, initial_message)
76 |
77 |
def test_should_transform_response_invalid_paths(
    middleware, request_scope, initial_message
):
    """Check that responses for the listed non-matching paths are left alone."""
    other_paths = (
        "/api",
        "/collections/test-collection/items/test-item/assets",
        "/random",
    )

    for other_path in other_paths:
        # Reuse the same scope, swapping only the path for each case.
        request_scope["path"] = other_path
        req = Request(request_scope)
        assert not middleware.should_transform_response(req, initial_message)
92 |
93 |
def test_should_transform_response_invalid_content_type(middleware, request_scope):
    """Check that a ``text/html`` response is not selected for transformation."""
    req = Request(request_scope)
    html_response_start = {
        "type": "http.response.start",
        "status": 200,
        "headers": [
            (b"date", b"Mon, 07 Apr 2025 06:55:37 GMT"),
            (b"server", b"uvicorn"),
            (b"content-length", b"27642"),
            (b"content-type", b"text/html"),
            (b"x-upstream-time", b"0.063"),
        ],
    }

    should_transform = middleware.should_transform_response(req, html_response_start)
    assert not should_transform
111 |
112 |
def test_transform_json_catalog(middleware, request_scope, oidc_discovery_url):
    """Check the extension URL and auth scheme added to a STAC catalog."""
    req = Request(request_scope)
    catalog_doc = {
        "stac_version": "1.0.0",
        "id": "test-catalog",
        "description": "Test catalog",
        "links": [
            {"rel": "self", "href": "/"},
            {"rel": "root", "href": "/"},
        ],
    }

    result = middleware.transform_json(catalog_doc, req)

    # The extension must be declared on the document...
    assert "stac_extensions" in result
    assert middleware.extension_url in result["stac_extensions"]
    # ...and the scheme registered under the configured name.
    assert "auth:schemes" in result
    assert "test_auth" in result["auth:schemes"]

    auth_scheme = result["auth:schemes"]["test_auth"]
    assert auth_scheme["type"] == "openIdConnect"
    assert auth_scheme["openIdConnectUrl"] == oidc_discovery_url
137 |
138 |
def test_transform_json_collection(middleware, request_scope):
    """Check the extension URL and auth scheme added to a STAC collection."""
    req = Request(request_scope)
    collection_doc = {
        "stac_version": "1.0.0",
        "type": "Collection",
        "id": "test-collection",
        "description": "Test collection",
        "links": [
            {"rel": "self", "href": "/collections/test-collection"},
            {"rel": "items", "href": "/collections/test-collection/items"},
        ],
    }

    result = middleware.transform_json(collection_doc, req)

    assert "stac_extensions" in result
    assert middleware.extension_url in result["stac_extensions"]
    # For collections the schemes are expected at the top level of the document.
    assert "auth:schemes" in result
    assert "test_auth" in result["auth:schemes"]
160 |
161 |
def test_transform_json_item(middleware, request_scope):
    """Check the extension URL and auth scheme added to a STAC item."""
    req = Request(request_scope)
    item_doc = {
        "stac_version": "1.0.0",
        "type": "Feature",
        "id": "test-item",
        "properties": {},
        "links": [
            {"rel": "self", "href": "/collections/test-collection/items/test-item"},
            {"rel": "collection", "href": "/collections/test-collection"},
        ],
    }

    result = middleware.transform_json(item_doc, req)

    assert "stac_extensions" in result
    assert middleware.extension_url in result["stac_extensions"]
    # For items the schemes are expected under "properties", not at the top level.
    item_properties = result["properties"]
    assert "auth:schemes" in item_properties
    assert "test_auth" in item_properties["auth:schemes"]
183 |
184 |
def test_transform_json_missing_oidc_metadata(middleware, request_scope):
    """Check that the transformed catalog equals the input when OIDC metadata is missing."""
    req = Request(request_scope)
    catalog_doc = {
        "stac_version": "1.0.0",
        "id": "test-catalog",
        "description": "Test catalog",
    }

    result = middleware.transform_json(catalog_doc, req)

    # Should return unchanged when OIDC metadata is missing
    assert result == catalog_doc
198 |
--------------------------------------------------------------------------------
/src/stac_auth_proxy/middleware/EnforceAuthMiddleware.py:
--------------------------------------------------------------------------------
1 | """Middleware to enforce authentication."""
2 |
3 | import logging
4 | from dataclasses import dataclass, field
5 | from typing import Annotated, Any, Optional, Sequence
6 | from urllib.parse import urlparse, urlunparse
7 |
8 | import httpx
9 | import jwt
10 | from fastapi import HTTPException, Request, Security, status
11 | from pydantic import HttpUrl
12 | from starlette.responses import JSONResponse
13 | from starlette.types import ASGIApp, Receive, Scope, Send
14 |
15 | from ..config import EndpointMethods
16 | from ..utils.requests import find_match
17 |
18 | logger = logging.getLogger(__name__)
19 |
20 |
@dataclass
class OidcService:
    """
    OIDC configuration and JWKS client.

    On construction the discovery document is fetched synchronously from
    ``oidc_discovery_url`` and a JWKS client is created for the advertised
    ``jwks_uri``, rewritten to use the scheme and host of the discovery URL.
    """

    oidc_discovery_url: HttpUrl
    jwks_client: jwt.PyJWKClient = field(init=False)
    metadata: dict[str, Any] = field(init=False)

    def __post_init__(self) -> None:
        """
        Initialize OIDC config and JWKS client.

        Raises
        ------
        OidcFetchError
            If the discovery request fails or returns an error status.
        AssertionError
            If the discovery document is empty.

        """
        logger.debug("Requesting OIDC config")
        origin_url = str(self.oidc_discovery_url)

        try:
            response = httpx.get(origin_url)
            response.raise_for_status()
            self.metadata = response.json()
            # Raised explicitly (instead of using `assert`) so the check is not
            # stripped when Python runs with -O; the exception type is unchanged.
            if not self.metadata:
                raise AssertionError("OIDC metadata is empty")

            # NOTE: We manually replace the origin of the jwks_uri in the event that
            # the jwks_uri is not available from within the proxy.
            reported_jwks_uri = self.metadata["jwks_uri"]
            oidc_url = urlparse(origin_url)
            jwks_uri = urlunparse(
                urlparse(reported_jwks_uri)._replace(
                    netloc=oidc_url.netloc, scheme=oidc_url.scheme
                )
            )
            if jwks_uri != reported_jwks_uri:
                logger.warning(
                    "OIDC Discovery reported a JWKS URI of %s but we're going to use %s to match the OIDC Discovery URL",
                    reported_jwks_uri,
                    jwks_uri,
                )
            self.jwks_client = jwt.PyJWKClient(jwks_uri)
        except httpx.HTTPStatusError as e:
            logger.error(
                "Received a non-200 response when fetching OIDC config: %s",
                e.response.text,
            )
            raise OidcFetchError(
                f"Request for OIDC config failed with status {e.response.status_code}"
            ) from e
        except httpx.RequestError as e:
            logger.error("Error fetching OIDC config from %s: %s", origin_url, str(e))
            raise OidcFetchError(f"Request for OIDC config failed: {str(e)}") from e
66 |
67 |
@dataclass
class EnforceAuthMiddleware:
    """
    Middleware to enforce authentication.

    Non-HTTP scopes and ``OPTIONS`` requests are passed straight through.
    For every other request the ``Authorization`` header is validated; on
    success the decoded token payload (or ``None`` when no header was sent to
    a public endpoint) is stored on ``request.state`` under ``state_key`` and
    the OIDC metadata under ``oidc_metadata``. On failure a JSON error
    response is returned and the wrapped app is not called.
    """

    # Downstream ASGI application.
    app: ASGIApp
    # Endpoint maps and default policy, forwarded to `find_match`.
    private_endpoints: EndpointMethods
    public_endpoints: EndpointMethods
    default_public: bool
    # Discovery URL used to lazily build the `OidcService`.
    oidc_discovery_url: HttpUrl
    # Passed to `jwt.decode` as the `audience` argument.
    allowed_jwt_audiences: Optional[Sequence[str]] = None
    # Attribute name on `request.state` that receives the token payload.
    state_key: str = "payload"

    # Cached OIDC service; created on first access of `oidc_config`.
    _oidc_config: Optional[OidcService] = None

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """Enforce authentication."""
        if scope["type"] != "http":
            return await self.app(scope, receive, send)

        request = Request(scope)

        # Skip authentication for OPTIONS requests, https://fetch.spec.whatwg.org/#cors-protocol-and-credentials
        if request.method == "OPTIONS":
            return await self.app(scope, receive, send)

        match = find_match(
            request.url.path,
            request.method,
            private_endpoints=self.private_endpoints,
            public_endpoints=self.public_endpoints,
            default_public=self.default_public,
        )
        try:
            payload = self.validate_token(
                request.headers.get("Authorization"),
                auto_error=match.is_private,
                required_scopes=match.required_scopes,
            )

        except HTTPException as e:
            # Forward the exception's headers so the `WWW-Authenticate`
            # challenge set by `validate_token` actually reaches the client.
            response = JSONResponse(
                {"detail": e.detail},
                status_code=e.status_code,
                headers=e.headers,
            )
            return await response(scope, receive, send)

        # Set the payload in the request state
        setattr(
            request.state,
            self.state_key,
            payload,
        )
        setattr(request.state, "oidc_metadata", self.oidc_config.metadata)
        return await self.app(scope, receive, send)

    def validate_token(
        self,
        auth_header: Annotated[str, Security(...)],
        auto_error: bool = True,
        required_scopes: Optional[Sequence[str]] = None,
    ) -> Optional[dict[str, Any]]:
        """
        Dependency to validate an OIDC token.

        Args:
            auth_header: Raw ``Authorization`` header value (may be empty/None).
            auto_error: When true, a missing header raises a 401; when false,
                a missing header yields ``None``. A header that is present but
                invalid always raises.
            required_scopes: Scopes that must all appear in the token's
                space-separated ``scope`` claim.

        Returns:
            The decoded token payload, or ``None`` if no header was supplied
            and ``auto_error`` is false.

        Raises:
            HTTPException: 401 for a missing/malformed header or an invalid
                token; 403 when required scopes are missing.

        """
        if not auth_header:
            if auto_error:
                raise HTTPException(
                    status_code=status.HTTP_401_UNAUTHORIZED,
                    detail="Not authenticated",
                    headers={"WWW-Authenticate": "Bearer"},
                )
            return None

        # Extract token from header
        token_parts = auth_header.split(" ")
        if len(token_parts) != 2 or token_parts[0].lower() != "bearer":
            logger.error("Invalid Authorization header format")
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Invalid Authorization header format",
                headers={"WWW-Authenticate": "Bearer"},
            )
        [_, token] = token_parts

        # Parse & validate token
        try:
            key = self.oidc_config.jwks_client.get_signing_key_from_jwt(token).key
            payload = jwt.decode(
                token,
                key,
                algorithms=["RS256"],
                # NOTE: Audience validation MUST match audience claim if set in token (https://pyjwt.readthedocs.io/en/stable/changelog.html?highlight=audience#id40)
                audience=self.allowed_jwt_audiences,
            )
        except jwt.InvalidAudienceError as e:
            logger.error("Token audience validation failed: %s", str(e))
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Invalid token audience",
                headers={"WWW-Authenticate": "Bearer"},
            ) from e
        except (
            jwt.exceptions.InvalidTokenError,
            jwt.exceptions.DecodeError,
            jwt.exceptions.PyJWKClientError,
        ) as e:
            logger.error("Token validation failed: %s", type(e).__name__)
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Invalid or expired token",
                headers={"WWW-Authenticate": "Bearer"},
            ) from e

        # Check authorization (scopes)
        if required_scopes:
            token_scopes = set(payload.get("scope", "").split())
            missing_scopes = set(required_scopes) - token_scopes
            if missing_scopes:
                logger.warning(
                    "Insufficient scopes for user %s. Required: %s, Has: %s",
                    payload.get("sub"),
                    required_scopes,
                    token_scopes,
                )
                # Sort so the error message is deterministic across runs.
                raise HTTPException(
                    status_code=status.HTTP_403_FORBIDDEN,
                    detail=f"Insufficient permissions. Required scopes: {', '.join(sorted(missing_scopes))}",
                    headers={
                        "WWW-Authenticate": f'Bearer scope="{" ".join(required_scopes)}"'
                    },
                )

        return payload

    @property
    def oidc_config(self) -> OidcService:
        """Get the OIDC configuration, creating it on first access."""
        if not self._oidc_config:
            self._oidc_config = OidcService(oidc_discovery_url=self.oidc_discovery_url)
        return self._oidc_config
203 |
204 |
class OidcFetchError(Exception):
    """Raised when the OIDC configuration cannot be fetched."""
209 |
--------------------------------------------------------------------------------
/src/stac_auth_proxy/utils/requests.py:
--------------------------------------------------------------------------------
1 | """Utility functions for working with HTTP requests."""
2 |
3 | import json
4 | import logging
5 | import re
6 | from dataclasses import dataclass, field
7 | from typing import Dict, Optional, Sequence
8 | from urllib.parse import urlparse
9 |
10 | from starlette.requests import Request
11 |
12 | from ..config import EndpointMethods
13 |
14 | logger = logging.getLogger(__name__)
15 |
16 |
def extract_variables(url: str) -> dict:
    """
    Extract variables from a URL path. Being that we use a catch-all endpoint for the proxy,
    we can't rely on the path parameters that FastAPI provides.

    Returns a dict containing ``collection_id`` and, when present, ``item_id``.
    Keys whose group did not match are omitted; a path that does not match
    the collections pattern yields an empty dict.
    """
    path = urlparse(url).path
    # This allows either /items or /bulk_items, with an optional item_id following.
    # The groups must be named: `(?P<name>...)` — a bare `(?P[...` is a regex error.
    pattern = (
        r"^/collections/(?P<collection_id>[^/]+)"
        r"(?:/(?:items|bulk_items)(?:/(?P<item_id>[^/]+))?)?/?$"
    )
    match = re.match(pattern, path)
    return {k: v for k, v in match.groupdict().items() if v} if match else {}
27 |
28 |
def dict_to_bytes(d: dict) -> bytes:
    """Serialize a dictionary to compact (no extra whitespace) UTF-8 JSON bytes."""
    compact_json = json.dumps(d, separators=(",", ":"))
    return compact_json.encode("utf-8")
32 |
33 |
def _check_endpoint_match(
    path: str,
    method: str,
    endpoints: EndpointMethods,
) -> tuple[bool, Sequence[str]]:
    """
    Look for an endpoint entry whose regex matches ``path`` and whose method equals ``method``.

    Each entry is either a bare method name or a ``(method, scopes)`` tuple
    where ``scopes`` is a space-separated string. Methods are compared
    case-insensitively. Returns ``(True, scopes)`` for the first hit and
    ``(False, [])`` when nothing matches.
    """
    for route_regex, allowed in endpoints.items():
        if not re.match(route_regex, path):
            continue

        for entry in allowed:
            if isinstance(entry, tuple):
                verb, scope_string = entry
                # Ignore empty scopes, e.g. `["POST", ""]`
                scopes: Sequence[str] = scope_string.split(" ") if scope_string else []
            else:
                verb, scopes = entry, []

            if method.casefold() == verb.casefold():
                return True, scopes

    return False, []
51 |
52 |
def find_match(
    path: str,
    method: str,
    private_endpoints: EndpointMethods,
    public_endpoints: EndpointMethods,
    default_public: bool,
) -> "MatchResult":
    """
    Classify a request as public or private from the configured endpoint maps.

    When ``default_public`` is true, only ``private_endpoints`` is consulted:
    a hit is private, anything else is public. Otherwise ``public_endpoints``
    is consulted first (a hit is public), then ``private_endpoints`` is
    checked to pick up required scopes, and anything unmatched is private.
    """
    # The "exceptions to the default" list is the one checked first.
    if default_public:
        exceptions = private_endpoints
    else:
        exceptions = public_endpoints

    hit, scopes = _check_endpoint_match(path, method, exceptions)
    if hit:
        return MatchResult(is_private=default_public, required_scopes=scopes)

    if not default_public:
        # Private by default; an explicit private entry may still carry scopes.
        hit, scopes = _check_endpoint_match(path, method, private_endpoints)
        if hit:
            return MatchResult(is_private=True, required_scopes=scopes)
        return MatchResult(is_private=True)

    # Public by default and no private entry matched.
    return MatchResult(is_private=False)
80 |
81 |
@dataclass
class MatchResult:
    """Result of a match between a path and method and a set of endpoints."""

    # True when the matched request is classified as private.
    is_private: bool
    # Scopes attached to the matched endpoint entry; empty when none were configured.
    required_scopes: Sequence[str] = field(default_factory=list)
88 |
89 |
def build_server_timing_header(
    current_value: Optional[str] = None, *, name: str, desc: str, dur: float
):
    """
    Build a Server-Timing header value with one more metric.

    The metric is formatted as ``name;desc="...";dur=<3 decimals>`` and is
    appended to ``current_value`` (comma-separated) when that is non-empty.
    """
    entry = f'{name};desc="{desc}";dur={dur:.3f}'
    return f"{current_value}, {entry}" if current_value else entry
98 |
99 |
def parse_forwarded_header(forwarded_header: str) -> Dict[str, str]:
    """
    Parse the Forwarded header according to RFC 7239.

    Parameter names are case-insensitive per RFC 7239, so keys are returned
    lower-cased. Surrounding double quotes are removed from values. When
    several ``for`` values are present, only the first is kept.

    Args:
        forwarded_header: The Forwarded header value

    Returns:
        Dictionary containing parsed forwarded information (proto, host, for, by, etc.).
        An empty dictionary is returned if parsing fails.

    Example:
        >>> parse_forwarded_header("for=192.0.2.43; by=203.0.113.60; proto=https; host=api.example.com")
        {'for': '192.0.2.43', 'by': '203.0.113.60', 'proto': 'https', 'host': 'api.example.com'}

    """
    # Forwarded header format: "for=192.0.2.43, for=198.51.100.17; by=203.0.113.60; proto=https; host=example.com"
    # The format is: for=value1, for=value2; by=value; proto=value; host=value
    # We need to parse all the key=value pairs, taking the first 'for' value
    forwarded_info: Dict[str, str] = {}

    try:
        # Parse all key=value pairs separated by semicolons
        for pair in forwarded_header.split(";"):
            key, separator, value = pair.partition("=")
            if not separator:
                continue

            # Lower-case so "Proto=https" and "proto=https" are treated alike.
            key = key.strip().lower()
            value = value.strip().strip('"')

            if key == "for":
                # Only the first 'for' value is kept if there are multiple.
                if key not in forwarded_info:
                    # Take the part before the first comma, then drop any quote
                    # left over from a quoted first value (e.g. `"a", for=b`).
                    forwarded_info[key] = value.split(",")[0].strip().strip('"')
            else:
                # For other fields, just use the value as-is
                forwarded_info[key] = value
    except Exception as e:
        logger.warning(
            "Failed to parse Forwarded header '%s': %s", forwarded_header, e
        )
        return {}

    return forwarded_info
142 |
143 |
def get_base_url(request: Request) -> str:
    """
    Work out the client-facing base URL of a request, honouring proxy headers.

    The standard ``Forwarded`` header (RFC 7239) is tried first; if it yields
    neither ``proto`` nor ``host``, the legacy ``X-Forwarded-Host`` /
    ``X-Forwarded-Proto`` / ``X-Forwarded-Path`` headers are used. With no
    forwarding information the request's own scheme, netloc and base path
    are returned.

    Args:
        request: The Starlette request object

    Returns:
        The reconstructed client base URL

    Example:
        >>> # With Forwarded header
        >>> request.headers = {"Forwarded": "for=192.0.2.43; proto=https; host=api.example.com"}
        >>> get_base_url(request)
        "https://api.example.com/"

        >>> # With X-Forwarded-* headers
        >>> request.headers = {"X-Forwarded-Host": "api.example.com", "X-Forwarded-Proto": "https"}
        >>> get_base_url(request)
        "https://api.example.com/"

    """
    # 1. Standard Forwarded header (RFC 7239)
    rfc7239_value = request.headers.get("Forwarded")
    if rfc7239_value:
        try:
            parsed = parse_forwarded_header(rfc7239_value)
            # Only trust the header when it actually told us a proto or a host.
            if parsed and ("proto" in parsed or "host" in parsed):
                scheme = parsed.get("proto", request.url.scheme)
                host = parsed.get("host", request.url.netloc)
                # Forwarded carries no path information, so the base path is kept.
                path = request.base_url.path
                return f"{scheme}://{host}{path}"
        except Exception as e:
            logger.warning(f"Failed to parse Forwarded header: {e}")

    # 2. Legacy X-Forwarded-* headers
    legacy_host = request.headers.get("X-Forwarded-Host")
    legacy_proto = request.headers.get("X-Forwarded-Proto")
    legacy_path = request.headers.get("X-Forwarded-Path")

    if not legacy_host:
        # 3. No forwarding information at all: use the request as received.
        scheme = request.url.scheme
        netloc = request.url.netloc
        path = request.base_url.path
        return f"{scheme}://{netloc}{path}"

    scheme = legacy_proto or request.url.scheme
    path = legacy_path or request.base_url.path
    return f"{scheme}://{legacy_host}{path}"
205 |
--------------------------------------------------------------------------------
/tests/conftest.py:
--------------------------------------------------------------------------------
1 | """Pytest fixtures."""
2 |
3 | import os
4 | import socket
5 | import threading
6 | from functools import partial
7 | from typing import Any, AsyncGenerator
8 | from unittest.mock import DEFAULT, AsyncMock, MagicMock, patch
9 |
10 | import pytest
11 | import uvicorn
12 | from fastapi import FastAPI
13 | from jwcrypto import jwk, jwt
14 | from starlette_cramjam.middleware import CompressionMiddleware
15 | from utils import single_chunk_async_stream_response
16 |
17 |
@pytest.fixture
def test_key() -> jwk.JWK:
    """Create a fresh 2048-bit RSA signing key (kid ``test``, alg ``RS256``) for tests."""
    key_params = {
        "kty": "RSA",
        "size": 2048,
        "kid": "test",
        "use": "sig",
        "e": "AQAB",
        "alg": "RS256",
    }
    return jwk.JWK.generate(**key_params)
24 |
25 |
@pytest.fixture
def public_key(test_key: jwk.JWK) -> dict[str, Any]:
    """Provide the public part of the test key as a dictionary."""
    exported_public_jwk = test_key.export_public(as_dict=True)
    return exported_public_jwk
30 |
31 |
@pytest.fixture(autouse=True)
def mock_jwks(public_key: dict[str, Any]):
    """
    Patch the OIDC discovery request and the JWKS download for every test.

    ``httpx.get`` returns a mock whose ``json()`` is a minimal discovery
    document, and ``jwt.PyJWKClient.fetch_data`` returns a key set holding
    the test public key. Yields the ``httpx.get`` mock.
    """
    discovery_document = {
        "jwks_uri": "https://example.com/jwks",
        "authorization_endpoint": "https://example.com/auth",
        "token_endpoint": "https://example.com/token",
        "scopes_supported": ["openid", "profile", "email", "collection:create"],
    }
    key_set = {"keys": [public_key]}

    # Canned response object handed back by the patched httpx.get
    discovery_response = MagicMock()
    discovery_response.json.return_value = discovery_document
    discovery_response.status = 200

    with patch("httpx.get") as mock_urlopen:
        with patch("jwt.PyJWKClient.fetch_data") as mock_fetch_data:
            mock_urlopen.return_value = discovery_response
            mock_fetch_data.return_value = key_set
            yield mock_urlopen
55 |
56 |
@pytest.fixture
def token_builder(test_key: jwk.JWK):
    """Provide a function that turns a claims dict into a signed, serialized JWT."""

    def build_token(payload: dict[str, Any], key=None) -> str:
        # The header always advertises the test key's alg/kid, even when a
        # different signing key is supplied.
        token_header = {"alg": test_key.get("alg"), "kid": test_key.get("kid")}
        signing_key = key or test_key

        token = jwt.JWT(header=token_header, claims=payload)
        token.make_signed_token(signing_key)
        return token.serialize()

    return build_token
70 |
71 |
@pytest.fixture(scope="session")
def source_api():
    """
    Create upstream API for testing purposes.

    A FastAPI app is built with one route per (path, method) pair listed in
    the default responses table below. The table is stored on
    ``app.state.default_responses`` and is read at request time, so replacing
    that attribute changes what the routes return.

    You can customize the response for each endpoint by passing a dict of responses:
    {
        "path": {
            "method": response_body
        }
    }
    """
    app = FastAPI(docs_url="/api.html", openapi_url="/api")

    # minimum_size=0 is passed so that no response is exempt on size grounds.
    app.add_middleware(CompressionMiddleware, minimum_size=0, compression_level=1)

    # Default responses for each endpoint
    default_responses = {
        "/": {
            "GET": {"id": "Response from GET@"},
            "OPTIONS": {"id": "Response from OPTIONS@"},
        },
        "/conformance": {
            "GET": {"conformsTo": ["http://example.com/conformance"]},
            "OPTIONS": {"id": "Response from OPTIONS@"},
        },
        "/queryables": {
            "GET": {"queryables": {}},
            "OPTIONS": {"id": "Response from OPTIONS@"},
        },
        "/search": {
            "GET": {"type": "FeatureCollection", "features": []},
            "POST": {"type": "FeatureCollection", "features": []},
            "OPTIONS": {"id": "Response from OPTIONS@"},
        },
        "/collections": {
            "GET": {"collections": []},
            "POST": {"id": "Response from POST@"},
            "OPTIONS": {"id": "Response from OPTIONS@"},
        },
        "/collections/{collection_id}": {
            "GET": {"id": "Response from GET@"},
            "PUT": {"id": "Response from PUT@"},
            "PATCH": {"id": "Response from PATCH@"},
            "DELETE": {"id": "Response from DELETE@"},
            "OPTIONS": {"id": "Response from OPTIONS@"},
        },
        "/collections/{collection_id}/items": {
            "GET": {"type": "FeatureCollection", "features": []},
            "POST": {"id": "Response from POST@"},
            "OPTIONS": {"id": "Response from OPTIONS@"},
        },
        "/collections/{collection_id}/items/{item_id}": {
            "GET": {"id": "Response from GET@"},
            "PUT": {"id": "Response from PUT@"},
            "PATCH": {"id": "Response from PATCH@"},
            "DELETE": {"id": "Response from DELETE@"},
            "OPTIONS": {"id": "Response from OPTIONS@"},
        },
        "/collections/{collection_id}/bulk_items": {
            "POST": {"id": "Response from POST@"},
            "OPTIONS": {"id": "Response from OPTIONS@"},
        },
    }

    # Store responses in app state
    app.state.default_responses = default_responses

    def get_response(path: str, method: str) -> dict:
        """Get response for a given path and method."""
        # Looked up on app.state at call time (not captured from the local
        # table), falling back to a generated body for unknown path/method.
        return app.state.default_responses.get(path, {}).get(
            method, {"id": f"Response from {method}@{path}"}
        )

    for path, methods in default_responses.items():
        for method in methods:
            # partial binds both arguments now, so each registered handler
            # takes no further parameters and avoids late-binding of the
            # loop variables.
            app.add_api_route(
                path,
                partial(get_response, path, method),
                methods=[method],
            )

    return app
155 |
156 |
@pytest.fixture
def source_api_responses(source_api):
    """
    Let a single test override what the source API returns.

    Yields a per-test copy of the response table (each path's method map is
    copied) and installs it on the app; the previous table is put back after
    the test.

    Usage:
        def test_something(source_api_responses):
            # Override responses for specific endpoints
            source_api_responses["/collections"]["GET"] = {"collections": [{"id": "test"}]}
            source_api_responses["/search"]["POST"] = {"type": "FeatureCollection", "features": [{"id": "test"}]}

            # Your test code here
    """
    app_state = source_api.state
    previous_table = app_state.default_responses

    # Copy one level deep so assigning to table[path][method] leaves the
    # previous table's method maps untouched.
    per_test_table = {
        path: methods.copy() for path, methods in previous_table.items()
    }
    app_state.default_responses = per_test_table

    yield per_test_table

    # Teardown: reinstate whatever table was active before this test.
    app_state.default_responses = previous_table
185 |
186 |
@pytest.fixture(scope="session")
def free_port():
    """
    Get a free port.

    Binds a temporary socket to port 0, reads back the port number it was
    given, and closes the socket before returning so the port can be bound
    again by the caller.
    """
    # The context manager closes the socket deterministically instead of
    # leaving it to garbage collection.
    with socket.socket() as sock:
        # Needed for Github Actions, https://stackoverflow.com/a/4466035
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        sock.bind(("", 0))
        return sock.getsockname()[1]
195 |
196 |
@pytest.fixture(scope="session")
def source_api_server(source_api, free_port):
    """
    Serve the source API from a background thread for the test session.

    Yields the server's base URL; afterwards the server is asked to exit and
    the thread is joined.
    """
    host = "127.0.0.1"
    config = uvicorn.Config(source_api, host=host, port=free_port)
    server = uvicorn.Server(config)

    worker = threading.Thread(target=server.run)
    worker.start()

    base_url = f"http://{host}:{free_port}"
    yield base_url

    # Teardown: signal uvicorn to stop, then wait for the thread to finish.
    server.should_exit = True
    worker.join()
213 |
214 |
@pytest.fixture(autouse=True, scope="session")
def mock_env():
    """Clear environment variables so the runtime environment does not pollute configs."""
    cleared_environment = patch.dict(os.environ, clear=True)
    with cleared_environment:
        yield
220 |
221 |
@pytest.fixture
async def mock_upstream() -> AsyncGenerator[MagicMock, None]:
    """
    Replace the HTTPX send method with a mock.

    Useful for inspecting the request that would be sent to the upstream API:
    the request body is drained and kept on the request as ``_streamed_body``.
    """
    # NOTE: This fixture will interfere with the source_api_responses fixture

    async def store_body(request, **kwargs):
        """Exhaust and store the request body."""
        collected = b""
        async for piece in request.stream:
            collected += piece
        setattr(request, "_streamed_body", collected)
        # Returning DEFAULT makes the mock fall back to its return_value.
        return DEFAULT

    send_patch = patch(
        "stac_auth_proxy.handlers.reverse_proxy.httpx.AsyncClient.send",
        new_callable=AsyncMock,
        side_effect=store_body,
        return_value=single_chunk_async_stream_response(b"{}"),
    )
    with send_patch as mock_send_method:
        yield mock_send_method
242 |
--------------------------------------------------------------------------------