├── .dockerignore ├── docs ├── index.md ├── changelog.md ├── assets │ ├── ds-symbol-negative-mono.png │ └── ds-symbol-positive-mono.png ├── architecture │ ├── filtering-data.md │ └── middleware-stack.md ├── overrides │ └── partials │ │ └── integrations │ │ └── analytics │ │ └── plausible.html └── user-guide │ ├── tips.md │ ├── getting-started.md │ ├── deployment.md │ └── route-level-auth.md ├── examples ├── custom-integration │ ├── .python-version │ ├── pyproject.toml │ ├── Dockerfile │ ├── README.md │ ├── docker-compose.yaml │ └── src │ │ └── custom_integration.py └── opa │ ├── policies │ └── stac │ │ └── policy.rego │ ├── docker-compose.yaml │ └── README.md ├── src └── stac_auth_proxy │ ├── utils │ ├── __init__.py │ ├── stac.py │ ├── filters.py │ ├── cache.py │ ├── middleware.py │ └── requests.py │ ├── filters │ ├── __init__.py │ ├── template.py │ └── opa.py │ ├── handlers │ ├── __init__.py │ ├── healthz.py │ ├── swagger_ui.py │ └── reverse_proxy.py │ ├── lambda.py │ ├── __init__.py │ ├── __main__.py │ ├── middleware │ ├── AddProcessTimeHeaderMiddleware.py │ ├── __init__.py │ ├── RemoveRootPathMiddleware.py │ ├── Cql2ApplyFilterQueryStringMiddleware.py │ ├── UpdateOpenApiMiddleware.py │ ├── Cql2ApplyFilterBodyMiddleware.py │ ├── AuthenticationExtensionMiddleware.py │ ├── Cql2RewriteLinksFilterMiddleware.py │ ├── ProcessLinksMiddleware.py │ ├── Cql2BuildFilterMiddleware.py │ ├── Cql2ValidateResponseBodyMiddleware.py │ └── EnforceAuthMiddleware.py │ ├── config.py │ ├── lifespan.py │ └── app.py ├── helm ├── Chart.yaml ├── templates │ ├── service.yaml │ ├── serviceaccount.yaml │ ├── ingress.yaml │ ├── _helpers.tpl │ ├── deployment.yaml │ └── NOTES.txt └── values.yaml ├── .vscode ├── settings.json └── launch.json ├── .github └── workflows │ ├── conventional-commits-prs.yaml │ ├── publish-pypi.yaml │ ├── release-please.yml │ ├── publish-helm.yaml │ ├── docs.yaml │ ├── cicd.yaml │ └── publish-docker.yaml ├── .coveragerc ├── .pre-commit-config.yaml ├── tests ├── 
test_configure_app.py ├── test_proxy.py ├── test_remove_root_path.py ├── utils.py ├── test_cache.py ├── test_filters_opa.py ├── test_defaults.py ├── test_utils.py ├── test_lifespan.py ├── test_auth_extension.py └── conftest.py ├── LICENSE ├── Dockerfile ├── Makefile ├── README.md ├── pyproject.toml ├── .gitignore ├── mkdocs.yml └── docker-compose.yaml /.dockerignore: -------------------------------------------------------------------------------- 1 | .pgdata -------------------------------------------------------------------------------- /docs/index.md: -------------------------------------------------------------------------------- 1 | ../README.md -------------------------------------------------------------------------------- /docs/changelog.md: -------------------------------------------------------------------------------- 1 | ../CHANGELOG.md -------------------------------------------------------------------------------- /examples/custom-integration/.python-version: -------------------------------------------------------------------------------- 1 | 3.13 2 | -------------------------------------------------------------------------------- /src/stac_auth_proxy/utils/__init__.py: -------------------------------------------------------------------------------- 1 | """Utils module for stac_auth_proxy.""" 2 | -------------------------------------------------------------------------------- /docs/assets/ds-symbol-negative-mono.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/developmentseed/stac-auth-proxy/HEAD/docs/assets/ds-symbol-negative-mono.png -------------------------------------------------------------------------------- /docs/assets/ds-symbol-positive-mono.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/developmentseed/stac-auth-proxy/HEAD/docs/assets/ds-symbol-positive-mono.png 
-------------------------------------------------------------------------------- /helm/Chart.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: v2 2 | name: stac-auth-proxy 3 | description: A Helm chart for stac-auth-proxy 4 | type: application 5 | version: 0.1.2 6 | appVersion: "1.0.0" 7 | -------------------------------------------------------------------------------- /.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "python.testing.pytestArgs": [ 3 | "tests", 4 | "-n", 5 | "auto" 6 | ], 7 | "python.testing.unittestEnabled": false, 8 | "python.testing.pytestEnabled": true 9 | } -------------------------------------------------------------------------------- /src/stac_auth_proxy/filters/__init__.py: -------------------------------------------------------------------------------- 1 | """CQL2 filter factories.""" 2 | 3 | from .opa import Opa 4 | from .template import Template 5 | 6 | __all__ = [ 7 | "Opa", 8 | "Template", 9 | ] 10 | -------------------------------------------------------------------------------- /examples/custom-integration/pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "custom_integration" 3 | version = "0.1.0" 4 | description = "Add your description here" 5 | readme = "README.md" 6 | requires-python = ">=3.9" 7 | dependencies = [] 8 | -------------------------------------------------------------------------------- /examples/custom-integration/Dockerfile: -------------------------------------------------------------------------------- 1 | ARG STAC_AUTH_PROXY_VERSION 2 | FROM ghcr.io/developmentseed/stac-auth-proxy:${STAC_AUTH_PROXY_VERSION} 3 | 4 | ADD . 
/opt/stac-auth-proxy-integration 5 | 6 | RUN pip install /opt/stac-auth-proxy-integration 7 | -------------------------------------------------------------------------------- /src/stac_auth_proxy/handlers/__init__.py: -------------------------------------------------------------------------------- 1 | """Handlers to process requests.""" 2 | 3 | from .healthz import HealthzHandler 4 | from .reverse_proxy import ReverseProxyHandler 5 | from .swagger_ui import SwaggerUI 6 | 7 | __all__ = ["ReverseProxyHandler", "HealthzHandler", "SwaggerUI"] 8 | -------------------------------------------------------------------------------- /examples/opa/policies/stac/policy.rego: -------------------------------------------------------------------------------- 1 | package stac 2 | 3 | default items_cql2 := "\"naip:year\" = 2021" 4 | 5 | items_cql2 := "1=1" if { 6 | input.payload.sub != null 7 | } 8 | 9 | default collections_cql2 := "id = 'naip'" 10 | 11 | collections_cql2 := "1=1" if { 12 | input.payload.sub != null 13 | } 14 | -------------------------------------------------------------------------------- /examples/custom-integration/README.md: -------------------------------------------------------------------------------- 1 | # Custom Integration Example 2 | 3 | This example demonstrates how to integrate with a custom filter factory. 4 | 5 | ## Running the Example 6 | 7 | From the root directory, run: 8 | 9 | ```sh 10 | docker compose -f docker-compose.yaml -f examples/custom-integration/docker-compose.yaml up 11 | ``` 12 | -------------------------------------------------------------------------------- /examples/custom-integration/docker-compose.yaml: -------------------------------------------------------------------------------- 1 | # This compose file is intended to be run alongside the `docker-compose.yaml` file in the 2 | # root directory. 
3 | 4 | services: 5 | proxy: 6 | build: 7 | context: examples/custom-integration 8 | args: 9 | STAC_AUTH_PROXY_VERSION: 0.1.2 10 | environment: 11 | ITEMS_FILTER_CLS: custom_integration:cql2_builder 12 | ITEMS_FILTER_KWARGS: '{"admin_user": "user123"}' 13 | -------------------------------------------------------------------------------- /helm/templates/service.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: v1 2 | kind: Service 3 | metadata: 4 | name: {{ include "stac-auth-proxy.fullname" . }} 5 | labels: 6 | {{- include "stac-auth-proxy.labels" . | nindent 4 }} 7 | spec: 8 | type: {{ .Values.service.type }} 9 | ports: 10 | - port: {{ .Values.service.port }} 11 | targetPort: http 12 | protocol: TCP 13 | name: http 14 | selector: 15 | {{- include "stac-auth-proxy.selectorLabels" . | nindent 4 }} -------------------------------------------------------------------------------- /.github/workflows/conventional-commits-prs.yaml: -------------------------------------------------------------------------------- 1 | name: PR Conventional Commit Validation 2 | 3 | on: 4 | pull_request: 5 | types: [opened, synchronize, reopened, edited] 6 | 7 | jobs: 8 | validate-pr-title: 9 | runs-on: ubuntu-latest 10 | steps: 11 | - name: PR Conventional Commit Validation 12 | uses: ytanikin/pr-conventional-commits@1.4.0 13 | with: 14 | task_types: '["feat","fix","docs","test","ci","refactor","perf","chore","revert"]' 15 | -------------------------------------------------------------------------------- /src/stac_auth_proxy/lambda.py: -------------------------------------------------------------------------------- 1 | """Handler for AWS Lambda.""" 2 | 3 | from stac_auth_proxy import create_app 4 | 5 | try: 6 | from mangum import Mangum 7 | except ImportError: 8 | raise ImportError( 9 | "mangum is required to use the Lambda handler. Install stac-auth-proxy[lambda]." 
10 | ) 11 | 12 | 13 | handler = Mangum( 14 | create_app(), 15 | # NOTE: lifespan="off" skips conformance check and upstream health checks on startup 16 | lifespan="off", 17 | ) 18 | -------------------------------------------------------------------------------- /.github/workflows/publish-pypi.yaml: -------------------------------------------------------------------------------- 1 | name: Publish to PyPI 2 | 3 | on: 4 | release: 5 | types: [published] 6 | workflow_dispatch: 7 | 8 | permissions: 9 | contents: read 10 | id-token: write 11 | 12 | jobs: 13 | publish-to-pypi: 14 | runs-on: ubuntu-latest 15 | steps: 16 | - uses: actions/checkout@v4 17 | - uses: actions/setup-python@v3 18 | - uses: astral-sh/setup-uv@v4 19 | with: 20 | enable-cache: true 21 | - run: uv build 22 | - run: uv publish 23 | -------------------------------------------------------------------------------- /helm/templates/serviceaccount.yaml: -------------------------------------------------------------------------------- 1 | {{- if .Values.serviceAccount.create -}} 2 | apiVersion: v1 3 | kind: ServiceAccount 4 | metadata: 5 | name: {{ include "stac-auth-proxy.serviceAccountName" . }} 6 | labels: 7 | {{- include "stac-auth-proxy.labels" . | nindent 4 }} 8 | {{- with .Values.serviceAccount.annotations }} 9 | annotations: 10 | {{- toYaml . | nindent 4 }} 11 | {{- end }} 12 | {{- with .Values.serviceAccount.imagePullSecrets }} 13 | imagePullSecrets: 14 | {{- toYaml . | nindent 2 }} 15 | {{- end }} 16 | {{- end }} -------------------------------------------------------------------------------- /src/stac_auth_proxy/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | STAC Auth Proxy package. 3 | 4 | This package contains the components for the STAC authentication and proxying system. 5 | It includes FastAPI routes for handling authentication, authorization, and interaction 6 | with some internal STAC API. 
7 | """ 8 | 9 | from .app import configure_app, create_app 10 | from .config import Settings 11 | from .lifespan import build_lifespan 12 | 13 | __all__ = [ 14 | "build_lifespan", 15 | "create_app", 16 | "configure_app", 17 | "Settings", 18 | ] 19 | -------------------------------------------------------------------------------- /examples/opa/docker-compose.yaml: -------------------------------------------------------------------------------- 1 | services: 2 | proxy: 3 | environment: 4 | ITEMS_FILTER_CLS: stac_auth_proxy.filters:Opa 5 | ITEMS_FILTER_ARGS: '["http://opa:8181", "stac/items_cql2"]' 6 | COLLECTIONS_FILTER_CLS: stac_auth_proxy.filters:Opa 7 | COLLECTIONS_FILTER_ARGS: '["http://opa:8181", "stac/collections_cql2"]' 8 | 9 | opa: 10 | image: openpolicyagent/opa:latest 11 | command: "run --server --addr=:8181 --watch /policies" 12 | ports: 13 | - "8181:8181" 14 | volumes: 15 | - ./examples/opa/policies:/policies 16 | -------------------------------------------------------------------------------- /src/stac_auth_proxy/utils/stac.py: -------------------------------------------------------------------------------- 1 | """STAC-specific utilities.""" 2 | 3 | from itertools import chain 4 | 5 | 6 | def get_links(data: dict) -> chain[dict]: 7 | """Get all links from a STAC response.""" 8 | return chain( 9 | # Item/Collection 10 | data.get("links", []), 11 | # Collections/Items/Search 12 | ( 13 | link 14 | for prop in ["features", "collections"] 15 | for object_with_links in data.get(prop, []) 16 | for link in object_with_links.get("links", []) 17 | ), 18 | ) 19 | -------------------------------------------------------------------------------- /src/stac_auth_proxy/__main__.py: -------------------------------------------------------------------------------- 1 | """Entry point for running the module without customized code.""" 2 | 3 | import os 4 | 5 | import uvicorn 6 | from uvicorn.config import LOGGING_CONFIG 7 | 8 | uvicorn.run( 9 | 
f"{__package__}.app:create_app", 10 | host="0.0.0.0", 11 | port=int(os.getenv("PORT", 8000)), 12 | log_config={ 13 | **LOGGING_CONFIG, 14 | "loggers": { 15 | **LOGGING_CONFIG["loggers"], 16 | __package__: { 17 | "level": "DEBUG", 18 | "handlers": ["default"], 19 | }, 20 | }, 21 | }, 22 | reload=True, 23 | factory=True, 24 | ) 25 | -------------------------------------------------------------------------------- /.vscode/launch.json: -------------------------------------------------------------------------------- 1 | { 2 | // Use IntelliSense to learn about possible attributes. 3 | // Hover to view descriptions of existing attributes. 4 | // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 5 | "version": "0.2.0", 6 | "configurations": [ 7 | { 8 | "name": "Python Debugger: FastAPI", 9 | "type": "debugpy", 10 | "request": "launch", 11 | "module": "uvicorn", 12 | "args": [ 13 | "stac_auth_proxy.app:create_app", 14 | "--reload", 15 | "--factory" 16 | ], 17 | "jinja": true, 18 | "cwd": "${workspaceFolder}/src", 19 | "justMyCode": false 20 | } 21 | ] 22 | } -------------------------------------------------------------------------------- /.coveragerc: -------------------------------------------------------------------------------- 1 | [run] 2 | branch = True 3 | source = src/stac_auth_proxy 4 | omit = 5 | */tests/* 6 | */test_* 7 | */__pycache__/* 8 | */venv/* 9 | */build/* 10 | */dist/* 11 | */htmlcov/* 12 | */lambda.py 13 | */__main__.py 14 | 15 | [report] 16 | exclude_lines = 17 | pragma: no cover 18 | def __repr__ 19 | if self.debug: 20 | if settings.DEBUG 21 | raise AssertionError 22 | raise NotImplementedError 23 | if 0: 24 | if __name__ == .__main__.: 25 | class .*\bProtocol\): 26 | @(abc\.)?abstractmethod 27 | # Have to re-enable the standard pragma 28 | pragma: no cover 29 | 30 | [html] 31 | directory = htmlcov 32 | 33 | [xml] 34 | output = coverage.xml 35 | -------------------------------------------------------------------------------- 
/.github/workflows/release-please.yml: -------------------------------------------------------------------------------- 1 | on: 2 | push: 3 | branches: 4 | - main 5 | 6 | permissions: 7 | contents: write 8 | pull-requests: write 9 | 10 | name: release-please 11 | 12 | jobs: 13 | release-please: 14 | runs-on: ubuntu-latest 15 | outputs: 16 | release_created: ${{ steps.release.outputs.release_created }} 17 | steps: 18 | - uses: actions/create-github-app-token@v2.0.6 19 | id: app-token 20 | with: 21 | app-id: ${{ secrets.DS_RELEASE_BOT_ID }} 22 | private-key: ${{ secrets.DS_RELEASE_BOT_PRIVATE_KEY }} 23 | 24 | - uses: googleapis/release-please-action@v4 25 | id: release 26 | with: 27 | release-type: python 28 | token: ${{ steps.app-token.outputs.token }} 29 | -------------------------------------------------------------------------------- /examples/opa/README.md: -------------------------------------------------------------------------------- 1 | # Open Policy Agent (OPA) Integration 2 | 3 | This example demonstrates how to integrate with an Open Policy Agent (OPA) to authorize requests to a STAC API. 
4 | 5 | ## Running the Example 6 | 7 | From the root directory, run: 8 | 9 | ```sh 10 | docker compose -f docker-compose.yaml -f examples/opa/docker-compose.yaml up 11 | ``` 12 | 13 | ## Testing OPA 14 | 15 | ```sh 16 | ▶ curl -X POST "http://localhost:8181/v1/data/stac/cql2" \ 17 | -H "Content-Type: application/json" \ 18 | -d '{"input":{"payload": null}}' 19 | {"result":"private = true"} 20 | ``` 21 | 22 | ```sh 23 | ▶ curl -X POST "http://localhost:8181/v1/data/stac/cql2" \ 24 | -H "Content-Type: application/json" \ 25 | -d '{"input":{"payload": {"sub": "user1"}}}' 26 | {"result":"1=1"} 27 | ``` 28 | -------------------------------------------------------------------------------- /src/stac_auth_proxy/filters/template.py: -------------------------------------------------------------------------------- 1 | """Generate CQL2 filter expressions via Jinja2 templating.""" 2 | 3 | from dataclasses import dataclass, field 4 | from typing import Any 5 | 6 | from jinja2 import BaseLoader, Environment 7 | 8 | 9 | @dataclass 10 | class Template: 11 | """Generate CQL2 filter expressions via Jinja2 templating.""" 12 | 13 | template_str: str 14 | env: Environment = field(init=False) 15 | 16 | def __post_init__(self): 17 | """Initialize the Jinja2 environment.""" 18 | self.env = Environment(loader=BaseLoader).from_string(self.template_str) 19 | 20 | async def __call__(self, context: dict[str, Any]) -> str: 21 | """Render a CQL2 filter expression with the request and auth token.""" 22 | return self.env.render(**context).strip() 23 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/abravalheri/validate-pyproject 3 | rev: v0.24.1 4 | hooks: 5 | - id: validate-pyproject 6 | 7 | - repo: https://github.com/PyCQA/isort 8 | rev: 6.0.1 9 | hooks: 10 | - id: isort 11 | language_version: python 12 | 13 | - repo: 
https://github.com/charliermarsh/ruff-pre-commit 14 | rev: v0.12.11 15 | hooks: 16 | - id: ruff-check 17 | args: ["--fix"] 18 | - id: ruff-format 19 | 20 | - repo: https://github.com/pre-commit/mirrors-mypy 21 | rev: v1.17.1 22 | hooks: 23 | - id: mypy 24 | language_version: python 25 | exclude: tests/.* 26 | additional_dependencies: 27 | - types-simplejson 28 | - types-attrs 29 | - pydantic~=2.0 30 | -------------------------------------------------------------------------------- /tests/test_configure_app.py: -------------------------------------------------------------------------------- 1 | """Tests for configuring an external FastAPI application.""" 2 | 3 | from fastapi import FastAPI 4 | from fastapi.routing import APIRoute 5 | 6 | from stac_auth_proxy import Settings, configure_app 7 | 8 | 9 | def test_configure_app_excludes_proxy_route(): 10 | """Ensure `configure_app` adds health route and omits proxy route.""" 11 | app = FastAPI() 12 | settings = Settings( 13 | upstream_url="https://example.com", 14 | oidc_discovery_url="https://example.com/.well-known/openid-configuration", 15 | wait_for_upstream=False, 16 | check_conformance=False, 17 | default_public=True, 18 | ) 19 | 20 | configure_app(app, settings) 21 | 22 | routes = [r.path for r in app.router.routes if isinstance(r, APIRoute)] 23 | assert settings.healthz_prefix in routes 24 | assert "/{path:path}" not in routes 25 | -------------------------------------------------------------------------------- /src/stac_auth_proxy/handlers/healthz.py: -------------------------------------------------------------------------------- 1 | """Health check endpoints.""" 2 | 3 | from dataclasses import dataclass, field 4 | 5 | from fastapi import APIRouter 6 | from httpx import AsyncClient 7 | 8 | 9 | @dataclass 10 | class HealthzHandler: 11 | """Handler for health check endpoints.""" 12 | 13 | upstream_url: str 14 | router: APIRouter = field(init=False) 15 | 16 | def __post_init__(self): 17 | """Initialize the 
router.""" 18 | self.router = APIRouter() 19 | self.router.add_api_route("", self.healthz, methods=["GET"]) 20 | self.router.add_api_route("/upstream", self.healthz_upstream, methods=["GET"]) 21 | 22 | async def healthz(self): 23 | """Return health of this API.""" 24 | return {"status": "ok"} 25 | 26 | async def healthz_upstream(self): 27 | """Return health of upstream STAC API.""" 28 | async with AsyncClient() as client: 29 | response = await client.get(self.upstream_url) 30 | response.raise_for_status() 31 | return {"status": "ok", "code": response.status_code} 32 | -------------------------------------------------------------------------------- /.github/workflows/publish-helm.yaml: -------------------------------------------------------------------------------- 1 | name: Publish Helm Chart 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | paths: 8 | - 'helm/**' 9 | - '.github/workflows/publish-helm.yaml' 10 | release: 11 | types: [created] 12 | 13 | jobs: 14 | publish-helm: 15 | runs-on: ubuntu-latest 16 | permissions: 17 | contents: read 18 | packages: write 19 | 20 | steps: 21 | - name: Checkout 22 | uses: actions/checkout@v4 23 | with: 24 | fetch-depth: 0 25 | 26 | - name: Install Helm 27 | uses: azure/setup-helm@v3 28 | with: 29 | version: v3.12.1 30 | 31 | - name: Login to GHCR 32 | uses: docker/login-action@v3 33 | with: 34 | registry: ghcr.io 35 | username: ${{ github.actor }} 36 | password: ${{ secrets.GITHUB_TOKEN }} 37 | 38 | - name: Package Helm Chart 39 | run: | 40 | helm package helm/ 41 | 42 | - name: Push Helm Chart 43 | run: | 44 | helm push *.tgz oci://ghcr.io/${{ github.repository }}/charts -------------------------------------------------------------------------------- /.github/workflows/docs.yaml: -------------------------------------------------------------------------------- 1 | name: Documentation 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | 8 | jobs: 9 | test: 10 | name: Test 11 | runs-on: ubuntu-latest 12 | steps: 13 | - name: 
Checkout main 14 | uses: actions/checkout@v4 15 | 16 | - name: Set up Python 3.11 17 | uses: actions/setup-python@v5 18 | with: 19 | python-version: 3.11 20 | 21 | - uses: astral-sh/setup-uv@v4 22 | with: 23 | enable-cache: true 24 | 25 | - name: Test docs 26 | run: uv run --extra docs mkdocs build --strict 27 | 28 | deploy: 29 | name: Deploy 30 | runs-on: ubuntu-latest 31 | if: github.ref_name == 'main' 32 | steps: 33 | - name: Checkout main 34 | uses: actions/checkout@v4 35 | 36 | - name: Set up Python 3.11 37 | uses: actions/setup-python@v5 38 | with: 39 | python-version: 3.11 40 | 41 | - uses: astral-sh/setup-uv@v4 42 | with: 43 | enable-cache: true 44 | 45 | - name: Deploy docs 46 | run: uv run --extra docs mkdocs gh-deploy --force 47 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Development Seed 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /src/stac_auth_proxy/middleware/AddProcessTimeHeaderMiddleware.py: -------------------------------------------------------------------------------- 1 | """Middleware to add Server-Timing header with proxy processing time.""" 2 | 3 | import time 4 | 5 | from fastapi import Request, Response 6 | from starlette.middleware.base import BaseHTTPMiddleware 7 | 8 | from stac_auth_proxy.utils.requests import build_server_timing_header 9 | 10 | 11 | class AddProcessTimeHeaderMiddleware(BaseHTTPMiddleware): 12 | """Middleware to add Server-Timing header with proxy processing time.""" 13 | 14 | async def dispatch(self, request: Request, call_next) -> Response: 15 | """Add Server-Timing header with proxy processing time to the response.""" 16 | start_time = time.perf_counter() 17 | response = await call_next(request) 18 | process_time = time.perf_counter() - start_time 19 | 20 | # Add Server-Timing header with proxy processing time 21 | response.headers["Server-Timing"] = build_server_timing_header( 22 | response.headers.get("Server-Timing"), 23 | name="proxy", 24 | dur=process_time, 25 | desc="Proxy processing time", 26 | ) 27 | 28 | return response 29 | -------------------------------------------------------------------------------- /examples/custom-integration/src/custom_integration.py: -------------------------------------------------------------------------------- 1 | """ 2 | A custom integration example. 3 | 4 | In this example, we're intentionally using a functional pattern but you could also use a 5 | class like we do in the integrations found in stac_auth_proxy.filters. 
6 | """ 7 | 8 | from typing import Any 9 | 10 | 11 | def cql2_builder(admin_user: str): 12 | """CQL2 builder integration filter.""" 13 | # NOTE: This is where you would set up things like connection pools. 14 | # NOTE: args/kwargs are passed in via environment variables. 15 | 16 | async def custom_integration_filter(ctx: dict[str, Any]) -> str: 17 | """ 18 | Generate CQL2 expressions based on the request context. 19 | 20 | Returns a CQL2 expression, either as a string (cql2-text) or as a dict (cql2-json). 21 | """ 22 | # NOTE: This is where you would perform a lookup from a database, API, etc. 23 | # NOTE: ctx is the request context, which includes the payload, headers, etc. 24 | 25 | if ctx["payload"] and ctx["payload"]["sub"] == admin_user: 26 | return "1=1" 27 | return "private = true" 28 | 29 | return custom_integration_filter 30 | -------------------------------------------------------------------------------- /docs/architecture/filtering-data.md: -------------------------------------------------------------------------------- 1 | # Filtering Data 2 | 3 | > [!NOTE] 4 | > 5 | > For more information on using filters to solve authorization needs, see the [user guide](../user-guide/record-level-auth.md). 6 | 7 | ## Example Request Flow for multi-record endpoints 8 | 9 | ```mermaid 10 | sequenceDiagram 11 | Client->>Proxy: GET /collections 12 | Note over Proxy: EnforceAuth checks credentials 13 | Note over Proxy: BuildCql2Filter creates filter 14 | Note over Proxy: ApplyCql2Filter applies filter to request 15 | Proxy->>STAC API: GET /collections?filter=(collection=landsat) 16 | STAC API->>Client: Response 17 | ``` 18 | 19 | ## Example Request Flow for single-record endpoints 20 | 21 | The Filter Extension does not apply to fetching individual records. 
As such, we must validate the record _after_ it is returned from the upstream API but _before_ it is returned to the user: 22 | 23 | ```mermaid 24 | sequenceDiagram 25 | Client->>Proxy: GET /collections/abc123 26 | Note over Proxy: EnforceAuth checks credentials 27 | Note over Proxy: BuildCql2Filter creates filter 28 | Proxy->>STAC API: GET /collections/abc123 29 | Note over Proxy: ApplyCql2Filter validates the response 30 | STAC API->>Client: Response 31 | ``` 32 | -------------------------------------------------------------------------------- /src/stac_auth_proxy/middleware/__init__.py: -------------------------------------------------------------------------------- 1 | """Custom middleware.""" 2 | 3 | from .AddProcessTimeHeaderMiddleware import AddProcessTimeHeaderMiddleware 4 | from .AuthenticationExtensionMiddleware import AuthenticationExtensionMiddleware 5 | from .Cql2ApplyFilterBodyMiddleware import Cql2ApplyFilterBodyMiddleware 6 | from .Cql2ApplyFilterQueryStringMiddleware import Cql2ApplyFilterQueryStringMiddleware 7 | from .Cql2BuildFilterMiddleware import Cql2BuildFilterMiddleware 8 | from .Cql2RewriteLinksFilterMiddleware import Cql2RewriteLinksFilterMiddleware 9 | from .Cql2ValidateResponseBodyMiddleware import Cql2ValidateResponseBodyMiddleware 10 | from .EnforceAuthMiddleware import EnforceAuthMiddleware 11 | from .ProcessLinksMiddleware import ProcessLinksMiddleware 12 | from .RemoveRootPathMiddleware import RemoveRootPathMiddleware 13 | from .UpdateOpenApiMiddleware import OpenApiMiddleware 14 | 15 | __all__ = [ 16 | "AddProcessTimeHeaderMiddleware", 17 | "AuthenticationExtensionMiddleware", 18 | "Cql2ApplyFilterBodyMiddleware", 19 | "Cql2ApplyFilterQueryStringMiddleware", 20 | "Cql2BuildFilterMiddleware", 21 | "Cql2RewriteLinksFilterMiddleware", 22 | "Cql2ValidateResponseBodyMiddleware", 23 | "EnforceAuthMiddleware", 24 | "OpenApiMiddleware", 25 | "ProcessLinksMiddleware", 26 | "RemoveRootPathMiddleware", 27 | ] 28 | 
-------------------------------------------------------------------------------- /helm/templates/ingress.yaml: -------------------------------------------------------------------------------- 1 | {{- if .Values.ingress.enabled -}} 2 | {{- $fullName := include "stac-auth-proxy.fullname" . -}} 3 | {{- $svcPort := .Values.service.port -}} 4 | {{- if and .Values.ingress.className (not (semverCompare ">=1.18-0" .Capabilities.KubeVersion.GitVersion)) }} 5 | {{- if not (hasKey .Values.ingress.annotations "kubernetes.io/ingress.class") }} 6 | {{- $_ := set .Values.ingress.annotations "kubernetes.io/ingress.class" .Values.ingress.className}} 7 | {{- end }} 8 | {{- end }} 9 | apiVersion: networking.k8s.io/v1 10 | kind: Ingress 11 | metadata: 12 | name: {{ $fullName }} 13 | labels: 14 | {{- include "stac-auth-proxy.labels" . | nindent 4 }} 15 | {{- with .Values.ingress.annotations }} 16 | annotations: 17 | {{- toYaml . | nindent 4 }} 18 | {{- end }} 19 | spec: 20 | ingressClassName: {{ .Values.ingress.className }} 21 | {{- if and .Values.ingress.tls.enabled .Values.ingress.host }} 22 | tls: 23 | - hosts: 24 | - {{ .Values.ingress.host }} 25 | secretName: {{ .Values.ingress.tls.secretName | default (printf "%s-tls" .Values.ingress.host) }} 26 | {{- end }} 27 | rules: 28 | {{- if .Values.ingress.host }} 29 | - host: {{ .Values.ingress.host }} 30 | http: 31 | paths: 32 | - path: / 33 | pathType: Prefix 34 | backend: 35 | service: 36 | name: {{ $fullName }} 37 | port: 38 | number: {{ $svcPort }} 39 | {{- end }} 40 | {{- end }} -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | # https://github.com/astral-sh/uv-docker-example/blob/c16a61fb3e6ab568ac58d94b73a7d79594a5d570/Dockerfile 2 | 3 | # Build stage 4 | FROM python:3.13-slim AS builder 5 | 6 | WORKDIR /app 7 | 8 | ENV UV_COMPILE_BYTECODE=1 9 | ENV UV_LINK_MODE=copy 10 | ENV 
async def test_proxied_headers_no_encoding(source_api_server, mock_upstream):
    """Clients that don't accept encoding should not receive it."""
    test_app = app_factory(upstream_url=source_api_server)

    client = TestClient(test_app)
    req = client.build_request(method="GET", url="/", headers={})
    # Drop any accept-encoding header present on the built request. This is
    # done with a direct membership check rather than by deleting entries while
    # looping over the same headers mapping.
    if "accept-encoding" in req.headers:
        del req.headers["accept-encoding"]
    client.send(req)

    # The request that reached the upstream must not advertise an encoding.
    proxied_request = await get_upstream_request(mock_upstream)
    assert "accept-encoding" not in proxied_request.headers
--cov-fail-under=85 39 | - name: Upload coverage reports to Codecov 40 | uses: codecov/codecov-action@v5 41 | with: 42 | files: ./coverage.xml 43 | flags: unittests 44 | fail_ci_if_error: false 45 | - name: Archive coverage reports 46 | uses: actions/upload-artifact@v4 47 | if: always() 48 | with: 49 | name: coverage-reports 50 | path: | 51 | htmlcov/ 52 | coverage.xml 53 | -------------------------------------------------------------------------------- /.github/workflows/publish-docker.yaml: -------------------------------------------------------------------------------- 1 | name: Publish Docker image 2 | 3 | on: 4 | release: 5 | types: [published] 6 | 7 | env: 8 | REGISTRY: ghcr.io 9 | IMAGE_NAME: ${{ github.repository }} 10 | 11 | jobs: 12 | build-and-push: 13 | runs-on: ubuntu-latest 14 | permissions: 15 | contents: read 16 | packages: write 17 | 18 | steps: 19 | - name: Checkout repository 20 | uses: actions/checkout@v4 21 | 22 | - name: Set up Docker Buildx 23 | uses: docker/setup-buildx-action@v3 24 | 25 | - name: Log in to the Container registry 26 | uses: docker/login-action@v3 27 | with: 28 | registry: ${{ env.REGISTRY }} 29 | username: ${{ github.actor }} 30 | password: ${{ secrets.GITHUB_TOKEN }} 31 | 32 | - name: Extract metadata (tags, labels) for Docker 33 | id: meta 34 | uses: docker/metadata-action@v5 35 | with: 36 | images: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }} 37 | tags: | 38 | type=raw,value=${{ github.event.release.tag_name }} 39 | type=raw,value=latest,enable=${{ github.event.release.prerelease == false }} 40 | 41 | - name: Build and push Docker image 42 | uses: docker/build-push-action@v5 43 | with: 44 | context: . 
@dataclass
class Opa:
    """Ask Open Policy Agent (OPA) for the CQL2 filter that applies to a request context."""

    host: str
    decision: str

    client: httpx.AsyncClient = field(init=False)
    cache: MemoryCache = field(init=False)
    cache_key: str = "req.headers.authorization"
    cache_ttl: float = 5.0

    def __post_init__(self):
        """Create the result cache and the HTTP client bound to the OPA host."""
        self.cache = MemoryCache(ttl=self.cache_ttl)
        self.client = httpx.AsyncClient(base_url=self.host)

    async def __call__(self, context: dict[str, Any]) -> str:
        """Return the CQL2 filter for the context, reusing a cached result when available."""
        # The value found at `cache_key` inside the context identifies the cache entry.
        lookup_key = get_value_by_path(context, self.cache_key)
        try:
            return self.cache[lookup_key]
        except KeyError:
            # Missing or expired entry: ask OPA and remember the answer.
            fresh = await self._fetch(context)
            self.cache[lookup_key] = fresh
            return fresh

    async def _fetch(self, context: dict[str, Any]) -> str:
        """POST the context to the OPA decision endpoint and return its ``result``."""
        endpoint = f"/v1/data/{self.decision}"
        payload = {"input": context}
        response = await self.client.post(endpoint, json=payload)
        response.raise_for_status()
        return response.json()["result"]
@dataclass
class SwaggerUI:
    """Serve the Swagger UI page with configurable OAuth2 settings."""

    openapi_url: str
    title: Optional[str] = "STAC API"
    init_oauth: dict = field(default_factory=dict)
    parameters: dict = field(default_factory=dict)
    oauth2_redirect_url: str = "/docs/oauth2-redirect"

    async def route(self, req: Request) -> HTMLResponse:
        """Render the Swagger UI HTML, prefixing its URLs with the request's root path."""
        prefix = req.scope.get("root_path", "").rstrip("/")

        # Only prefix the redirect URL when one is configured; a falsy value
        # is handed to the HTML builder untouched.
        redirect_url = (
            prefix + self.oauth2_redirect_url
            if self.oauth2_redirect_url
            else self.oauth2_redirect_url
        )

        return get_swagger_ui_html(
            openapi_url=prefix + self.openapi_url,
            title=f"{self.title} - Swagger UI",
            oauth2_redirect_url=redirect_url,
            init_oauth=self.init_oauth,
            swagger_ui_parameters=self.parameters,
        )
def dict_to_query_string(params: dict) -> str:
    """
    Convert a dictionary to a percent-encoded query string.

    Dict and list values are converted to compact JSON strings, unlike the default
    behavior of urllib.parse.urlencode. Keys and values are then percent-encoded
    so that reserved characters inside them (``&``, ``=``, ``+``, ``#``, spaces,
    quotes, ...) cannot split or otherwise corrupt the resulting query string.

    Args:
        params: Mapping of query parameter names to their values.

    Returns:
        The ``key=value`` pairs joined with ``&``; an empty string for an empty dict.
    """
    # Local import: the enclosing module only imports `parse_qs` from urllib.parse.
    from urllib.parse import quote

    # "/", ":" and "," are legal inside a query component, so leave them readable
    # (keeps values such as bbox lists and timestamps unchanged).
    safe_chars = "/:,"

    parts = []
    for key, val in params.items():
        if isinstance(val, (dict, list)):
            val = json.dumps(val, separators=(",", ":"))
        encoded_key = quote(str(key), safe=safe_chars)
        encoded_val = quote(str(val), safe=safe_chars)
        parts.append(f"{encoded_key}={encoded_val}")
    return "&".join(parts)
21 | """ 22 | 23 | app: ASGIApp 24 | root_path: str 25 | 26 | async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: 27 | """Remove ROOT_PATH from the request path if it exists.""" 28 | if scope["type"] != "http": 29 | return await self.app(scope, receive, send) 30 | 31 | # If root_path is set and path doesn't start with it, return 404 32 | if self.root_path and not scope["path"].startswith(self.root_path): 33 | response = Response("Not Found", status_code=404) 34 | logger.error( 35 | f"Root path {self.root_path!r} not found in path {scope['path']!r}" 36 | ) 37 | await response(scope, receive, send) 38 | return 39 | 40 | # Remove root_path if it exists at the start of the path 41 | if scope["path"].startswith(self.root_path): 42 | scope["raw_path"] = scope["path"].encode() 43 | scope["path"] = scope["path"][len(self.root_path) :] or "/" 44 | 45 | return await self.app(scope, receive, send) 46 | -------------------------------------------------------------------------------- /docs/overrides/partials/integrations/analytics/plausible.html: -------------------------------------------------------------------------------- 1 | 6 | 7 | 8 | 54 | -------------------------------------------------------------------------------- /helm/templates/_helpers.tpl: -------------------------------------------------------------------------------- 1 | {{/* 2 | Expand the name of the chart. 3 | */}} 4 | {{- define "stac-auth-proxy.name" -}} 5 | {{- default .Chart.Name .Values.nameOverride | trunc 63 | trimSuffix "-" }} 6 | {{- end }} 7 | 8 | {{/* 9 | Create a default fully qualified app name. 
10 | */}} 11 | {{- define "stac-auth-proxy.fullname" -}} 12 | {{- if .Values.fullnameOverride }} 13 | {{- .Values.fullnameOverride | trunc 63 | trimSuffix "-" }} 14 | {{- else }} 15 | {{- $name := default .Chart.Name .Values.nameOverride }} 16 | {{- if contains $name .Release.Name }} 17 | {{- .Release.Name | trunc 63 | trimSuffix "-" }} 18 | {{- else }} 19 | {{- printf "%s-%s" .Release.Name $name | trunc 63 | trimSuffix "-" }} 20 | {{- end }} 21 | {{- end }} 22 | {{- end }} 23 | 24 | {{/* 25 | Create chart name and version as used by the chart label. 26 | */}} 27 | {{- define "stac-auth-proxy.chart" -}} 28 | {{- printf "%s-%s" .Chart.Name .Chart.Version | replace "+" "_" | trunc 63 | trimSuffix "-" }} 29 | {{- end }} 30 | 31 | {{/* 32 | Common labels 33 | */}} 34 | {{- define "stac-auth-proxy.labels" -}} 35 | helm.sh/chart: {{ include "stac-auth-proxy.chart" . }} 36 | {{ include "stac-auth-proxy.selectorLabels" . }} 37 | {{- if .Chart.AppVersion }} 38 | app.kubernetes.io/version: {{ .Chart.AppVersion | quote }} 39 | {{- end }} 40 | app.kubernetes.io/managed-by: {{ .Release.Service }} 41 | {{- end }} 42 | 43 | {{/* 44 | Selector labels 45 | */}} 46 | {{- define "stac-auth-proxy.selectorLabels" -}} 47 | app.kubernetes.io/name: {{ include "stac-auth-proxy.name" . }} 48 | app.kubernetes.io/instance: {{ .Release.Name }} 49 | {{- end }} 50 | 51 | {{/* 52 | Create the name of the service account to use 53 | */}} 54 | {{- define "stac-auth-proxy.serviceAccountName" -}} 55 | {{- if .Values.serviceAccount.create }} 56 | {{- default (include "stac-auth-proxy.fullname" .) .Values.serviceAccount.name }} 57 | {{- else }} 58 | {{- default "default" .Values.serviceAccount.name }} 59 | {{- end }} 60 | {{- end }} 61 | 62 | {{/* 63 | Render env var value based on type 64 | */}} 65 | {{- define "stac-auth-proxy.envValue" -}} 66 | {{- if kindIs "string" . -}} 67 | {{- . | quote -}} 68 | {{- else -}} 69 | {{- . 
| toJson | quote -}} 70 | {{- end -}} 71 | {{- end -}} 72 | -------------------------------------------------------------------------------- /helm/templates/deployment.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: apps/v1 2 | kind: Deployment 3 | metadata: 4 | name: {{ include "stac-auth-proxy.fullname" . }} 5 | labels: 6 | {{- include "stac-auth-proxy.labels" . | nindent 4 }} 7 | spec: 8 | replicas: {{ .Values.replicaCount }} 9 | selector: 10 | matchLabels: 11 | {{- include "stac-auth-proxy.selectorLabels" . | nindent 6 }} 12 | template: 13 | metadata: 14 | labels: 15 | {{- include "stac-auth-proxy.selectorLabels" . | nindent 8 }} 16 | spec: 17 | serviceAccountName: {{ include "stac-auth-proxy.serviceAccountName" . }} 18 | securityContext: 19 | {{- toYaml .Values.securityContext | nindent 8 }} 20 | {{- with .Values.initContainers }} 21 | initContainers: 22 | {{- toYaml . | nindent 8 }} 23 | {{- end }} 24 | containers: 25 | - name: {{ .Chart.Name }} 26 | securityContext: 27 | {{- toYaml .Values.containerSecurityContext | nindent 12 }} 28 | image: "{{ .Values.image.repository }}:{{ .Values.image.tag | default .Chart.AppVersion }}" 29 | imagePullPolicy: {{ .Values.image.pullPolicy }} 30 | ports: 31 | - name: http 32 | containerPort: 8000 33 | protocol: TCP 34 | resources: 35 | {{- toYaml .Values.resources | nindent 12 }} 36 | env: 37 | {{- range $key, $value := .Values.env }} 38 | - name: {{ $key }} 39 | value: {{ include "stac-auth-proxy.envValue" $value }} 40 | {{- end }} 41 | {{- with .Values.extraVolumeMounts }} 42 | volumeMounts: 43 | {{- toYaml . | nindent 12 }} 44 | {{- end }} 45 | 46 | {{- with .Values.extraVolumes }} 47 | volumes: 48 | {{- toYaml . | nindent 8 }} 49 | {{- end }} 50 | {{- with .Values.nodeSelector }} 51 | nodeSelector: 52 | {{- toYaml . | nindent 8 }} 53 | {{- end }} 54 | {{- with .Values.affinity }} 55 | affinity: 56 | {{- toYaml . 
@required_conformance(
    r"http://www.opengis.net/spec/cql2/1.0/conf/basic-cql2",
    r"http://www.opengis.net/spec/cql2/1.0/conf/cql2-text",
    r"http://www.opengis.net/spec/cql2/1.0/conf/cql2-json",
)
@dataclass(frozen=True)
class Cql2ApplyFilterQueryStringMiddleware:
    """ASGI middleware that adds the request's CQL2 filter to the query string of GET/list requests."""

    app: ASGIApp
    state_key: str = "cql2_filter"

    # Paths addressing a single record; these requests are passed through as-is.
    single_record_endpoints = [
        r"^/collections/([^/]+)/items/([^/]+)$",
        r"^/collections/([^/]+)$",
    ]

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """Inject the CQL2 filter into the query string when the request qualifies."""
        if scope["type"] == "http":
            request = Request(scope)
            # The filter is read from request state under `state_key`.
            cql2_filter: Optional[Expr] = getattr(request.state, self.state_key, None)

            # Rewrite only when a filter exists, the method is GET, and the
            # path is not one of the single-record endpoints.
            if (
                cql2_filter
                and request.method == "GET"
                and not any(
                    re.match(pattern, request.url.path)
                    for pattern in self.single_record_endpoints
                )
            ):
                # Copy the scope so the caller's mapping is left unmodified.
                scope = dict(scope)
                scope["query_string"] = filters.append_qs_filter(
                    request.url.query, cql2_filter
                )

        return await self.app(scope, receive, send)
29 | @echo "📊 HTML report available at: htmlcov/index.html" 30 | @echo "📄 XML report available at: coverage.xml" 31 | @if [ "$(CI)" = "true" ]; then \ 32 | echo "🚀 Running in CI environment"; \ 33 | else \ 34 | echo "💻 Running locally - opening HTML report..."; \ 35 | if command -v open >/dev/null 2>&1; then \ 36 | open htmlcov/index.html; \ 37 | elif command -v xdg-open >/dev/null 2>&1; then \ 38 | xdg-open htmlcov/index.html; \ 39 | else \ 40 | echo "Please open htmlcov/index.html in your browser to view the coverage report"; \ 41 | fi; \ 42 | fi 43 | 44 | lint: ## Run linting 45 | uv run pre-commit run ruff-check --all-files 46 | uv run pre-commit run mypy --all-files 47 | 48 | format: ## Format code 49 | uv run pre-commit run ruff-format --all-files 50 | 51 | clean: ## Clean up generated files 52 | rm -rf htmlcov/ 53 | rm -rf .coverage 54 | rm -rf coverage.xml 55 | rm -rf .pytest_cache/ 56 | rm -rf build/ 57 | rm -rf dist/ 58 | rm -rf *.egg-info/ 59 | find . -type d -name __pycache__ -delete 60 | find . -type f -name "*.pyc" -delete 61 | 62 | ci: ## Run CI checks locally 63 | uv run pre-commit run --all-files 64 | @echo "🧪 Running tests with coverage..." 65 | uv run pytest \ 66 | -n auto \ 67 | --cov=src/stac_auth_proxy \ 68 | --cov-report=term-missing \ 69 | --cov-report=html \ 70 | --cov-report=xml \ 71 | --cov-fail-under=85 \ 72 | -v 73 | @echo "✅ CI checks completed!" 
74 | 75 | docs: ## Serve documentation locally 76 | uv sync --extra docs 77 | DYLD_FALLBACK_LIBRARY_PATH=/opt/homebrew/lib uv run mkdocs serve -------------------------------------------------------------------------------- /tests/test_remove_root_path.py: -------------------------------------------------------------------------------- 1 | """Tests for RemoveRootPathMiddleware.""" 2 | 3 | import pytest 4 | from fastapi import FastAPI 5 | from starlette.testclient import TestClient 6 | from starlette.types import Receive, Scope, Send 7 | 8 | from stac_auth_proxy.middleware.RemoveRootPathMiddleware import RemoveRootPathMiddleware 9 | 10 | 11 | class MockASGIApp: 12 | """Mock ASGI application for testing.""" 13 | 14 | def __init__(self): 15 | """Initialize the mock app.""" 16 | self.called = False 17 | self.scope = None 18 | 19 | async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: 20 | """Mock ASGI call.""" 21 | self.called = True 22 | self.scope = scope 23 | 24 | 25 | @pytest.mark.asyncio 26 | async def test_remove_root_path_middleware(): 27 | """Test that root path is removed from request path.""" 28 | mock_app = MockASGIApp() 29 | middleware = RemoveRootPathMiddleware(mock_app, root_path="/api") 30 | 31 | # Test with root path 32 | scope = { 33 | "type": "http", 34 | "path": "/api/test", 35 | "raw_path": b"/api/test", 36 | } 37 | await middleware(scope, None, None) 38 | assert mock_app.called 39 | assert mock_app.scope["path"] == "/test" 40 | assert mock_app.scope["raw_path"] == b"/api/test" 41 | 42 | 43 | @pytest.mark.asyncio 44 | async def test_remove_root_path_middleware_non_http(): 45 | """Test that non-HTTP requests are passed through unchanged.""" 46 | mock_app = MockASGIApp() 47 | middleware = RemoveRootPathMiddleware(mock_app, root_path="/api") 48 | 49 | scope = { 50 | "type": "websocket", 51 | "path": "/api/test", 52 | } 53 | await middleware(scope, None, None) 54 | assert mock_app.called 55 | assert mock_app.scope["path"] == 
"/api/test" 56 | 57 | 58 | @pytest.mark.asyncio 59 | async def test_remove_root_path_middleware_empty_path(): 60 | """Test that empty path after root path removal is set to '/'.""" 61 | mock_app = MockASGIApp() 62 | middleware = RemoveRootPathMiddleware(mock_app, root_path="/api") 63 | 64 | scope = { 65 | "type": "http", 66 | "path": "/api", 67 | "raw_path": b"/api", 68 | } 69 | await middleware(scope, None, None) 70 | assert mock_app.called 71 | assert mock_app.scope["path"] == "/" 72 | assert mock_app.scope["raw_path"] == b"/api" 73 | 74 | 75 | def test_remove_root_path_middleware_integration(): 76 | """Test middleware integration with FastAPI.""" 77 | app = FastAPI() 78 | app.add_middleware(RemoveRootPathMiddleware, root_path="/api") 79 | 80 | @app.get("/test") 81 | async def test_endpoint(): 82 | return {"message": "test"} 83 | 84 | client = TestClient(app) 85 | 86 | # Test with root path 87 | response = client.get("/api/test") 88 | assert response.status_code == 200 89 | assert response.json() == {"message": "test"} 90 | 91 | # Test without root path 92 | response = client.get("/test") 93 | assert response.status_code == 404 # Should not find the endpoint 94 | -------------------------------------------------------------------------------- /tests/utils.py: -------------------------------------------------------------------------------- 1 | """Utilities for testing.""" 2 | 3 | import json 4 | from dataclasses import dataclass 5 | from typing import Callable, cast 6 | from unittest.mock import MagicMock 7 | from urllib.parse import parse_qs, unquote 8 | 9 | import httpx 10 | from httpx import Headers, Request 11 | 12 | from stac_auth_proxy import Settings, create_app 13 | 14 | 15 | class AppFactory: 16 | """Factory for creating test apps with default settings.""" 17 | 18 | def __init__(self, **defaults): 19 | """Initialize the factory with default settings.""" 20 | self.defaults = defaults 21 | 22 | def __call__(self, *, upstream_url, **overrides) -> Callable: 
def parse_query_string(qs: str) -> dict:
    """Turn a query string into a dict, keeping only the first value of each parameter."""
    # parse_qs leaves JSON values as plain strings
    # (e.g. parse_qs('foo={"x":"y"}') == {'foo': ['{"x":"y"}']}), so a cql2-json
    # filter is decoded here to let callers compare it against expected objects.
    params = parse_qs(qs)
    filter_is_json = params.get("filter-lang") == ["cql2-json"]

    def _decode(name, values):
        text = unquote(values[0])
        if name == "filter" and filter_is_json:
            return json.loads(text)
        return text

    return {name: _decode(name, values) for name, values in params.items()}
@dataclass
class MemoryCache:
    """
    In-memory key/value cache whose entries expire ``ttl`` seconds after being set.

    Entries are stored as ``(value, time_set)`` pairs. An entry is considered
    expired once its age is strictly greater than ``ttl``. Expired entries are
    dropped lazily on lookup and in bulk by ``_prune``, which runs on writes
    at most once per ``ttl`` interval.
    """

    ttl: float = 5.0
    cache: dict[Any, tuple[Any, float]] = field(default_factory=dict)
    _last_pruned: float = field(default_factory=time)

    def __getitem__(self, key: Any) -> Any:
        """
        Return the cached value for ``key``.

        Raises:
            KeyError: If the key is absent or its entry has expired. The
                message only ever contains the truncated key (see
                ``_key_str``) because keys may hold sensitive values such
                as request credentials.

        """
        if key not in self.cache:
            msg = f"{self._key_str(key)!r} not in cache."
            logger.debug(msg)
            raise KeyError(msg)

        result, timestamp = self.cache[key]
        if (time() - timestamp) > self.ttl:
            msg = f"{self._key_str(key)!r} in cache, but expired."
            del self.cache[key]
            logger.debug(msg)
            # Raise with the truncated message rather than the raw key so the
            # full key never ends up in tracebacks or error reports.
            raise KeyError(msg)

        logger.debug(f"{self._key_str(key)} in cache, returning cached result.")
        return result

    def __setitem__(self, key: Any, value: Any):
        """Store ``value`` under ``key`` with the current time, then prune if due."""
        self.cache[key] = (value, time())
        self._prune()

    def __contains__(self, key: Any) -> bool:
        """Return True if ``key`` is cached and not expired."""
        try:
            self[key]
        except KeyError:
            return False
        return True

    def get(self, key: Any, default: Any = None) -> Any:
        """Return the cached value for ``key``, or ``default`` if absent or expired."""
        try:
            return self[key]
        except KeyError:
            return default

    def _prune(self):
        """Drop all expired entries, at most once per ``ttl`` interval."""
        now = time()  # read the clock once so every entry is judged against the same instant
        if now - self._last_pruned < self.ttl:
            return
        # Keep exactly the entries __getitem__ would still serve (age <= ttl).
        self.cache = {
            k: (v, time_entered)
            for k, (v, time_entered) in self.cache.items()
            if (now - time_entered) <= self.ttl
        }
        self._last_pruned = now

    @staticmethod
    def _key_str(key: Any) -> str:
        """Return ``key`` as a string, truncated to 9 characters plus ``...`` when 10 or longer."""
        text = str(key)
        return text if len(text) < 10 else f"{text[:9]}..."
"payload.sub") 79 | default: Default value to return if path doesn't exist 80 | 81 | Returns: 82 | ------- 83 | The value at the specified path or default if path doesn't exist 84 | 85 | """ 86 | try: 87 | for key in path.split("."): 88 | if obj is None: 89 | return default 90 | obj = obj.get(key, default) 91 | return obj 92 | except (AttributeError, KeyError, TypeError): 93 | return default 94 | -------------------------------------------------------------------------------- /helm/values.yaml: -------------------------------------------------------------------------------- 1 | # Default values for stac-auth-proxy 2 | 3 | replicaCount: 1 4 | 5 | image: 6 | repository: ghcr.io/developmentseed/stac-auth-proxy 7 | pullPolicy: IfNotPresent 8 | tag: "latest" 9 | 10 | service: 11 | type: ClusterIP 12 | port: 8000 13 | 14 | ingress: 15 | enabled: true 16 | className: "nginx" 17 | annotations: 18 | cert-manager.io/cluster-issuer: "letsencrypt-prod" 19 | host: "stac-proxy.example.com" # This should be overridden in production 20 | tls: 21 | enabled: true 22 | secretName: "" # If empty, will be auto-generated as "{host}-tls" 23 | 24 | resources: 25 | limits: 26 | cpu: 500m 27 | memory: 512Mi 28 | requests: 29 | cpu: 200m 30 | memory: 256Mi 31 | 32 | # Pod-level security context 33 | securityContext: 34 | runAsNonRoot: true 35 | runAsUser: 1000 36 | runAsGroup: 1000 37 | 38 | # Container-level security context 39 | containerSecurityContext: 40 | allowPrivilegeEscalation: false 41 | capabilities: 42 | drop: 43 | - ALL 44 | 45 | nodeSelector: {} 46 | tolerations: [] 47 | affinity: {} 48 | 49 | # Additional volumes to mount 50 | extraVolumes: [] 51 | # Example: 52 | # extraVolumes: 53 | # - name: filters 54 | # configMap: 55 | # name: stac-auth-proxy-filters 56 | 57 | # Additional volume mounts for the container 58 | extraVolumeMounts: [] 59 | # Example: 60 | # extraVolumeMounts: 61 | # - name: filters 62 | # mountPath: /app/src/stac_auth_proxy/custom_filters.py 63 | # subPath: 
custom_filters.py 64 | # readOnly: true 65 | 66 | # Init containers to run before the main container starts 67 | # initContainers: [] 68 | # Example: 69 | # initContainers: 70 | # - name: wait-for-oidc 71 | # image: busybox:1.35 72 | # command: ['sh', '-c', 'until nc -z oidc-server 8080; do sleep 2; done'] 73 | 74 | # Environment variables for the application 75 | env: 76 | # Required configuration 77 | UPSTREAM_URL: "" # STAC API URL 78 | OIDC_DISCOVERY_URL: "" # OpenID Connect discovery URL 79 | 80 | # Optional configuration 81 | WAIT_FOR_UPSTREAM: true 82 | HEALTHZ_PREFIX: "/healthz" 83 | OIDC_DISCOVERY_INTERNAL_URL: "" 84 | DEFAULT_PUBLIC: false 85 | PRIVATE_ENDPOINTS: | 86 | { 87 | "^/collections$": ["POST"], 88 | "^/collections/([^/]+)$": ["PUT", "PATCH", "DELETE"], 89 | "^/collections/([^/]+)/items$": ["POST"], 90 | "^/collections/([^/]+)/items/([^/]+)$": ["PUT", "PATCH", "DELETE"], 91 | "^/collections/([^/]+)/bulk_items$": ["POST"] 92 | } 93 | PUBLIC_ENDPOINTS: | 94 | { 95 | "^/api.html$": ["GET"], 96 | "^/api$": ["GET"], 97 | "^/docs/oauth2-redirect": ["GET"], 98 | "^/healthz": ["GET"] 99 | } 100 | 101 | 102 | 103 | serviceAccount: 104 | # Specifies whether a service account should be created 105 | create: true 106 | # Annotations to add to the service account 107 | annotations: {} 108 | # The name of the service account to use. 109 | # If not set and create is true, a name is generated using the fullname template 110 | name: "" 111 | # Image pull secrets to add to the service account 112 | imagePullSecrets: [] 113 | # - name: my-registry-secret -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 |

stac auth proxy

3 |

Reverse proxy to apply auth*n to your STAC API.

4 |
5 | 6 | --- 7 | 8 | [![PyPI - Version][pypi-version-badge]][pypi-link] 9 | [![GHCR - Version][ghcr-version-badge]][ghcr-link] 10 | [![GHCR - Size][ghcr-size-badge]][ghcr-link] 11 | [![codecov][codecov-badge]][codecov-link] 12 | [![Tests][tests-badge]][tests-link] 13 | 14 | STAC Auth Proxy is a proxy API that mediates between the client and your internally accessible STAC API to provide flexible authentication, authorization, and content-filtering mechanisms. 15 | 16 | > [!IMPORTANT] 17 | > 18 | > **We would :heart: to hear from you!** 19 | > Please [join the discussion](https://github.com/developmentseed/eoAPI/discussions/209) and let us know how you're using eoAPI! This helps us improve the project for you and others. 20 | > If you prefer to remain anonymous, you can email us at eoapi@developmentseed.org, and we'll be happy to post a summary on your behalf. 21 | 22 | ## ✨Features✨ 23 | 24 | - **🔐 Authentication:** Apply [OpenID Connect (OIDC)](https://openid.net/developers/how-connect-works/) token validation and optional scope checks to specified endpoints and methods 25 | - **🛂 Content Filtering:** Use CQL2 filters via the [Filter Extension](https://github.com/stac-api-extensions/filter?tab=readme-ov-file) to tailor API responses based on request context (e.g. user role) 26 | - **🤝 External Policy Integration:** Integrate with external systems (e.g. 
[Open Policy Agent (OPA)](https://www.openpolicyagent.org/)) to generate CQL2 filters dynamically from policy decisions 27 | - **🧩 Authentication Extension:** Add the [Authentication Extension](https://github.com/stac-extensions/authentication) to API responses to expose auth-related metadata 28 | - **📘 OpenAPI Augmentation:** Enhance the [OpenAPI spec](https://swagger.io/specification/) with security details to keep auto-generated docs and UIs (e.g., [Swagger UI](https://swagger.io/tools/swagger-ui/)) accurate 29 | - **🗜️ Response Compression:** Optimize response sizes using [`starlette-cramjam`](https://github.com/developmentseed/starlette-cramjam/) 30 | 31 | ## Documentation 32 | 33 | [Full documentation is available on the website](https://developmentseed.org/stac-auth-proxy). 34 | 35 | Head to [Getting Started](https://developmentseed.org/stac-auth-proxy/user-guide/getting-started/) to dig in. 36 | 37 | [pypi-version-badge]: https://badge.fury.io/py/stac-auth-proxy.svg 38 | [pypi-link]: https://pypi.org/project/stac-auth-proxy/ 39 | [ghcr-version-badge]: https://ghcr-badge.egpl.dev/developmentseed/stac-auth-proxy/latest_tag?color=%2344cc11&ignore=latest&label=image+version&trim= 40 | [ghcr-size-badge]: https://ghcr-badge.egpl.dev/developmentseed/stac-auth-proxy/size?color=%2344cc11&tag=latest&label=image+size&trim= 41 | [ghcr-link]: https://github.com/developmentseed/stac-auth-proxy/pkgs/container/stac-auth-proxy 42 | [codecov-badge]: https://codecov.io/gh/developmentseed/stac-auth-proxy/branch/main/graph/badge.svg 43 | [codecov-link]: https://codecov.io/gh/developmentseed/stac-auth-proxy 44 | [tests-badge]: https://github.com/developmentseed/stac-auth-proxy/actions/workflows/cicd.yaml/badge.svg 45 | [tests-link]: https://github.com/developmentseed/stac-auth-proxy/actions/workflows/cicd.yaml 46 | -------------------------------------------------------------------------------- /docs/user-guide/tips.md: 
-------------------------------------------------------------------------------- 1 | # Tips 2 | 3 | ## CORS 4 | 5 | The STAC Auth Proxy does not modify the [CORS response headers](https://developer.mozilla.org/en-US/docs/Web/HTTP/Guides/CORS#the_http_response_headers) from the upstream STAC API. All CORS configuration must be handled by the upstream API. 6 | 7 | Because the STAC Auth Proxy introduces authentication, the upstream API’s CORS settings may need adjustment to support credentials. In most cases, this means: 8 | 9 | - [`Access-Control-Allow-Credentials`](https://developer.mozilla.org/en-US/docs/Web/HTTP/Reference/Headers/Access-Control-Allow-Credentials) must be `true` 10 | - [`Access-Control-Allow-Origin`](https://developer.mozilla.org/en-US/docs/Web/HTTP/Reference/Headers/Access-Control-Allow-Origin) must _not_ be `*`[^CORSNotSupportingCredentials] 11 | 12 | [^CORSNotSupportingCredentials]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Guides/CORS/Errors/CORSNotSupportingCredentials 13 | 14 | ## Root Paths 15 | 16 | The proxy can be optionally served from a non-root path (e.g., `/api/v1`). Additionally, the proxy can optionally proxy requests to an upstream API served from a non-root path (e.g., `/stac`). To handle this, the proxy will: 17 | 18 | - Remove the `ROOT_PATH` from incoming requests before forwarding to the upstream API 19 | - Remove the proxy's prefix from all links in STAC API responses 20 | - Add the `ROOT_PATH` prefix to all links in STAC API responses 21 | - Update the OpenAPI specification to include the `ROOT_PATH` in the servers field 22 | - Handle requests that don't match the `ROOT_PATH` with a 404 response 23 | 24 | ## Non-OIDC Workaround 25 | 26 | If the upstream server utilizes RS256 JWTs but does not utilize a proper OIDC server, the proxy can be configured to work around this by setting the `OIDC_DISCOVERY_URL` to a statically-hosted OIDC discovery document that points to a valid JWKS endpoint. 
27 | 28 | ## Swagger UI Direct JWT Input 29 | 30 | Rather than performing the login flow, the Swagger UI can be configured to accept a JWT directly as input with the following configuration: 31 | 32 | ```sh 33 | OPENAPI_AUTH_SCHEME_NAME=jwtAuth 34 | OPENAPI_AUTH_SCHEME_OVERRIDE='{ 35 | "type": "http", 36 | "scheme": "bearer", 37 | "bearerFormat": "JWT", 38 | "description": "Paste your raw JWT here. This API uses Bearer token authorization." 39 | }' 40 | ``` 41 | 42 | ## Non-proxy Configuration 43 | 44 | While the STAC Auth Proxy is designed to work out-of-the-box as an application, it might not address every project's needs. When the need for customization arises, the codebase can instead be treated as a library of components that can be used to augment a FastAPI server. 45 | 46 | This may look something like the following: 47 | 48 | ```py 49 | from fastapi import FastAPI 50 | from stac_fastapi.api.app import StacApi 51 | from stac_auth_proxy import configure_app, Settings as StacAuthSettings 52 | 53 | # Create Auth Settings 54 | auth_settings = StacAuthSettings( 55 | upstream_url='https://stac-server', # Dummy value, we don't make use of this value in non-proxy mode 56 | oidc_discovery_url='https://auth-server/.well-known/openid-configuration', 57 | ) 58 | 59 | # Setup App 60 | app = FastAPI( ... ) 61 | 62 | # Apply STAC Auth Proxy middleware 63 | configure_app(app, auth_settings) 64 | 65 | # Setup STAC API 66 | api = StacApi( app, ... ) 67 | ``` 68 | 69 | > [!IMPORTANT] 70 | > Avoid using `build_lifespan()` when operating in non-proxy mode, as we are unable to check for the non-existent upstream API. 
@dataclass(frozen=True)
class OpenApiMiddleware(JsonResponseMiddleware):
    """
    Middleware that augments the upstream OpenAPI document with auth details.

    For successful JSON responses on the OpenAPI spec path it replaces the
    ``servers`` entry, registers a security scheme and attaches a security
    requirement to every operation that ``find_match`` reports as private.
    """

    app: ASGIApp
    openapi_spec_path: str
    oidc_discovery_url: str
    private_endpoints: EndpointMethods
    public_endpoints: EndpointMethods
    default_public: bool
    root_path: str = ""
    auth_scheme_name: str = "oidcAuth"
    auth_scheme_override: Optional[dict] = None

    json_content_type_expr: str = r"application/(vnd\.oai\.openapi\+json?|json)"

    def should_transform_response(self, request: Request, scope: Scope) -> bool:
        """
        Return True only for 2xx JSON responses served on the OpenAPI spec path.

        ``openapi_spec_path`` and ``json_content_type_expr`` are both applied
        with ``re.match``, i.e. as regular expressions anchored at the start.
        """
        content_type = Headers(scope=scope).get("content-type", "")
        checks = (
            (self.openapi_spec_path, request.url.path),
            (self.json_content_type_expr, content_type),
        )
        return (
            all(re.match(expr, val) for expr, val in checks)
            and 200 <= scope["status"] < 300
        )

    def transform_json(self, data: dict[str, Any], request: Request) -> dict[str, Any]:
        """
        Augment the OpenAPI spec with auth information.

        Args:
            data: The parsed OpenAPI document; it is modified in place.
            request: The incoming request (unused here).

        Returns:
            The same ``data`` object, with ``servers`` replaced, the security
            scheme registered and private operations marked as secured.

        """
        # Drop any upstream ``servers`` declaration so it cannot conflict with
        # the proxy's own; re-add one only when a root path is configured.
        data.pop("servers", None)
        if self.root_path:
            data["servers"] = [{"url": self.root_path}]

        # Register the security scheme (an explicit override wins).
        components = data.setdefault("components", {})
        security_schemes = components.setdefault("securitySchemes", {})
        security_schemes[self.auth_scheme_name] = self.auth_scheme_override or {
            "type": "openIdConnect",
            "openIdConnectUrl": self.oidc_discovery_url,
        }

        # A Path Item may carry non-operation keys ("parameters", "summary",
        # "description", "servers", "$ref") whose values are lists or strings,
        # so only genuine HTTP-method keys are treated as operations.
        # OPTIONS is deliberately excluded: those requests are not
        # authenticated, https://fetch.spec.whatwg.org/#cors-protocol-and-credentials
        secured_methods = {"get", "put", "post", "delete", "head", "patch", "trace"}

        # ``paths`` is optional in a spec; tolerate it being absent or null.
        for path, path_item in (data.get("paths") or {}).items():
            if not isinstance(path_item, dict):
                continue
            for method, operation in path_item.items():
                if method.lower() not in secured_methods:
                    continue
                if not isinstance(operation, dict):
                    continue
                match = find_match(
                    path,
                    method,
                    self.private_endpoints,
                    self.public_endpoints,
                    self.default_public,
                )
                if match.is_private:
                    operation.setdefault("security", []).append(
                        {self.auth_scheme_name: match.required_scopes}
                    )
        return data
readme = "README.md" 26 | requires-python = ">=3.10" 27 | version = "0.11.0" 28 | 29 | [project.optional-dependencies] 30 | docs = [ 31 | "griffe-fieldz>=0.3.0", 32 | "griffe-inherited-docstrings>=1.1.1", 33 | "markdown-gfm-admonition>=0.1.1", 34 | "mkdocs>=1.6.1", 35 | "mkdocs-api-autonav>=0.3.0", 36 | "mkdocs-material[imaging]>=9.6.16", 37 | "mkdocstrings[python]>=0.30.0", 38 | ] 39 | lambda = [ 40 | "mangum>=0.19.0", 41 | ] 42 | 43 | [tool.coverage.run] 44 | branch = true 45 | source = ["src/stac_auth_proxy"] 46 | omit = [ 47 | "*/tests/*", 48 | "*/test_*", 49 | "*/__pycache__/*", 50 | "*/venv/*", 51 | "*/build/*", 52 | "*/dist/*", 53 | "*/htmlcov/*", 54 | "*/lambda.py", # Lambda entry point not tested in unit tests 55 | ] 56 | 57 | [tool.coverage.report] 58 | exclude_lines = [ 59 | "pragma: no cover", 60 | "def __repr__", 61 | "if self.debug:", 62 | "if settings.DEBUG", 63 | "raise AssertionError", 64 | "raise NotImplementedError", 65 | "if 0:", 66 | "if __name__ == .__main__.:", 67 | "class .*\\bProtocol\\):", 68 | "@(abc\\.)?abstractmethod", 69 | ] 70 | 71 | [tool.coverage.html] 72 | directory = "htmlcov" 73 | 74 | [tool.coverage.xml] 75 | output = "coverage.xml" 76 | 77 | [tool.isort] 78 | known_first_party = ["stac_auth_proxy"] 79 | profile = "black" 80 | 81 | [tool.ruff.lint] 82 | ignore = ["E501", "D203", "D205", "D212"] 83 | select = ["D", "E", "F"] 84 | 85 | [build-system] 86 | build-backend = "hatchling.build" 87 | requires = ["hatchling>=1.12.0"] 88 | 89 | [dependency-groups] 90 | dev = [ 91 | "jwcrypto>=1.5.6", 92 | "mypy>=1.3.0", 93 | "pre-commit>=3.5.0", 94 | "pytest-asyncio>=0.25.1", 95 | "pytest-cov>=5.0.0", 96 | "pytest-xdist>=3.6.1", 97 | "pytest>=8.3.3", 98 | "ruff>=0.0.238", 99 | "starlette-cramjam>=0.4.0", 100 | "types-simplejson", 101 | "types-attrs", 102 | ] 103 | 104 | [tool.pytest.ini_options] 105 | asyncio_default_fixture_loop_scope = "function" 106 | asyncio_mode = "auto" 107 | testpaths = ["tests"] 108 | python_files = ["test_*.py", 
"*_test.py"] 109 | python_classes = ["Test*"] 110 | python_functions = ["test_*"] 111 | addopts = [ 112 | "--strict-markers", 113 | "--strict-config", 114 | "--verbose", 115 | "--tb=short", 116 | "--cov=src/stac_auth_proxy", 117 | "--cov-report=term-missing", 118 | "--cov-report=html", 119 | "--cov-report=xml", 120 | "--cov-fail-under=85", 121 | ] 122 | markers = [ 123 | "slow: marks tests as slow (deselect with '-m \"not slow\"')", 124 | "integration: marks tests as integration tests", 125 | "unit: marks tests as unit tests", 126 | ] 127 | -------------------------------------------------------------------------------- /tests/test_cache.py: -------------------------------------------------------------------------------- 1 | """Tests for cache utilities.""" 2 | 3 | from unittest.mock import patch 4 | 5 | import pytest 6 | 7 | from stac_auth_proxy.utils.cache import MemoryCache, get_value_by_path 8 | 9 | 10 | def test_memory_cache_basic_operations(): 11 | """Test basic cache operations.""" 12 | cache = MemoryCache(ttl=5.0) # 5 second TTL 13 | key = "test_key" 14 | value = "test_value" 15 | 16 | # Test setting and getting a value 17 | cache[key] = value 18 | assert cache[key] == value 19 | assert key in cache 20 | 21 | # Test getting non-existent key 22 | with pytest.raises(KeyError): 23 | _ = cache["non_existent"] 24 | 25 | # Test get() method 26 | assert cache.get(key) == value 27 | assert cache.get("non_existent") is None 28 | 29 | 30 | def test_memory_cache_expiration(): 31 | """Test cache expiration.""" 32 | cache = MemoryCache(ttl=5.0) # 5 second TTL 33 | key = "test_key" 34 | value = "test_value" 35 | 36 | # Set initial time 37 | with patch("stac_auth_proxy.utils.cache.time") as mock_time: 38 | mock_time.return_value = 1000.0 39 | cache[key] = value 40 | assert cache[key] == value 41 | 42 | # Advance time past TTL 43 | mock_time.return_value = 1006.0 # 6 seconds later 44 | 45 | # Test expired key 46 | with pytest.raises(KeyError): 47 | cache[key] 48 | 49 | # 
Test contains after expiration 50 | assert key not in cache 51 | 52 | 53 | def test_memory_cache_pruning(): 54 | """Test cache pruning.""" 55 | cache = MemoryCache(ttl=5.0) # 5 second TTL 56 | key1 = "key1" 57 | key2 = "key2" 58 | value = "test_value" 59 | 60 | with patch("stac_auth_proxy.utils.cache.time") as mock_time: 61 | # Set initial time 62 | mock_time.return_value = 1000.0 63 | cache[key1] = value 64 | cache[key2] = value 65 | 66 | # Advance time past TTL 67 | mock_time.return_value = 1006.0 # 6 seconds later 68 | 69 | # Force pruning by adding a new item 70 | cache["key3"] = value 71 | 72 | # Check that expired items were pruned 73 | assert key1 not in cache 74 | assert key2 not in cache 75 | assert "key3" in cache 76 | 77 | 78 | def test_memory_cache_key_str(): 79 | """Test key string representation.""" 80 | cache = MemoryCache() 81 | 82 | # Test short key 83 | short_key = "123" 84 | assert cache._key_str(short_key) == short_key 85 | 86 | # Test long key 87 | long_key = "1234567890" 88 | assert cache._key_str(long_key) == "123456789..." 
89 | 90 | 91 | @pytest.mark.parametrize( 92 | "obj, path, default, expected", 93 | [ 94 | # Basic path 95 | ({"a": {"b": 1}}, "a.b", None, 1), 96 | # Nested path 97 | ({"a": {"b": {"c": 2}}}, "a.b.c", None, 2), 98 | # Non-existent path 99 | ({"a": {"b": 1}}, "a.c", None, None), 100 | # Default value 101 | ({"a": {"b": 1}}, "a.c", "default", "default"), 102 | # None in path 103 | ({"a": None}, "a.b", None, None), 104 | # Empty path 105 | ({"a": 1}, "", None, None), 106 | # Complex object 107 | ({"a": {"b": [1, 2, 3]}}, "a.b", None, [1, 2, 3]), 108 | ], 109 | ) 110 | def test_get_value_by_path(obj, path, default, expected): 111 | """Test getting values by path.""" 112 | assert get_value_by_path(obj, path, default) == expected 113 | -------------------------------------------------------------------------------- /helm/templates/NOTES.txt: -------------------------------------------------------------------------------- 1 | Thank you for installing {{ .Chart.Name }}. 2 | 3 | Your STAC Auth Proxy has been deployed with the following configuration: 4 | 5 | 1. Application Access: 6 | {{- if .Values.ingress.enabled }} 7 | {{- if .Values.ingress.host }} 8 | Your proxy is available at: 9 | {{- if .Values.ingress.tls.enabled }} 10 | https://{{ .Values.ingress.host }} 11 | {{- else }} 12 | http://{{ .Values.ingress.host }} 13 | {{- end }} 14 | {{- end }} 15 | {{- else if contains "NodePort" .Values.service.type }} 16 | Get the application URL by running these commands: 17 | export NODE_PORT=$(kubectl get --namespace {{ .Release.Namespace }} -o jsonpath="{.spec.ports[0].nodePort}" services {{ include "stac-auth-proxy.fullname" . 
}}) 18 | export NODE_IP=$(kubectl get nodes --namespace {{ .Release.Namespace }} -o jsonpath="{.items[0].status.addresses[0].address}") 19 | echo http://$NODE_IP:$NODE_PORT 20 | {{- else if contains "LoadBalancer" .Values.service.type }} 21 | Get the application URL by running these commands: 22 | NOTE: It may take a few minutes for the LoadBalancer IP to be available. 23 | You can watch the status by running: 24 | kubectl get svc --namespace {{ .Release.Namespace }} {{ include "stac-auth-proxy.fullname" . }} -w 25 | 26 | Once ready, get the external IP/hostname with: 27 | export SERVICE_IP=$(kubectl get svc --namespace {{ .Release.Namespace }} {{ include "stac-auth-proxy.fullname" . }} --template "{{"{{ range (index .status.loadBalancer.ingress 0) }}{{.}}{{ end }}"}}") 28 | echo http://$SERVICE_IP:{{ .Values.service.port }} 29 | {{- else }} 30 | The service is accessible within the cluster at: 31 | {{ include "stac-auth-proxy.fullname" . }}.{{ .Release.Namespace }}.svc.cluster.local:{{ .Values.service.port }} 32 | {{- end }} 33 | 34 | 2. Configuration Details: 35 | - Upstream STAC API: {{ .Values.env.UPSTREAM_URL }} 36 | - OIDC Discovery URL: {{ .Values.env.OIDC_DISCOVERY_URL }} 37 | - Health Check Endpoint: {{ .Values.env.HEALTHZ_PREFIX | default "/healthz" }} 38 | - Default Public Access: {{ .Values.env.DEFAULT_PUBLIC | default "false" }} 39 | 40 | 3. Verify the deployment: 41 | kubectl get pods --namespace {{ .Release.Namespace }} -l "app.kubernetes.io/name={{ include "stac-auth-proxy.name" . }},app.kubernetes.io/instance={{ .Release.Name }}" 42 | 43 | 4. View the logs: 44 | kubectl logs --namespace {{ .Release.Namespace }} -l "app.kubernetes.io/name={{ include "stac-auth-proxy.name" . }},app.kubernetes.io/instance={{ .Release.Name }}" 45 | 46 | 5. 
Health check: 47 | {{- if .Values.ingress.enabled }} 48 | {{- if .Values.ingress.host }} 49 | {{- if .Values.ingress.tls.enabled }} 50 | curl https://{{ .Values.ingress.host }}{{ .Values.env.HEALTHZ_PREFIX | default "/healthz" }} 51 | {{- else }} 52 | curl http://{{ .Values.ingress.host }}{{ .Values.env.HEALTHZ_PREFIX | default "/healthz" }} 53 | {{- end }} 54 | {{- end }} 55 | {{- else }} 56 | kubectl port-forward --namespace {{ .Release.Namespace }} service/{{ include "stac-auth-proxy.fullname" . }} 8000:{{ .Values.service.port }} 57 | curl http://localhost:8000{{ .Values.env.HEALTHZ_PREFIX | default "/healthz" }} 58 | {{- end }} 59 | 60 | For more information about STAC Auth Proxy, please visit: 61 | https://github.com/developmentseed/stac-auth-proxy 62 | 63 | {{- if or (not .Values.env.UPSTREAM_URL) (not .Values.env.OIDC_DISCOVERY_URL) }} 64 | WARNING: Some required configuration values are not set. Please ensure you have configured: 65 | {{- if not .Values.env.UPSTREAM_URL }} 66 | - env.UPSTREAM_URL 67 | {{- end }} 68 | {{- if not .Values.env.OIDC_DISCOVERY_URL }} 69 | - env.OIDC_DISCOVERY_URL 70 | {{- end }} 71 | {{- end }} -------------------------------------------------------------------------------- /tests/test_filters_opa.py: -------------------------------------------------------------------------------- 1 | """Test OPA filter integration.""" 2 | 3 | from unittest.mock import AsyncMock, MagicMock, patch 4 | 5 | import pytest 6 | from httpx import AsyncClient, Response 7 | 8 | from stac_auth_proxy.filters.opa import Opa 9 | 10 | 11 | @pytest.fixture 12 | def opa_filter_factory(): 13 | """Create an OPA instance for testing.""" 14 | return Opa(host="http://localhost:8181", decision="stac/filter") 15 | 16 | 17 | @pytest.fixture 18 | def mock_opa_response(): 19 | """Create a mock httpx Response.""" 20 | response = MagicMock(spec=Response) 21 | response.json.return_value = {"result": "collection = 'test'"} 22 | response.raise_for_status.return_value = 
response 23 | return response 24 | 25 | 26 | @pytest.mark.asyncio 27 | async def test_opa_initialization(opa_filter_factory): 28 | """Test OPA initialization.""" 29 | assert opa_filter_factory.host == "http://localhost:8181" 30 | assert opa_filter_factory.decision == "stac/filter" 31 | assert opa_filter_factory.cache_key == "req.headers.authorization" 32 | assert opa_filter_factory.cache_ttl == 5.0 33 | assert isinstance(opa_filter_factory.client, AsyncClient) 34 | assert opa_filter_factory.cache is not None 35 | 36 | 37 | @pytest.mark.asyncio 38 | async def test_opa_cache_hit(opa_filter_factory, mock_opa_response): 39 | """Test OPA cache hit behavior.""" 40 | context = {"req": {"headers": {"authorization": "test-token"}}} 41 | 42 | # Mock the OPA response 43 | with patch.object( 44 | opa_filter_factory.client, "post", new_callable=AsyncMock 45 | ) as mock_post: 46 | mock_post.return_value = mock_opa_response 47 | 48 | # First call should hit OPA 49 | result = await opa_filter_factory(context) 50 | assert result == "collection = 'test'" 51 | assert mock_post.call_count == 1 52 | 53 | # Second call should use cache 54 | result = await opa_filter_factory(context) 55 | assert result == "collection = 'test'" 56 | assert mock_post.call_count == 1 # Still 1, no new call made 57 | 58 | 59 | @pytest.mark.asyncio 60 | async def test_opa_cache_miss(opa_filter_factory, mock_opa_response): 61 | """Test OPA cache miss behavior.""" 62 | context = {"req": {"headers": {"authorization": "test-token"}}} 63 | 64 | with patch.object( 65 | opa_filter_factory.client, "post", new_callable=AsyncMock 66 | ) as mock_post: 67 | mock_post.return_value = mock_opa_response 68 | 69 | # First call with token1 70 | result = await opa_filter_factory(context) 71 | assert result == "collection = 'test'" 72 | assert mock_post.call_count == 1 73 | 74 | # Call with different token should miss cache 75 | context["req"]["headers"]["authorization"] = "different-token" 76 | result = await 
@required_conformance(
    r"http://www.opengis.net/spec/cql2/1.0/conf/basic-cql2",
    r"http://www.opengis.net/spec/cql2/1.0/conf/cql2-text",
    r"http://www.opengis.net/spec/cql2/1.0/conf/cql2-json",
)
@dataclass(frozen=True)
class Cql2ApplyFilterBodyMiddleware:
    """
    Middleware to augment the request body with a CQL2 filter for POST/PUT/PATCH requests.

    The filter is read from ``request.state.<state_key>``. When present, the
    JSON request body is buffered, combined with the filter via
    ``filters.append_body_filter`` and replayed to the wrapped app with a
    corrected ``content-length`` header.
    """

    app: ASGIApp
    state_key: str = "cql2_filter"

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """
        Apply the CQL2 filter to the request body.

        Requests that are not HTTP, carry no filter, or use a method other
        than POST/PUT/PATCH are passed through untouched. A body that is not
        valid JSON, or is not a JSON object, is answered with a 400 response.
        """
        if scope["type"] != "http":
            return await self.app(scope, receive, send)

        request = Request(scope)
        cql2_filter: Optional[Expr] = getattr(request.state, self.state_key, None)
        if not cql2_filter or request.method not in ("POST", "PUT", "PATCH"):
            return await self.app(scope, receive, send)

        body = await self._read_body(receive)
        if body is None:
            # Client went away before the body was complete; nothing to serve.
            logger.debug("Client disconnected before request body was received")
            return

        try:
            body_json = json.loads(body) if body else {}
        except json.JSONDecodeError:
            logger.warning("Failed to parse request body as JSON")
            await self._reject(
                scope, receive, send, "ParseError", "Request body must be valid JSON."
            )
            return

        if not isinstance(body_json, dict):
            logger.warning("Request body must be a JSON object")
            await self._reject(
                scope, receive, send, "TypeError", "Request body must be a JSON object."
            )
            return

        new_body = json.dumps(
            filters.append_body_filter(body_json, cql2_filter)
        ).encode("utf-8")

        # Patch content-length. ASGI headers are a list of pairs in which a
        # name may repeat, so the list is filtered rather than round-tripped
        # through a dict (which would silently drop repeated headers).
        headers = [
            (name, value)
            for name, value in scope["headers"]
            if name.lower() != b"content-length"
        ]
        headers.append((b"content-length", str(len(new_body)).encode("latin1")))
        scope = {**scope, "headers": headers}

        body_delivered = False

        async def new_receive():
            # Hand the rewritten body over exactly once. Afterwards defer to
            # the real channel so that downstream code waiting for
            # "http.disconnect" suspends instead of being fed the same
            # "http.request" message in a non-yielding loop.
            nonlocal body_delivered
            if body_delivered:
                return await receive()
            body_delivered = True
            return {
                "type": "http.request",
                "body": new_body,
                "more_body": False,
            }

        await self.app(scope, new_receive, send)

    @staticmethod
    async def _read_body(receive: Receive) -> Optional[bytes]:
        """Buffer the whole request body; return None if the client disconnects first."""
        chunks = []
        while True:
            message = await receive()
            if message["type"] == "http.disconnect":
                return None
            if message["type"] == "http.request":
                chunks.append(message.get("body", b""))
                if not message.get("more_body", False):
                    return b"".join(chunks)

    @staticmethod
    async def _reject(
        scope: Scope, receive: Receive, send: Send, code: str, description: str
    ) -> None:
        """Send a 400 JSON error response with the given code and description."""
        from starlette.responses import JSONResponse

        response = JSONResponse(
            {"code": code, "description": description},
            status_code=400,
        )
        await response(scope, receive, send)
@pytest.mark.parametrize(
    "path,method,expected_status",
    [
        ("/", "GET", 401),
        ("/conformance", "GET", 401),
        ("/queryables", "GET", 401),
        ("/search", "GET", 401),
        ("/search", "POST", 401),
        ("/collections", "GET", 401),
        ("/collections", "POST", 401),
        ("/collections/example-collection", "GET", 401),
        ("/collections/example-collection", "PUT", 401),
        ("/collections/example-collection", "DELETE", 401),
        ("/collections/example-collection/items", "GET", 401),
        ("/collections/example-collection/items", "POST", 401),
        ("/collections/example-collection/items/example-item", "GET", 401),
        ("/collections/example-collection/items/example-item", "PUT", 401),
        ("/collections/example-collection/items/example-item", "DELETE", 401),
        ("/collections/example-collection/bulk_items", "POST", 401),
        ("/api.html", "GET", 200),
        ("/api", "GET", 200),
    ],
)
def test_default_public_false(source_api_server, path, method, expected_status):
    """
    When default_public=false and private_endpoints aren't set, all endpoints should be
    private (401 without a token) except for those listed in public_endpoints.
    """
    proxy_app = app_factory(
        upstream_url=source_api_server,
        public_endpoints={"/api.html": ["GET"], "/api": ["GET"]},
        private_endpoints={},
        default_public=False,
    )
    response = TestClient(proxy_app).request(method=method, url=path)
    assert response.status_code == expected_status
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 
106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 110 | .pdm.toml 111 | .pdm-python 112 | .pdm-build/ 113 | 114 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 115 | __pypackages__/ 116 | 117 | # Celery stuff 118 | celerybeat-schedule 119 | celerybeat.pid 120 | 121 | # SageMath parsed files 122 | *.sage.py 123 | 124 | # Environments 125 | .env 126 | .venv 127 | env/ 128 | venv/ 129 | ENV/ 130 | env.bak/ 131 | venv.bak/ 132 | 133 | # Spyder project settings 134 | .spyderproject 135 | .spyproject 136 | 137 | # Rope project settings 138 | .ropeproject 139 | 140 | # mkdocs documentation 141 | /site 142 | 143 | # mypy 144 | .mypy_cache/ 145 | .dmypy.json 146 | dmypy.json 147 | 148 | # Pyre type checker 149 | .pyre/ 150 | 151 | # pytype static type analyzer 152 | .pytype/ 153 | 154 | # Cython debug symbols 155 | cython_debug/ 156 | 157 | # PyCharm 158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 160 | # and can be added to the global gitignore or merged into this file. For a more nuclear 161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
@dataclass
class AuthenticationExtensionMiddleware(JsonResponseMiddleware):
    """Middleware to add the authentication extension to the response."""

    app: ASGIApp

    default_public: bool
    private_endpoints: EndpointMethods
    public_endpoints: EndpointMethods

    oidc_discovery_url: str
    auth_scheme_name: str = "oidc"
    auth_scheme: dict[str, Any] = field(default_factory=dict)
    extension_url: str = (
        "https://stac-extensions.github.io/authentication/v1.1.0/schema.json"
    )

    json_content_type_expr: str = r"application/(geo\+)?json"

    def should_transform_response(self, request: Request, scope: Scope) -> bool:
        """Determine if the response should be transformed.

        True only when the request path is a STAC catalog, collection(s),
        item(s) or search URL, the response content type is JSON/GeoJSON, and
        the response status is 2xx.
        """
        # Match STAC catalog, collection, or item URLs with a single regex
        return (
            all(
                (
                    re.match(expr, val)
                    for expr, val in [
                        (
                            # catalog, collections, collection, items, item, search
                            r"^(/|/collections(/[^/]+(/items(/[^/]+)?)?)?|/search)$",
                            request.url.path,
                        ),
                        (
                            self.json_content_type_expr,
                            Headers(scope=scope).get("content-type", ""),
                        ),
                    ]
                ),
            )
            and 200 <= scope["status"] < 300
        )

    def transform_json(self, data: dict[str, Any], request: Request) -> dict[str, Any]:
        """Augment the STAC Item with auth information.

        Mutates and returns ``data``: registers the extension URL, declares the
        OpenID Connect scheme under ``auth:schemes``, and tags every link whose
        GET target is private with ``auth:refs``.
        """
        extensions = data.setdefault("stac_extensions", [])
        if self.extension_url not in extensions:
            extensions.append(self.extension_url)

        # auth:schemes
        # ---
        # A property that contains all of the scheme definitions used by Assets and
        # Links in the STAC Item or Collection.
        # - Catalogs
        # - Collections
        # - Item Properties

        scheme_loc = data["properties"] if "properties" in data else data
        schemes = scheme_loc.setdefault("auth:schemes", {})
        schemes[self.auth_scheme_name] = {
            "type": "openIdConnect",
            "openIdConnectUrl": self.oidc_discovery_url,
        }

        # auth:refs
        # ---
        # Annotate links with "auth:refs": [auth_scheme]
        for link in get_links(data):
            if "href" not in link:
                logger.warning("Link %s has no href", link)
                continue
            match = find_match(
                path=urlparse(link["href"]).path,
                method="GET",
                private_endpoints=self.private_endpoints,
                public_endpoints=self.public_endpoints,
                default_public=self.default_public,
            )
            if match.is_private:
                refs = link.setdefault("auth:refs", [])
                # Avoid listing the scheme twice when the link is already
                # annotated with it.
                if self.auth_scheme_name not in refs:
                    refs.append(self.auth_scheme_name)

        return data
4 | 5 | ## Core Requirements 6 | 7 | To get started with STAC Auth Proxy, you need to provide two essential pieces of information: 8 | 9 | ### 1. OIDC Discovery URL 10 | 11 | You need a valid OpenID Connect (OIDC) discovery URL that points to your identity provider's configuration. This URL typically follows the pattern: 12 | 13 | ``` 14 | https://your-auth-provider.com/.well-known/openid-configuration 15 | ``` 16 | 17 | > [!TIP] 18 | > 19 | > Common OIDC providers include: 20 | > 21 | > - **Auth0**: `https://{tenant-id}.auth0.com/.well-known/openid-configuration` 22 | > - **AWS Cognito**: `https://cognito-idp.{region}.amazonaws.com/{user-pool-id}/.well-known/openid-configuration` 23 | > - **Azure AD**: `https://login.microsoftonline.com/{tenant-id}/v2.0/.well-known/openid-configuration` 24 | > - **Google**: `https://accounts.google.com/.well-known/openid-configuration` 25 | > - **Keycloak**: `https://{keycloak-server}/auth/realms/{realm-id}/.well-known/openid-configuration` 26 | 27 | ### 2. Upstream STAC API URL 28 | 29 | You need the URL to your upstream STAC API that the proxy will protect: 30 | 31 | ``` 32 | https://your-stac-api.com/stac 33 | ``` 34 | 35 | This should be a valid STAC API that conforms to the STAC specification. 36 | 37 | ## Quick Start 38 | 39 | Here's a minimal example to get you started: 40 | 41 | ### Using Docker 42 | 43 | ```bash 44 | docker run -p 8000:8000 \ 45 | -e UPSTREAM_URL=https://your-stac-api.com/stac \ 46 | -e OIDC_DISCOVERY_URL=https://your-auth-provider.com/.well-known/openid-configuration \ 47 | ghcr.io/developmentseed/stac-auth-proxy:latest 48 | ``` 49 | 50 | ### Using Python 51 | 52 | 1. Install the package: 53 | ```bash 54 | pip install stac-auth-proxy 55 | ``` 56 | 2. Set environment variables: 57 | ```bash 58 | export UPSTREAM_URL=https://your-stac-api.com/stac 59 | export OIDC_DISCOVERY_URL=https://your-auth-provider.com/.well-known/openid-configuration 60 | ``` 61 | 3. 
Run the proxy: 62 | ```bash 63 | python -m stac_auth_proxy 64 | ``` 65 | 66 | ### Using Docker Compose 67 | 68 | For development and experimentation, the codebase (i.e., within the repository, not within the Docker or Python distributions) ships with a `docker-compose.yaml` file, allowing the proxy to be run locally alongside various supporting services: the database, the STAC API, and a mock OIDC provider. 69 | 70 | #### pgSTAC Backend 71 | 72 | Run the application stack with a pgSTAC backend using [stac-fastapi-pgstac](https://github.com/stac-utils/stac-fastapi-pgstac): 73 | 74 | ```sh 75 | docker compose up 76 | ``` 77 | 78 | #### OpenSearch Backend 79 | 80 | Run the application stack with an OpenSearch backend using [stac-fastapi-elasticsearch-opensearch](https://github.com/stac-utils/stac-fastapi-elasticsearch-opensearch): 81 | 82 | ```sh 83 | docker compose --profile os up 84 | ``` 85 | 86 | The proxy will start on `http://localhost:8000` by default. 87 | 88 | ## What Happens Next? 89 | 90 | Once the proxy starts successfully: 91 | 92 | 1. **Health Check**: The proxy verifies your upstream STAC API is accessible 93 | 2. **Conformance Check**: It ensures your STAC API conforms to required specifications 94 | 3. **OIDC Discovery**: It fetches and validates your OIDC provider configuration 95 | 4. 
**Ready**: The proxy is now ready to handle requests 96 | 97 | ## Testing Your Setup 98 | 99 | You can test that your proxy is working by accessing the health endpoint: 100 | 101 | ```bash 102 | curl http://localhost:8000/healthz 103 | ``` 104 | 105 | ## Next Steps 106 | 107 | - [Configuration Guide](configuration.md) - Learn about all available configuration options 108 | - [Route-Level Authentication](route-level-auth.md) - Configure which endpoints require authentication 109 | - [Record-Level Authentication](record-level-auth.md) - Set up content filtering based on user permissions 110 | -------------------------------------------------------------------------------- /tests/test_utils.py: -------------------------------------------------------------------------------- 1 | """Tests for OpenAPI spec handling.""" 2 | 3 | import pytest 4 | from utils import parse_query_string 5 | 6 | from stac_auth_proxy.utils.requests import ( 7 | extract_variables, 8 | get_base_url, 9 | parse_forwarded_header, 10 | ) 11 | 12 | 13 | @pytest.mark.parametrize( 14 | "url, expected", 15 | ( 16 | ("/collections/123", {"collection_id": "123"}), 17 | ("/collections/123/items", {"collection_id": "123"}), 18 | ("/collections/123/bulk_items", {"collection_id": "123"}), 19 | ("/collections/123/items/456", {"collection_id": "123", "item_id": "456"}), 20 | ("/collections/123/bulk_items/456", {"collection_id": "123", "item_id": "456"}), 21 | ("/other/123", {}), 22 | ), 23 | ) 24 | def test_extract_variables(url, expected): 25 | """Test extracting variables from a URL path.""" 26 | assert extract_variables(url) == expected 27 | 28 | 29 | @pytest.mark.parametrize( 30 | "query, expected", 31 | ( 32 | ("foo=bar", {"foo": "bar"}), 33 | ( 34 | 'filter={"xyz":"abc"}&filter-lang=cql2-json', 35 | {"filter": {"xyz": "abc"}, "filter-lang": "cql2-json"}, 36 | ), 37 | ), 38 | ) 39 | def test_parse_query_string(query, expected): 40 | """Validate test helper for parsing query strings.""" 41 | assert 
@pytest.mark.parametrize(
    "headers, expected_url",
    (
        # Forwarded header
        (
            [
                (b"host", b"internal-proxy:8080"),
                (b"forwarded", b"for=192.0.2.43; proto=https; host=api.example.com"),
            ],
            "https://api.example.com/",
        ),
        # X-Forwarded-* headers
        (
            [
                (b"host", b"internal-proxy:8080"),
                (b"x-forwarded-host", b"api.example.com"),
                (b"x-forwarded-proto", b"https"),
            ],
            "https://api.example.com/",
        ),
        # No forwarded headers
        (
            [
                (b"host", b"proxy.example.com"),
            ],
            "http://proxy.example.com/",
        ),
    ),
)
def test_get_base_url(headers, expected_url):
    """Check the base URL derived from each set of request headers."""
    from starlette.requests import Request

    # Minimal HTTP scope; only the headers vary between cases.
    request = Request(
        {
            "type": "http",
            "method": "GET",
            "path": "/test",
            "headers": headers,
        }
    )

    assert get_base_url(request) == expected_url
def test_lifespan_reusable():
    """Ensure the public lifespan handler runs health and conformance checks."""
    upstream_url = "https://example.com"
    oidc_discovery_url = "https://example.com/.well-known/openid-configuration"
    # Swap both startup checks for AsyncMocks so the lifespan can run without
    # contacting the example.com URLs above.
    with (
        patch(
            "stac_auth_proxy.lifespan.check_server_health",
            new=AsyncMock(),
        ) as mock_health,
        patch(
            "stac_auth_proxy.lifespan.check_conformance",
            new=AsyncMock(),
        ) as mock_conf,
    ):
        app = FastAPI(
            lifespan=build_lifespan(
                upstream_url=upstream_url,
                oidc_discovery_url=oidc_discovery_url,
            )
        )
        # Entering and leaving the TestClient context drives the app's
        # lifespan startup/shutdown; no request needs to be made.
        with TestClient(app):
            pass
        # NOTE(review): two awaits presumably correspond to the upstream and
        # OIDC URLs — confirm against build_lifespan.
        assert mock_health.await_count == 2
        # The conformance check is expected to receive the upstream URL with
        # exactly one trailing slash.
        expected_upstream = upstream_url.rstrip("/") + "/"
        mock_conf.assert_awaited_once_with(app.user_middleware, expected_upstream)
27 | social: 28 | - icon: fontawesome/brands/github 29 | link: https://github.com/developmentseed 30 | nav: 31 | - Overview: index.md 32 | - User Guide: 33 | - Getting Started: user-guide/getting-started.md 34 | - Configuration: user-guide/configuration.md 35 | - Route-Level Auth: user-guide/route-level-auth.md 36 | - Record-Level Auth: user-guide/record-level-auth.md 37 | - Deployment: user-guide/deployment.md 38 | - Tips: user-guide/tips.md 39 | - Architecture: 40 | - Middleware Stack: architecture/middleware-stack.md 41 | - Filtering Data: architecture/filtering-data.md 42 | - Changelog: changelog.md 43 | 44 | plugins: 45 | - search 46 | - social 47 | - api-autonav: 48 | modules: 49 | - src/stac_auth_proxy 50 | - mkdocstrings: 51 | enable_inventory: true 52 | handlers: 53 | python: 54 | paths: 55 | - src 56 | options: 57 | extensions: 58 | - griffe_fieldz 59 | show_signature_annotations: true 60 | inventories: 61 | - https://docs.python.org/3/objects.inv 62 | - https://docs.pydantic.dev/latest/objects.inv 63 | - https://fastapi.tiangolo.com/objects.inv 64 | - https://www.starlette.io/objects.inv 65 | 66 | theme: 67 | name: material 68 | palette: 69 | # Palette toggle for automatic mode 70 | - media: (prefers-color-scheme) 71 | toggle: 72 | icon: material/brightness-auto 73 | name: Switch to light mode 74 | 75 | # Palette toggle for light mode 76 | - media: "(prefers-color-scheme: light)" 77 | scheme: default 78 | primary: indigo 79 | accent: indigo 80 | toggle: 81 | icon: material/brightness-7 82 | name: Switch to dark mode 83 | 84 | # Palette toggle for dark mode 85 | - media: "(prefers-color-scheme: dark)" 86 | scheme: slate 87 | primary: indigo 88 | accent: indigo 89 | toggle: 90 | icon: material/brightness-4 91 | name: Switch to light mode 92 | 93 | custom_dir: docs/overrides 94 | favicon: assets/ds-symbol-positive-mono.png 95 | logo: assets/ds-symbol-negative-mono.png 96 | 97 | features: 98 | - content.code.annotate 99 | - content.code.copy 100 | - 
navigation.indexes 101 | - navigation.instant 102 | - navigation.tracking 103 | - search.suggest 104 | - search.share 105 | 106 | # https://github.com/kylebarron/cogeo-mosaic/blob/mkdocs/mkdocs.yml#L50-L75 107 | markdown_extensions: 108 | - attr_list 109 | - codehilite: 110 | guess_lang: false 111 | - def_list 112 | - footnotes 113 | - markdown_gfm_admonition # support github-flavored admonitions 114 | - pymdownx.emoji: # Convert emoji shortcodes to images 115 | emoji_index: !!python/name:material.extensions.emoji.twemoji 116 | emoji_generator: !!python/name:material.extensions.emoji.to_svg 117 | - pymdownx.magiclink: # render hrefs as links 118 | hide_protocol: true 119 | repo_url_shortener: true # handle github links 120 | - pymdownx.superfences: 121 | custom_fences: 122 | - name: mermaid # support mermaid diagrams 123 | class: mermaid 124 | format: !!python/name:pymdownx.superfences.fence_code_format 125 | - toc: 126 | permalink: true # Add permalink hover link to each heading 127 | - pymdownx.tasklist: 128 | custom_checkbox: true 129 | -------------------------------------------------------------------------------- /docker-compose.yaml: -------------------------------------------------------------------------------- 1 | services: 2 | stac-pg: 3 | profiles: [""] # default profile 4 | image: ghcr.io/stac-utils/stac-fastapi-pgstac:6.1.1 5 | environment: 6 | STAC_FASTAPI_TITLE: stac-fastapi-pgstac + stac-auth-proxy 7 | STAC_FASTAPI_DESCRIPTION: A STAC FastAPI implemented with pgSTAC, protected with STAC Auth Proxy 8 | APP_HOST: 0.0.0.0 9 | APP_PORT: 8001 10 | CORS_HEADERS: "*" 11 | RELOAD: true 12 | POSTGRES_USER: username 13 | POSTGRES_PASS: password 14 | POSTGRES_DBNAME: postgis 15 | POSTGRES_HOST_READER: database-pg 16 | POSTGRES_HOST_WRITER: database-pg 17 | POSTGRES_PORT: 5432 18 | DB_MIN_CONN_SIZE: 1 19 | DB_MAX_CONN_SIZE: 1 20 | USE_API_HYDRATE: ${USE_API_HYDRATE:-false} 21 | ENABLE_TRANSACTIONS_EXTENSIONS: 1 22 | hostname: stac 23 | ports: 24 | - 
"8001:8001" 25 | depends_on: 26 | - database-pg 27 | command: python -m stac_fastapi.pgstac.app 28 | # command: bash -c "./scripts/wait-for-it.sh database-pg:5432 && python -m stac_fastapi.pgstac.app" 29 | 30 | stac-os: 31 | profiles: ["os"] 32 | container_name: stac-fastapi-os 33 | image: ghcr.io/stac-utils/stac-fastapi-os:v6.1.0 34 | environment: 35 | STAC_FASTAPI_TITLE: stac-fastapi-opensearch + stac-auth-proxy 36 | STAC_FASTAPI_DESCRIPTION: A STAC FastAPI with an Opensearch backend, protected with STAC Auth Proxy 37 | STAC_FASTAPI_LANDING_PAGE_ID: stac-fastapi-opensearch 38 | APP_HOST: 0.0.0.0 39 | APP_PORT: 8001 40 | RELOAD: true 41 | ENVIRONMENT: local 42 | ES_HOST: database-os 43 | ES_PORT: 9200 44 | ES_USE_SSL: false 45 | ES_VERIFY_CERTS: false 46 | BACKEND: opensearch 47 | STAC_FASTAPI_RATE_LIMIT: 200/minute 48 | hostname: stac 49 | ports: 50 | - "8001:8001" 51 | depends_on: 52 | - database-os 53 | command: | 54 | bash -c "./scripts/wait-for-it-es.sh database-os:9200 && python -m stac_fastapi.opensearch.app" 55 | 56 | database-pg: 57 | profiles: [""] # default profile 58 | container_name: database-pg 59 | image: ghcr.io/stac-utils/pgstac:v0.9.5 60 | environment: 61 | POSTGRES_USER: username 62 | POSTGRES_PASSWORD: password 63 | POSTGRES_DB: postgis 64 | PGUSER: username 65 | PGPASSWORD: password 66 | PGDATABASE: postgis 67 | ports: 68 | - "${MY_DOCKER_IP:-127.0.0.1}:5439:5432" 69 | command: postgres -N 500 70 | volumes: 71 | - ./.pgdata:/var/lib/postgresql/data 72 | 73 | database-os: 74 | profiles: ["os"] 75 | container_name: database-os 76 | image: opensearchproject/opensearch:${OPENSEARCH_VERSION:-2.11.1} 77 | hostname: database-os 78 | environment: 79 | cluster.name: stac-cluster 80 | node.name: os01 81 | http.port: 9200 82 | http.cors.allow-headers: X-Requested-With,Content-Type,Content-Length,Accept,Authorization 83 | discovery.type: single-node 84 | plugins.security.disabled: true 85 | OPENSEARCH_JAVA_OPTS: -Xms512m -Xmx512m 86 | ports: 87 | - 
"9200:9200" 88 | 89 | proxy: 90 | depends_on: 91 | - oidc 92 | build: 93 | context: . 94 | environment: 95 | UPSTREAM_URL: ${UPSTREAM_URL:-http://stac:8001} 96 | OIDC_DISCOVERY_URL: ${OIDC_DISCOVERY_URL:-http://localhost:8888/.well-known/openid-configuration} 97 | OIDC_DISCOVERY_INTERNAL_URL: ${OIDC_DISCOVERY_INTERNAL_URL:-http://oidc:8888/.well-known/openid-configuration} 98 | env_file: 99 | - path: .env 100 | required: false 101 | ports: 102 | - "8000:8000" 103 | volumes: 104 | - ./src:/app/src 105 | 106 | oidc: 107 | image: ghcr.io/alukach/mock-oidc-server:latest 108 | environment: 109 | ISSUER: http://localhost:8888 110 | SCOPES: item:create,item:update,item:delete,collection:create,collection:update,collection:delete 111 | PORT: 8888 112 | ports: 113 | - "8888:8888" 114 | 115 | stac-browser: 116 | image: ghcr.io/radiantearth/stac-browser:latest 117 | ports: 118 | - 8080:8080 119 | environment: 120 | SB_catalogUrl: "http://localhost:8000" 121 | SB_authConfig: | 122 | { 123 | "type": "openIdConnect", 124 | "openIdConnectUrl": "http://localhost:8888/.well-known/openid-configuration", 125 | "oidcOptions": { 126 | "client_id": "stac-browser" 127 | } 128 | } 129 | -------------------------------------------------------------------------------- /src/stac_auth_proxy/config.py: -------------------------------------------------------------------------------- 1 | """Configuration for the STAC Auth Proxy.""" 2 | 3 | import importlib 4 | from typing import Any, Literal, Optional, Sequence, TypeAlias, Union 5 | 6 | from pydantic import BaseModel, Field, field_validator, model_validator 7 | from pydantic.networks import HttpUrl 8 | from pydantic_settings import BaseSettings, SettingsConfigDict 9 | 10 | METHODS = Literal["GET", "POST", "PUT", "DELETE", "PATCH"] 11 | EndpointMethods: TypeAlias = dict[str, Sequence[METHODS]] 12 | EndpointMethodsWithScope: TypeAlias = dict[ 13 | str, Sequence[Union[METHODS, tuple[METHODS, str]]] 14 | ] 15 | 16 | _PREFIX_PATTERN = r"^/.*$" 17 | 
def str2list(x: Optional[str] = None) -> Optional[Sequence[str]]:
    """
    Convert a comma-delimited string into a list.

    Every space is removed before splitting, so ``"a, b"`` yields
    ``["a", "b"]``.

    Args:
        x: Comma-delimited string. ``None`` and ``""`` mean "no value".

    Returns:
        The list of items, or ``None`` when ``x`` is empty or ``None``.

    """
    if not x:
        return None
    return x.replace(" ", "").split(",")


class _ClassInput(BaseModel):
    """Input model for dynamically loading a class or function."""

    # Import target in "module.path:attribute" form.
    cls: str
    # Positional arguments forwarded to the loaded callable.
    args: Sequence[str] = Field(default_factory=list)
    # Keyword arguments forwarded to the loaded callable.
    kwargs: dict[str, str] = Field(default_factory=dict)

    def __call__(self):
        """
        Dynamically load the configured callable and call it with args & kwargs.

        Returns:
            Whatever the loaded class or function returns.

        Raises:
            ValueError: If ``cls`` is not in ``module.path:attribute`` form.
            ImportError: If the module cannot be imported.
            AttributeError: If the module does not define the attribute.

        """
        # Explicit check instead of ``assert``: assertions are stripped when
        # Python runs with -O, which would turn a bad value into an opaque
        # unpacking/import error.
        module_path, separator, attr_name = self.cls.rpartition(":")
        if not (separator and module_path and attr_name):
            raise ValueError(
                f"Invalid import path {self.cls!r}, expected 'module.path:attribute'"
            )
        module = importlib.import_module(module_path)
        target = getattr(module, attr_name)
        return target(*self.args, **self.kwargs)


class Settings(BaseSettings):
    """Configuration settings for the STAC Auth Proxy."""

    # External URLs
    upstream_url: HttpUrl
    oidc_discovery_url: HttpUrl
    oidc_discovery_internal_url: HttpUrl
    allowed_jwt_audiences: Optional[Sequence[str]] = None

    root_path: str = ""
    override_host: bool = True
    healthz_prefix: str = Field(pattern=_PREFIX_PATTERN, default="/healthz")
    wait_for_upstream: bool = True
    check_conformance: bool = True
    enable_compression: bool = True

    # OpenAPI / Swagger UI
    openapi_spec_endpoint: Optional[str] = Field(
        pattern=_PREFIX_PATTERN, default="/api"
    )
    openapi_auth_scheme_name: str = "oidcAuth"
    openapi_auth_scheme_override: Optional[dict] = None
    swagger_ui_endpoint: Optional[str] = Field(
        pattern=_PREFIX_PATTERN, default="/api.html"
    )
    swagger_ui_init_oauth: dict = Field(default_factory=dict)

    # Auth
    enable_authentication_extension: bool = True
    default_public: bool = False
    public_endpoints: EndpointMethods = {
        r"^/$": ["GET"],
        r"^/api.html$": ["GET"],
        r"^/api$": ["GET"],
        r"^/conformance$": ["GET"],
        r"^/docs/oauth2-redirect": ["GET"],
        r"^/healthz": ["GET"],
    }
    private_endpoints: EndpointMethodsWithScope = {
        # https://github.com/stac-api-extensions/collection-transaction/blob/v1.0.0-beta.1/README.md#methods
        r"^/collections$": ["POST"],
        r"^/collections/([^/]+)$": ["PUT", "PATCH", "DELETE"],
        # https://github.com/stac-api-extensions/transaction/blob/v1.0.0-rc.3/README.md#methods
        r"^/collections/([^/]+)/items$": ["POST"],
        r"^/collections/([^/]+)/items/([^/]+)$": ["PUT", "PATCH", "DELETE"],
        # https://stac-utils.github.io/stac-fastapi/api/stac_fastapi/extensions/third_party/bulk_transactions/#bulktransactionextension
        r"^/collections/([^/]+)/bulk_items$": ["POST"],
    }

    # Filters
    items_filter: Optional[_ClassInput] = None
    items_filter_path: str = r"^(/collections/([^/]+)/items(/[^/]+)?$|/search$)"
    collections_filter: Optional[_ClassInput] = None
    collections_filter_path: str = r"^/collections(/[^/]+)?$"

    model_config = SettingsConfigDict(
        env_nested_delimiter="_",
    )

    @model_validator(mode="before")
    @classmethod
    def _default_oidc_discovery_internal_url(cls, data: Any) -> Any:
        """Set the internal OIDC discovery URL to the public URL if not set."""
        # A "before" validator receives the raw input, which is not guaranteed
        # to be a dict; leave anything else for pydantic to reject.
        if isinstance(data, dict) and not data.get("oidc_discovery_internal_url"):
            data["oidc_discovery_internal_url"] = data.get("oidc_discovery_url")
        return data

    @field_validator("allowed_jwt_audiences", mode="before")
    @classmethod
    def parse_audience(cls, v) -> Optional[Sequence[str]]:
        """
        Parse a comma separated string list of audiences into a list.

        Values that are already a non-empty sequence (e.g. passed as a keyword
        argument) are handed to pydantic untouched instead of being treated as
        a string.
        """
        if isinstance(v, str) or not v:
            return str2list(v)
        return v
removing the filter from the request state.""" 2 | 3 | import json 4 | from dataclasses import dataclass 5 | from logging import getLogger 6 | from typing import Optional 7 | from urllib.parse import parse_qs, urlencode, urlparse, urlunparse 8 | 9 | from cql2 import Expr 10 | from starlette.requests import Request 11 | from starlette.types import ASGIApp, Message, Receive, Scope, Send 12 | 13 | logger = getLogger(__name__) 14 | 15 | 16 | @dataclass(frozen=True) 17 | class Cql2RewriteLinksFilterMiddleware: 18 | """ASGI middleware to rewrite 'filter' in .links of the JSON response, removing the filter from the request state.""" 19 | 20 | app: ASGIApp 21 | state_key: str = "cql2_filter" 22 | 23 | async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: 24 | """Replace 'filter' in .links of the JSON response to state before we had applied the filter.""" 25 | if scope["type"] != "http": 26 | return await self.app(scope, receive, send) 27 | 28 | request = Request(scope) 29 | original_filter = request.query_params.get("filter") 30 | cql2_filter: Optional[Expr] = getattr(request.state, self.state_key, None) 31 | if cql2_filter is None: 32 | # No filter set, just pass through 33 | return await self.app(scope, receive, send) 34 | 35 | # Intercept the response 36 | response_start = None 37 | body_chunks = [] 38 | more_body = True 39 | 40 | async def send_wrapper(message: Message): 41 | nonlocal response_start, body_chunks, more_body 42 | if message["type"] == "http.response.start": 43 | response_start = message 44 | elif message["type"] == "http.response.body": 45 | body_chunks.append(message.get("body", b"")) 46 | more_body = message.get("more_body", False) 47 | if not more_body: 48 | await self._process_and_send_response( 49 | response_start, body_chunks, send, original_filter 50 | ) 51 | else: 52 | await send(message) 53 | 54 | await self.app(scope, receive, send_wrapper) 55 | 56 | async def _process_and_send_response( 57 | self, 58 | response_start: 
Message, 59 | body_chunks: list[bytes], 60 | send: Send, 61 | original_filter: Optional[str], 62 | ): 63 | body = b"".join(body_chunks) 64 | try: 65 | data = json.loads(body) 66 | except Exception: 67 | await send(response_start) 68 | await send({"type": "http.response.body", "body": body, "more_body": False}) 69 | return 70 | 71 | cql2_filter = Expr(original_filter) if original_filter else None 72 | links = data.get("links") 73 | if isinstance(links, list): 74 | for link in links: 75 | # Handle filter in query string 76 | if "href" in link: 77 | url = urlparse(link["href"]) 78 | qs = parse_qs(url.query) 79 | if "filter" in qs: 80 | if cql2_filter: 81 | qs["filter"] = [cql2_filter.to_text()] 82 | else: 83 | qs.pop("filter", None) 84 | qs.pop("filter-lang", None) 85 | new_query = urlencode(qs, doseq=True) 86 | link["href"] = urlunparse(url._replace(query=new_query)) 87 | 88 | # Handle filter in body (for POST links) 89 | if "body" in link and isinstance(link["body"], dict): 90 | if "filter" in link["body"]: 91 | if cql2_filter: 92 | link["body"]["filter"] = cql2_filter.to_json() 93 | else: 94 | link["body"].pop("filter", None) 95 | link["body"].pop("filter-lang", None) 96 | 97 | # Send the modified response 98 | new_body = json.dumps(data).encode("utf-8") 99 | 100 | # Patch content-length 101 | headers = [ 102 | (k, v) for k, v in response_start["headers"] if k != b"content-length" 103 | ] 104 | headers.append((b"content-length", str(len(new_body)).encode("latin1"))) 105 | response_start = dict(response_start) 106 | response_start["headers"] = headers 107 | await send(response_start) 108 | await send({"type": "http.response.body", "body": new_body, "more_body": False}) 109 | -------------------------------------------------------------------------------- /src/stac_auth_proxy/middleware/ProcessLinksMiddleware.py: -------------------------------------------------------------------------------- 1 | """Middleware to remove the application root path from incoming 
@dataclass
class ProcessLinksMiddleware(JsonResponseMiddleware):
    """
    Middleware to update links in responses, removing the upstream_url path and adding
    the root_path if it exists.
    """

    app: ASGIApp
    upstream_url: str
    root_path: Optional[str] = None

    json_content_type_expr: str = r"application/(geo\+)?json"

    def should_transform_response(self, request: Request, scope: Scope) -> bool:
        """Only transform responses with JSON content type."""
        content_type = Headers(scope=scope).get("content-type", "")
        return bool(re.match(self.json_content_type_expr, content_type))

    def transform_json(self, data: dict[str, Any], request: Request) -> dict[str, Any]:
        """
        Update links in the response to include root_path.

        Links that cannot be processed are logged and left unchanged so that a
        single malformed link never fails the whole response.
        """
        # Get the client's actual base URL (accounting for load balancers/proxies)
        req_base_url = get_base_url(request)
        parsed_req_url = urlparse(req_base_url)
        parsed_upstream_url = urlparse(self.upstream_url)

        for link in get_links(data):
            try:
                self._update_link(link, parsed_req_url, parsed_upstream_url)
            except Exception as e:
                logger.error(
                    "Failed to parse link href %r, (ignoring): %s",
                    link.get("href") if isinstance(link, dict) else link,
                    str(e),
                )
        return data

    def _update_link(
        self, link: dict[str, Any], request_url: ParseResult, upstream_url: ParseResult
    ) -> None:
        """
        Ensure that link hrefs that are local to upstream url are rewritten as local to
        the proxy.

        Args:
            link: The link object; its ``href`` is updated in place.
            request_url: Parsed base URL as seen by the client.
            upstream_url: Parsed URL of the upstream API.
        """
        if "href" not in link:
            logger.warning("Link %r has no href", link)
            return

        original_href = link["href"]
        parsed_link = urlparse(original_href)

        if parsed_link.netloc not in (request_url.netloc, upstream_url.netloc):
            logger.debug(
                "Ignoring link %s because it is not for an endpoint behind this proxy (%s or %s)",
                original_href,
                request_url.netloc,
                upstream_url.netloc,
            )
            return

        # Normalise the upstream prefix so "/stac" and "/stac/" behave alike,
        # and match on whole path segments only: "/stac" must not match
        # "/stac-other/...".
        upstream_prefix = upstream_url.path.rstrip("/")
        path = parsed_link.path
        if upstream_prefix:
            is_descendant = path == upstream_prefix or path.startswith(
                f"{upstream_prefix}/"
            )
            # If the link path is not a descendant of the upstream path, don't transform it
            if not is_descendant:
                logger.debug(
                    "Ignoring link %s because it is not descendant of upstream path (%s)",
                    original_href,
                    upstream_url.path,
                )
                return
            # Remove the upstream prefix from the link path (keeps the leading "/")
            path = path[len(upstream_prefix) :]

        # Replace the upstream host with the client's host
        if parsed_link.netloc == upstream_url.netloc:
            parsed_link = parsed_link._replace(
                netloc=request_url.netloc, scheme=request_url.scheme
            )

        # Add the root_path to the link if it exists
        if self.root_path:
            path = f"{self.root_path}{path}"

        updated_href = urlunparse(parsed_link._replace(path=path))
        if updated_href == original_href:
            return

        logger.debug(
            "Rewriting %r link %r to %r",
            link.get("rel"),
            original_href,
            updated_href,
        )

        link["href"] = updated_href
-------------------------------------------------------------------------------- 1 | # Deployment 2 | 3 | ## General 4 | 5 | Deploying the STAC Auth Proxy is similar to deploying any other service. In general, we recommend you mirror the architecture of your other systems. 6 | 7 | The core principles of deploying the STAC Auth Proxy are: 8 | 9 | 1. The STAC API should not be available on the public internet 10 | 2. The STAC Auth Proxy should be able to communicate with both the STAC API and the OIDC Server (namely, the discovery endpoint and JWKS endpoint) 11 | 12 | ### Networking Considerations 13 | 14 | #### Hiding the STAC API 15 | 16 | The STAC API should not be directly accessible from the public internet. The STAC Auth Proxy acts as the public-facing endpoint. 17 | 18 | ##### AWS Strategy 19 | 20 | - Place the STAC API in a private subnet 21 | - Place the STAC Auth Proxy in a public subnet with internet access 22 | - Use security groups to restrict access between components 23 | 24 | ##### Kubernetes Strategy 25 | 26 | - Deploy the STAC API as an internal service (ClusterIP) 27 | - Deploy the STAC Auth Proxy with an Ingress for external access 28 | - Use network policies to control traffic flow 29 | 30 | #### Communicating with the OIDC Server 31 | 32 | The STAC Auth Proxy needs to communicate with your OIDC provider for authentication. If your OIDC server is not directly available to the STAC Auth Proxy, use [`OIDC_DISCOVERY_INTERNAL_URL`](configuration.md#oidc_discovery_internal_url) (the [`OIDC_DISCOVERY_URL`](configuration.md#oidc_discovery_url) will still be used for auth in the browser, such as the Swagger UI page). 33 | 34 | ## AWS Lambda 35 | 36 | For AWS Lambda deployments, we recommend using the [Mangum](https://pypi.org/project/mangum/) handler with disabled lifespan events. Such a handler is available at `stac_auth_proxy.lambda:handler`. 
37 | 38 | > [!TIP] 39 | > 40 | > If using `stac_auth_proxy.lambda:handler`, be sure to install the `lambda` optional dependencies: 41 | > 42 | > ```bash 43 | > pip install stac_auth_proxy[lambda] 44 | > ``` 45 | 46 | ### CDK 47 | 48 | If using [AWS CDK](https://docs.aws.amazon.com/cdk/), a [`StacAuthProxy` Construct](https://developmentseed.org/eoapi-cdk/#stacauthproxylambda-) is made available within the [`eoapi-cdk`](https://github.com/developmentseed/eoapi-cdk) project. 49 | 50 | ## Docker 51 | 52 | The STAC Auth Proxy is available as a [Docker image from the GitHub Container Registry (GHCR)](https://github.com/developmentseed/stac-auth-proxy/pkgs/container/stac-auth-proxy). 53 | 54 | ```bash 55 | # Latest version 56 | docker pull ghcr.io/developmentseed/stac-auth-proxy:latest 57 | 58 | # Specific version 59 | docker pull ghcr.io/developmentseed/stac-auth-proxy:v0.7.1 60 | ``` 61 | 62 | ## Kubernetes 63 | 64 | The STAC Auth Proxy can be deployed to Kubernetes via the [Helm Chart available on the GitHub Container Registry (GHCR)](https://github.com/developmentseed/stac-auth-proxy/pkgs/container/stac-auth-proxy%2Fcharts%2Fstac-auth-proxy). 
65 | 66 | ### Prerequisites 67 | 68 | - Kubernetes 1.19+ 69 | - Helm 3.2.0+ 70 | 71 | ### Installation 72 | 73 | ```bash 74 | # Add the Helm repository 75 | helm registry login ghcr.io 76 | 77 | # Install with minimal configuration 78 | helm install stac-auth-proxy oci://ghcr.io/developmentseed/stac-auth-proxy/charts/stac-auth-proxy \ 79 | --set env.UPSTREAM_URL=https://your-stac-api.com/stac \ 80 | --set env.OIDC_DISCOVERY_URL=https://your-auth-server/.well-known/openid-configuration \ 81 | --set ingress.host=stac-proxy.your-domain.com 82 | ``` 83 | 84 | ### Configuration 85 | 86 | | Parameter | Description | Required | Default | 87 | | ------------------------ | --------------------------------------------- | -------- | ------- | 88 | | `env.UPSTREAM_URL` | URL of the STAC API to proxy | Yes | - | 89 | | `env.OIDC_DISCOVERY_URL` | OpenID Connect discovery document URL | Yes | - | 90 | | `env` | Environment variables passed to the container | No | `{}` | 91 | | `ingress.enabled` | Enable ingress | No | `true` | 92 | | `ingress.className` | Ingress class name | No | `nginx` | 93 | | `ingress.host` | Hostname for the ingress | No | `""` | 94 | | `ingress.tls.enabled` | Enable TLS for ingress | No | `true` | 95 | | `replicaCount` | Number of replicas | No | `1` | 96 | 97 | For a complete list of values, see the [values.yaml](https://github.com/developmentseed/stac-auth-proxy/blob/main/helm/values.yaml) file. 
@required_conformance(
    "http://www.opengis.net/spec/cql2/1.0/conf/basic-cql2",
    "http://www.opengis.net/spec/cql2/1.0/conf/cql2-text",
    "http://www.opengis.net/spec/cql2/1.0/conf/cql2-json",
)
@dataclass(frozen=True)
class Cql2BuildFilterMiddleware:
    """
    Middleware to build the Cql2Filter.

    For requests whose path matches one of the configured filter paths, the
    matching filter builder is awaited, its result is validated as CQL2 and the
    resulting expression is stored on the request state under ``state_key``.
    """

    app: ASGIApp

    state_key: str = "cql2_filter"

    # Filters
    collections_filter: Optional[Callable] = None
    collections_filter_path: str = r"^/collections(/[^/]+)?$"
    items_filter: Optional[Callable] = None
    items_filter_path: str = r"^(/collections/([^/]+)/items(/[^/]+)?$|/search$)"

    def __post_init__(self):
        """Set required conformances based on the filter functions."""
        required_conformances = set()
        if self.collections_filter:
            logger.debug("Appending required conformance for collections filter")
            # https://github.com/stac-api-extensions/collection-search/blob/4825b4b1cee96bdc0cbfbb342d5060d0031976f0/README.md#L5
            required_conformances.update(
                [
                    "https://api.stacspec.org/v1.0.0/core",
                    r"https://api.stacspec.org/v1\.0\.0(?:-[\w\.]+)?/collection-search",
                    r"https://api.stacspec.org/v1\.0\.0(?:-[\w\.]+)?/collection-search#filter",
                    "http://www.opengis.net/spec/ogcapi-common-2/1.0/conf/simple-query",
                ]
            )
        if self.items_filter:
            logger.debug("Appending required conformance for items filter")
            # https://github.com/stac-api-extensions/filter/blob/c763dbbf0a52210ab8d9866ff048da448d270f93/README.md#conformance-classes
            required_conformances.update(
                [
                    "http://www.opengis.net/spec/ogcapi-features-3/1.0/conf/filter",
                    "http://www.opengis.net/spec/ogcapi-features-3/1.0/conf/features-filter",
                    r"https://api.stacspec.org/v1\.0\.0(?:-[\w\.]+)?/item-search#filter",
                ]
            )

        # Must set required conformances on class
        self.__class__.__required_conformances__ = required_conformances.union(
            getattr(self.__class__, "__required_conformances__", [])
        )

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """
        Build the CQL2 filter, place on the request state.

        If the filter produced by the builder fails CQL2 validation, a 502
        response is sent and the wrapped application is not called.
        """
        if scope["type"] != "http":
            return await self.app(scope, receive, send)

        request = Request(scope)

        filter_builder = self._get_filter(request.url.path)
        if not filter_builder:
            return await self.app(scope, receive, send)

        filter_expr = await filter_builder(
            {
                "req": {
                    "path": request.url.path,
                    "method": request.method,
                    "query_params": dict(request.query_params),
                    "path_params": requests.extract_variables(request.url.path),
                    "headers": dict(request.headers),
                },
                # "state" is only present once something has populated it;
                # fall back to an empty mapping rather than raising KeyError.
                **scope.get("state", {}),
            }
        )
        cql2_filter = Expr(filter_expr)
        try:
            cql2_filter.validate()
        except ValidationError:
            logger.error("Invalid CQL2 filter: %s", filter_expr)
            # A Starlette Response is itself an ASGI app: it must be awaited
            # with (scope, receive, send) to be sent. Merely returning it from
            # an ASGI callable would leave the client without any response.
            response = Response(status_code=502, content="Invalid CQL2 filter")
            return await response(scope, receive, send)
        setattr(request.state, self.state_key, cql2_filter)

        return await self.app(scope, receive, send)

    def _get_filter(
        self, path: str
    ) -> Optional[Callable[..., Awaitable[str | dict[str, Any]]]]:
        """
        Get the CQL2 filter builder for the given path.

        Returns the builder configured for the first path expression that
        matches (which may itself be ``None``), or ``None`` if no expression
        matches.
        """
        endpoint_filters = [
            (self.collections_filter_path, self.collections_filter),
            (self.items_filter_path, self.items_filter),
        ]
        for expr, builder in endpoint_filters:
            if re.match(expr, path):
                return builder
        return None
__call__(self, scope: Scope, receive: Receive, send: Send) -> None: 36 | """Validate the response body with a CQL2 filter for single-record endpoints.""" 37 | if scope["type"] != "http": 38 | return await self.app(scope, receive, send) 39 | 40 | request = Request(scope) 41 | cql2_filter: Optional[Expr] = getattr(request.state, self.state_key, None) 42 | if not cql2_filter: 43 | return await self.app(scope, receive, send) 44 | 45 | if not any( 46 | re.match(expr, request.url.path) for expr in self.single_record_endpoints 47 | ): 48 | return await self.app(scope, receive, send) 49 | 50 | # Intercept the response 51 | response_start = None 52 | body_chunks = [] 53 | more_body = True 54 | 55 | async def send_wrapper(message: Message): 56 | nonlocal response_start, body_chunks, more_body 57 | if message["type"] == "http.response.start": 58 | response_start = message 59 | elif message["type"] == "http.response.body": 60 | body_chunks.append(message.get("body", b"")) 61 | more_body = message.get("more_body", False) 62 | if not more_body: 63 | await self._process_and_send_response( 64 | response_start, body_chunks, send, cql2_filter 65 | ) 66 | else: 67 | await send(message) 68 | 69 | await self.app(scope, receive, send_wrapper) 70 | 71 | async def _process_and_send_response( 72 | self, response_start, body_chunks, send, cql2_filter 73 | ): 74 | body = b"".join(body_chunks) 75 | try: 76 | body_json = json.loads(body) 77 | except json.JSONDecodeError: 78 | logger.warning("Failed to parse response body as JSON") 79 | await self._send_json_response( 80 | send, 81 | status=502, 82 | content={ 83 | "code": "ParseError", 84 | "description": "Failed to parse response body as JSON", 85 | }, 86 | ) 87 | return 88 | 89 | try: 90 | cql2_matches = cql2_filter.matches(body_json) 91 | except Exception as e: 92 | cql2_matches = False 93 | logger.warning("Failed to apply filter: %s", e) 94 | 95 | if cql2_matches: 96 | logger.debug("Response matches filter, returning record") 97 | # Send 
the original response start 98 | await send(response_start) 99 | # Send the filtered body 100 | await send( 101 | { 102 | "type": "http.response.body", 103 | "body": json.dumps(body_json).encode("utf-8"), 104 | "more_body": False, 105 | } 106 | ) 107 | else: 108 | logger.debug("Response did not match filter, returning 404") 109 | await self._send_json_response( 110 | send, 111 | status=404, 112 | content={"code": "NotFoundError", "description": "Record not found."}, 113 | ) 114 | 115 | async def _send_json_response(self, send, status, content): 116 | response_bytes = json.dumps(content).encode("utf-8") 117 | await send( 118 | { 119 | "type": "http.response.start", 120 | "status": status, 121 | "headers": [ 122 | (b"content-type", b"application/json"), 123 | (b"content-length", str(len(response_bytes)).encode("latin1")), 124 | ], 125 | } 126 | ) 127 | await send( 128 | { 129 | "type": "http.response.body", 130 | "body": response_bytes, 131 | "more_body": False, 132 | } 133 | ) 134 | -------------------------------------------------------------------------------- /docs/user-guide/route-level-auth.md: -------------------------------------------------------------------------------- 1 | # Route-Level Authorization 2 | 3 | Route-level authorization can provide a base layer of security for the simplest use cases. This typically looks like: 4 | 5 | - the entire catalog being private, available only to authenticated users 6 | - most of the catalog being public, available to anonymous or authenticated users. 
However, a subset of endpoints (typically the [transactions extension](https://github.com/stac-api-extensions/transaction) endpoints) are only available to all or a subset of authenticated users 7 | 8 | ## Configuration Variables 9 | 10 | Route-level authorization is controlled by three key environment variables: 11 | 12 | - **[`DEFAULT_PUBLIC`](../../configuration/#default_public)**: Sets the default access policy for all endpoints 13 | - **[`PUBLIC_ENDPOINTS`](../../configuration/#public_endpoints)**: Marks endpoints as not requiring authentication (used only when `DEFAULT_PUBLIC=false`). By default, we keep the catalog root, conformance declaration, OpenAPI spec, Swagger UI, Swagger UI auth redirect, and the proxy health endpoint as public. Note that these are all endpoints that don't serve actual STAC data; they only acknowledge the presence of a STAC catalog. This is defined by a mapping of regex path expressions to arrays of HTTP methods. 14 | - **[`PRIVATE_ENDPOINTS`](../../configuration/#private_endpoints)**: Marks endpoints as requiring authentication. By default, the transactions endpoints are all marked as private. This is defined by a mapping of regex path expressions to arrays of either HTTP methods or tuples of HTTP methods and space-separated required scopes. 15 | 16 | > [!TIP] 17 | > 18 | > Users typically don't need to specify both `PRIVATE_ENDPOINTS` and `PUBLIC_ENDPOINTS`. 19 | 20 | ## Strategies 21 | 22 | ### Private by Default 23 | 24 | Make the entire STAC API private, requiring authentication for all endpoints. 25 | 26 | > [!NOTE] 27 | > 28 | > This is the out-of-the-box configuration of the STAC Auth Proxy. 29 | 30 | **Configuration** 31 | 32 | ```bash 33 | # Set default policy to private 34 | DEFAULT_PUBLIC=false 35 | 36 | # The default public endpoints are typically sufficient. Otherwise, they can be specified. 37 | # PUBLIC_ENDPOINTS='{ ... 
}' 38 | ``` 39 | 40 | **Behavior** 41 | 42 | - All endpoints require authentication by default 43 | - Only explicitly listed endpoints in `PUBLIC_ENDPOINTS` are accessible without authentication. By default, these are endpoints that don't reveal STAC data 44 | - Useful for internal or enterprise STAC APIs where all data should be protected 45 | 46 | ### Public by Default with Protected Write Operations 47 | 48 | Make the STAC API mostly public for read operations, but require authentication for write operations (transactions extension). 49 | 50 | **Configuration** 51 | 52 | ```bash 53 | # Set default policy to public 54 | DEFAULT_PUBLIC=true 55 | 56 | # The default private endpoints are typically sufficient. Otherwise, they can be specified. 57 | # PRIVATE_ENDPOINTS='{ ... }' 58 | ``` 59 | 60 | **Behavior** 61 | 62 | - Read operations (GET requests) are accessible to everyone 63 | - Write operations require authentication 64 | - The default `PRIVATE_ENDPOINTS` value already covers the transaction endpoints, so only `DEFAULT_PUBLIC=true` needs to be set 65 | - Ideal for public STAC catalogs where data discovery is open but modifications are restricted 66 | 67 | ### Authenticated Access with Scope-based Authorization 68 | 69 | For a level of control beyond simple anonymous vs. authenticated status, the proxy can be configured so that path/method access requires JWTs containing particular permissions in the form of the [scopes claim](https://datatracker.ietf.org/doc/html/rfc8693#name-scope-scopes-claim). 
70 | 71 | **Configuration** 72 | 73 | For granular permissions on a public API: 74 | 75 | ```bash 76 | # Set default policy to public 77 | DEFAULT_PUBLIC=true 78 | 79 | # Require specific scopes for write operations 80 | PRIVATE_ENDPOINTS='{ 81 | "^/collections$": [["POST", "collection:create"]], 82 | "^/collections/([^/]+)$": [["PUT", "collection:update"], ["PATCH", "collection:update"], ["DELETE", "collection:delete"]], 83 | "^/collections/([^/]+)/items$": [["POST", "item:create"]], 84 | "^/collections/([^/]+)/items/([^/]+)$": [["PUT", "item:update"], ["PATCH", "item:update"], ["DELETE", "item:delete"]], 85 | "^/collections/([^/]+)/bulk_items$": [["POST", "item:create"]] 86 | }' 87 | ``` 88 | 89 | For role-based permissions on a private API: 90 | 91 | ```bash 92 | # Set default policy to private 93 | DEFAULT_PUBLIC=false 94 | 95 | # Require specific scopes for write operations 96 | PRIVATE_ENDPOINTS='{ 97 | "^/collections$": [["POST", "admin"]], 98 | "^/collections/([^/]+)$": [["PUT", "admin"], ["PATCH", "admin"], ["DELETE", "admin"]], 99 | "^/collections/([^/]+)/items$": [["POST", "editor"]], 100 | "^/collections/([^/]+)/items/([^/]+)$": [["PUT", "editor"], ["PATCH", "editor"], ["DELETE", "editor"]], 101 | "^/collections/([^/]+)/bulk_items$": [["POST", "editor"]] 102 | }' 103 | ``` 104 | 105 | **Behavior** 106 | 107 | - Users must be authenticated AND have the required scope(s) 108 | - Different HTTP methods can require different scopes 109 | - Scopes are checked against the user's JWT scope claim 110 | - Unauthorized requests receive a 401 Unauthorized response 111 | 112 | > [!TIP] 113 | > 114 | > Multiple scopes can be provided in a space-separated format, such as `["POST", "scope_a scope_b scope_c"]`. These scope requirements are applied with AND logic, meaning that the incoming JWT must contain all the mentioned scopes. 
115 | -------------------------------------------------------------------------------- /src/stac_auth_proxy/utils/middleware.py: -------------------------------------------------------------------------------- 1 | """Utilities for middleware response handling.""" 2 | 3 | import json 4 | import logging 5 | from abc import ABC, abstractmethod 6 | from typing import Any, Optional 7 | 8 | from starlette.datastructures import MutableHeaders 9 | from starlette.requests import Request 10 | from starlette.responses import JSONResponse 11 | from starlette.types import ASGIApp, Message, Receive, Scope, Send 12 | 13 | logger = logging.getLogger(__name__) 14 | 15 | 16 | class JsonResponseMiddleware(ABC): 17 | """Base class for middleware that transforms JSON response bodies.""" 18 | 19 | app: ASGIApp 20 | 21 | # Expected data type for JSON responses. Only responses matching this type will be transformed. 22 | # If None, all JSON responses will be transformed regardless of type. 23 | expected_data_type: Optional[type] = dict 24 | 25 | @abstractmethod 26 | def should_transform_response( 27 | self, request: Request, scope: Scope 28 | ) -> bool: # mypy: ignore 29 | """ 30 | Determine if this response should be transformed. At a minimum, this 31 | should check the request's path and content type. 32 | 33 | Returns 34 | ------- 35 | bool: True if the response should be transformed 36 | 37 | """ 38 | ... 39 | 40 | @abstractmethod 41 | def transform_json(self, data: Any, request: Request) -> Any: 42 | """ 43 | Transform the JSON data. 44 | 45 | Args: 46 | data: The parsed JSON data 47 | request: The HTTP request object 48 | 49 | Returns: 50 | ------- 51 | The transformed JSON data 52 | 53 | """ 54 | ... 
55 | 56 | async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: 57 | """Process the request/response.""" 58 | if scope["type"] != "http": 59 | return await self.app(scope, receive, send) 60 | 61 | start_message: Optional[Message] = None 62 | body = b"" 63 | 64 | async def transform_response(message: Message) -> None: 65 | nonlocal start_message 66 | nonlocal body 67 | 68 | start_message = start_message or message 69 | headers = MutableHeaders(scope=start_message) 70 | request = Request(scope) 71 | 72 | if not self.should_transform_response( 73 | request=request, 74 | scope=start_message, 75 | ): 76 | # For non-JSON responses, send the start message immediately 77 | await send(message) 78 | return 79 | 80 | # Delay sending start message until we've processed the body 81 | if message["type"] == "http.response.start": 82 | return 83 | 84 | body += message["body"] 85 | 86 | # Skip body chunks until all chunks have been received 87 | if message.get("more_body"): 88 | return 89 | 90 | # Transform the JSON body 91 | if body: 92 | try: 93 | data = json.loads(body) 94 | except json.JSONDecodeError as e: 95 | logger.error("Error parsing JSON: %s", e) 96 | logger.error("Body: %s", body) 97 | logger.error("Response scope: %s", scope) 98 | response = JSONResponse( 99 | {"error": "Received invalid JSON from upstream server"}, 100 | status_code=502, 101 | ) 102 | await response(scope, receive, send) 103 | return 104 | 105 | if self.expected_data_type is None or isinstance( 106 | data, self.expected_data_type 107 | ): 108 | transformed = self.transform_json(data, request=request) 109 | body = json.dumps(transformed).encode() 110 | else: 111 | logger.warning( 112 | "Received JSON response with unexpected data type %r from upstream server (%r %r), " 113 | "skipping transformation (expected: %r)", 114 | type(data).__name__, 115 | request.method, 116 | request.url, 117 | self.expected_data_type.__name__, 118 | ) 119 | 120 | # Update content-length header 121 | 
@dataclass
class ReverseProxyHandler:
    """
    Forward incoming requests to an upstream STAC API and relay the response.

    Parameters
    ----------
    upstream : str
        Base URL of the upstream API; used as ``base_url`` of the default client.
    client : httpx.AsyncClient, optional
        Pre-configured client. When omitted, an HTTP/2-enabled client is created
        in ``__post_init__``.
    timeout : httpx.Timeout
        Timeout applied to the default client (15 seconds unless overridden).
    proxy_name : str
        Name advertised in the ``Via`` header.
    override_host : bool
        When true, the ``Host`` header is replaced with the upstream host.
    legacy_forwarded_headers : bool
        When true, ``X-Forwarded-*`` headers are emitted alongside ``Forwarded``.

    """

    upstream: str
    client: httpx.AsyncClient = None
    timeout: httpx.Timeout = field(default_factory=lambda: httpx.Timeout(timeout=15.0))

    proxy_name: str = "stac-auth-proxy"
    override_host: bool = True
    legacy_forwarded_headers: bool = False

    def __post_init__(self):
        """Initialize the HTTP client."""
        self.client = self.client or httpx.AsyncClient(
            base_url=self.upstream,
            timeout=self.timeout,
            http2=True,
        )

    @staticmethod
    def _host_with_port(host: str, port: str) -> str:
        """
        Return ``host`` with ``port`` appended, unless that would be wrong.

        The port is left off when it is empty (unknown) or when ``host`` already
        carries one, e.g. ``localhost:8000`` or ``[::1]:9000``.
        """
        if not port:
            return host
        # Only look after the closing bracket of an IPv6 literal so that the
        # colons inside "[::1]" are not mistaken for a port separator.
        if ":" in host.rpartition("]")[2]:
            return host
        return f"{host}:{port}"

    def _prepare_headers(self, request: Request) -> MutableHeaders:
        """
        Prepare headers for the proxied request. Construct a Forwarded header to inform
        the upstream API about the original request context, which will allow it to
        properly construct URLs in responses (namely, in the Links). If there are
        existing X-Forwarded-*/Forwarded headers (typically, in situations where the
        STAC Auth Proxy is behind a proxy like Traefik or NGINX), we use those values.
        """
        headers = MutableHeaders(request.headers)
        headers.setdefault("Via", f"1.1 {self.proxy_name}")

        proxy_client = headers.get(
            "X-Forwarded-For", request.client.host if request.client else "unknown"
        )
        proxy_proto = headers.get("X-Forwarded-Proto", request.url.scheme)
        proxy_host = headers.get("X-Forwarded-Host", request.url.netloc)
        proxy_path = headers.get("X-Forwarded-Path", request.base_url.path)

        # ``request.url.port`` is None when the URL has no explicit port. Calling
        # ``str()`` on it would produce the truthy string "None", so map an
        # unknown port to the empty string instead.
        raw_port = headers.get("X-Forwarded-Port", request.url.port)
        proxy_port = "" if raw_port is None else str(raw_port)

        # NOTE: If we don't include a port, it's possible that the upstream server may
        # mistakenly use the port from the Host header (which may be the internal port
        # of the upstream server) when constructing URLs. The helper avoids adding
        # the port a second time when the host value (e.g. the request netloc)
        # already contains it.
        forwarded_host = self._host_with_port(proxy_host, proxy_port)

        headers.setdefault(
            "Forwarded",
            f"for={proxy_client};host={forwarded_host};proto={proxy_proto};path={proxy_path}",
        )

        # NOTE: This is useful if the upstream API does not support the Forwarded header
        # and there were no existing X-Forwarded-* headers on the incoming request.
        if self.legacy_forwarded_headers:
            headers.setdefault("X-Forwarded-For", proxy_client)
            headers.setdefault("X-Forwarded-Host", proxy_host)
            headers.setdefault("X-Forwarded-Path", proxy_path)
            headers.setdefault("X-Forwarded-Proto", proxy_proto)
            # Only advertise a port we actually know.
            if proxy_port:
                headers.setdefault("X-Forwarded-Port", proxy_port)

        # Set host to the upstream host
        if self.override_host:
            headers["Host"] = self.client.base_url.netloc.decode("utf-8")

        return headers

    async def proxy_request(self, request: Request) -> Response:
        """
        Proxy a request to the upstream STAC API.

        The upstream body is read in full (and thereby decompressed by HTTPX) so
        that downstream middleware can work with plain content. The returned
        response carries the upstream status code and headers, minus
        ``Content-Encoding``, plus a ``Server-Timing`` entry for the upstream call.
        """
        headers = self._prepare_headers(request)

        # https://github.com/fastapi/fastapi/discussions/7382#discussioncomment-5136466
        rp_req = self.client.build_request(
            request.method,
            url=httpx.URL(
                path=request.url.path,
                query=request.url.query.encode("utf-8"),
            ),
            headers=headers,
            content=request.stream(),
        )

        # NOTE: HTTPX adds headers, so we need to trim them before sending request.
        # Iterate over a snapshot since entries are deleted inside the loop.
        for h in list(rp_req.headers):
            if h not in headers:
                del rp_req.headers[h]

        logger.debug(f"Proxying request to {rp_req.url}")

        start_time = time.perf_counter()
        rp_resp = await self.client.send(rp_req, stream=True)
        proxy_time = time.perf_counter() - start_time
        try:
            rp_resp.headers["Server-Timing"] = build_server_timing_header(
                rp_resp.headers.get("Server-Timing"),
                name="upstream",
                dur=proxy_time,
                desc="Upstream processing time",
            )
            logger.debug(
                f"Received response status {rp_resp.status_code!r} from {rp_req.url} in {proxy_time:.3f}s"
            )

            # We read the content here to make use of HTTPX's decompression, ensuring we have
            # non-compressed content for the middleware to work with.
            content = await rp_resp.aread()
        finally:
            # The response was opened with stream=True; release the connection
            # even if reading the body fails part-way.
            await rp_resp.aclose()

        if rp_resp.headers.get("Content-Encoding"):
            del rp_resp.headers["Content-Encoding"]
            # The upstream Content-Length described the compressed payload and
            # no longer matches the decoded body we are about to send.
            if content:
                rp_resp.headers["Content-Length"] = str(len(content))

        return Response(
            content=content,
            status_code=rp_resp.status_code,
            headers=dict(rp_resp.headers),
        )
async def check_conformance(
    middleware_classes: list[Middleware],
    api_url: str,
    attr_name: str = "__required_conformances__",
    endpoint: str = "/conformance",
):
    """
    Verify that the upstream API advertises every required conformance class.

    Each middleware may declare required conformance patterns on its class under
    ``attr_name``. The patterns are matched with ``re.match`` against the
    ``conformsTo`` list returned by ``endpoint`` on ``api_url``.

    Raises
    ------
    RuntimeError
        If at least one required pattern matches none of the advertised classes.

    """
    # Required pattern -> names of the middleware classes that asked for it.
    dependents_by_pattern: dict[str, list[str]] = {}
    for entry in middleware_classes:
        for pattern in getattr(entry.cls, attr_name, []):
            if pattern not in dependents_by_pattern:
                dependents_by_pattern[pattern] = []
            dependents_by_pattern[pattern].append(entry.cls.__name__)

    async with httpx.AsyncClient(base_url=api_url) as client:
        response = await client.get(endpoint)
        response.raise_for_status()
        advertised = response.json().get("conformsTo", [])

    def describe(pattern: str) -> str:
        dependents = ",".join(dependents_by_pattern[pattern])
        return f" - {pattern} [{dependents}]"

    # A pattern is unmet when no advertised class matches it.
    unmet = []
    for pattern in dependents_by_pattern:
        for conformance_class in advertised:
            if re.match(pattern, conformance_class):
                break
        else:
            unmet.append(pattern)

    if unmet:
        report = ["Upstream catalog is missing the following conformance classes:"]
        report.extend(describe(pattern) for pattern in unmet)
        raise RuntimeError("\n".join(report))

    logger.info(
        "Upstream catalog conforms to the following required conformance classes: \n%s",
        "\n".join(describe(pattern) for pattern in dependents_by_pattern),
    )
Settings | None = None, **settings_kwargs: Any): 111 | """ 112 | Create a lifespan handler that runs startup checks. 113 | 114 | Parameters 115 | ---------- 116 | settings : Settings | None, optional 117 | Pre-built settings instance. If omitted, a new one is constructed from 118 | ``settings_kwargs``. 119 | **settings_kwargs : Any 120 | Keyword arguments used to configure the health and conformance checks if 121 | ``settings`` is not provided. 122 | 123 | Returns 124 | ------- 125 | Callable[[FastAPI], AsyncContextManager[Any]] 126 | A callable suitable for the ``lifespan`` parameter of ``FastAPI``. 127 | 128 | """ 129 | if settings is None: 130 | settings = Settings(**settings_kwargs) 131 | 132 | @asynccontextmanager 133 | async def lifespan(app: "FastAPI"): 134 | assert settings is not None # Required for type checking 135 | 136 | # Wait for upstream servers to become available 137 | if settings.wait_for_upstream: 138 | await check_server_healths( 139 | settings.upstream_url, settings.oidc_discovery_internal_url 140 | ) 141 | 142 | # Log all middleware connected to the app 143 | logger.info( 144 | "Connected middleware:\n%s", 145 | "\n".join([f" - {m.cls.__name__}" for m in app.user_middleware]), 146 | ) 147 | 148 | if settings.check_conformance: 149 | await check_conformance(app.user_middleware, str(settings.upstream_url)) 150 | 151 | yield 152 | 153 | return lifespan 154 | -------------------------------------------------------------------------------- /docs/architecture/middleware-stack.md: -------------------------------------------------------------------------------- 1 | # Middleware Stack 2 | 3 | Aside from the actual communication with the upstream STAC API, the majority of the proxy's functionality occurs within a chain of middlewares. Each request passes through this chain, wherein each middleware performs a specific task. The middleware chain is ordered from last added (first to run) to first added (last to run). 
4 | 5 | > [!TIP] 6 | > If you want to apply just the middleware onto your existing FastAPI application, you can do this with [`configure_app`][stac_auth_proxy.configure_app] rather than setting up a separate proxy application. 7 | 8 | > [!IMPORTANT] 9 | > The order of middleware execution is **critical**. For example, `RemoveRootPathMiddleware` must run before `EnforceAuthMiddleware` so that authentication decisions are made on the correct path after root path removal. 10 | 11 | 1. **[`CompressionMiddleware`](https://github.com/developmentseed/starlette-cramjam)** 12 | 13 | - **Enabled if:** [`ENABLE_COMPRESSION`](../../user-guide/configuration#enable_compression) is enabled 14 | - Handles response compression 15 | - Reduces response size for better performance 16 | 17 | 2. **[`RemoveRootPathMiddleware`][stac_auth_proxy.middleware.RemoveRootPathMiddleware]** 18 | 19 | - **Enabled if:** [`ROOT_PATH`](../../user-guide/configuration#root_path) is configured 20 | - Removes the application root path from incoming requests 21 | - Ensures requests are properly routed to upstream API 22 | 23 | 3. **[`ProcessLinksMiddleware`][stac_auth_proxy.middleware.ProcessLinksMiddleware]** 24 | 25 | - **Enabled if:** [`ROOT_PATH`](../../user-guide/configuration#root_path) is set or [`UPSTREAM_URL`](../../user-guide/configuration#upstream_url) path is not `"/"` 26 | - Updates links in JSON responses to handle root path and upstream URL path differences 27 | - Removes upstream URL path from links and adds root path if configured 28 | 29 | 4. 
**[`EnforceAuthMiddleware`][stac_auth_proxy.middleware.EnforceAuthMiddleware]** 30 | 31 | - **Enabled if:** Always active (core authentication middleware) 32 | - Handles authentication and authorization 33 | - Configurable public/private endpoints via [`PUBLIC_ENDPOINTS`](../../user-guide/configuration#public_endpoints) and [`PRIVATE_ENDPOINTS`](../../user-guide/configuration#private_endpoints) 34 | - OIDC integration via [`OIDC_DISCOVERY_INTERNAL_URL`](../../user-guide/configuration#oidc_discovery_internal_url) 35 | - JWT audience validation via [`ALLOWED_JWT_AUDIENCES`](../../user-guide/configuration#allowed_jwt_audiences) 36 | - Places auth token payload in request state 37 | 38 | 5. **[`AddProcessTimeHeaderMiddleware`][stac_auth_proxy.middleware.AddProcessTimeHeaderMiddleware]** 39 | 40 | - **Enabled if:** Always active (monitoring middleware) 41 | - Adds processing time headers to responses 42 | - Useful for monitoring and debugging 43 | 44 | 6. **[`Cql2BuildFilterMiddleware`][stac_auth_proxy.middleware.Cql2BuildFilterMiddleware]** 45 | 46 | - **Enabled if:** [`ITEMS_FILTER_CLS`](../../user-guide/configuration#items_filter_cls) or [`COLLECTIONS_FILTER_CLS`](../../user-guide/configuration#collections_filter_cls) is configured 47 | - Builds CQL2 filters based on request context/state 48 | - Places [CQL2 expression](http://developmentseed.org/cql2-rs/latest/python/#cql2.Expr) in request state 49 | 50 | 7. **[`Cql2RewriteLinksFilterMiddleware`][stac_auth_proxy.middleware.Cql2RewriteLinksFilterMiddleware]** 51 | 52 | - **Enabled if:** [`ITEMS_FILTER_CLS`](../../user-guide/configuration#items_filter_cls) or [`COLLECTIONS_FILTER_CLS`](../../user-guide/configuration#collections_filter_cls) is configured 53 | - Rewrites filter parameters in response links to remove applied filters 54 | - Ensures links in responses show the original filter state 55 | 56 | 8. 
**[`Cql2ApplyFilterQueryStringMiddleware`][stac_auth_proxy.middleware.Cql2ApplyFilterQueryStringMiddleware]** 57 | 58 | - **Enabled if:** [`ITEMS_FILTER_CLS`](../../user-guide/configuration#items_filter_cls) or [`COLLECTIONS_FILTER_CLS`](../../user-guide/configuration#collections_filter_cls) is configured 59 | - Retrieves [CQL2 expression](http://developmentseed.org/cql2-rs/latest/python/#cql2.Expr) from request state 60 | - Augments `GET` requests with CQL2 filter by appending to querystring 61 | 62 | 9. **[`Cql2ApplyFilterBodyMiddleware`][stac_auth_proxy.middleware.Cql2ApplyFilterBodyMiddleware]** 63 | 64 | - **Enabled if:** [`ITEMS_FILTER_CLS`](../../user-guide/configuration#items_filter_cls) or [`COLLECTIONS_FILTER_CLS`](../../user-guide/configuration#collections_filter_cls) is configured 65 | - Retrieves [CQL2 expression](http://developmentseed.org/cql2-rs/latest/python/#cql2.Expr) from request state 66 | - Augments `POST`/`PUT`/`PATCH` requests with CQL2 filter by modifying body 67 | 68 | 10. **[`Cql2ValidateResponseBodyMiddleware`][stac_auth_proxy.middleware.Cql2ValidateResponseBodyMiddleware]** 69 | 70 | - **Enabled if:** [`ITEMS_FILTER_CLS`](../../user-guide/configuration#items_filter_cls) or [`COLLECTIONS_FILTER_CLS`](../../user-guide/configuration#collections_filter_cls) is configured 71 | - Retrieves [CQL2 expression](http://developmentseed.org/cql2-rs/latest/python/#cql2.Expr) from request state 72 | - Validates response against CQL2 filter for non-filterable endpoints 73 | 74 | 11. 
**[`OpenApiMiddleware`][stac_auth_proxy.middleware.OpenApiMiddleware]** 75 | 76 | - **Enabled if:** [`OPENAPI_SPEC_ENDPOINT`](../../user-guide/configuration#openapi_spec_endpoint) is set 77 | - Modifies OpenAPI specification based on endpoint configuration, adding security requirements 78 | - Configurable via [`OPENAPI_AUTH_SCHEME_NAME`](../../user-guide/configuration#openapi_auth_scheme_name) and [`OPENAPI_AUTH_SCHEME_OVERRIDE`](../../user-guide/configuration#openapi_auth_scheme_override) 79 | 80 | 12. **[`AuthenticationExtensionMiddleware`][stac_auth_proxy.middleware.AuthenticationExtensionMiddleware]** 81 | - **Enabled if:** [`ENABLE_AUTHENTICATION_EXTENSION`](../../user-guide/configuration#enable_authentication_extension) is enabled 82 | - Adds authentication extension information to STAC responses 83 | - Annotates links with authentication requirements based on [`PUBLIC_ENDPOINTS`](../../user-guide/configuration#public_endpoints) and [`PRIVATE_ENDPOINTS`](../../user-guide/configuration#private_endpoints) 84 | -------------------------------------------------------------------------------- /src/stac_auth_proxy/app.py: -------------------------------------------------------------------------------- 1 | """ 2 | STAC Auth Proxy API. 3 | 4 | This module defines the FastAPI application for the STAC Auth Proxy, which handles 5 | authentication, authorization, and proxying of requests to some internal STAC API. 
def configure_app(
    app: FastAPI,
    settings: Optional[Settings] = None,
    **settings_kwargs: Any,
) -> FastAPI:
    """
    Apply routes and middleware to a FastAPI app.

    Parameters
    ----------
    app : FastAPI
        The FastAPI app to configure.
    settings : Settings | None, optional
        Pre-built settings instance. If omitted, a new one is constructed from
        ``settings_kwargs``.
    **settings_kwargs : Any
        Keyword arguments passed to ``Settings`` when ``settings`` is not
        provided.

    Returns
    -------
    FastAPI
        The same ``app`` instance, with routes and middleware registered.

    """
    settings = settings or Settings(**settings_kwargs)

    #
    # Route Handlers
    #

    # If we have customized Swagger UI Init settings (e.g. a provided client_id)
    # then we need to serve our own Swagger UI in place of the upstream's. This requires
    # that we know the Swagger UI endpoint and the OpenAPI spec endpoint.
    if all(
        [
            settings.swagger_ui_endpoint,
            settings.openapi_spec_endpoint,
            settings.swagger_ui_init_oauth,
        ]
    ):
        app.add_route(
            settings.swagger_ui_endpoint,
            SwaggerUI(
                openapi_url=settings.openapi_spec_endpoint,  # type: ignore
                init_oauth=settings.swagger_ui_init_oauth,
            ).route,
            include_in_schema=False,
        )

    # Health endpoint is only mounted when a prefix is configured.
    if settings.healthz_prefix:
        app.include_router(
            HealthzHandler(upstream_url=str(settings.upstream_url)).router,
            prefix=settings.healthz_prefix,
        )

    #
    # Middleware (order is important, last added = first to run)
    #

    # Added first, so it runs last in the chain. Receives the public OIDC
    # discovery URL (not the internal one used by EnforceAuthMiddleware below).
    if settings.enable_authentication_extension:
        app.add_middleware(
            AuthenticationExtensionMiddleware,
            default_public=settings.default_public,
            public_endpoints=settings.public_endpoints,
            private_endpoints=settings.private_endpoints,
            oidc_discovery_url=str(settings.oidc_discovery_url),
        )

    if settings.openapi_spec_endpoint:
        app.add_middleware(
            OpenApiMiddleware,
            openapi_spec_path=settings.openapi_spec_endpoint,
            oidc_discovery_url=str(settings.oidc_discovery_url),
            public_endpoints=settings.public_endpoints,
            private_endpoints=settings.private_endpoints,
            default_public=settings.default_public,
            root_path=settings.root_path,
            auth_scheme_name=settings.openapi_auth_scheme_name,
            auth_scheme_override=settings.openapi_auth_scheme_override,
        )

    # The CQL2 middlewares are registered as a group whenever either filter is
    # configured. Cql2BuildFilterMiddleware is added last of the group so that it
    # runs before the other CQL2 middlewares.
    if settings.items_filter or settings.collections_filter:
        app.add_middleware(Cql2ValidateResponseBodyMiddleware)
        app.add_middleware(Cql2ApplyFilterBodyMiddleware)
        app.add_middleware(Cql2ApplyFilterQueryStringMiddleware)
        app.add_middleware(Cql2RewriteLinksFilterMiddleware)
        app.add_middleware(
            Cql2BuildFilterMiddleware,
            # The configured filter settings are called here to produce the
            # objects handed to the middleware.
            items_filter=settings.items_filter() if settings.items_filter else None,
            collections_filter=(
                settings.collections_filter() if settings.collections_filter else None
            ),
            collections_filter_path=settings.collections_filter_path,
            items_filter_path=settings.items_filter_path,
        )

    # Always registered.
    app.add_middleware(
        AddProcessTimeHeaderMiddleware,
    )

    # Always registered. Uses the internal OIDC discovery URL.
    app.add_middleware(
        EnforceAuthMiddleware,
        public_endpoints=settings.public_endpoints,
        private_endpoints=settings.private_endpoints,
        default_public=settings.default_public,
        oidc_discovery_url=settings.oidc_discovery_internal_url,
        allowed_jwt_audiences=settings.allowed_jwt_audiences,
    )

    # Only needed when a root path is set or the upstream URL has a non-root path.
    if settings.root_path or settings.upstream_url.path != "/":
        app.add_middleware(
            ProcessLinksMiddleware,
            upstream_url=str(settings.upstream_url),
            root_path=settings.root_path,
        )

    # Added after EnforceAuthMiddleware so that it runs before it.
    if settings.root_path:
        app.add_middleware(
            RemoveRootPathMiddleware,
            root_path=settings.root_path,
        )

    # Added last, so it is the first middleware to run.
    if settings.enable_compression:
        app.add_middleware(
            CompressionMiddleware,
        )

    return app
def test_should_transform_response_valid_paths(
    middleware, request_scope, initial_message
):
    """Every STAC path handled by the middleware is flagged for transformation."""
    stac_paths = (
        "/",
        "/collections",
        "/collections/test-collection",
        "/collections/test-collection/items",
        "/collections/test-collection/items/test-item",
        "/search",
    )

    for stac_path in stac_paths:
        # Reuse the shared scope, swapping in the path under test.
        request_scope.update(path=stac_path)
        should_transform = middleware.should_transform_response(
            Request(request_scope), initial_message
        )
        assert should_transform
def test_transform_json_collection(middleware, request_scope):
    """A STAC collection gains the auth extension and the configured scheme."""
    links = [
        {"rel": "self", "href": "/collections/test-collection"},
        {"rel": "items", "href": "/collections/test-collection/items"},
    ]
    collection = {
        "stac_version": "1.0.0",
        "type": "Collection",
        "id": "test-collection",
        "description": "Test collection",
        "links": links,
    }

    result = middleware.transform_json(collection, Request(request_scope))

    # The extension must be declared on the collection...
    assert "stac_extensions" in result
    assert middleware.extension_url in result["stac_extensions"]
    # ...and the scheme must be listed at the top level of the document.
    assert "auth:schemes" in result
    assert "test_auth" in result["auth:schemes"]
@dataclass
class OidcService:
    """
    Holder for OIDC discovery metadata and the matching JWKS client.

    Instantiation issues a synchronous HTTP request to ``oidc_discovery_url``.
    The parsed document is stored in ``metadata`` and ``jwks_client`` is built
    from its ``jwks_uri``, with scheme and host taken from the discovery URL.

    Raises
    ------
    OidcFetchError
        If the discovery request fails or returns an error status.

    """

    oidc_discovery_url: HttpUrl
    jwks_client: jwt.PyJWKClient = field(init=False)
    metadata: dict[str, Any] = field(init=False)

    def __post_init__(self) -> None:
        """Fetch the discovery document and set up the JWKS client."""
        logger.debug("Requesting OIDC config")
        discovery_url = str(self.oidc_discovery_url)

        try:
            discovery_response = httpx.get(discovery_url)
            discovery_response.raise_for_status()
            self.metadata = discovery_response.json()
            assert self.metadata, "OIDC metadata is empty"

            # NOTE: The advertised jwks_uri may not be reachable from within the
            # proxy, so its origin is swapped for that of the discovery URL.
            advertised_jwks_uri = self.metadata["jwks_uri"]
            discovery_parts = urlparse(discovery_url)
            rewritten_parts = urlparse(advertised_jwks_uri)._replace(
                netloc=discovery_parts.netloc, scheme=discovery_parts.scheme
            )
            effective_jwks_uri = urlunparse(rewritten_parts)

            if effective_jwks_uri != advertised_jwks_uri:
                logger.warning(
                    "OIDC Discovery reported a JWKS URI of %s but we're going to use %s to match the OIDC Discovery URL",
                    advertised_jwks_uri,
                    effective_jwks_uri,
                )
            self.jwks_client = jwt.PyJWKClient(effective_jwks_uri)
        except httpx.HTTPStatusError as exc:
            logger.error(
                "Received a non-200 response when fetching OIDC config: %s",
                exc.response.text,
            )
            failure_status = exc.response.status_code
            raise OidcFetchError(
                f"Request for OIDC config failed with status {failure_status}"
            ) from exc
        except httpx.RequestError as exc:
            logger.error(
                "Error fetching OIDC config from %s: %s", discovery_url, str(exc)
            )
            raise OidcFetchError(
                f"Request for OIDC config failed: {str(exc)}"
            ) from exc
| try: 101 | payload = self.validate_token( 102 | request.headers.get("Authorization"), 103 | auto_error=match.is_private, 104 | required_scopes=match.required_scopes, 105 | ) 106 | 107 | except HTTPException as e: 108 | response = JSONResponse({"detail": e.detail}, status_code=e.status_code) 109 | return await response(scope, receive, send) 110 | 111 | # Set the payload in the request state 112 | setattr( 113 | request.state, 114 | self.state_key, 115 | payload, 116 | ) 117 | setattr(request.state, "oidc_metadata", self.oidc_config.metadata) 118 | return await self.app(scope, receive, send) 119 | 120 | def validate_token( 121 | self, 122 | auth_header: Annotated[str, Security(...)], 123 | auto_error: bool = True, 124 | required_scopes: Optional[Sequence[str]] = None, 125 | ) -> Optional[dict[str, Any]]: 126 | """Dependency to validate an OIDC token.""" 127 | if not auth_header: 128 | if auto_error: 129 | raise HTTPException( 130 | status_code=status.HTTP_401_UNAUTHORIZED, 131 | detail="Not authenticated", 132 | headers={"WWW-Authenticate": "Bearer"}, 133 | ) 134 | return None 135 | 136 | # Extract token from header 137 | token_parts = auth_header.split(" ") 138 | if len(token_parts) != 2 or token_parts[0].lower() != "bearer": 139 | logger.error("Invalid Authorization header format") 140 | raise HTTPException( 141 | status_code=status.HTTP_401_UNAUTHORIZED, 142 | detail="Invalid Authorization header format", 143 | headers={"WWW-Authenticate": "Bearer"}, 144 | ) 145 | [_, token] = token_parts 146 | 147 | # Parse & validate token 148 | try: 149 | key = self.oidc_config.jwks_client.get_signing_key_from_jwt(token).key 150 | payload = jwt.decode( 151 | token, 152 | key, 153 | algorithms=["RS256"], 154 | # NOTE: Audience validation MUST match audience claim if set in token (https://pyjwt.readthedocs.io/en/stable/changelog.html?highlight=audience#id40) 155 | audience=self.allowed_jwt_audiences, 156 | ) 157 | except jwt.InvalidAudienceError as e: 158 | logger.error("Token 
audience validation failed: %s", str(e)) 159 | raise HTTPException( 160 | status_code=status.HTTP_401_UNAUTHORIZED, 161 | detail="Invalid token audience", 162 | headers={"WWW-Authenticate": "Bearer"}, 163 | ) 164 | except ( 165 | jwt.exceptions.InvalidTokenError, 166 | jwt.exceptions.DecodeError, 167 | jwt.exceptions.PyJWKClientError, 168 | ) as e: 169 | logger.error("Token validation failed: %s", type(e).__name__) 170 | raise HTTPException( 171 | status_code=status.HTTP_401_UNAUTHORIZED, 172 | detail="Invalid or expired token", 173 | headers={"WWW-Authenticate": "Bearer"}, 174 | ) from e 175 | 176 | # Check authorization (scopes) 177 | if required_scopes: 178 | token_scopes = set(payload.get("scope", "").split()) 179 | missing_scopes = set(required_scopes) - token_scopes 180 | if missing_scopes: 181 | logger.warning( 182 | "Insufficient scopes for user %s. Required: %s, Has: %s", 183 | payload.get("sub"), 184 | required_scopes, 185 | token_scopes, 186 | ) 187 | raise HTTPException( 188 | status_code=status.HTTP_403_FORBIDDEN, 189 | detail=f"Insufficient permissions. Required scopes: {', '.join(missing_scopes)}", 190 | headers={ 191 | "WWW-Authenticate": f'Bearer scope="{" ".join(required_scopes)}"' 192 | }, 193 | ) 194 | 195 | return payload 196 | 197 | @property 198 | def oidc_config(self) -> OidcService: 199 | """Get the OIDC configuration.""" 200 | if not self._oidc_config: 201 | self._oidc_config = OidcService(oidc_discovery_url=self.oidc_discovery_url) 202 | return self._oidc_config 203 | 204 | 205 | class OidcFetchError(Exception): 206 | """Error fetching OIDC configuration.""" 207 | 208 | ... 
# --------------------------------------------------------------------------------
# /src/stac_auth_proxy/utils/requests.py:
# --------------------------------------------------------------------------------
"""Utility functions for working with HTTP requests."""

import json
import logging
import re
from dataclasses import dataclass, field
from typing import Dict, Optional, Sequence
from urllib.parse import urlparse

from starlette.requests import Request

from ..config import EndpointMethods

logger = logging.getLogger(__name__)


def extract_variables(url: str) -> dict:
    """
    Extract path variables from a collections URL.

    Being that we use a catch-all endpoint for the proxy, we can't rely on the
    path parameters that FastAPI provides, so they are parsed from the path.

    Args:
        url: A URL or bare path; only its path component is inspected.

    Returns:
        A dict holding ``collection_id`` and, when the path has one, ``item_id``.
        Empty when the path is not a ``/collections/...`` path matched by the
        pattern below.

    """
    path = urlparse(url).path
    # This allows either /items or /bulk_items, with an optional item_id following.
    pattern = r"^/collections/(?P<collection_id>[^/]+)(?:/(?:items|bulk_items)(?:/(?P<item_id>[^/]+))?)?/?$"
    match = re.match(pattern, path)
    # Groups that did not participate in the match are None and are dropped.
    return {k: v for k, v in match.groupdict().items() if v} if match else {}


def dict_to_bytes(d: dict) -> bytes:
    """Convert a dictionary to a body (compact JSON, UTF-8 encoded)."""
    return json.dumps(d, separators=(",", ":")).encode("utf-8")


def _check_endpoint_match(
    path: str,
    method: str,
    endpoints: EndpointMethods,
) -> tuple[bool, Sequence[str]]:
    """
    Check if the path and method match any endpoint in the given endpoints map.

    Args:
        path: Request path, tested against each key of ``endpoints`` with ``re.match``.
        method: HTTP method, compared case-insensitively.
        endpoints: Mapping of path regex to method entries. An entry is either
            a method name or a ``(method, scopes)`` tuple where ``scopes`` is a
            whitespace-separated string.

    Returns:
        ``(True, required_scopes)`` for the first matching entry, otherwise
        ``(False, [])``.

    """
    for pattern, endpoint_methods in endpoints.items():
        if not re.match(pattern, path):
            continue
        for entry in endpoint_methods:
            required_scopes: Sequence[str] = []
            if isinstance(entry, tuple):
                endpoint_method, raw_scopes = entry
                if raw_scopes:  # Ignore empty scopes, e.g. `["POST", ""]`
                    # No-arg split() so repeated spaces don't produce empty
                    # scope names, which no token could ever satisfy.
                    required_scopes = raw_scopes.split()
            else:
                endpoint_method = entry
            if method.casefold() == endpoint_method.casefold():
                return True, required_scopes
    return False, []


def find_match(
    path: str,
    method: str,
    private_endpoints: EndpointMethods,
    public_endpoints: EndpointMethods,
    default_public: bool,
) -> "MatchResult":
    """
    Check if the given path and method match any of the regex patterns and methods in the endpoints.

    When ``default_public`` is true the private endpoints are the exceptions to
    look for; otherwise the public endpoints are.

    Returns:
        A ``MatchResult`` saying whether the endpoint is private and which
        scopes it requires.

    """
    primary_endpoints = private_endpoints if default_public else public_endpoints
    matched, required_scopes = _check_endpoint_match(path, method, primary_endpoints)
    if matched:
        # A hit in the exception list flips the default: private when the
        # default is public, public when the default is private.
        return MatchResult(
            is_private=default_public,
            required_scopes=required_scopes,
        )

    # If default_public and no match found in private_endpoints, it's public
    if default_public:
        return MatchResult(is_private=False)

    # If not default_public, check private_endpoints for required scopes
    matched, required_scopes = _check_endpoint_match(path, method, private_endpoints)
    if matched:
        return MatchResult(is_private=True, required_scopes=required_scopes)

    # Default case: if not default_public and no explicit match, it's private
    return MatchResult(is_private=True)


@dataclass
class MatchResult:
    """Result of a match between a path and method and a set of endpoints."""

    is_private: bool
    required_scopes: Sequence[str] = field(default_factory=list)


def build_server_timing_header(
    current_value: Optional[str] = None, *, name: str, desc: str, dur: float
) -> str:
    """
    Append a timing metric to a Server-Timing header value.

    Args:
        current_value: Existing header value, if any.
        name: Metric name.
        desc: Metric description (emitted inside double quotes).
        dur: Duration, formatted with three decimal places.

    Returns:
        The new metric alone, or ``current_value`` followed by ``", "`` and the
        new metric.

    """
    metric = f'{name};desc="{desc}";dur={dur:.3f}'
    if current_value:
        return f"{current_value}, {metric}"
    return metric


def parse_forwarded_header(forwarded_header: str) -> Dict[str, str]:
    """
    Parse the Forwarded header according to RFC 7239.

    Args:
        forwarded_header: The Forwarded header value

    Returns:
        Dictionary containing parsed forwarded information (proto, host, for, by, etc.)

    Example:
        >>> parse_forwarded_header("for=192.0.2.43; by=203.0.113.60; proto=https; host=api.example.com")
        {'for': '192.0.2.43', 'by': '203.0.113.60', 'proto': 'https', 'host': 'api.example.com'}

    """
    # Forwarded header format: "for=192.0.2.43, for=198.51.100.17; by=203.0.113.60; proto=https; host=example.com"
    # The format is: for=value1, for=value2; by=value; proto=value; host=value
    # We need to parse all the key=value pairs, taking the first 'for' value
    forwarded_info: Dict[str, str] = {}

    try:
        # Parse all key=value pairs separated by semicolons
        for pair in forwarded_header.split(";"):
            pair = pair.strip()
            if "=" not in pair:
                continue
            key, value = pair.split("=", 1)
            key = key.strip()
            value = value.strip().strip('"')

            # For 'for' field, only take the first value if there are multiple
            if key == "for":
                if key not in forwarded_info:
                    # Extract the first for value (before comma if present)
                    forwarded_info[key] = value.split(",")[0].strip()
            else:
                # For other fields, just use the value as-is
                forwarded_info[key] = value
    except Exception as e:
        # Best effort: a header we can't parse is treated as absent.
        logger.warning("Failed to parse Forwarded header '%s': %s", forwarded_header, e)
        return {}

    return forwarded_info


def get_base_url(request: Request) -> str:
    """
    Get the request's base URL, accounting for forwarded headers from load balancers/proxies.

    This function handles both the standard Forwarded header (RFC 7239) and legacy
    X-Forwarded-* headers to reconstruct the original client URL when the service
    is deployed behind load balancers or reverse proxies.

    Args:
        request: The Starlette request object

    Returns:
        The reconstructed client base URL

    Example:
        >>> # With Forwarded header
        >>> request.headers = {"Forwarded": "for=192.0.2.43; proto=https; host=api.example.com"}
        >>> get_base_url(request)
        "https://api.example.com/"

        >>> # With X-Forwarded-* headers
        >>> request.headers = {"X-Forwarded-Host": "api.example.com", "X-Forwarded-Proto": "https"}
        >>> get_base_url(request)
        "https://api.example.com/"

    """
    # Check for standard Forwarded header first (RFC 7239)
    forwarded_header = request.headers.get("Forwarded")
    if forwarded_header:
        try:
            forwarded_info = parse_forwarded_header(forwarded_header)
            # Only use Forwarded header if we successfully parsed it and got useful info
            if forwarded_info and (
                "proto" in forwarded_info or "host" in forwarded_info
            ):
                scheme = forwarded_info.get("proto", request.url.scheme)
                host = forwarded_info.get("host", request.url.netloc)
                # Note: Forwarded header doesn't include path, so we use request.base_url.path
                path = request.base_url.path
                return f"{scheme}://{host}{path}"
        except Exception as e:
            logger.warning("Failed to parse Forwarded header: %s", e)

    # Fall back to legacy X-Forwarded-* headers
    forwarded_host = request.headers.get("X-Forwarded-Host")
    forwarded_proto = request.headers.get("X-Forwarded-Proto")
    forwarded_path = request.headers.get("X-Forwarded-Path")

    if forwarded_host:
        # Use forwarded headers to reconstruct the original client URL
        scheme = forwarded_proto or request.url.scheme
        netloc = forwarded_host
        # Use forwarded path if available, otherwise use request base URL path
        path = forwarded_path or request.base_url.path
    else:
        # Fall back to the request's base URL if no forwarded headers
        scheme = request.url.scheme
        netloc = request.url.netloc
        path = request.base_url.path

    return f"{scheme}://{netloc}{path}"

# --------------------------------------------------------------------------------
# /tests/conftest.py:
# --------------------------------------------------------------------------------
"""Pytest fixtures."""

import os
import socket
import threading
from functools import partial
from typing import Any, AsyncGenerator
from unittest.mock import DEFAULT, AsyncMock, MagicMock, patch

import pytest
import uvicorn
from fastapi import FastAPI
from jwcrypto import jwk, jwt
from starlette_cramjam.middleware import CompressionMiddleware
from utils import single_chunk_async_stream_response


@pytest.fixture
def test_key() -> jwk.JWK:
    """Generate a test RSA key."""
    return jwk.JWK.generate(
        kty="RSA", size=2048, kid="test", use="sig", e="AQAB", alg="RS256"
    )


@pytest.fixture
def public_key(test_key: jwk.JWK) -> dict[str, Any]:
    """Export public key."""
    return test_key.export_public(as_dict=True)


@pytest.fixture(autouse=True)
def mock_jwks(public_key: dict[str, Any]):
    """
    Mock the OIDC discovery request and the JWKS fetch.

    Patches ``httpx.get`` to return a canned OIDC configuration and
    ``jwt.PyJWKClient.fetch_data`` to return a key set holding the test public
    key. Yields the ``httpx.get`` mock.
    """
    mock_oidc_config = {
        "jwks_uri": "https://example.com/jwks",
        "authorization_endpoint": "https://example.com/auth",
        "token_endpoint": "https://example.com/token",
        "scopes_supported": ["openid", "profile", "email", "collection:create"],
    }

    mock_jwks = {"keys": [public_key]}

    with (
        patch("httpx.get") as mock_urlopen,
        patch("jwt.PyJWKClient.fetch_data") as mock_fetch_data,
    ):
        mock_oidc_config_response = MagicMock()
        mock_oidc_config_response.json.return_value = mock_oidc_config
        # httpx responses expose the status as `status_code`.
        mock_oidc_config_response.status_code = 200

        mock_urlopen.return_value = mock_oidc_config_response
        mock_fetch_data.return_value = mock_jwks
        yield mock_urlopen


@pytest.fixture
def token_builder(test_key: jwk.JWK):
    """Generate a valid JWT token builder."""

    def build_token(payload: dict[str, Any], key=None) -> str:
        """Sign ``payload`` with ``key`` (the test key by default) and serialize it."""
        jwt_token = jwt.JWT(
            # The header always advertises the test key's alg/kid, even when
            # a different signing key is supplied.
            header={k: test_key.get(k) for k in ["alg", "kid"]},
            claims=payload,
        )
        jwt_token.make_signed_token(key or test_key)
        return jwt_token.serialize()

    return build_token


@pytest.fixture(scope="session")
def source_api():
    """
    Create upstream API for testing purposes.

    You can customize the response for each endpoint by passing a dict of responses:
    {
        "path": {
            "method": response_body
        }
    }
    """
    app = FastAPI(docs_url="/api.html", openapi_url="/api")

    app.add_middleware(CompressionMiddleware, minimum_size=0, compression_level=1)

    # Default responses for each endpoint
    default_responses = {
        "/": {
            "GET": {"id": "Response from GET@"},
            "OPTIONS": {"id": "Response from OPTIONS@"},
        },
        "/conformance": {
            "GET": {"conformsTo": ["http://example.com/conformance"]},
            "OPTIONS": {"id": "Response from OPTIONS@"},
        },
        "/queryables": {
            "GET": {"queryables": {}},
            "OPTIONS": {"id": "Response from OPTIONS@"},
        },
        "/search": {
            "GET": {"type": "FeatureCollection", "features": []},
            "POST": {"type": "FeatureCollection", "features": []},
            "OPTIONS": {"id": "Response from OPTIONS@"},
        },
        "/collections": {
            "GET": {"collections": []},
            "POST": {"id": "Response from POST@"},
            "OPTIONS": {"id": "Response from OPTIONS@"},
        },
        "/collections/{collection_id}": {
            "GET": {"id": "Response from GET@"},
            "PUT": {"id": "Response from PUT@"},
            "PATCH": {"id": "Response from PATCH@"},
            "DELETE": {"id": "Response from DELETE@"},
            "OPTIONS": {"id": "Response from OPTIONS@"},
        },
        "/collections/{collection_id}/items": {
            "GET": {"type": "FeatureCollection", "features": []},
            "POST": {"id": "Response from POST@"},
            "OPTIONS": {"id": "Response from OPTIONS@"},
        },
        "/collections/{collection_id}/items/{item_id}": {
            "GET": {"id": "Response from GET@"},
            "PUT": {"id": "Response from PUT@"},
            "PATCH": {"id": "Response from PATCH@"},
            "DELETE": {"id": "Response from DELETE@"},
            "OPTIONS": {"id": "Response from OPTIONS@"},
        },
        "/collections/{collection_id}/bulk_items": {
            "POST": {"id": "Response from POST@"},
            "OPTIONS": {"id": "Response from OPTIONS@"},
        },
    }

    # Store responses in app state
    app.state.default_responses = default_responses

    def get_response(path: str, method: str) -> dict:
        """Get response for a given path and method."""
        # Looked up on every call so per-test overrides stored on app.state
        # are picked up.
        return app.state.default_responses.get(path, {}).get(
            method, {"id": f"Response from {method}@{path}"}
        )

    for path, methods in default_responses.items():
        for method in methods:
            app.add_api_route(
                path,
                partial(get_response, path, method),
                methods=[method],
            )

    return app


@pytest.fixture
def source_api_responses(source_api):
    """
    Fixture to override source API responses for specific tests.

    Usage:
        def test_something(source_api_responses):
            # Override responses for specific endpoints
            source_api_responses["/collections"]["GET"] = {"collections": [{"id": "test"}]}
            source_api_responses["/search"]["POST"] = {"type": "FeatureCollection", "features": [{"id": "test"}]}

            # Your test code here
    """
    # Get the default responses from the source_api fixture
    default_responses = source_api.state.default_responses

    # Create a new dict that can be modified by tests (per-path copies, so
    # assigning a method's response doesn't touch the defaults)
    responses = {path: methods.copy() for path, methods in default_responses.items()}

    # Store the responses in the app state for the get_response function to use
    source_api.state.default_responses = responses

    yield responses

    # Restore the original responses after the test
    source_api.state.default_responses = default_responses


@pytest.fixture(scope="session")
def free_port():
    """
    Get a free port.

    The probing socket is closed before the port number is returned.
    """
    with socket.socket() as sock:
        # Needed for Github Actions, https://stackoverflow.com/a/4466035
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        # Port 0 asks the OS to pick an unused port.
        sock.bind(("", 0))
        return sock.getsockname()[1]


@pytest.fixture(scope="session")
def source_api_server(source_api, free_port):
    """Run the source API in a background thread and yield its base URL."""
    host = "127.0.0.1"
    server = uvicorn.Server(
        uvicorn.Config(
            source_api,
            host=host,
            port=free_port,
        )
    )
    thread = threading.Thread(target=server.run)
    thread.start()
    yield f"http://{host}:{free_port}"
    # Ask the server to stop, then wait for its thread.
    server.should_exit = True
    thread.join()


@pytest.fixture(autouse=True, scope="session")
def mock_env():
    """Clear environment variables to avoid polluting configs from runtime env."""
    with patch.dict(os.environ, clear=True):
        yield


@pytest.fixture
async def mock_upstream() -> AsyncGenerator[MagicMock, None]:
    """
    Mock the HTTPX send method.

    Useful when we want to inspect the request that is sent to the upstream API.
    Each request passed to the mock has its body drained and stored on the
    request as ``_streamed_body``; the mock then returns a canned ``{}``
    response.
    """
    # NOTE: This fixture will interfere with the source_api_responses fixture

    async def store_body(request, **kwargs):
        """Exhaust and store the request body."""
        # Collect the chunks and join once instead of repeated bytes concatenation.
        request._streamed_body = b"".join([chunk async for chunk in request.stream])
        # Returning DEFAULT makes the mock fall through to its return_value.
        return DEFAULT

    with patch(
        "stac_auth_proxy.handlers.reverse_proxy.httpx.AsyncClient.send",
        new_callable=AsyncMock,
        side_effect=store_body,
        return_value=single_chunk_async_stream_response(b"{}"),
    ) as mock_send_method:
        yield mock_send_method

# --------------------------------------------------------------------------------