├── .devcontainer ├── devcontainer.json ├── docker-compose.yml ├── post-create.sh ├── postgres_init.sql └── telegraf.conf ├── .flake8 ├── .github └── workflows │ └── release.yaml ├── .gitignore ├── .svu.yaml ├── .vscode ├── launch.json └── settings.json ├── Dockerfile ├── LICENSE ├── README.md ├── aip.png ├── alembic.ini ├── alembic ├── README ├── env.py ├── script.py.mako └── versions │ └── 380b77abd479_init.py ├── bin └── scripts │ ├── entrypoint.sh │ ├── init-db.sh │ └── init.sh ├── docker-compose.yml ├── pdm.lock ├── pyproject.toml ├── src └── social │ ├── __init__.py │ └── graze │ ├── __init__.py │ └── aip │ ├── __init__.py │ ├── app │ ├── __init__.py │ ├── __main__.py │ ├── config.py │ ├── cors.py │ ├── handlers │ │ ├── __init__.py │ │ ├── app_password.py │ │ ├── credentials.py │ │ ├── helpers.py │ │ ├── internal.py │ │ ├── oauth.py │ │ ├── permissions.py │ │ └── proxy.py │ ├── server.py │ ├── tasks.py │ └── util │ │ ├── __init__.py │ │ └── __main__.py │ ├── atproto │ ├── __init__.py │ ├── app_password.py │ ├── chain.py │ ├── oauth.py │ └── pds.py │ ├── model │ ├── __init__.py │ ├── app_password.py │ ├── base.py │ ├── handles.py │ ├── health.py │ └── oauth.py │ └── resolve │ ├── __init__.py │ ├── __main__.py │ └── handle.py ├── static ├── login.css ├── pico.classless.blue.css └── pico.colors.min.css ├── telegraf.conf ├── templates ├── atproto_debug.html ├── atproto_login.html ├── base.html ├── index.html └── settings.html └── tests ├── __init__.py └── test_example.py /.devcontainer/devcontainer.json: -------------------------------------------------------------------------------- 1 | { 2 | "dockerComposeFile": [ 3 | "docker-compose.yml" 4 | ], 5 | "service": "devcontainer", 6 | "workspaceFolder": "/workspace", 7 | "features": { 8 | "ghcr.io/devcontainers/features/docker-in-docker:2": {} 9 | }, 10 | "postCreateCommand": "/workspace/.devcontainer/post-create.sh", 11 | "forwardPorts": [5100] 12 | } 
-------------------------------------------------------------------------------- /.devcontainer/docker-compose.yml: -------------------------------------------------------------------------------- 1 | version: '3.8' 2 | volumes: 3 | aip_db: 4 | aip_ts: 5 | services: 6 | devcontainer: 7 | image: mcr.microsoft.com/devcontainers/python:3.13 8 | volumes: 9 | - ..:/workspace:cached 10 | - /var/run/docker.sock:/var/run/docker.sock 11 | command: sleep infinity 12 | environment: 13 | - HTTP_PORT=5100 14 | - TZ=America/New_York 15 | - DATABASE_URL=postgres://postgres:password@postgres/aip 16 | - JSON_WEB_KEYS=/workspace/signing_keys.json 17 | 18 | postgres: 19 | image: postgres:17-alpine 20 | restart: unless-stopped 21 | volumes: 22 | - aip_db:/var/lib/postgresql/data 23 | - ./postgres_init.sql:/docker-entrypoint-initdb.d/init.sql 24 | environment: 25 | - POSTGRES_PASSWORD=password 26 | healthcheck: 27 | test: 'pg_isready -U postgres' 28 | interval: 500ms 29 | timeout: 10s 30 | retries: 20 31 | 32 | telegraf: 33 | image: docker.io/telegraf:latest 34 | volumes: 35 | - ./telegraf.conf:/etc/telegraf/telegraf.conf 36 | 37 | valkey: 38 | image: valkey/valkey:8-alpine 39 | 40 | tailscale: 41 | image: tailscale/tailscale:latest 42 | restart: unless-stopped 43 | environment: 44 | - TS_STATE_DIR=/var/run/tailscale 45 | - TS_EXTRA_ARGS=--advertise-tags=tag:aip 46 | volumes: 47 | - aip_ts:/var/run/tailscale 48 | -------------------------------------------------------------------------------- /.devcontainer/post-create.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | sudo usermod -a -G docker vscode 4 | 5 | sudo apt-get update 6 | sudo apt-get install -y postgresql-client 7 | 8 | curl -sSL https://pdm-project.org/install-pdm.py | python3 - 9 | -------------------------------------------------------------------------------- /.devcontainer/postgres_init.sql: 
-------------------------------------------------------------------------------- 1 | -- aip 2 | CREATE DATABASE aip; 3 | GRANT ALL PRIVILEGES ON DATABASE aip TO postgres; 4 | -------------------------------------------------------------------------------- /.devcontainer/telegraf.conf: -------------------------------------------------------------------------------- 1 | 2 | [global_tags] 3 | [agent] 4 | interval = "10s" 5 | round_interval = true 6 | metric_batch_size = 1000 7 | metric_buffer_limit = 10000 8 | collection_jitter = "0s" 9 | flush_interval = "10s" 10 | flush_jitter = "0s" 11 | precision = "0s" 12 | [[outputs.influxdb_v2]] 13 | urls = ["http://metrics.bowfin-woodpecker.ts.net:8086"] 14 | token = "7eAc0CgtNV4-yeDwnKl01tBYdxMrMPtlmmz3h-urW6uBGt2Uv3byhsTkgwiHnkn45Vr0gdaqnc6tbyfhUWxpEw==" 15 | organization = "graze" 16 | bucket = "telegraf" 17 | [[inputs.statsd]] 18 | protocol = "udp" 19 | service_address = ":8125" 20 | datadog_extensions = true 21 | -------------------------------------------------------------------------------- /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | extend-ignore = E203 3 | exclude = 4 | .git, 5 | .venv, 6 | .pdm-build, 7 | .pytest_cache, 8 | __pycache__, 9 | alembic, 10 | dist 11 | max-complexity = 20 12 | max-line-length = 120 -------------------------------------------------------------------------------- /.github/workflows/release.yaml: -------------------------------------------------------------------------------- 1 | name: Create AIP Release 2 | on: 3 | pull_request: 4 | types: 5 | - closed 6 | branches: 7 | - main 8 | jobs: 9 | release_aip: 10 | runs-on: ubuntu-latest 11 | name: create_release 12 | permissions: 13 | contents: write 14 | id-token: write 15 | steps: 16 | - name: checkout 17 | uses: actions/checkout@v3 18 | with: 19 | fetch-depth: 0 20 | 21 | - name: setup_svu 22 | run: curl -kL 
https://github.com/caarlos0/svu/releases/download/v3.2.2/svu_3.2.2_linux_amd64.tar.gz | tar zx && mv svu /usr/local/bin/svu && chmod +x /usr/local/bin/svu 23 | 24 | - name: create_tag 25 | id: create_tag 26 | run: | 27 | echo "VERSION_TAG=$(svu next)" >> $GITHUB_ENV 28 | echo "VERSION_TAG=$(svu next)" >> $GITHUB_OUTPUT 29 | 30 | - name: create_release 31 | env: 32 | GH_TOKEN: ${{ github.token }} 33 | run: |- 34 | gh release create ${{ env.VERSION_TAG }} -t ${{ env.VERSION_TAG }} --generate-notes 35 | 36 | - name: repo_dispatch 37 | uses: peter-evans/repository-dispatch@v3 38 | with: 39 | repository: graze-social/turbo-deploy 40 | event-type: aip-release 41 | token: ${{ secrets.DISPATCH_GH_TOKEN }} 42 | client-payload: |- 43 | { 44 | "ref": "${{ github.ref }}", 45 | "sha": "${{ github.sha }}", 46 | "version_tag": "${{env.VERSION_TAG}}", 47 | "pr_context": ${{toJson(github.event.pull_request)}} 48 | } 49 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # UV 98 | # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | #uv.lock 102 | 103 | # poetry 104 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 105 | # This is especially recommended for binary packages to ensure reproducibility, and is more 106 | # commonly ignored for libraries. 
107 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 108 | #poetry.lock 109 | 110 | # pdm 111 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 112 | #pdm.lock 113 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 114 | # in version control. 115 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 116 | .pdm.toml 117 | .pdm-python 118 | .pdm-build/ 119 | 120 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 121 | __pypackages__/ 122 | 123 | # Celery stuff 124 | celerybeat-schedule 125 | celerybeat.pid 126 | 127 | # SageMath parsed files 128 | *.sage.py 129 | 130 | # Environments 131 | .env 132 | .venv 133 | env/ 134 | venv/ 135 | ENV/ 136 | env.bak/ 137 | venv.bak/ 138 | 139 | # Spyder project settings 140 | .spyderproject 141 | .spyproject 142 | 143 | # Rope project settings 144 | .ropeproject 145 | 146 | # mkdocs documentation 147 | /site 148 | 149 | # mypy 150 | .mypy_cache/ 151 | .dmypy.json 152 | dmypy.json 153 | 154 | # Pyre type checker 155 | .pyre/ 156 | 157 | # pytype static type analyzer 158 | .pytype/ 159 | 160 | # Cython debug symbols 161 | cython_debug/ 162 | 163 | # PyCharm 164 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 165 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 166 | # and can be added to the global gitignore or merged into this file. For a more nuclear 167 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
168 | #.idea/ 169 | 170 | # PyPI configuration file 171 | .pypirc 172 | scratch/ 173 | .DS_Store 174 | **/*/.DS_Store 175 | -------------------------------------------------------------------------------- /.svu.yaml: -------------------------------------------------------------------------------- 1 | always: true 2 | v0: true -------------------------------------------------------------------------------- /.vscode/launch.json: -------------------------------------------------------------------------------- 1 | { 2 | "configurations": [ 3 | { 4 | "name": "Python Debugger: Module", 5 | "type": "debugpy", 6 | "request": "launch", 7 | "module": "social.graze.aip.app.__main__", 8 | "env": { 9 | "ACTIVE_SIGNING_KEYS": "[\"01JK8QQSW7YSVDEYQS3J7ZE3XJ\"]", 10 | "EXTERNAL_HOSTNAME": "grazeaip.tunn.dev", 11 | "WORKER_ID": "dev1", 12 | "SERVICE_AUTH_KEYS": "[\"01JK8QQSW7YSVDEYQS3J7ZE3XJ\"]", 13 | "DATABASE_URL": "postgres://postgres:password@postgres/aip", 14 | "ENCRYPTION_KEY": "TFZTT0lOSWlTSk8xbWhJMzJYY0lBS2dQcXFQMTVjX0o4ZUlJTFNyWVpzQT0=", 15 | "JSON_WEB_KEYS": "/workspace/signing_keys.json", 16 | "PLC_HOSTNAME": "plc.bowfin-woodpecker.ts.net", 17 | "DEBUG": "true" 18 | }, 19 | } 20 | ] 21 | } -------------------------------------------------------------------------------- /.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "python.analysis.inlayHints.pytestParameters": true, 3 | "python.analysis.inlayHints.functionReturnTypes": true, 4 | "python.analysis.inlayHints.variableTypes": true, 5 | "python.analysis.typeCheckingMode": "standard" 6 | } -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | # Use a lightweight Python image 2 | FROM python:3.13-slim 3 | 4 | # Set working directory 5 | WORKDIR /app 6 | 7 | # Install system dependencies 8 | RUN apt-get update && apt-get install -y \ 9 | 
curl \ 10 | libpq-dev \ 11 | postgresql-client \ 12 | gcc \ 13 | python3-dev \ 14 | jq \ 15 | && rm -rf /var/lib/apt/lists/* 16 | 17 | # Install PDM globally 18 | RUN pip install pdm 19 | 20 | # TODO: we can split into a multi-stage build and just ship the .venv 21 | # just keeping everything for now 22 | COPY . . 23 | 24 | # Ensure PDM always uses the in-project virtual environment 25 | ENV PDM_VENV_IN_PROJECT=1 26 | ENV PATH="/app/.venv/bin:$PATH" 27 | 28 | # Install dependencies properly inside the virtual environment 29 | RUN pdm install 30 | 31 | # Expose the application port 32 | # TODO: These should be configurable, not hard-coded and thus publishing them here is moot. 33 | EXPOSE 8080 34 | EXPOSE 5100 35 | 36 | # Available CMDs 37 | # See pyproject.toml for more details 38 | # CMD ["pdm", "run", "aipserver"] 39 | # CMD ["pdm", "run", "resolve"] 40 | # CMD ["pdm", "run", "aiputil"] 41 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 Graze Social 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ATmosphere Authentication, Identity, and Permission Proxy 2 | 3 | ![Image from 391 Vol 1– 19 by Francis Picabia, https://archive.org/details/391-vol-1-19/page/n98/mode/1up](./aip.png) 4 | ## Running Locally 5 | Example resolve request (once the service is running): http://localhost:8080/internal/api/resolve?subject=ngerakines.me&subject=mattie.thegem.city 6 | 7 | 1. Install dependencies: `pdm install` (Note that `pdm install` may require `sudo apt install -y clang libpq-dev python3-dev build-essential` to build the Postgres requirements!) 8 | 2. Start up a Postgres server and `export DATABASE_URL=address`, where `address` is the connection URL for that server 9 | 3. Run migrations: `pdm run alembic upgrade head` 10 | 4. Populate signing keys: `OUT=$(pdm run aiputil gen-jwk)` and then `echo "{\"keys\":[$OUT]}" > signing_keys.json` 11 | 12 | 5. Set hostname: `export EXTERNAL_HOSTNAME=grazeaip.tunn.dev` 13 | 6. Set PLC hostname: `export PLC_HOSTNAME=plc.bowfin-woodpecker.ts.net` 14 | 7. Set active signing keys: `export ACTIVE_SIGNING_KEYS='["{KEY_ID}"]'` 15 | 8. Start the service: `pdm run aipserver` 16 | 9. Generate a handle: https://pdsdns.bowfin-woodpecker.ts.net 17 | 10. Verify resolution: `pdm run resolve --plc-hostname ${PLC_HOSTNAME} enabling-boxer.pyroclastic.cloud` 18 | 11. Authenticate with it: https://grazeaip.tunn.dev/auth/atproto 19 | 20 | ## Running via Docker 21 | 22 | 1. Install dependencies: `pdm install` (Note that `pdm install` may require `sudo apt install -y clang libpq-dev python3-dev build-essential` to build the Postgres requirements!) 23 | 2. 
Populate signing keys: `OUT=$(pdm run aiputil gen-jwk)` and then `echo "{\"keys\":[$OUT]}" > signing_keys.json` 24 | 3. Add `EXTERNAL_HOSTNAME: your-host` to your docker compose file[1]. 25 | 4. Run `docker compose build && docker compose up` 26 | 5. Then navigate to your host at the path `/auth/atproto`. This is your login page! 27 | 6. Note that in the config you can change colors, text, display image, and default post-login destination. You can *also* forward the response to *any* URL with a `?destination={URL}` parameter on the sign-in page. 28 | 29 | [1]: This must match the hostname this service is reachable at. For development purposes, you can install ngrok and then run `ngrok http 8080`, which will forward HTTPS traffic from an ngrok-provided URL to the service on port 8080. You would then take that host (without `https://`) and put it in your docker compose file before starting up. 30 | 31 | ## How to Use AIP tokens to access ATProto / Bluesky: 32 | 33 | Please see this [example usage file in Python](https://gist.github.com/DGaffney/99f209e5ff9bb01cc50c4202c9c46554) and port it to your use case! 34 | -------------------------------------------------------------------------------- /aip.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/graze-social/aip/c740a2cf63fdf5a9d2cef6da3bfb8dd660e15ea0/aip.png -------------------------------------------------------------------------------- /alembic.ini: -------------------------------------------------------------------------------- 1 | [alembic] 2 | script_location = alembic 3 | 4 | prepend_sys_path = . 
5 | 6 | version_path_separator = os 7 | 8 | sqlalchemy.url = %(DATABASE_URL)s 9 | 10 | 11 | [post_write_hooks] 12 | 13 | [loggers] 14 | keys = root,sqlalchemy,alembic 15 | 16 | [handlers] 17 | keys = console 18 | 19 | [formatters] 20 | keys = generic 21 | 22 | [logger_root] 23 | level = WARNING 24 | handlers = console 25 | qualname = 26 | 27 | [logger_sqlalchemy] 28 | level = WARNING 29 | handlers = 30 | qualname = sqlalchemy.engine 31 | 32 | [logger_alembic] 33 | level = INFO 34 | handlers = 35 | qualname = alembic 36 | 37 | [handler_console] 38 | class = StreamHandler 39 | args = (sys.stderr,) 40 | level = NOTSET 41 | formatter = generic 42 | 43 | [formatter_generic] 44 | format = %(levelname)-5.5s [%(name)s] %(message)s 45 | datefmt = %H:%M:%S 46 | -------------------------------------------------------------------------------- /alembic/README: -------------------------------------------------------------------------------- 1 | Generic single-database configuration. -------------------------------------------------------------------------------- /alembic/env.py: -------------------------------------------------------------------------------- 1 | import os 2 | from logging.config import fileConfig 3 | 4 | from sqlalchemy import engine_from_config, pool 5 | from alembic import context 6 | 7 | # Load Alembic Config 8 | config = context.config 9 | 10 | # Set up logging 11 | if config.config_file_name is not None: 12 | fileConfig(config.config_file_name) 13 | 14 | # Read DATABASE_URL from the environment (required: there is no fallback, so an unset variable fails on the next line) and swap the asyncpg driver for psycopg2 15 | DATABASE_URL = os.getenv("DATABASE_URL") 16 | DATABASE_URL = DATABASE_URL.replace("postgresql+asyncpg", "postgresql+psycopg2") 17 | 18 | # Override sqlalchemy.url in Alembic config dynamically 19 | config.set_main_option("sqlalchemy.url", DATABASE_URL) 20 | 21 | # Metadata for migrations (update if needed) 22 | target_metadata = None 23 | 24 | 25 | def run_migrations_offline() -> None: 26 | """Run migrations in 'offline' mode.""" 27 | 
context.configure( 28 | url=DATABASE_URL, # Use the dynamic URL here 29 | target_metadata=target_metadata, 30 | literal_binds=True, 31 | dialect_opts={"paramstyle": "named"}, 32 | ) 33 | 34 | with context.begin_transaction(): 35 | context.run_migrations() 36 | 37 | 38 | def run_migrations_online() -> None: 39 | """Run migrations in 'online' mode.""" 40 | connectable = engine_from_config( 41 | config.get_section(config.config_ini_section, {}), 42 | prefix="sqlalchemy.", 43 | poolclass=pool.NullPool, 44 | ) 45 | 46 | with connectable.connect() as connection: 47 | context.configure(connection=connection, target_metadata=target_metadata) 48 | 49 | with context.begin_transaction(): 50 | context.run_migrations() 51 | 52 | 53 | if context.is_offline_mode(): 54 | run_migrations_offline() 55 | else: 56 | run_migrations_online() 57 | -------------------------------------------------------------------------------- /alembic/script.py.mako: -------------------------------------------------------------------------------- 1 | """${message} 2 | 3 | Revision ID: ${up_revision} 4 | Revises: ${down_revision | comma,n} 5 | Create Date: ${create_date} 6 | 7 | """ 8 | from typing import Sequence, Union 9 | 10 | from alembic import op 11 | import sqlalchemy as sa 12 | ${imports if imports else ""} 13 | 14 | # revision identifiers, used by Alembic. 
15 | revision: str = ${repr(up_revision)} 16 | down_revision: Union[str, None] = ${repr(down_revision)} 17 | branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)} 18 | depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)} 19 | 20 | 21 | def upgrade() -> None: 22 | ${upgrades if upgrades else "pass"} 23 | 24 | 25 | def downgrade() -> None: 26 | ${downgrades if downgrades else "pass"} 27 | -------------------------------------------------------------------------------- /alembic/versions/380b77abd479_init.py: -------------------------------------------------------------------------------- 1 | """init 2 | 3 | Revision ID: 380b77abd479 4 | Revises: 5 | Create Date: 2025-02-03 16:15:52.235674 6 | 7 | """ 8 | 9 | from typing import Sequence, Union 10 | 11 | from alembic import op 12 | import sqlalchemy as sa 13 | 14 | 15 | # revision identifiers, used by Alembic. 16 | revision: str = "380b77abd479" 17 | down_revision: Union[str, None] = None 18 | branch_labels: Union[str, Sequence[str], None] = None 19 | depends_on: Union[str, Sequence[str], None] = None 20 | 21 | 22 | def upgrade() -> None: 23 | op.create_table( 24 | "handles", 25 | sa.Column("guid", sa.String(512), primary_key=True), 26 | sa.Column("did", sa.String(512), nullable=False), 27 | sa.Column("handle", sa.String(512), nullable=False), 28 | sa.Column("pds", sa.String(512), nullable=False), 29 | ) 30 | op.create_index("idx_handles_did", "handles", ["did"], unique=True) 31 | op.create_index("idx_handles_handle", "handles", ["handle"]) 32 | 33 | op.create_table( 34 | "oauth_requests", 35 | sa.Column("oauth_state", sa.String(64), primary_key=True), 36 | sa.Column("issuer", sa.String(512), nullable=False), 37 | sa.Column("guid", sa.String(512), nullable=False), 38 | sa.Column("pkce_verifier", sa.String(128), nullable=False), 39 | sa.Column("secret_jwk_id", sa.String(32), nullable=False), 40 | sa.Column("dpop_jwk", sa.JSON, nullable=False), 41 | sa.Column("destination", sa.String(512), 
nullable=False), 42 | sa.Column("created_at", sa.DateTime(timezone=True), nullable=False), 43 | sa.Column("expires_at", sa.DateTime(timezone=True), nullable=False), 44 | ) 45 | op.create_index("idx_oauth_requests_guid", "oauth_requests", ["guid"]) 46 | op.create_index("idx_oauth_requests_expires", "oauth_requests", ["expires_at"]) 47 | 48 | op.create_table( 49 | "oauth_sessions", 50 | sa.Column("session_group", sa.String(64), primary_key=True), 51 | sa.Column("access_token", sa.String(1024), nullable=False), 52 | sa.Column("guid", sa.String(512), nullable=False), 53 | sa.Column("refresh_token", sa.String(512), nullable=False), 54 | sa.Column("issuer", sa.String(512), nullable=False), 55 | sa.Column("secret_jwk_id", sa.String(32), nullable=False), 56 | sa.Column("dpop_jwk", sa.JSON, nullable=False), 57 | sa.Column("created_at", sa.DateTime(timezone=True), nullable=False), 58 | sa.Column( 59 | "access_token_expires_at", sa.DateTime(timezone=True), nullable=False 60 | ), 61 | sa.Column("hard_expires_at", sa.DateTime(timezone=True), nullable=False), 62 | ) 63 | op.create_index("idx_oauth_sessions_guid", "oauth_sessions", ["guid"]) 64 | op.create_index( 65 | "idx_oauth_sessions_expires", "oauth_sessions", ["access_token_expires_at"] 66 | ) 67 | 68 | op.create_table( 69 | "atproto_app_passwords", 70 | sa.Column("guid", sa.String(512), primary_key=True), 71 | sa.Column("app_password", sa.String(512), nullable=False), 72 | sa.Column("created_at", sa.DateTime(timezone=True), nullable=False), 73 | ) 74 | 75 | op.create_table( 76 | "atproto_app_password_sessions", 77 | sa.Column("guid", sa.String(512), primary_key=True), 78 | sa.Column("access_token", sa.String(1024), nullable=False), 79 | sa.Column( 80 | "access_token_expires_at", sa.DateTime(timezone=True), nullable=False 81 | ), 82 | sa.Column("refresh_token", sa.String(512), nullable=False), 83 | sa.Column( 84 | "refresh_token_expires_at", sa.DateTime(timezone=True), nullable=False 85 | ), 86 | sa.Column("created_at", 
sa.DateTime(timezone=True), nullable=False), 87 | ) 88 | 89 | # `guid` has permission `permission` on `target_guid` 90 | op.create_table( 91 | "guid_permissions", 92 | sa.Column("guid", sa.String(512), nullable=False), 93 | sa.Column("target_guid", sa.String(512), nullable=False), 94 | sa.Column("permission", sa.Integer, nullable=False), 95 | sa.Column("created_at", sa.DateTime(timezone=True), nullable=False), 96 | ) 97 | op.create_primary_key( 98 | "pk_guid_permissions", "guid_permissions", ["guid", "target_guid"] 99 | ) 100 | op.create_index( 101 | "idx_guid_permissions_target_guid", "guid_permissions", ["target_guid"] 102 | ) 103 | 104 | 105 | def downgrade() -> None: 106 | op.drop_table("handles") 107 | op.drop_table("oauth_requests") 108 | op.drop_table("oauth_sessions") 109 | op.drop_table("atproto_app_passwords") 110 | op.drop_table("guid_permissions") 111 | op.drop_table("atproto_app_password_sessions") 112 | -------------------------------------------------------------------------------- /bin/scripts/entrypoint.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | set -e # Exit immediately if a command exits with a non-zero status 3 | 4 | # Run init script 5 | /app/init.sh 6 | 7 | # Start the AIP server 8 | echo "Starting AIP server..." 9 | exec pdm run aipserver 10 | -------------------------------------------------------------------------------- /bin/scripts/init-db.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -e 3 | 4 | # Create the database if it doesn't exist 5 | psql -v ON_ERROR_STOP=1 --username "$POSTGRES_USER" <<-EOSQL 6 | CREATE DATABASE aip_db; 7 | EOSQL 8 | -------------------------------------------------------------------------------- /bin/scripts/init.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | echo "Waiting for database connection..." 
3 | until pg_isready -h db -p 5432 -U aip; do 4 | sleep 2 5 | done 6 | echo "Database is ready." 7 | 8 | # Ensure the database exists 9 | echo "Checking if database exists..." 10 | DB_EXISTS=$(PGPASSWORD="aip_password" psql -h db -U aip -tAc "SELECT 1 FROM pg_database WHERE datname='aip_db'") 11 | if [ "$DB_EXISTS" != "1" ]; then 12 | echo "Database aip_db not found, creating..." 13 | PGPASSWORD="aip_password" createdb -h db -U aip aip_db 14 | else 15 | echo "Database already exists." 16 | fi 17 | 18 | # Generate signing keys if they don't exist 19 | if [ ! -f signing_keys.json ]; then 20 | echo "Generating signing keys..." 21 | SIGNING_KEY=$(pdm run aiputil gen-jwk) 22 | if [ -n "$SIGNING_KEY" ]; then 23 | echo "{\"keys\":[$SIGNING_KEY]}" > signing_keys.json 24 | echo "Signing keys generated." 25 | else 26 | echo "Error generating signing keys!" 27 | exit 1 28 | fi 29 | else 30 | echo "Signing keys already exist." 31 | fi 32 | 33 | # Extract 'kid' values from signing_keys.json 34 | if command -v jq >/dev/null 2>&1; then 35 | export ACTIVE_SIGNING_KEYS=$(jq -c '[.keys[].kid]' signing_keys.json) 36 | else 37 | echo "Error: jq is required but not installed. Install jq to continue." 38 | exit 1 39 | fi 40 | 41 | # Run Alembic migrations 42 | echo "Running Alembic migrations..." 43 | pdm run alembic upgrade head || { echo "Alembic migrations failed!"; exit 1; } 44 | 45 | 46 | # Mark initialization as complete 47 | touch /app/init_done 48 | echo "Initialization complete." 
49 | -------------------------------------------------------------------------------- /docker-compose.yml: -------------------------------------------------------------------------------- 1 | services: 2 | db: 3 | image: postgres:17 4 | container_name: aip_db 5 | restart: always 6 | environment: 7 | POSTGRES_USER: aip 8 | POSTGRES_PASSWORD: aip_password 9 | POSTGRES_DB: aip_db 10 | volumes: 11 | - aip_db_data:/var/lib/postgresql/data 12 | ports: 13 | - "5432:5432" 14 | healthcheck: 15 | test: ["CMD-SHELL", "pg_isready -U aip -d aip_db"] 16 | interval: 5s 17 | retries: 5 18 | timeout: 3s 19 | 20 | valkey: 21 | image: valkey/valkey:8-alpine 22 | container_name: aip_valkey 23 | restart: always 24 | ports: 25 | - "6379:6379" 26 | 27 | telegraf: 28 | image: telegraf:latest 29 | container_name: aip_telegraf 30 | restart: always 31 | depends_on: 32 | - db 33 | - valkey 34 | ports: 35 | - "8125:8125/udp" 36 | volumes: 37 | - ./telegraf.conf:/etc/telegraf/telegraf.conf 38 | 39 | aip: 40 | build: . 
41 | container_name: aip_service 42 | command: ["pdm", "run", "aipserver"] 43 | restart: always 44 | depends_on: 45 | db: 46 | condition: service_healthy 47 | valkey: 48 | condition: service_started 49 | telegraf: 50 | condition: service_started 51 | environment: 52 | DATABASE_URL: postgresql+asyncpg://aip:aip_password@db:5432/aip_db 53 | REDIS_URL: redis://valkey:6379/0 54 | DEBUG: "true" 55 | WORKER_ID: "dev1" 56 | ACTIVE_SIGNING_KEYS: '["01JNEKAHBPFQYJX3RS7HH7W2RY"]' 57 | JSON_WEB_KEYS: /app/signing_keys.json 58 | ports: 59 | - "8080:8080" 60 | - "5100:5100" 61 | volumes: 62 | - .:/app # Mount local directory to container 63 | - /app/.venv # Preserve virtual environment 64 | # TODO: move to Tilt-based debugging 65 | # volumes: 66 | # - .:/app 67 | # - /app/.venv 68 | 69 | volumes: 70 | aip_db_data: 71 | influxdb_data: 72 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["pdm-backend"] 3 | build-backend = "pdm.backend" 4 | 5 | [project] 6 | name = "aip" 7 | version = "0.2.4" 8 | description = "ATmosphere Authentication, Identity, and Permission Proxy" 9 | requires-python = ">=3.13" 10 | 11 | license = { file = "LICENSE" } 12 | 13 | authors = [{ name = "Nick Gerakines", email = "nick.gerakines@gmail.com" }] 14 | maintainers = [{ name = "Nick Gerakines", email = "nick.gerakines@gmail.com" }] 15 | 16 | classifiers = [ 17 | "Development Status :: 3 - Alpha", 18 | "Programming Language :: Python :: 3", 19 | "Programming Language :: Python :: 3.13", 20 | ] 21 | 22 | dependencies = [ 23 | "sqlalchemy>=2.0.37", 24 | "pydantic>=2.10.6", 25 | "aiohttp>=3.11.11", 26 | "jinja2>=3.1.5", 27 | "aiohttp-jinja2>=1.6", 28 | "alembic>=1.14.1", 29 | "psycopg2>=2.9.10", 30 | "jwcrypto>=1.5.6", 31 | "pydantic-settings>=2.7.1", 32 | "python-json-logger>=3.2.1", 33 | "aiodns>=3.2.0", 34 | "asyncpg>=0.30.0", 35 | 
"python-ulid>=3.0.0", 36 | "cryptography>=44.0.1", 37 | "redis>=5.2.1", 38 | "aio-statsd>=0.2.9", 39 | "sentry-sdk>=2.24.1", 40 | ] 41 | readme = "README.md" 42 | 43 | [project.optional-dependencies] 44 | dev = ["flake8>=7.1.1", "pytest>=8.3.4", "black>=25.1.0"] 45 | test = ["coverage"] 46 | 47 | [project.urls] 48 | "Homepage" = "https://github.com/graze-social/aip" 49 | "Bug Reports" = "https://github.com/graze-social/aip/issues" 50 | "Source" = "https://github.com/graze-social/aip" 51 | 52 | [project.scripts] 53 | aipserver = "social.graze.aip.app.__main__:main" 54 | resolve = "social.graze.aip.resolve.__main__:main" 55 | aiputil = "social.graze.aip.app.util.__main__:main" 56 | 57 | [tool.pdm] 58 | distribution = true 59 | 60 | [tool.pdm.build] 61 | includes = ["src/social", "LICENSE"] 62 | 63 | [tool.pytest.ini_options] 64 | pythonpath = ["src/"] 65 | 66 | [tool.pyright] 67 | venvPath = "." 68 | venv = ".venv" 69 | -------------------------------------------------------------------------------- /src/social/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/graze-social/aip/c740a2cf63fdf5a9d2cef6da3bfb8dd660e15ea0/src/social/__init__.py -------------------------------------------------------------------------------- /src/social/graze/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/graze-social/aip/c740a2cf63fdf5a9d2cef6da3bfb8dd660e15ea0/src/social/graze/__init__.py -------------------------------------------------------------------------------- /src/social/graze/aip/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | AIP - AT Protocol Identity Provider 3 | 4 | This module implements an AT Protocol Identity Provider service that handles user authentication 5 | and identity resolution for the Bluesky/AT Protocol ecosystem. 
It serves as a bridge between 6 | users and AT Protocol services, managing authentication flows, token management, and identity 7 | resolution. 8 | 9 | Key Components: 10 | - app: Web application layer with request handlers and server configuration 11 | - atproto: Integration with AT Protocol, handling authentication and PDS communication 12 | - model: Database models for storing authentication data and user information 13 | - resolve: Identity resolution utilities for AT Protocol DIDs and handles 14 | 15 | Architecture Overview: 16 | 1. Authentication Flow: 17 | - User initiates OAuth flow through the service 18 | - Service verifies user identity with AT Protocol PDS 19 | - Secure tokens are issued and managed 20 | 21 | 2. Identity Resolution: 22 | - Resolves user handles to DIDs through DNS and HTTP mechanisms 23 | - Resolves DIDs to canonical data (handle, PDS location) 24 | 25 | 3. Token Management: 26 | - Background tasks refresh tokens before expiry 27 | - Redis-backed token caching 28 | - Task distribution using work queues 29 | 30 | The service is designed with security, performance, and reliability in mind, 31 | following OAuth 2.0 and AT Protocol specifications. 32 | """ -------------------------------------------------------------------------------- /src/social/graze/aip/app/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | AIP Application Layer 3 | 4 | This package implements the web application layer for the AIP service, handling HTTP requests 5 | and responses using the aiohttp framework. It provides handlers for OAuth flows, app password 6 | authentication, and internal API endpoints. 
7 | 8 | Key Components: 9 | - __main__.py: Entry point for running the application 10 | - server.py: Web server configuration and middleware setup 11 | - config.py: Configuration management using Pydantic settings 12 | - handlers/: Request handlers for different endpoints 13 | - tasks.py: Background tasks for token refresh and health monitoring 14 | - cors.py: CORS handling for cross-origin requests 15 | - util/: Utility functions for the application layer 16 | 17 | The application uses several middleware layers: 18 | - CORS middleware for handling cross-origin requests 19 | - Statsd middleware for metrics collection 20 | - Sentry middleware for error reporting 21 | 22 | It provides the following main endpoints: 23 | - OAuth authentication endpoints (/auth/atproto/*) 24 | - Internal API endpoints (/internal/api/*) 25 | - XRPC proxy endpoints (/xrpc/*) 26 | """ -------------------------------------------------------------------------------- /src/social/graze/aip/app/__main__.py: -------------------------------------------------------------------------------- 1 | import os 2 | from aiohttp import web 3 | import logging 4 | from logging.config import dictConfig 5 | import json 6 | 7 | 8 | def configure_logging(): 9 | logging_config_file = os.getenv("LOGGING_CONFIG_FILE", "") 10 | 11 | if len(logging_config_file) > 0: 12 | with open(logging_config_file) as fl: 13 | dictConfig(json.load(fl)) 14 | return 15 | 16 | logging.basicConfig() 17 | logging.getLogger().setLevel(logging.DEBUG) 18 | 19 | 20 | def main(): 21 | configure_logging() 22 | 23 | from social.graze.aip.app.server import start_web_server 24 | 25 | web.run_app(start_web_server()) 26 | 27 | 28 | if __name__ == "__main__": 29 | main() 30 | -------------------------------------------------------------------------------- /src/social/graze/aip/app/config.py: -------------------------------------------------------------------------------- 1 | """ 2 | Configuration Module for AIP Service 3 | 4 | This module defines 
the configuration system for the AIP (AT Protocol Identity Provider) service, 5 | using Pydantic for settings validation and dependency injection through AppKeys. 6 | 7 | The configuration follows these principles: 8 | 1. Environment-based configuration with sensible defaults 9 | 2. Strong validation and typing through Pydantic 10 | 3. Dependency injection pattern using aiohttp's app context 11 | 4. Secure handling of cryptographic materials 12 | 13 | The Settings class serves as the central configuration point, loaded from environment variables 14 | with defaults suitable for development environments. All application components access settings 15 | and shared resources through typed AppKeys to maintain clean dependency injection. 16 | 17 | Key configuration areas include: 18 | - Service identification and networking 19 | - Database and cache connections 20 | - Cryptographic materials (signing keys, encryption) 21 | - Background processing configuration 22 | - UI customization 23 | """ 24 | 25 | import os 26 | import asyncio 27 | from typing import Annotated, Final, List, Optional 28 | import logging 29 | from aio_statsd import TelegrafStatsdClient 30 | from jwcrypto import jwk 31 | from pydantic import ( 32 | AliasChoices, 33 | Field, 34 | field_validator, 35 | PostgresDsn, 36 | RedisDsn, 37 | ) 38 | import base64 39 | from pydantic_settings import BaseSettings, NoDecode 40 | from aiohttp import web 41 | from cryptography.fernet import Fernet 42 | from sqlalchemy.ext.asyncio import ( 43 | AsyncEngine, 44 | async_sessionmaker, 45 | AsyncSession, 46 | ) 47 | from aiohttp import ClientSession 48 | from redis import asyncio as redis 49 | 50 | from social.graze.aip.model.health import HealthGauge 51 | 52 | 53 | logger = logging.getLogger(__name__) 54 | 55 | 56 | class Settings(BaseSettings): 57 | """ 58 | Application settings for the AIP service. 
59 | 60 | This class uses Pydantic's BaseSettings to automatically load values from environment 61 | variables, with sensible defaults for development environments. It handles validation, 62 | type conversion, and provides centralized configuration management. 63 | 64 | Settings are organized into the following categories: 65 | - Environment and debugging 66 | - Network and service identification 67 | - Database and cache connections 68 | - Security and cryptography 69 | - Background task configuration 70 | - Monitoring and observability 71 | - UI customization 72 | 73 | Environment variables are automatically mapped to settings fields, with aliases 74 | provided for backward compatibility. For example, the database connection string 75 | can be set with either PG_DSN or DATABASE_URL environment variables. 76 | """ 77 | 78 | # Environment and debugging settings 79 | debug: bool = False 80 | """ 81 | Enable debug mode for verbose logging and development features. 82 | Set with DEBUG=true environment variable. 83 | """ 84 | 85 | allowed_domains: str = "https://www.graze.social, https://sky-feeder-git-astro-graze.vercel.app" 86 | """ 87 | Comma-separated list of domains allowed for CORS. 88 | Set with ALLOWED_DOMAINS environment variable. 89 | """ 90 | 91 | # Network and service identification settings 92 | http_port: int = Field(alias="port", default=5100) 93 | """ 94 | HTTP port for the service to listen on. 95 | Set with PORT environment variable. 96 | """ 97 | 98 | external_hostname: str = "aip_service" 99 | """ 100 | Public hostname for the service, used for generating callback URLs. 101 | Set with EXTERNAL_HOSTNAME environment variable. 102 | """ 103 | 104 | plc_hostname: str = "plc.directory" 105 | """ 106 | Hostname for the PLC directory service for DID resolution. 107 | Set with PLC_HOSTNAME environment variable. 108 | """ 109 | 110 | # Monitoring and error reporting 111 | sentry_dsn: Optional[str] = None 112 | """ 113 | Sentry DSN for error reporting. 
Optional, no error reporting if not set. 114 | Set with SENTRY_DSN environment variable. 115 | """ 116 | 117 | # Database and cache connections 118 | redis_dsn: RedisDsn = Field( 119 | "redis://valkey:6379/1?decode_responses=True", 120 | validation_alias=AliasChoices("redis_dsn", "redis_url"), 121 | ) # type: ignore 122 | """ 123 | Redis connection string for caching and background tasks. 124 | Set with REDIS_DSN or REDIS_URL environment variables. 125 | Default: redis://valkey:6379/1?decode_responses=True 126 | """ 127 | 128 | pg_dsn: PostgresDsn = Field( 129 | "postgresql+asyncpg://postgres:password@db/aip", 130 | validation_alias=AliasChoices("pg_dsn", "database_url"), 131 | ) # type: ignore 132 | """ 133 | PostgreSQL connection string for database access. 134 | Set with PG_DSN or DATABASE_URL environment variables. 135 | Default: postgresql+asyncpg://postgres:password@db/aip 136 | """ 137 | 138 | # Security and cryptography settings 139 | json_web_keys: Annotated[jwk.JWKSet, NoDecode] = jwk.JWKSet() 140 | """ 141 | JSON Web Key Set containing signing keys for JWT operations. 142 | Can be set to a JWKSet object or path to a JSON file containing keys. 143 | Set with JSON_WEB_KEYS environment variable. 144 | """ 145 | 146 | active_signing_keys: List[str] = list() 147 | """ 148 | List of key IDs (kid) from json_web_keys that should be used for signing. 149 | Set with ACTIVE_SIGNING_KEYS environment variable as comma-separated values. 150 | """ 151 | 152 | service_auth_keys: List[str] = list() 153 | """ 154 | List of key IDs (kid) from json_web_keys used for service-to-service auth. 155 | Set with SERVICE_AUTH_KEYS environment variable as comma-separated values. 156 | """ 157 | 158 | encryption_key: Fernet = Fernet(Fernet.generate_key()) 159 | """ 160 | Fernet symmetric encryption key for sensitive data. 161 | Can be set to a Fernet object or base64-encoded key string. 162 | Set with ENCRYPTION_KEY environment variable. 
163 | """ 164 | 165 | # Worker identification 166 | worker_id: str 167 | """ 168 | Unique identifier for this worker instance (required, no default). 169 | Used to distribute work among multiple instances. 170 | Set with WORKER_ID environment variable. 171 | """ 172 | 173 | # Background processing configuration 174 | refresh_queue_oauth: str = "refresh_queue:oauth" 175 | """ 176 | Redis queue name for OAuth token refresh tasks. 177 | Set with REFRESH_QUEUE_OAUTH environment variable. 178 | """ 179 | 180 | refresh_queue_app_password: str = "refresh_queue:app_password" 181 | """ 182 | Redis queue name for App Password refresh tasks. 183 | Set with REFRESH_QUEUE_APP_PASSWORD environment variable. 184 | """ 185 | 186 | # Token expiration settings 187 | app_password_access_token_expiry: int = 720 # 12 minutes 188 | """ 189 | Expiration time in seconds for app password access tokens. 190 | Set with APP_PASSWORD_ACCESS_TOKEN_EXPIRY environment variable. 191 | Default: 720 (12 minutes) 192 | """ 193 | 194 | app_password_refresh_token_expiry: int = 7776000 # 90 days 195 | """ 196 | Expiration time in seconds for app password refresh tokens. 197 | Set with APP_PASSWORD_REFRESH_TOKEN_EXPIRY environment variable. 198 | Default: 7776000 (90 days) 199 | """ 200 | 201 | token_refresh_before_expiry_ratio: float = 0.8 202 | """ 203 | Ratio of token lifetime to wait before refreshing. 204 | For example, 0.8 means tokens are refreshed after 80% of their lifetime. 205 | Set with TOKEN_REFRESH_BEFORE_EXPIRY_RATIO environment variable. 206 | Default: 0.8 207 | """ 208 | 209 | oauth_refresh_max_retries: int = 3 210 | """ 211 | Maximum number of retry attempts for failed OAuth refresh operations. 212 | Set with OAUTH_REFRESH_MAX_RETRIES environment variable. 213 | Default: 3 214 | """ 215 | 216 | oauth_refresh_retry_base_delay: int = 300 217 | """ 218 | Base delay in seconds for OAuth refresh retry attempts (exponential backoff). 
219 | Actual delay = base_delay * (2 ^ retry_attempt) 220 | Set with OAUTH_REFRESH_RETRY_BASE_DELAY environment variable. 221 | Default: 300 (5 minutes) 222 | """ 223 | 224 | default_destination: str = "https://localhost:5100/auth/atproto/debug" 225 | """ 226 | Default redirect destination after authentication if none specified. 227 | Set with DEFAULT_DESTINATION environment variable. 228 | """ 229 | 230 | # Monitoring and observability settings 231 | statsd_host: str = Field(alias="TELEGRAF_HOST", default="telegraf") 232 | """ 233 | StatsD/Telegraf host for metrics collection. 234 | Set with TELEGRAF_HOST environment variable. 235 | """ 236 | 237 | statsd_port: int = Field(alias="TELEGRAF_PORT", default=8125) 238 | """ 239 | StatsD/Telegraf port for metrics collection. 240 | Set with TELEGRAF_PORT environment variable. 241 | """ 242 | 243 | statsd_prefix: str = "aip" 244 | """ 245 | Prefix for all StatsD metrics from this service. 246 | Set with STATSD_PREFIX environment variable. 247 | """ 248 | 249 | # UI customization settings for login page 250 | svg_logo: str = "https://www.graze.social/logo.svg" 251 | """URL for the logo displayed on the login page""" 252 | 253 | brand_name: str = "Graze" 254 | """Brand name displayed on the login page""" 255 | 256 | destination: str = "https://graze.social/app/auth/callback" 257 | """Default destination URL after authentication""" 258 | 259 | background_from: str = "#0588f0" 260 | """Starting gradient color for login page background""" 261 | 262 | background_to: str = "#5eb1ef" 263 | """Ending gradient color for login page background""" 264 | 265 | text_color: str = "#FFFFFF" 266 | """Text color for login page""" 267 | 268 | form_color: str = "#FFFFFF" 269 | """Form background color for login page""" 270 | 271 | @field_validator("json_web_keys", mode="before") 272 | @classmethod 273 | def decode_json_web_keys(cls, v) -> jwk.JWKSet: 274 | """ 275 | Validate and process the json_web_keys setting. 
276 | 277 | This validator accepts either: 278 | - An existing JWKSet object (for programmatic configuration) 279 | - A file path to a JSON file containing a JWK Set 280 | 281 | Args: 282 | v: The input value to validate 283 | 284 | Returns: 285 | jwk.JWKSet: A valid JWKSet object 286 | 287 | Raises: 288 | ValueError: If the input is neither a JWKSet nor a valid file path 289 | """ 290 | if isinstance(v, jwk.JWKSet): # If it's already a JWKSet, return it directly 291 | return v 292 | elif isinstance(v, str): # If it's a file path, load from file 293 | with open(v) as fd: 294 | data = fd.read() 295 | return jwk.JWKSet.from_json(data) 296 | raise ValueError( 297 | "json_web_keys must be a JWKSet object or a valid JSON file path" 298 | ) 299 | 300 | @field_validator("encryption_key", mode="before") 301 | @classmethod 302 | def decode_encryption_key(cls, v) -> Fernet: 303 | """ 304 | Validate and process the encryption_key setting. 305 | 306 | This validator accepts either: 307 | - An existing Fernet object (for programmatic configuration) 308 | - A base64-encoded string containing a Fernet key 309 | 310 | Args: 311 | v: The input value to validate 312 | 313 | Returns: 314 | Fernet: A valid Fernet encryption object 315 | 316 | Raises: 317 | ValueError: If the input is neither a Fernet object nor a valid base64 key 318 | """ 319 | if isinstance(v, Fernet): # Already a Fernet instance, return it 320 | return v 321 | elif isinstance(v, str): # Decode from a base64-encoded string 322 | key_data = base64.b64decode(v) 323 | return Fernet(key_data) 324 | raise ValueError( 325 | "encryption_key must be a Fernet object or a base64-encoded key string" 326 | ) 327 | 328 | 329 | # Background task queue constants 330 | OAUTH_REFRESH_QUEUE = "auth_session:oauth:refresh" 331 | """ 332 | Redis sorted set key for scheduling OAuth token refresh operations. 333 | Contains session_group IDs with refresh timestamps as scores. 
334 | """ 335 | 336 | OAUTH_REFRESH_RETRY_QUEUE = "auth_session:oauth:refresh:retry" 337 | """ 338 | Redis hash key for tracking OAuth refresh retry attempts. 339 | Keys are session_group IDs, values are retry counts. 340 | """ 341 | 342 | APP_PASSWORD_REFRESH_QUEUE = "auth_session:app-password:refresh" 343 | """ 344 | Redis sorted set key for scheduling App Password refresh operations. 345 | Contains user GUIDs with refresh timestamps as scores. 346 | """ 347 | 348 | # Application context keys for dependency injection 349 | SettingsAppKey: Final = web.AppKey("settings", Settings) 350 | """AppKey for accessing the application settings""" 351 | 352 | DatabaseAppKey: Final = web.AppKey("database", AsyncEngine) 353 | """AppKey for accessing the SQLAlchemy async database engine""" 354 | 355 | DatabaseSessionMakerAppKey: Final = web.AppKey( 356 | "database_session_maker", async_sessionmaker[AsyncSession] 357 | ) 358 | """AppKey for accessing the SQLAlchemy async session factory""" 359 | 360 | SessionAppKey: Final = web.AppKey("http_session", ClientSession) 361 | """AppKey for accessing the shared aiohttp client session""" 362 | 363 | RedisPoolAppKey: Final = web.AppKey("redis_pool", redis.ConnectionPool) 364 | """AppKey for accessing the Redis connection pool""" 365 | 366 | RedisClientAppKey: Final = web.AppKey("redis_client", redis.Redis) 367 | """AppKey for accessing the Redis client""" 368 | 369 | HealthGaugeAppKey: Final = web.AppKey("health_gauge", HealthGauge) 370 | """AppKey for accessing the health monitoring gauge""" 371 | 372 | OAuthRefreshTaskAppKey: Final = web.AppKey("oauth_refresh_task", asyncio.Task[None]) 373 | """AppKey for the background task that refreshes OAuth tokens""" 374 | 375 | AppPasswordRefreshTaskAppKey: Final = web.AppKey( 376 | "app_password_refresh_task", asyncio.Task[None] 377 | ) 378 | """AppKey for the background task that refreshes App Passwords""" 379 | 380 | TickHealthTaskAppKey: Final = web.AppKey("tick_health_task", 
asyncio.Task[None]) 381 | """AppKey for the background task that monitors service health""" 382 | 383 | TelegrafStatsdClientAppKey: Final = web.AppKey( 384 | "telegraf_statsd_client", TelegrafStatsdClient 385 | ) 386 | """AppKey for the Telegraf/StatsD metrics client""" -------------------------------------------------------------------------------- /src/social/graze/aip/app/cors.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Optional 2 | from urllib.parse import urlparse 3 | 4 | def get_cors_headers( 5 | origin_value: Optional[str], path: str, debug: bool 6 | ) -> Dict[str, str]: 7 | """Return appropriate CORS headers based on origin and path.""" 8 | allowed_origins = { 9 | "https://graze.social", 10 | "https://www.graze.social", 11 | "https://sky-feeder-git-astro-graze.vercel.app", 12 | } 13 | 14 | allowed_debug_hosts = { 15 | "localhost", 16 | "127.0.0.1", 17 | } 18 | 19 | headers = { 20 | "Access-Control-Allow-Methods": "GET, POST, OPTIONS", 21 | "Access-Control-Allow-Headers": ( 22 | "Keep-Alive, User-Agent, X-Requested-With, " 23 | "If-Modified-Since, Cache-Control, Content-Type, " 24 | "Authorization, X-Subject, X-Service" 25 | ), 26 | "Vary": "Origin" 27 | } 28 | 29 | if path.startswith("/auth/"): 30 | headers["Access-Control-Allow-Origin"] = "*" 31 | elif origin_value: 32 | parsed = urlparse(origin_value) 33 | base = f"{parsed.scheme}://{parsed.hostname}" if parsed.scheme and parsed.hostname else origin_value 34 | 35 | if base in allowed_origins: 36 | headers["Access-Control-Allow-Origin"] = origin_value 37 | elif debug and parsed.hostname in allowed_debug_hosts: 38 | headers["Access-Control-Allow-Origin"] = origin_value 39 | 40 | return headers 41 | -------------------------------------------------------------------------------- /src/social/graze/aip/app/handlers/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/graze-social/aip/c740a2cf63fdf5a9d2cef6da3bfb8dd660e15ea0/src/social/graze/aip/app/handlers/__init__.py -------------------------------------------------------------------------------- /src/social/graze/aip/app/handlers/app_password.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime, timezone, timedelta 2 | import logging 3 | from typing import Optional 4 | from aiohttp import web 5 | from pydantic import BaseModel, ValidationError, field_validator 6 | from sqlalchemy import delete 7 | from sqlalchemy.dialects.postgresql import insert 8 | import sentry_sdk 9 | 10 | from social.graze.aip.app.config import ( 11 | APP_PASSWORD_REFRESH_QUEUE, 12 | DatabaseSessionMakerAppKey, 13 | RedisClientAppKey, 14 | TelegrafStatsdClientAppKey, 15 | ) 16 | from social.graze.aip.app.handlers.helpers import auth_token_helper 17 | from social.graze.aip.model.app_password import AppPassword 18 | 19 | 20 | logger = logging.getLogger(__name__) 21 | 22 | 23 | class AppPasswordOperation(BaseModel): 24 | value: Optional[str] = None 25 | 26 | @field_validator("value") 27 | def app_password_check(cls, v: Optional[str]) -> Optional[str]: 28 | if v is None: 29 | return None 30 | 31 | if len(v) != 19: 32 | raise ValueError("invalid format") 33 | 34 | if v.count("-") != 3: 35 | raise ValueError("invalid format") 36 | 37 | return v 38 | 39 | 40 | async def handle_internal_app_password(request: web.Request) -> web.Response: 41 | database_session_maker = request.app[DatabaseSessionMakerAppKey] 42 | redis_session = request.app[RedisClientAppKey] 43 | statsd_client = request.app[TelegrafStatsdClientAppKey] 44 | 45 | try: 46 | data = await request.read() 47 | app_password_operation = AppPasswordOperation.model_validate_json(data) 48 | except (OSError, ValidationError) as e: 49 | # TODO: Fix the returned error message when JSON fails because of pydantic validation functions. 
50 | sentry_sdk.capture_exception(e) 51 | return web.json_response(status=400, data={"error": "Invalid JSON"}) 52 | 53 | try: 54 | async with (database_session_maker() as database_session,): 55 | auth_token = await auth_token_helper( 56 | database_session, statsd_client, request, allow_permissions=False 57 | ) 58 | if auth_token is None: 59 | return web.json_response(status=401, data={"error": "Not Authorized"}) 60 | 61 | now = datetime.now(timezone.utc) 62 | 63 | async with database_session.begin(): 64 | if app_password_operation.value is None: 65 | stmt = delete(AppPassword).where( 66 | AppPassword.guid == auth_token.guid 67 | ) 68 | await database_session.execute(stmt) 69 | await redis_session.zrem( 70 | APP_PASSWORD_REFRESH_QUEUE, auth_token.guid 71 | ) 72 | else: 73 | stmt = ( 74 | insert(AppPassword) 75 | .values( 76 | [ 77 | { 78 | "guid": auth_token.guid, 79 | "app_password": app_password_operation.value, 80 | "created_at": now, 81 | } 82 | ] 83 | ) 84 | .on_conflict_do_update( 85 | index_elements=["guid"], 86 | set_={"app_password": app_password_operation.value}, 87 | ) 88 | ) 89 | await database_session.execute(stmt) 90 | 91 | refresh_at = now + timedelta(0, 5) 92 | 93 | await redis_session.zadd( 94 | APP_PASSWORD_REFRESH_QUEUE, 95 | {auth_token.guid: int(refresh_at.timestamp())}, 96 | ) 97 | 98 | await database_session.commit() 99 | 100 | app_password_key = f"auth_session:app-password:{auth_token.guid}" 101 | await redis_session.delete(app_password_key) 102 | 103 | return web.Response(status=200) 104 | except web.HTTPException as e: 105 | sentry_sdk.capture_exception(e) 106 | logging.exception("handle_internal_permissions: web.HTTPException") 107 | raise e 108 | except Exception as e: 109 | sentry_sdk.capture_exception(e) 110 | logging.exception("handle_internal_permissions: Exception") 111 | return web.json_response(status=500, data={"error": "Internal Server Error"}) 112 | -------------------------------------------------------------------------------- 
/src/social/graze/aip/app/handlers/credentials.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from aiohttp import web 3 | import sentry_sdk 4 | 5 | from social.graze.aip.app.config import ( 6 | DatabaseSessionMakerAppKey, 7 | TelegrafStatsdClientAppKey, 8 | ) 9 | from social.graze.aip.app.handlers.helpers import auth_token_helper 10 | 11 | logger = logging.getLogger(__name__) 12 | 13 | 14 | async def handle_internal_credentials(request: web.Request) -> web.Response: 15 | statsd_client = request.app[TelegrafStatsdClientAppKey] 16 | database_session_maker = request.app[DatabaseSessionMakerAppKey] 17 | 18 | try: 19 | async with (database_session_maker() as database_session,): 20 | # TODO: Allow optional auth here. 21 | auth_token = await auth_token_helper( 22 | database_session, statsd_client, request 23 | ) 24 | if auth_token is None: 25 | return web.json_response(status=401, data={"error": "Not Authorized"}) 26 | except web.HTTPException as e: 27 | sentry_sdk.capture_exception(e) 28 | raise e 29 | except Exception as e: 30 | sentry_sdk.capture_exception(e) 31 | return web.json_response(status=500, data={"error": "Internal Server Error"}) 32 | 33 | if auth_token.app_password_session is not None: 34 | return web.json_response( 35 | { 36 | "type": "bearer", 37 | "token": auth_token.app_password_session.access_token, 38 | } 39 | ) 40 | 41 | if auth_token.oauth_session is not None: 42 | return web.json_response( 43 | { 44 | "type": "dpop", 45 | "token": auth_token.oauth_session.access_token, 46 | "jwk": auth_token.oauth_session.dpop_jwk, 47 | "issuer": auth_token.oauth_session.issuer, 48 | } 49 | ) 50 | 51 | return web.json_response( 52 | { 53 | "type": "none", 54 | } 55 | ) 56 | -------------------------------------------------------------------------------- /src/social/graze/aip/app/handlers/helpers.py: -------------------------------------------------------------------------------- 1 | from dataclasses 
import dataclass 2 | from datetime import datetime, timezone 3 | import json 4 | import logging 5 | from typing import ( 6 | Optional, 7 | Dict, 8 | ) 9 | from aio_statsd import TelegrafStatsdClient 10 | from aiohttp import web 11 | from jwcrypto import jwt 12 | from sqlalchemy import select 13 | from sqlalchemy.ext.asyncio import ( 14 | AsyncSession, 15 | ) 16 | import sentry_sdk 17 | 18 | from social.graze.aip.app.config import ( 19 | SettingsAppKey, 20 | ) 21 | from social.graze.aip.model.app_password import AppPasswordSession 22 | from social.graze.aip.model.handles import Handle 23 | from social.graze.aip.model.oauth import OAuthSession, Permission 24 | 25 | logger = logging.getLogger(__name__) 26 | 27 | 28 | @dataclass(repr=False, eq=False) 29 | class AuthToken: 30 | """ 31 | Represents an authenticated token with user identity and session information. 32 | 33 | This class contains both the authenticating user's information (guid, subject, handle) 34 | and the context for the current request (which may be different if using permissions 35 | to act on behalf of another user). 
class AuthenticationException(Exception):
    """
    Error type signalling that a request could not be authenticated.

    Instances are normally obtained through the static factory methods below,
    each of which attaches a stable ``error-auth-helper-NNNN`` code to a
    human-readable message.
    """

    @staticmethod
    def _coded(code: int, detail: str) -> "AuthenticationException":
        """Build an exception whose message is ``error-auth-helper-<code> <detail>``."""
        return AuthenticationException(f"error-auth-helper-{code} {detail}")

    @staticmethod
    def jwt_subject_missing() -> "AuthenticationException":
        """The JWT does not carry the mandatory 'sub' claim."""
        return AuthenticationException._coded(1000, "JWT missing subject")

    @staticmethod
    def jwt_session_group_missing() -> "AuthenticationException":
        """The JWT does not carry the mandatory 'grp' claim."""
        return AuthenticationException._coded(1001, "JWT missing session group")

    @staticmethod
    def session_not_found() -> "AuthenticationException":
        """No valid session exists for the authenticated user."""
        return AuthenticationException._coded(1002, "No valid session found")

    @staticmethod
    def session_expired() -> "AuthenticationException":
        """The session is past its expiry and can no longer be used."""
        return AuthenticationException._coded(1003, "Session has expired")

    @staticmethod
    def handle_not_found() -> "AuthenticationException":
        """No handle record exists for the authenticated user."""
        return AuthenticationException._coded(1004, "Handle record not found")

    @staticmethod
    def permission_denied() -> "AuthenticationException":
        """The user is not permitted to perform the requested action."""
        return AuthenticationException._coded(1005, "Permission denied")

    @staticmethod
    def unexpected(msg: str = "") -> "AuthenticationException":
        """Something unanticipated went wrong while authenticating."""
        return AuthenticationException._coded(
            1999, f"Unexpected authentication error: {msg}"
        )
return an AuthToken with user and context information. 127 | 128 | This helper enforces the following policies: 129 | * All API calls must be authenticated and require an `Authorization` header with a bearer token. 130 | * The `X-Subject` header can optionally specify the subject of the request. The value must be a known guid. 131 | * The `X-Service` header can optionally specify the hostname of the service providing the invoked XRPC method. 132 | 133 | The function validates the JWT token, retrieves the associated session information, and checks permissions 134 | if the request is acting on behalf of another user. 135 | 136 | Args: 137 | database_session: SQLAlchemy async session for database queries 138 | statsd_client: Statsd client for metrics 139 | request: The HTTP request to authenticate 140 | allow_permissions: Whether to allow acting on behalf of another user via permissions 141 | 142 | Returns: 143 | An AuthToken object if authentication succeeds, None otherwise 144 | 145 | Raises: 146 | AuthenticationException: If there's a specific authentication failure that should be reported 147 | """ 148 | 149 | # Check for Authorization header with Bearer token 150 | authorizations: Optional[str] = request.headers.getone("Authorization", None) 151 | if ( 152 | authorizations is None 153 | or not authorizations.startswith("Bearer ") 154 | or len(authorizations) < 8 155 | ): 156 | return None 157 | 158 | serialized_auth_token = authorizations[7:] 159 | 160 | settings = request.app[SettingsAppKey] 161 | 162 | try: 163 | # Validate the JWT token 164 | validated_auth_token = jwt.JWT( 165 | jwt=serialized_auth_token, key=settings.json_web_keys, algs=["ES256"] 166 | ) 167 | 168 | # Parse and validate claims 169 | auth_token_claims: Dict[str, str] = json.loads(validated_auth_token.claims) 170 | 171 | auth_token_subject: Optional[str] = auth_token_claims.get("sub", None) 172 | if auth_token_subject is None: 173 | raise AuthenticationException.jwt_subject_missing() 174 | 175 
| auth_token_session_group: Optional[str] = auth_token_claims.get("grp", None) 176 | if auth_token_session_group is None: 177 | raise AuthenticationException.jwt_session_group_missing() 178 | 179 | now = datetime.now(timezone.utc) 180 | 181 | async with database_session.begin(): 182 | 183 | # 1. Get the OAuthSession from the database and validate it 184 | oauth_session_stmt = select(OAuthSession).where( 185 | OAuthSession.guid == auth_token_subject, 186 | 187 | # Sessions are not limited to the same session group because we 188 | # don't actually care about which session we end up getting. The 189 | # only requirement is that it's valid. 190 | # OAuthSession.session_group == auth_token_session_group, 191 | 192 | OAuthSession.access_token_expires_at > now, 193 | OAuthSession.hard_expires_at > now 194 | ).order_by(OAuthSession.created_at.desc()) 195 | 196 | oauth_session: Optional[OAuthSession] = ( 197 | await database_session.scalars(oauth_session_stmt) 198 | ).first() 199 | 200 | if oauth_session is None: 201 | raise AuthenticationException.session_not_found() 202 | 203 | # 2. Get the Handle from the database and validate it 204 | oauth_session_handle_stmt = select(Handle).where( 205 | Handle.guid == auth_token_subject 206 | ) 207 | oauth_session_handle_result = await database_session.scalars(oauth_session_handle_stmt) 208 | oauth_session_handle = oauth_session_handle_result.first() 209 | 210 | if oauth_session_handle is None: 211 | raise AuthenticationException.handle_not_found() 212 | 213 | # 3. 
Get the X-Subject header, defaulting to the current oauth session handle's guid 214 | x_subject: str = request.headers.getone( 215 | "X-Subject", oauth_session_handle.guid 216 | ) 217 | 218 | # If the subject of the request is the same as the subject of the auth token, 219 | # then we have everything we need and can return a fully formed AuthToken 220 | if x_subject == oauth_session_handle.guid: 221 | x_service: str = request.headers.getone( 222 | "X-Service", oauth_session_handle.pds 223 | ) 224 | 225 | # Look up app password session if it exists 226 | app_password_session_stmt = select(AppPasswordSession).where( 227 | AppPasswordSession.guid == oauth_session.guid, 228 | ) 229 | app_password_session: Optional[AppPasswordSession] = ( 230 | await database_session.scalars(app_password_session_stmt) 231 | ).first() 232 | 233 | return AuthToken( 234 | oauth_session=oauth_session, 235 | app_password_session=app_password_session, 236 | guid=oauth_session_handle.guid, 237 | subject=oauth_session_handle.did, 238 | handle=oauth_session_handle.handle, 239 | context_service=x_service, 240 | context_guid=oauth_session_handle.guid, 241 | context_subject=oauth_session_handle.did, 242 | context_pds=oauth_session_handle.pds, 243 | ) 244 | 245 | # If permissions are not allowed but the subject differs, deny the request 246 | if allow_permissions is False: 247 | raise AuthenticationException.permission_denied() 248 | 249 | # 4. 
Get the permission record for the oauth session handle to the x_repository guid 250 | permission_stmt = select(Permission).where( 251 | Permission.guid == oauth_session_handle.guid, 252 | Permission.target_guid == x_subject, 253 | Permission.permission > 0, 254 | ) 255 | permission: Optional[Permission] = ( 256 | await database_session.scalars(permission_stmt) 257 | ).first() 258 | 259 | # If no permission is found, deny access 260 | if permission is None: 261 | raise AuthenticationException.permission_denied() 262 | 263 | # Get the handle for the target subject 264 | subject_handle_stmt = select(Handle).where( 265 | Handle.guid == permission.target_guid 266 | ) 267 | subject_handle_result = await database_session.scalars(subject_handle_stmt) 268 | subject_handle = subject_handle_result.first() 269 | 270 | if subject_handle is None: 271 | raise AuthenticationException.handle_not_found() 272 | 273 | # Get the app password session for the target subject if it exists 274 | app_password_session_stmt = select(AppPasswordSession).where( 275 | AppPasswordSession.guid == subject_handle.guid, 276 | ) 277 | app_password_session: Optional[AppPasswordSession] = ( 278 | await database_session.scalars(app_password_session_stmt) 279 | ).first() 280 | 281 | # Get a valid OAuth session for the target subject 282 | target_oauth_session_stmt = select(OAuthSession).where( 283 | OAuthSession.guid == subject_handle.guid, 284 | OAuthSession.access_token_expires_at > now, 285 | OAuthSession.hard_expires_at > now 286 | ).order_by(OAuthSession.created_at.desc()) 287 | 288 | target_oauth_session: Optional[OAuthSession] = ( 289 | await database_session.scalars(target_oauth_session_stmt) 290 | ).first() 291 | 292 | # Get the service endpoint from the X-Service header or use the subject's PDS 293 | x_service: str = request.headers.getone("X-Service", subject_handle.pds) 294 | 295 | await database_session.commit() 296 | 297 | return AuthToken( 298 | oauth_session=target_oauth_session, 299 | 
async def handle_internal_me(request: web.Request):
    """
    Describe the identity behind the request's auth token.

    Authenticates the request with ``auth_token_helper`` (acting on behalf of
    another user is disabled) and returns the user's handle, service endpoint,
    DID and guid, plus flags telling whether an OAuth session and an
    app-password session are currently available.

    Args:
        request: The incoming HTTP request.

    Returns:
        A JSON response describing the authenticated user.

    Raises:
        web.HTTPUnauthorized: When the request cannot be authenticated.
        web.HTTPInternalServerError: On any unexpected failure.
    """
    # Imported locally so this handler does not depend on a change to the
    # module-level import block.
    from social.graze.aip.app.handlers.helpers import AuthenticationException

    database_session_maker = request.app[DatabaseSessionMakerAppKey]
    statsd_client = request.app[TelegrafStatsdClientAppKey]

    unauthorized_body = json.dumps({"error": "Not Authorized"})

    try:
        async with database_session_maker() as database_session:
            auth_token = await auth_token_helper(
                database_session, statsd_client, request, allow_permissions=False
            )
            if auth_token is None:
                raise web.HTTPUnauthorized(
                    body=unauthorized_body,
                    content_type="application/json",
                )

            # TODO: Include has_app_password boolean in the response.

            return web.json_response(
                {
                    "handle": auth_token.handle,
                    "pds": auth_token.context_service,
                    "did": auth_token.subject,
                    "guid": auth_token.guid,
                    "oauth_session_valid": auth_token.oauth_session is not None,
                    "app_password_session_valid": auth_token.app_password_session
                    is not None,
                }
            )
    except AuthenticationException as e:
        # Already reported to Sentry by the helper. A missing/expired session
        # is an authentication failure and must not be reported as a 500.
        raise web.HTTPUnauthorized(
            body=unauthorized_body,
            content_type="application/json",
        ) from e
    except web.HTTPException as e:
        sentry_sdk.capture_exception(e)
        raise
    except Exception as e:
        sentry_sdk.capture_exception(e)
        raise web.HTTPInternalServerError(
            body=json.dumps({"error": "Internal Server Error"}),
            content_type="application/json",
        ) from e
84 | results = [] 85 | async with database_session_maker() as database_session: 86 | for subject in subjects: 87 | resolved_subject = await resolve_subject( 88 | request.app[SessionAppKey], settings.plc_hostname, subject 89 | ) 90 | if resolved_subject is None: 91 | continue 92 | async with database_session.begin(): 93 | stmt = upsert_handle_stmt( 94 | resolved_subject.did, resolved_subject.handle, resolved_subject.pds 95 | ) 96 | await database_session.execute(stmt) 97 | await database_session.commit() 98 | results.append(resolved_subject.model_dump()) 99 | return web.json_response(results) 100 | -------------------------------------------------------------------------------- /src/social/graze/aip/app/handlers/oauth.py: -------------------------------------------------------------------------------- 1 | """ 2 | AT Protocol OAuth Handlers 3 | 4 | This module implements the web request handlers for OAuth authentication with AT Protocol. 5 | It provides endpoints for initiating authentication, handling callbacks from the AT Protocol 6 | authorization server, refreshing tokens, and debugging authentication information. 7 | 8 | OAuth Flow with AT Protocol: 9 | 1. User enters their handle/DID in the login form 10 | 2. Application initiates OAuth flow by redirecting to AT Protocol authorization server 11 | 3. User authenticates with their AT Protocol PDS (Personal Data Server) 12 | 4. PDS redirects back to the application with an authorization code 13 | 5. Application exchanges the code for access and refresh tokens 14 | 6. Application stores tokens and returns an auth token to the client 15 | 7. 
def context_vars(settings):
    """
    Collect the UI customization values that templates need.

    Each listed attribute is read from ``settings`` and exposed to the template
    engine under the same name.

    Args:
        settings: Application settings object

    Returns:
        Dict mapping each UI customization name to its configured value
    """
    ui_keys = (
        "svg_logo",
        "brand_name",
        "destination",
        "background_from",
        "background_to",
        "text_color",
        "form_color",
    )
    return {key: getattr(settings, key) for key in ui_keys}
async def handle_atproto_login(request: web.Request):
    """
    Serve the AT Protocol login page (GET).

    Renders the form in which a user types their AT Protocol handle or DID.
    A non-empty ``destination`` query parameter replaces the configured
    destination in the template context.

    Args:
        request: HTTP request object

    Returns:
        HTTP response with rendered login template
    """
    base_context = context_vars(request.app[SettingsAppKey])

    requested_destination = request.query.get("destination")
    overrides = (
        {"destination": requested_destination} if requested_destination else {}
    )

    return await aiohttp_jinja2.render_template_async(
        "atproto_login.html",
        request,
        context={**base_context, **overrides},
    )
async def handle_atproto_callback(request: web.Request):
    """
    Handle OAuth callback from AT Protocol authorization server.

    Exchanges the authorization code for tokens via ``oauth_complete`` and
    redirects the user to their final destination with the serialized auth
    token appended as the ``auth_token`` query parameter.

    Query Parameters:
        state: OAuth state parameter to prevent CSRF
        iss: Issuer identifier (authorization server)
        code: Authorization code to exchange for tokens

    Args:
        request: HTTP request object

    Returns:
        The rendered login page with an error message when the OAuth flow
        cannot be completed.

    Raises:
        HTTPFound: To redirect to final destination
    """
    state: Optional[str] = request.query.get("state", None)
    issuer: Optional[str] = request.query.get("iss", None)
    code: Optional[str] = request.query.get("code", None)

    settings = request.app[SettingsAppKey]
    http_session = request.app[SessionAppKey]
    database_session_maker = request.app[DatabaseSessionMakerAppKey]
    redis_session = request.app[RedisClientAppKey]
    statsd_client = request.app[TelegrafStatsdClientAppKey]

    try:
        (serialized_auth_token, destination) = await oauth_complete(
            settings,
            http_session,
            statsd_client,
            database_session_maker,
            redis_session,
            state,
            issuer,
            code,
        )
    except Exception as e:
        # Log and report the failure, as the login-submit handler does.
        logger.exception("callback error")
        sentry_sdk.capture_exception(e)
        # Render the existing login template with the error; it already
        # accepts "error_message" and lets the user retry. The previous
        # "alert.html" is not among the project's templates.
        return await aiohttp_jinja2.render_template_async(
            "atproto_login.html",
            request,
            context=dict(**context_vars(settings), **{"error_message": str(e)}),
        )

    # Add (or overwrite) the auth_token query parameter on the destination URL.
    parsed_destination = urlparse(destination)
    query = dict(parse_qsl(parsed_destination.query))
    query.update({"auth_token": serialized_auth_token})
    parsed_destination = parsed_destination._replace(query=urlencode(query))
    redirect_destination = urlunparse(parsed_destination)
    raise web.HTTPFound(redirect_destination)
309 | 310 | This handler allows for manual refreshing of OAuth tokens. It extracts the 311 | auth token from the query parameters, validates it, finds the associated 312 | OAuth session, and refreshes the token. 313 | 314 | Query Parameters: 315 | auth_token: JWT authentication token 316 | 317 | Args: 318 | request: HTTP request object 319 | 320 | Returns: 321 | HTTP redirect to debug page with refreshed token 322 | 323 | Raises: 324 | HTTPFound: To redirect to debug page 325 | Exception: If auth token is invalid or session not found 326 | 327 | Flow: 328 | 1. Extract and validate auth token 329 | 2. Find associated OAuth session 330 | 3. Refresh tokens with oauth_refresh 331 | 4. Redirect to debug page 332 | """ 333 | settings = request.app[SettingsAppKey] 334 | http_session = request.app[SessionAppKey] 335 | statsd_client = request.app[TelegrafStatsdClientAppKey] 336 | database_session_maker = request.app[DatabaseSessionMakerAppKey] 337 | redis_session = request.app[RedisClientAppKey] 338 | 339 | serialized_auth_token: Optional[str] = request.query.get("auth_token", None) 340 | if serialized_auth_token is None: 341 | raise Exception("Invalid request") 342 | 343 | validated_auth_token = jwt.JWT( 344 | jwt=serialized_auth_token, key=settings.json_web_keys, algs=["ES256"] 345 | ) 346 | auth_token_claims = json.loads(validated_auth_token.claims) 347 | 348 | auth_token_subject: Optional[str] = auth_token_claims.get("sub", None) 349 | auth_token_session_group: Optional[str] = auth_token_claims.get("grp", None) 350 | 351 | # Fetch OAuth session first 352 | async with (database_session_maker() as database_session,): 353 | async with database_session.begin(): 354 | oauth_session_stmt = select(OAuthSession).where( 355 | OAuthSession.guid == auth_token_subject, 356 | OAuthSession.session_group == auth_token_session_group, 357 | ) 358 | oauth_session: OAuthSession = ( 359 | await database_session.scalars(oauth_session_stmt) 360 | ).one() 361 | 362 | # Create fresh database 
session for oauth_refresh to avoid transaction conflicts 363 | async with (database_session_maker() as fresh_database_session,): 364 | await oauth_refresh( 365 | settings, 366 | http_session, 367 | statsd_client, 368 | fresh_database_session, 369 | redis_session, 370 | oauth_session, 371 | ) 372 | 373 | # The same auth token is returned, but the access token is updated. 374 | raise web.HTTPFound(f"/auth/atproto/debug?auth_token={serialized_auth_token}") 375 | 376 | 377 | async def handle_atproto_debug(request: web.Request): 378 | """ 379 | Handle debug page request showing authentication information. 380 | 381 | This handler displays detailed information about the authentication session, 382 | including the JWT token contents, OAuth session details, and user handle. 383 | It's primarily used for debugging and development purposes. 384 | 385 | Query Parameters: 386 | auth_token: JWT authentication token 387 | 388 | Args: 389 | request: HTTP request object 390 | 391 | Returns: 392 | HTTP response with rendered debug template 393 | 394 | Raises: 395 | Exception: If auth token is invalid or session/handle not found 396 | """ 397 | settings = request.app[SettingsAppKey] 398 | database_session_maker = request.app[DatabaseSessionMakerAppKey] 399 | 400 | serialized_auth_token: Optional[str] = request.query.get("auth_token", None) 401 | if serialized_auth_token is None: 402 | raise Exception("Invalid request") 403 | 404 | validated_auth_token = jwt.JWT( 405 | jwt=serialized_auth_token, key=settings.json_web_keys, algs=["ES256"] 406 | ) 407 | auth_token_claims = json.loads(validated_auth_token.claims) 408 | auth_token_header = json.loads(validated_auth_token.header) 409 | 410 | auth_token_subject: Optional[str] = auth_token_claims.get("sub", None) 411 | auth_token_session_group: Optional[str] = auth_token_claims.get("grp", None) 412 | 413 | async with database_session_maker() as database_session: 414 | 415 | async with database_session.begin(): 416 | 417 | oauth_session_stmt 
async def handle_atproto_client_metadata(request: web.Request):
    """
    Handle OAuth client metadata endpoint request.

    Builds an ``ATProtocolOAuthClientMetadata`` document describing this client
    (identifier, redirect URI, JWKS location, grant/response types, scope and
    token-endpoint authentication) and returns it as JSON. Every URL is derived
    from ``settings.external_hostname``.

    Args:
        request: HTTP request object

    Returns:
        HTTP JSON response with client metadata
    """
    settings = request.app[SettingsAppKey]
    base_url = f"https://{settings.external_hostname}"

    client_metadata = ATProtocolOAuthClientMetadata(
        application_type="web",
        client_id=f"{base_url}/auth/atproto/client-metadata.json",
        client_name="Graze Social",
        client_uri=base_url,
        dpop_bound_access_tokens=True,
        grant_types=["authorization_code", "refresh_token"],
        jwks_uri=f"{base_url}/.well-known/jwks.json",
        logo_uri=f"{base_url}/logo.png",
        policy_uri=f"{base_url}/PLACEHOLDER",
        redirect_uris=[f"{base_url}/auth/atproto/callback"],
        response_types=["code"],
        scope="atproto transition:generic",
        token_endpoint_auth_method="private_key_jwt",
        token_endpoint_auth_signing_alg="ES256",
        subject_type="public",
        tos_uri=f"{base_url}/PLACEHOLDER",
    )
    # model_dump() is the pydantic v2 API already used elsewhere in this
    # project; .dict() is deprecated there and emits a warning on every call.
    return web.json_response(client_metadata.model_dump())
import ( 4 | Literal, 5 | Optional, 6 | Dict, 7 | List, 8 | Any, 9 | ) 10 | from aiohttp import web 11 | from pydantic import BaseModel, PositiveInt, RootModel, ValidationError 12 | from sqlalchemy import delete, select 13 | import sentry_sdk 14 | 15 | from social.graze.aip.app.config import ( 16 | DatabaseSessionMakerAppKey, 17 | TelegrafStatsdClientAppKey, 18 | ) 19 | from social.graze.aip.app.handlers.helpers import auth_token_helper 20 | from social.graze.aip.model.oauth import ( 21 | Permission, 22 | upsert_permission_stmt, 23 | ) 24 | 25 | logger = logging.getLogger(__name__) 26 | 27 | 28 | class PermissionOperation(BaseModel): 29 | op: Literal["test", "add", "remove", "replace"] 30 | path: str 31 | 32 | # TODO: Make these permission values mean something. 33 | value: Optional[PositiveInt] = None 34 | 35 | 36 | PermissionOperations = RootModel[list[PermissionOperation]] 37 | 38 | 39 | async def handle_internal_permissions(request: web.Request) -> web.Response: 40 | database_session_maker = request.app[DatabaseSessionMakerAppKey] 41 | statsd_client = request.app[TelegrafStatsdClientAppKey] 42 | 43 | # TODO: Support GET requests that returns paginated permission objects. 44 | 45 | try: 46 | data = await request.read() 47 | operations = PermissionOperations.model_validate_json(data) 48 | except (OSError, ValidationError) as e: 49 | sentry_sdk.capture_exception(e) 50 | return web.Response(text="Invalid JSON", status=400) 51 | 52 | try: 53 | async with (database_session_maker() as database_session,): 54 | auth_token = await auth_token_helper( 55 | database_session, statsd_client, request, allow_permissions=False 56 | ) 57 | if auth_token is None: 58 | return web.json_response(status=401, data={"error": "Not Authorized"}) 59 | 60 | # TODO: Fail with error if the user does not have an app-password set. 
61 | 62 | now = datetime.datetime.now(datetime.timezone.utc) 63 | 64 | async with database_session.begin(): 65 | results: List[Dict[str, Any]] = [] 66 | 67 | for operation in operations.root: 68 | guid = operation.path.removeprefix("/") 69 | if guid == auth_token.guid: 70 | return web.json_response( 71 | status=400, data={"error": "Invalid permission"} 72 | ) 73 | 74 | if ( 75 | operation.op == "add" or operation.op == "replace" 76 | ) and operation.value is not None: 77 | 78 | # TODO: Fail if the guid is unknown. Clients should use /internal/api/resolve on all subjects 79 | # prior to setting permissions. 80 | 81 | stmt = upsert_permission_stmt( 82 | guid=guid, 83 | target_guid=auth_token.guid, 84 | permission=operation.value, 85 | created_at=now, 86 | ) 87 | await database_session.execute(stmt) 88 | 89 | if operation.op == "remove": 90 | stmt = delete(Permission).where( 91 | Permission.guid == guid, 92 | Permission.target_guid == auth_token.guid, 93 | ) 94 | await database_session.execute(stmt) 95 | 96 | if operation.op == "test": 97 | permission_stmt = select(Permission).where( 98 | Permission.guid == guid, 99 | Permission.target_guid == auth_token.guid, 100 | ) 101 | if operation.value is not None: 102 | permission_stmt = permission_stmt.where( 103 | Permission.permission == operation.value 104 | ) 105 | permission: Optional[Permission] = ( 106 | await database_session.scalars(permission_stmt) 107 | ).first() 108 | if permission is not None: 109 | results.append( 110 | {"path": operation.path, "value": permission.permission} 111 | ) 112 | 113 | await database_session.commit() 114 | 115 | return web.json_response(results) 116 | except web.HTTPException as e: 117 | sentry_sdk.capture_exception(e) 118 | logging.exception("handle_internal_permissions: web.HTTPException") 119 | raise e 120 | except Exception as e: 121 | sentry_sdk.capture_exception(e) 122 | logging.exception("handle_internal_permissions: Exception") 123 | return web.json_response(status=500, 
async def handle_xrpc_proxy(request: web.Request) -> web.Response:
    """
    Proxy an XRPC call to the service recorded on the caller's auth token.

    The caller is authenticated with `auth_token_helper`. The request is then
    forwarded to `{auth_token.context_service}/xrpc/{method}` with the
    original query parameters and, for POST requests, the original body.

    App password sessions are forwarded with a `Bearer` authorization header.
    OAuth sessions are forwarded with a `DPoP` authorization header, and a
    `GenerateDpopMiddleware` is added to the request chain with the DPoP key,
    header and claims built here.

    Args:
        request: HTTP request object. The XRPC method name is read from the
            `method` route parameter.

    Returns:
        The upstream response converted to a web response, or a JSON error
        response with status 400 (no method), 401 (not authorized) or 500
        (unexpected error while authenticating).

    Raises:
        web.HTTPException: re-raised unchanged when raised while
            authenticating.
    """
    statsd_client = request.app[TelegrafStatsdClientAppKey]

    # TODO: Validate this against an allowlist.
    xrpc_method: Optional[str] = request.match_info.get("method", None)
    if xrpc_method is None:
        statsd_client.increment("aip.proxy.invalid_method", 1)
        return web.json_response(status=400, data={"error": "Invalid XRPC method"})

    database_session_maker = request.app[DatabaseSessionMakerAppKey]

    try:
        async with (database_session_maker() as database_session,):
            # TODO: Allow optional auth here.
            auth_token = await auth_token_helper(
                database_session, statsd_client, request
            )
            if auth_token is None:
                statsd_client.increment(
                    "aip.proxy.unauthorized", 1, tag_dict={"method": xrpc_method}
                )
                return web.json_response(status=401, data={"error": "Not Authorized"})
    except web.HTTPException as e:
        sentry_sdk.capture_exception(e)
        raise
    except Exception as e:
        sentry_sdk.capture_exception(e)
        statsd_client.increment(
            "aip.proxy.exception",
            1,
            tag_dict={"exception": type(e).__name__, "method": xrpc_method},
        )
        return web.json_response(status=500, data={"error": "Internal Server Error"})

    now = datetime.now(timezone.utc)

    # TODO: Look for endpoint header and fallback to context PDS value.
    # TODO: Support more complex URLs that include prefixes and or suffixes.
    upstream_base = f"{auth_token.context_service}/xrpc/{xrpc_method}"

    # Re-encode the query from the list of pairs, NOT from a dict: XRPC array
    # parameters are sent as repeated keys (`actors=a&actors=b`) and a dict
    # would silently keep only the last value of each key.
    forwarded_query = urlencode(parse_qsl(request.query_string))
    xrpc_url = urlunparse(urlparse(upstream_base)._replace(query=forwarded_query))

    http_session = request.app[SessionAppKey]

    headers = {
        "Content-Type": request.headers.get("Content-Type", "application/json"),
    }

    chain_middleware: List[RequestMiddlewareBase] = [StatsdMiddleware(statsd_client)]

    auth_method = "anonymous"

    # App password sessions use the `Authorization` header with `Bearer` scheme.
    # They take precedence over an OAuth session when both are present.
    if auth_token.app_password_session is not None:
        auth_method = "app-password"
        headers["Authorization"] = (
            f"Bearer {auth_token.app_password_session.access_token}"
        )

    # OAuth sessions use the `Authorization` header with `DPoP` scheme.
    elif auth_token.oauth_session is not None:
        auth_method = "oauth"

        # `ath` claim: unpadded base64url of the SHA-256 of the access token.
        hashed_access_token = hashlib.sha256(
            str(auth_token.oauth_session.access_token).encode("ascii")
        ).digest()
        encoded_hashed_access_token = base64.urlsafe_b64encode(hashed_access_token)
        pkcs_access_token = encoded_hashed_access_token.decode("ascii").rstrip("=")

        dpop_key = jwk.JWK(**auth_token.oauth_session.dpop_jwk)
        dpop_key_public_key = dpop_key.export_public(as_dict=True)
        dpop_assertation_header = {
            "alg": "ES256",
            "jwk": dpop_key_public_key,
            "typ": "dpop+jwt",
        }
        dpop_assertation_claims = {
            "htm": request.method,
            # The query string is deliberately not part of `htu`.
            "htu": upstream_base,
            "iat": int(now.timestamp()) - 1,
            "exp": int(now.timestamp()) + 30,
            # NOTE(review): placeholder value; presumably replaced by
            # GenerateDpopMiddleware with a server-provided nonce - confirm.
            "nonce": "tmp",
            "ath": pkcs_access_token,
            "iss": f"{auth_token.oauth_session.issuer}",
        }

        headers["Authorization"] = f"DPoP {auth_token.oauth_session.access_token}"

        chain_middleware.append(
            GenerateDpopMiddleware(
                dpop_key,
                dpop_assertation_header,
                dpop_assertation_claims,
            )
        )

    settings = request.app[SettingsAppKey]
    if settings.debug:
        chain_middleware.append(DebugMiddleware())

    rargs: Dict[str, Any] = {}

    # Only POST requests carry a body upstream.
    if request.method == "POST":
        rargs["data"] = await request.read()

    chain_client = ChainMiddlewareClient(
        client_session=http_session, raise_for_status=False, middleware=chain_middleware
    )

    start_time = time()
    cross_subject = auth_token.guid != auth_token.context_guid

    try:
        async with chain_client.request(
            request.method, xrpc_url, raise_for_status=None, headers=headers, **rargs
        ) as (
            _client_response,
            chain_response,
        ):
            # TODO: Figure out if websockets or SSE support is needed. Gut says no.
            # TODO: Think about using a header like `X-AIP-Error` for additional error context.
            return chain_response.to_web_response()
    finally:
        # Emitted on success and failure alike so latency is always recorded.
        statsd_client.timer(
            "aip.proxy.request.time",
            time() - start_time,
            tag_dict={
                "xrpc_service": auth_token.context_service.removeprefix("https://"),
                "xrpc_method": xrpc_method,
                "method": request.method.lower(),
                "authentication": auth_method,
                "cross_subject": str(cross_subject),
            },
        )
# Matches any ONE of the allowed browser origins. The alternatives must be
# joined with `|`; plain concatenation would only match all four glued together.
allowed_origin_pattern = re.compile(
    r"https:\/\/(www\.)?graze\.social"
    r"|https:\/\/(www\.)?sky-feeder-git-astro-graze\.vercel\.app"
    r"|http:\/\/localhost\:\d+"
    r"|http:\/\/127\.0\.0\.1\:\d+"
)


async def handle_index(request: web.Request):
    """Render the `index.html` template with an empty context."""
    return await aiohttp_jinja2.render_template_async("index.html", request, context={})


async def background_tasks(app):
    """
    Cleanup context that owns the application's shared resources.

    Before the `yield` it creates the database engine and session maker, the
    outgoing HTTP client session, the redis pool and client, the statsd
    client, and starts the health, OAuth refresh and app-password refresh
    tasks. After the `yield` it cancels those tasks, waits for them, and
    releases every resource it created.

    Args:
        app: the aiohttp application; resources are stored on it under the
            corresponding app keys.
    """
    logger.info("Starting up")
    settings: Settings = app[SettingsAppKey]

    engine = create_async_engine(str(settings.pg_dsn))
    app[DatabaseAppKey] = engine
    app[DatabaseSessionMakerAppKey] = async_sessionmaker(
        engine, class_=AsyncSession, expire_on_commit=False
    )

    trace_config = aiohttp.TraceConfig()

    if settings.debug:
        # In debug mode, log the lifecycle of every outgoing HTTP request.

        async def on_request_start(
            session, trace_config_ctx, params: aiohttp.TraceRequestStartParams
        ):
            logger.info("Starting request: %s", params)

        async def on_request_chunk_sent(
            session, trace_config_ctx, params: aiohttp.TraceRequestChunkSentParams
        ):
            logger.info("Chunk sent: %s", str(params.chunk))

        async def on_request_end(session, trace_config_ctx, params):
            logger.info("Ending request: %s", params)

        trace_config.on_request_start.append(on_request_start)
        trace_config.on_request_end.append(on_request_end)
        trace_config.on_request_chunk_sent.append(on_request_chunk_sent)

    app[SessionAppKey] = aiohttp.ClientSession(trace_configs=[trace_config])

    # The client must share the pool stored on the app: that pool is the one
    # closed on shutdown. Giving the client its own pool would leave its
    # connections open forever.
    redis_pool = redis.ConnectionPool.from_url(str(settings.redis_dsn))
    app[RedisPoolAppKey] = redis_pool
    app[RedisClientAppKey] = redis.Redis(connection_pool=redis_pool)

    statsd_client = TelegrafStatsdClient(
        host=settings.statsd_host, port=settings.statsd_port, debug=settings.debug
    )
    await statsd_client.connect()
    app[TelegrafStatsdClientAppKey] = statsd_client

    logger.info("Startup complete")

    app[TickHealthTaskAppKey] = asyncio.create_task(tick_health_task(app))
    app[OAuthRefreshTaskAppKey] = asyncio.create_task(oauth_refresh_task(app))
    app[AppPasswordRefreshTaskAppKey] = asyncio.create_task(
        app_password_refresh_task(app)
    )

    yield

    logger.info("Shutting down background tasks")

    task_keys = (
        TickHealthTaskAppKey,
        OAuthRefreshTaskAppKey,
        AppPasswordRefreshTaskAppKey,
    )

    # Request cancellation of every task first, then wait for each to finish.
    for task_key in task_keys:
        app[task_key].cancel()

    for task_key in task_keys:
        with contextlib.suppress(asyncio.CancelledError):
            await app[task_key]

    await app[DatabaseAppKey].dispose()
    await app[SessionAppKey].close()
    await app[RedisPoolAppKey].aclose()
    await app[TelegrafStatsdClientAppKey].close()


@web.middleware
async def cors_middleware(request: web.Request, handler):
    """
    Attach the CORS headers computed by `get_cors_headers` to every response.

    `OPTIONS` requests are answered immediately with status 200 and the CORS
    headers, without calling the handler. The `Origin` header is used when
    present, otherwise the `Host` header.
    """
    settings = request.app[SettingsAppKey]

    origin = request.headers.get("Origin")
    host = request.headers.get("Host")
    origin_value = origin if origin else host
    path = request.path

    headers = get_cors_headers(origin_value, path, settings.debug)

    if request.method == "OPTIONS":
        logger.debug(f"[CORS] Returning early for OPTIONS request: {path}")
        return web.Response(status=200, headers=headers)

    try:
        response = await handler(request)
    except web.HTTPException as http_error:
        # Responses raised as exceptions (404, redirects, ...) need the CORS
        # headers as well, otherwise a browser hides the real status.
        for k, v in headers.items():
            http_error.headers[k] = v
        raise

    for k, v in headers.items():
        response.headers[k] = v

    return response


@web.middleware
async def sentry_middleware(request: web.Request, handler):
    """Report any exception raised by the handler to Sentry, then re-raise it."""
    try:
        return await handler(request)
    except Exception as e:
        sentry_sdk.capture_exception(e)
        raise


@web.middleware
async def statsd_middleware(request: web.Request, handler):
    """
    Emit request metrics for every handled request.

    Always emits `aip.server.request.time` and `aip.server.request.count`;
    additionally emits `aip.server.request.exception` when the handler raises.
    The `status` tag is the response status, the status of a raised
    `web.HTTPException`, or 0 when any other exception was raised.
    """
    statsd_client = request.app[TelegrafStatsdClientAppKey]
    request_method: str = request.method
    request_path = request.path

    start_time: float = time()
    response_status_code = 0

    try:
        response = await handler(request)
        response_status_code = response.status
        return response
    except Exception as e:
        if isinstance(e, web.HTTPException):
            # HTTP exceptions carry a real status; report it instead of 0.
            response_status_code = e.status
        statsd_client.increment(
            "aip.server.request.exception",
            1,
            tag_dict={
                "exception": type(e).__name__,
                "path": request_path,
                "method": request_method,
            },
        )
        raise
    finally:
        statsd_client.timer(
            "aip.server.request.time",
            time() - start_time,
            tag_dict={"path": request_path, "method": request_method},
        )
        statsd_client.increment(
            "aip.server.request.count",
            1,
            tag_dict={
                "path": request_path,
                "method": request_method,
                "status": response_status_code,
            },
        )


async def shutdown(app):
    """
    Release the database engine, HTTP session, redis pool and statsd client.

    NOTE(review): `background_tasks` performs the same teardown after its
    `yield`; this function appears to be kept only for the commented-out
    `on_cleanup` hook - confirm before registering both.
    """
    await app[DatabaseAppKey].dispose()
    await app[SessionAppKey].close()
    await app[RedisPoolAppKey].aclose()
    await app[TelegrafStatsdClientAppKey].close()
web.post("/internal/api/permissions", handle_internal_permissions), 284 | web.post("/internal/api/app_password", handle_internal_app_password), 285 | web.get("/internal/api/credentials", handle_internal_credentials), 286 | ] 287 | ) 288 | 289 | app.add_routes( 290 | [ 291 | web.get("/xrpc/{method}", handle_xrpc_proxy), 292 | web.post("/xrpc/{method}", handle_xrpc_proxy), 293 | ] 294 | ) 295 | 296 | _ = aiohttp_jinja2.setup( 297 | app, 298 | enable_async=True, 299 | loader=jinja2.FileSystemLoader(os.path.join(os.getcwd(), "templates")), 300 | ) 301 | 302 | app["static_root_url"] = "/static" 303 | 304 | # app.on_startup.append(startup) 305 | # app.on_cleanup.append(shutdown) 306 | app.cleanup_ctx.append(background_tasks) 307 | 308 | return app 309 | -------------------------------------------------------------------------------- /src/social/graze/aip/app/tasks.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | from datetime import datetime, timezone 3 | import logging 4 | from time import time 5 | from typing import List, NoReturn, Tuple 6 | from aiohttp import web 7 | from sqlalchemy import select, delete 8 | import sentry_sdk 9 | 10 | from social.graze.aip.app.config import ( 11 | APP_PASSWORD_REFRESH_QUEUE, 12 | OAUTH_REFRESH_QUEUE, 13 | OAUTH_REFRESH_RETRY_QUEUE, 14 | DatabaseSessionMakerAppKey, 15 | HealthGaugeAppKey, 16 | RedisClientAppKey, 17 | SessionAppKey, 18 | SettingsAppKey, 19 | TelegrafStatsdClientAppKey, 20 | ) 21 | from social.graze.aip.atproto.app_password import populate_session 22 | from social.graze.aip.atproto.oauth import oauth_refresh 23 | from social.graze.aip.model.oauth import OAuthSession, OAuthRequest 24 | 25 | logger = logging.getLogger(__name__) 26 | 27 | 28 | async def tick_health_task(app: web.Application) -> NoReturn: 29 | """ 30 | Tick the health gauge every 30 seconds, reducing the health score by 1 each time. 
async def oauth_refresh_task(app: web.Application) -> NoReturn:
    """
    oauth_refresh_task is a background process that refreshes OAuth sessions immediately before they expire.

    The process is as follows:

    Given the queue name is "auth_refresh"
    Given the worker id is "worker1"
    Given `now = datetime.datetime.now(datetime.timezone.utc)`

    1. In a redis pipeline, get some work.
       * Populate the worker queue with work. This stores a range of things from the beginning of time to "now"
         into a new queue.
         ZRANGESTORE "auth_refresh:worker1" "auth_refresh" 0 {now} BYSCORE LIMIT 0 5

       * Store the difference between the worker queue and the main queue to remove the pulled work from the main
         queue.
         ZDIFFSTORE "auth_refresh" 2 "auth_refresh" "auth_refresh:worker1"

    2. Get the work that is in the worker queue.
       ZRANGE "auth_refresh:worker1" 0 {now} BYSCORE WITHSCORES

    3. For the work that we just got, process it all and remove each from the worker queue.
       ZREM "auth_refresh:worker1" {work_id}

    4. Sleep 10 seconds and repeat.

    This does a few things that are important to note.

    1. Work is queued up and indexed (redis zindex) against the time that it needs to be processed, not when
       it was queued. This lets the queue be lazily evaluated and also pull work that needs to be processed
       soonest.

    2. Work is batched into a worker queue outside of app instances, so it can be processed in parallel. If
       we need to scale up workers, we can do so by adjusting the deployment replica count.

    3. Work is grabbed in batches that don't need to be uniform, so there is no arbitrary delay. Workers
       don't have to wait for 5 jobs to be ready before taking them.

    4. If a worker dies, we have the temporary worker queue to recover the work that was in progress. If
       needed, we can create a watchdog worker that looks at orphaned worker queues and adds the work back to
       the main queue.

    A failed refresh is re-queued with exponential backoff until
    `settings.oauth_refresh_max_retries` is reached. A failure of a whole
    iteration (for example redis being unreachable) is reported and the loop
    continues with the next iteration instead of ending the task.
    """

    logger.info("Starting oauth refresh task")

    settings = app[SettingsAppKey]
    database_session_maker = app[DatabaseSessionMakerAppKey]
    http_session = app[SessionAppKey]

    redis_session = app[RedisClientAppKey]
    statsd_client = app[TelegrafStatsdClientAppKey]

    while True:
        # The sleep stays outside the guard so cancellation ends the task promptly.
        await asyncio.sleep(10)

        try:
            await _oauth_refresh_tick(
                settings,
                database_session_maker,
                http_session,
                redis_session,
                statsd_client,
            )
        except Exception as e:
            # Without this guard a single transient redis/database error
            # would terminate the task for the lifetime of the process.
            sentry_sdk.capture_exception(e)
            logger.exception("oauth refresh tick failed")


async def _oauth_refresh_tick(
    settings, database_session_maker, http_session, redis_session, statsd_client
) -> None:
    """Run one iteration: heartbeat, claim due work, refresh each claimed session group."""
    now = datetime.now(timezone.utc)
    now_ts = int(now.timestamp())

    worker_queue = f"{OAUTH_REFRESH_QUEUE}:{settings.worker_id}"
    workers_heartbeat = f"{OAUTH_REFRESH_QUEUE}:workers"

    await redis_session.hset(
        workers_heartbeat, settings.worker_id, str(now_ts)
    )  # type: ignore

    worker_queue_count: int = await redis_session.zcount(worker_queue, 0, now_ts)
    statsd_client.gauge(
        "aip.task.oauth_refresh.worker_queue_count",
        worker_queue_count,
        tag_dict={"worker_id": settings.worker_id},
    )

    global_queue_count: int = await redis_session.zcount(
        OAUTH_REFRESH_QUEUE, 0, now_ts
    )
    statsd_client.gauge(
        "aip.task.oauth_refresh.global_queue_count",
        global_queue_count,
        tag_dict={"worker_id": settings.worker_id},
    )

    # Only claim new work once the previously claimed batch is fully processed.
    if worker_queue_count == 0 and global_queue_count > 0:
        async with redis_session.pipeline() as redis_pipe:
            try:
                logger.debug(
                    f"tick_task: processing {OAUTH_REFRESH_QUEUE} up to {now_ts}"
                )
                redis_pipe.zrangestore(
                    worker_queue,
                    OAUTH_REFRESH_QUEUE,
                    0,
                    now_ts,
                    num=5,
                    offset=0,
                    byscore=True,
                )

                redis_pipe.zdiffstore(
                    OAUTH_REFRESH_QUEUE, [OAUTH_REFRESH_QUEUE, worker_queue]
                )
                (zrangestore_res, _zdiffstore_res) = await redis_pipe.execute()
                statsd_client.increment(
                    "aip.task.oauth_refresh.work_queued",
                    zrangestore_res,
                    tag_dict={"worker_id": settings.worker_id},
                )
            except Exception as e:
                sentry_sdk.capture_exception(e)
                logger.exception("error populating worker queue")

    tasks: List[Tuple[str, float]] = await redis_session.zrange(
        worker_queue, 0, now_ts, byscore=True, withscores=True
    )
    if not tasks:
        return

    async with (database_session_maker() as database_session,):
        for raw_session_group, deadline in tasks:
            # Queue members may arrive as bytes; normalise once so queries,
            # metric tags and redis keys all see the same text value.
            session_group = (
                raw_session_group.decode()
                if isinstance(raw_session_group, bytes)
                else raw_session_group
            )

            logger.debug(
                "tick_task: processing session_group %s deadline %s",
                session_group,
                deadline,
            )

            start_time = time()

            try:
                async with database_session.begin():
                    oauth_session_stmt = select(OAuthSession).where(
                        OAuthSession.session_group == session_group
                    )
                    oauth_session: OAuthSession = (
                        await database_session.scalars(oauth_session_stmt)
                    ).one()

                    await oauth_refresh(
                        settings,
                        http_session,
                        statsd_client,
                        database_session,
                        redis_session,
                        oauth_session,
                    )

                    # Clear retry count on successful refresh
                    await redis_session.hdel(OAUTH_REFRESH_RETRY_QUEUE, session_group)

            except Exception as e:
                sentry_sdk.capture_exception(e)
                logger.exception("error processing session group %s", session_group)

                await _schedule_oauth_refresh_retry(
                    settings, redis_session, statsd_client, session_group, now_ts
                )

                # TODO: Don't actually tag session_group because cardinality will be very high.
                statsd_client.increment(
                    "aip.task.oauth_refresh.exception",
                    1,
                    tag_dict={
                        "exception": type(e).__name__,
                        "session_group": session_group,
                        "worker_id": settings.worker_id,
                    },
                )

            finally:
                statsd_client.timer(
                    "aip.task.oauth_refresh.time",
                    time() - start_time,
                    tag_dict={"worker_id": settings.worker_id},
                )
                # TODO: Probably don't need this because it is the same as `COUNT(aip.task.oauth_refresh.time)`
                statsd_client.increment(
                    "aip.task.oauth_refresh.count",
                    1,
                    tag_dict={"worker_id": settings.worker_id},
                )
                # Success or failure, the item leaves the worker queue; a
                # failure has already been re-queued on the main queue.
                await redis_session.zrem(worker_queue, session_group)


async def _schedule_oauth_refresh_retry(
    settings, redis_session, statsd_client, session_group: str, now_ts: int
) -> None:
    """Re-queue a failed session group with exponential backoff, or give up at the retry limit."""
    stored_retries = await redis_session.hget(OAUTH_REFRESH_RETRY_QUEUE, session_group)
    current_retries = int(stored_retries) if stored_retries else 0

    if current_retries < settings.oauth_refresh_max_retries:
        # Calculate exponential backoff delay
        retry_delay = settings.oauth_refresh_retry_base_delay * (2**current_retries)
        retry_timestamp = now_ts + retry_delay

        # Re-queue with delay and increment retry count
        await redis_session.zadd(OAUTH_REFRESH_QUEUE, {session_group: retry_timestamp})
        await redis_session.hset(
            OAUTH_REFRESH_RETRY_QUEUE, session_group, current_retries + 1
        )

        logger.info(
            "Scheduled retry %d/%d for session_group %s in %d seconds",
            current_retries + 1,
            settings.oauth_refresh_max_retries,
            session_group,
            retry_delay,
        )

        statsd_client.increment(
            "aip.task.oauth_refresh.retry_scheduled",
            1,
            tag_dict={
                "retry_attempt": str(current_retries + 1),
                "worker_id": settings.worker_id,
            },
        )
        return

    # Max retries exceeded, give up and clean up retry tracking
    await redis_session.hdel(OAUTH_REFRESH_RETRY_QUEUE, session_group)

    logger.error(
        "Max retries exceeded for session_group %s, giving up after %d attempts",
        session_group,
        settings.oauth_refresh_max_retries,
    )

    statsd_client.increment(
        "aip.task.oauth_refresh.max_retries_exceeded",
        1,
        tag_dict={"worker_id": settings.worker_id},
    )
global_queue_count, 309 | tag_dict={"worker_id": settings.worker_id}, 310 | ) 311 | 312 | if worker_queue_count == 0 and global_queue_count > 0: 313 | async with redis_session.pipeline() as redis_pipe: 314 | try: 315 | logger.debug( 316 | f"tick_task: processing {APP_PASSWORD_REFRESH_QUEUE} up to {int(now.timestamp())}" 317 | ) 318 | 319 | redis_pipe.zrangestore( 320 | worker_queue, 321 | APP_PASSWORD_REFRESH_QUEUE, 322 | 0, 323 | int(now.timestamp()), 324 | num=5, 325 | offset=0, 326 | byscore=True, 327 | ) 328 | 329 | redis_pipe.zdiffstore( 330 | APP_PASSWORD_REFRESH_QUEUE, 331 | [APP_PASSWORD_REFRESH_QUEUE, worker_queue], 332 | ) 333 | 334 | (zrangestore_res, zdiffstore_res) = await redis_pipe.execute() 335 | statsd_client.increment( 336 | "aip.task.app_password_refresh.work_queued", 337 | zrangestore_res, 338 | tag_dict={"worker_id": settings.worker_id}, 339 | ) 340 | except Exception as e: 341 | sentry_sdk.capture_exception(e) 342 | logging.exception("error populating app password worker queue") 343 | 344 | tasks: List[Tuple[str, float]] = await redis_session.zrange( 345 | worker_queue, 0, int(now.timestamp()), byscore=True, withscores=True 346 | ) 347 | 348 | if len(tasks) > 0: 349 | for handle_guid, deadline in tasks: 350 | 351 | logger.debug( 352 | "tick_task: processing guid %s deadline %s", 353 | handle_guid, 354 | deadline, 355 | ) 356 | 357 | start_time = time() 358 | try: 359 | await populate_session( 360 | http_session, 361 | database_session_maker, 362 | redis_session, 363 | handle_guid, 364 | settings, 365 | ) 366 | 367 | except Exception as e: 368 | sentry_sdk.capture_exception(e) 369 | logging.exception("error processing guid %s", handle_guid) 370 | # TODO: Don't actually tag session_group because cardinality will be very high. 
371 | statsd_client.increment( 372 | "aip.task.app_password_refresh.exception", 373 | 1, 374 | tag_dict={ 375 | "exception": type(e).__name__, 376 | "guid": handle_guid, 377 | "worker_id": settings.worker_id, 378 | }, 379 | ) 380 | 381 | finally: 382 | statsd_client.timer( 383 | "aip.task.app_password_refresh.time", 384 | time() - start_time, 385 | tag_dict={"worker_id": settings.worker_id}, 386 | ) 387 | # TODO: Probably don't need this because it is the same as 388 | # `COUNT(aip.task.app_password_refresh.time)` 389 | statsd_client.increment( 390 | "aip.task.app_password_refresh.count", 391 | 1, 392 | tag_dict={"worker_id": settings.worker_id}, 393 | ) 394 | await redis_session.zrem(worker_queue, handle_guid) 395 | 396 | except Exception as e: 397 | sentry_sdk.capture_exception(e) 398 | logging.exception("app password tick failed") 399 | 400 | 401 | async def oauth_cleanup_task(app: web.Application) -> NoReturn: 402 | """ 403 | Background task to clean up expired OAuth records. 404 | 405 | This task runs every hour and removes: 406 | - OAuthRequest records that have expired 407 | - OAuthSession records that have reached their hard expiration 408 | 409 | This prevents database bloat from accumulating expired records. 
410 | """ 411 | logger.info("Starting OAuth cleanup task") 412 | 413 | settings = app[SettingsAppKey] 414 | database_session_maker = app[DatabaseSessionMakerAppKey] 415 | statsd_client = app[TelegrafStatsdClientAppKey] 416 | 417 | while True: 418 | try: 419 | # Run cleanup every hour 420 | await asyncio.sleep(3600) 421 | 422 | now = datetime.now(timezone.utc) 423 | 424 | async with (database_session_maker() as database_session,): 425 | async with database_session.begin(): 426 | # Clean up expired OAuthRequest records 427 | expired_requests_stmt = delete(OAuthRequest).where( 428 | OAuthRequest.expires_at < now 429 | ) 430 | expired_requests_result = await database_session.execute(expired_requests_stmt) 431 | expired_requests_count = expired_requests_result.rowcount 432 | 433 | # Clean up expired OAuthSession records 434 | expired_sessions_stmt = delete(OAuthSession).where( 435 | OAuthSession.hard_expires_at < now 436 | ) 437 | expired_sessions_result = await database_session.execute(expired_sessions_stmt) 438 | expired_sessions_count = expired_sessions_result.rowcount 439 | 440 | await database_session.commit() 441 | 442 | if expired_requests_count > 0 or expired_sessions_count > 0: 443 | logger.info( 444 | "Cleaned up %d expired OAuth requests and %d expired OAuth sessions", 445 | expired_requests_count, 446 | expired_sessions_count 447 | ) 448 | 449 | # Report metrics 450 | statsd_client.increment( 451 | "aip.task.oauth_cleanup.expired_requests_removed", 452 | expired_requests_count, 453 | tag_dict={"worker_id": settings.worker_id}, 454 | ) 455 | statsd_client.increment( 456 | "aip.task.oauth_cleanup.expired_sessions_removed", 457 | expired_sessions_count, 458 | tag_dict={"worker_id": settings.worker_id}, 459 | ) 460 | 461 | except Exception as e: 462 | sentry_sdk.capture_exception(e) 463 | logging.exception("OAuth cleanup task failed") 464 | -------------------------------------------------------------------------------- 
async def genAppPassword(
    handle: str,
    password: str,
    label: str,
    plc_hostname: str,
    auth_token: Optional[str] = None,
) -> None:
    """
    Create an app password for `handle` on its PDS and print it.

    Resolves `handle` through `resolve_subject`, signs in with
    `com.atproto.server.createSession`, then calls
    `com.atproto.server.createAppPassword` with `label` as the name.

    Args:
        handle: The handle to authenticate with.
        password: The account password.
        label: Name to give the new app password.
        plc_hostname: PLC hostname used when resolving the subject.
        auth_token: Optional value sent as `authToken` when creating the session.

    Raises:
        RuntimeError: If the handle cannot be resolved, the session cannot be
            created, or the created session does not match the resolved identity.
    """
    async with aiohttp.ClientSession() as http_session:
        resolved_handle = await resolve_subject(http_session, plc_hostname, handle)
        if resolved_handle is None:
            raise RuntimeError(f"Unable to resolve subject: {handle}")

        create_session_url = (
            f"{resolved_handle.pds}/xrpc/com.atproto.server.createSession"
        )
        create_session_body = {
            "identifier": resolved_handle.did,
            "password": password,
        }
        if auth_token is not None:
            create_session_body["authToken"] = auth_token

        async with http_session.post(
            create_session_url, json=create_session_body
        ) as resp:
            if resp.status != 200:
                raise RuntimeError(
                    f"Error creating session ({resp.status}): {await resp.text()}"
                )
            created_session: Dict[str, Any] = await resp.json()

        # Explicit checks rather than `assert`, which is stripped under `-O`.
        if created_session.get("did") != resolved_handle.did:
            raise RuntimeError("Created session DID does not match resolved DID")
        if created_session.get("handle") != handle:
            raise RuntimeError("Created session handle does not match given handle")

        create_app_password_url = (
            f"{resolved_handle.pds}/xrpc/com.atproto.server.createAppPassword"
        )
        create_app_password_body = {"name": label}
        create_app_password_headers = {
            "Authorization": f"Bearer {created_session['accessJwt']}"
        }

        async with http_session.post(
            create_app_password_url,
            headers=create_app_password_headers,
            json=create_app_password_body,
        ) as resp:
            if resp.status != 200:
                error_text = await resp.text()
                print(f"Error creating app password: {error_text}")
                return
            app_password = await resp.json()
            print(
                f"{app_password['name']} created {app_password['createdAt']}: "
                f"{app_password['password']}"
            )


async def genJwk() -> None:
    """Generate an ES256 (P-256) JWK with a ULID key id and print it, private key included."""
    key = jwk.JWK.generate(kty="EC", crv="P-256", kid=str(ULID()), alg="ES256")
    print(key.export(private_key=True))


async def genCryptoKey() -> None:
    """Generate a Fernet key and print it base64-encoded."""
    key = Fernet.generate_key()
    # NOTE(review): the generated key is base64-encoded a second time here;
    # presumably the consumer decodes once before use — confirm against config.
    print(base64.b64encode(key).decode("utf-8"))


async def realMain() -> None:
    """Parse the `aiputil` command line and run the selected sub-command."""
    parser = argparse.ArgumentParser(prog="aiputil", description="AIP utilities")

    parser.add_argument(
        "--plc-hostname",
        default="plc.directory",
        help="The PLC hostname to use for resolving did-method-plc DIDs.",
    )

    subparsers = parser.add_subparsers(dest="command", required=True)

    _ = subparsers.add_parser("gen-jwk", help="Generate a JWK")
    _ = subparsers.add_parser("gen-crypto", help="Generate an encryption key")
    gen_app_password = subparsers.add_parser(
        "gen-app-password", help="Generate an app-password"
    )

    gen_app_password.add_argument("handle", help="The handle to authenticate with.")
    gen_app_password.add_argument("password", help="The password to authenticate with.")
    gen_app_password.add_argument("label", help="The label for the app-password.")
    gen_app_password.add_argument(
        "--auth-token",
        default=None,
        help="Optional auth token to send when creating the session.",
    )

    args = vars(parser.parse_args())
    command = args.get("command")

    if command == "gen-jwk":
        await genJwk()
    elif command == "gen-crypto":
        await genCryptoKey()
    elif command == "gen-app-password":
        await genAppPassword(
            args["handle"],
            args["password"],
            args["label"],
            args["plc_hostname"],
            args.get("auth_token"),
        )


def main() -> None:
    """Synchronous entry point: run `realMain` on a fresh event loop."""
    asyncio.run(realMain())
async def _schedule_refresh(
    redis_session: redis.Redis,
    guid: str,
    now: datetime,
    access_token_expiry: float,
    refresh_ratio: float,
) -> None:
    """Queue `guid` for its next refresh at `now + access_token_expiry * refresh_ratio` seconds."""
    refresh_at = now + timedelta(seconds=access_token_expiry * refresh_ratio)
    await redis_session.zadd(
        APP_PASSWORD_REFRESH_QUEUE,
        {guid: int(refresh_at.timestamp())},
    )


async def populate_session(
    http_session: ClientSession,
    database_session_maker: async_sessionmaker[AsyncSession],
    redis_session: redis.Redis,
    subject_guid: str,
    settings: Optional[Settings] = None,
) -> None:
    """
    Refresh, or create, the app-password session stored for `subject_guid`.

    If an `AppPasswordSession` row exists, its refresh token is exchanged via
    `com.atproto.server.refreshSession`. If that fails, or the response is for
    another DID, is inactive, or carries no tokens, a brand new session is
    created from the stored app password via `com.atproto.server.createSession`.
    On success the tokens are persisted and the guid is re-queued in
    `APP_PASSWORD_REFRESH_QUEUE`.

    Tokens are only persisted after the response has been validated; tokens
    from a rejected response are never written.

    Args:
        http_session: HTTP client used to call the PDS.
        database_session_maker: Factory for database sessions.
        redis_session: Redis client holding the refresh queue.
        subject_guid: Guid shared by the Handle, AppPassword and AppPasswordSession rows.
        settings: Optional settings; when omitted, defaults of 720 s (access token),
            7776000 s (refresh token) and a 0.8 refresh ratio are used.

    Raises:
        ValueError: If no Handle or no AppPassword exists for `subject_guid`.
    """
    now = datetime.now(timezone.utc)
    async with database_session_maker() as database_session:
        async with database_session.begin():
            # 1. Get the handle by guid
            handle_stmt = select(Handle).where(Handle.guid == subject_guid)
            handle: Optional[Handle] = (
                await database_session.scalars(handle_stmt)
            ).first()

            if handle is None:
                raise ValueError(f"Handle not found for guid: {subject_guid}")

            # 2. Get the AppPassword by guid
            app_password_stmt = select(AppPassword).where(
                AppPassword.guid == subject_guid
            )
            app_password: Optional[AppPassword] = (
                await database_session.scalars(app_password_stmt)
            ).first()

            if app_password is None:
                raise ValueError(f"App password not found for guid: {subject_guid}")

            # 3. Get optional AppPasswordSession by guid
            app_password_session_stmt = select(AppPasswordSession).where(
                AppPasswordSession.guid == subject_guid
            )
            app_password_session: Optional[AppPasswordSession] = (
                await database_session.scalars(app_password_session_stmt)
            ).first()

            # Use settings if provided, otherwise use defaults
            access_token_expiry = (
                settings.app_password_access_token_expiry if settings else 720
            )
            refresh_token_expiry = (
                settings.app_password_refresh_token_expiry if settings else 7776000
            )
            refresh_ratio = (
                settings.token_refresh_before_expiry_ratio if settings else 0.8
            )

            # TODO: Pull this from the access token JWT claims payload
            access_token_expires_at = now + timedelta(seconds=access_token_expiry)

            # TODO: Pull this from the refresh token JWT claims payload
            refresh_token_expires_at = now + timedelta(seconds=refresh_token_expiry)

            start_over = False
            access_token: str | None = None
            refresh_token: str | None = None

            # 4. If AppPasswordSession exists: refresh it, update row, and return
            if app_password_session is not None:
                try:
                    refresh_url = f"{handle.pds}/xrpc/com.atproto.server.refreshSession"
                    headers = {
                        "Authorization": f"Bearer {app_password_session.refresh_token}"
                    }
                    # This could fail if the app password was revoked or if the server isn't honoring the
                    # expiration time of the refresh token. It's more likely that a user will remove / replace
                    # an app password. When that happens, we want to start over and create a new session.
                    async with http_session.post(
                        refresh_url, headers=headers
                    ) as response:
                        if response.status != 200:
                            raise Exception(
                                f"Failed to refresh session: {response.status}"
                            )

                        body: Dict[str, Any] = await response.json()

                    if body.get("did", "") != handle.did or not body.get(
                        "active", False
                    ):
                        # Wrong identity or deactivated account: do not trust
                        # these tokens, fall back to a fresh session.
                        start_over = True
                    else:
                        access_token = body.get("accessJwt", None)
                        refresh_token = body.get("refreshJwt", None)
                        if access_token is None or refresh_token is None:
                            # A 200 without tokens is unusable; start over
                            # instead of silently doing nothing.
                            start_over = True

                except Exception as e:
                    sentry_sdk.capture_exception(e)
                    logger.exception("Error refreshing session")
                    start_over = True

                if not start_over:
                    update_session_stmt = (
                        update(AppPasswordSession)
                        .where(
                            AppPasswordSession.guid == app_password_session.guid,
                        )
                        .values(
                            access_token=access_token,
                            access_token_expires_at=access_token_expires_at,
                            refresh_token=refresh_token,
                            refresh_token_expires_at=refresh_token_expires_at,
                        )
                    )
                    await database_session.execute(update_session_stmt)
                    await _schedule_refresh(
                        redis_session,
                        handle.guid,
                        now,
                        access_token_expiry,
                        refresh_ratio,
                    )

            if app_password_session is None or start_over:
                # Never carry tokens over from a rejected refresh attempt.
                access_token = None
                refresh_token = None

                try:
                    create_url = f"{handle.pds}/xrpc/com.atproto.server.createSession"
                    payload = {
                        "identifier": handle.did,
                        "password": app_password.app_password,
                    }
                    async with http_session.post(
                        create_url, headers={}, json=payload
                    ) as response:
                        if response.status != 200:
                            raise Exception(
                                f"Failed to create session: {response.status}"
                            )

                        body = await response.json()

                    found_did = body.get("did", "")

                    # It'd be pretty wild if this didn't match the handle, but would also lead to some really
                    # unexpected behavior.
                    if found_did != handle.did:
                        raise ValueError(
                            f"Handle did does not match found did: {handle.did} != {found_did}"
                        )

                    if not body.get("active", False):
                        raise ValueError("Handle is not active.")

                    # Only accept the tokens once the response has been validated.
                    access_token = body.get("accessJwt", None)
                    refresh_token = body.get("refreshJwt", None)

                except Exception as e:
                    sentry_sdk.capture_exception(e)
                    logger.exception("Error creating session")

                # 5. Create new AppPasswordSession
                # NOTE(review): when session creation fails nothing is re-queued
                # here, so the guid is not refreshed again by this function.
                if access_token is not None and refresh_token is not None:
                    upsert_session_stmt = (
                        insert(AppPasswordSession)
                        .values(
                            [
                                {
                                    "guid": handle.guid,
                                    "access_token": access_token,
                                    "access_token_expires_at": access_token_expires_at,
                                    "refresh_token": refresh_token,
                                    "refresh_token_expires_at": refresh_token_expires_at,
                                    "created_at": now,
                                }
                            ]
                        )
                        .on_conflict_do_update(
                            index_elements=["guid"],
                            set_={
                                "access_token": access_token,
                                "access_token_expires_at": access_token_expires_at,
                                "refresh_token": refresh_token,
                                "refresh_token_expires_at": refresh_token_expires_at,
                            },
                        )
                    )
                    await database_session.execute(upsert_session_stmt)
                    await _schedule_refresh(
                        redis_session,
                        handle.guid,
                        now,
                        access_token_expiry,
                        refresh_ratio,
                    )
class _LoggerStub(Protocol):
    """_Logger defines which methods logger object should have."""

    @abstractmethod
    def debug(self, msg: str, *args: Any, **kwargs: Any) -> None:
        pass

    @abstractmethod
    def warning(self, msg: str, *args: Any, **kwargs: Any) -> None:
        pass

    @abstractmethod
    def exception(self, msg: str, *args: Any, **kwargs: Any) -> None:
        pass


_LoggerType = Union[_LoggerStub, logging.Logger]


@dataclass
class ChainRequest:
    """A request travelling through the middleware chain."""

    method: str
    url: StrOrURL
    headers: dict[str, Any] | None = None
    trace_request_ctx: dict[str, Any] | None = None
    # Extra keyword arguments forwarded to the underlying request function.
    kwargs: dict[str, Any] | None = None

    # TODO: Fix bug where IDE doesn't like `def from_chain_request(request: Self) -> Self`
    @staticmethod
    def from_chain_request(request: "ChainRequest") -> "ChainRequest":
        """Return a shallow copy of `request`; headers and kwargs are shared, not cloned."""
        return ChainRequest(
            method=request.method,
            url=request.url,
            headers=request.headers,
            trace_request_ctx=request.trace_request_ctx,
            kwargs=request.kwargs,
        )


@dataclass
class ChainResponse:
    """Status, headers and fully-read body of a response."""

    status: int
    headers: CIMultiDictProxy[str]
    body: str | bytes | dict[str, Any] | None = None
    # exception: BaseException | None = None

    # TODO: Fix bug where IDE doesn't like `-> Self`
    @staticmethod
    async def from_aiohttp_response(response: ClientResponse) -> "ChainResponse":
        """Read `response` into a ChainResponse, decoding the body by Content-Type.

        `application/json` bodies are parsed, `text/*` bodies are decoded to
        `str`, anything else is kept as raw bytes.
        """
        status = response.status
        headers = response.headers

        content_type = response.headers.get(hdrs.CONTENT_TYPE, "")

        if content_type.startswith("application/json"):
            return ChainResponse(
                status=status, headers=headers, body=await response.json()
            )
        elif content_type.startswith("text/"):
            return ChainResponse(
                status=status, headers=headers, body=await response.text()
            )
        else:
            return ChainResponse(
                status=status, headers=headers, body=await response.read()
            )

    def body_contains(self, text: str) -> bool:
        """Return True if `text` occurs in the body (for dict bodies: is a key)."""
        if self.body is None:
            return False

        if isinstance(self.body, str):
            return text in self.body

        elif isinstance(self.body, bytes):
            return text.encode("utf-8") in self.body

        elif isinstance(self.body, dict):
            return text in self.body

        return False

    def body_matches_kv(self, key: str, value: Any) -> bool:
        """Return True if the body is a dict whose `key` equals `value`."""
        if self.body is None:
            return False

        return (
            isinstance(self.body, dict) and key in self.body and self.body[key] == value
        )

    def to_web_response(self) -> web.Response:
        """Build a `web.Response` with this status, Content-Type and body.

        Dict bodies are serialized with `json.dumps`.
        """
        rargs = {
            "status": self.status,
            "headers": {"Content-Type": self.headers.get(hdrs.CONTENT_TYPE, "")},
        }
        if isinstance(self.body, str):
            rargs["text"] = self.body
        elif isinstance(self.body, bytes):
            rargs["body"] = self.body
        elif isinstance(self.body, dict):
            rargs["body"] = json.dumps(self.body)
        return web.Response(**rargs)


# A chain step returns the raw response and its ChainResponse, plus optionally
# a follow-up request that the caller should send next (a retry).
NextChainResponseCallbackType = (
    Tuple[ClientResponse, ChainResponse]
    | Tuple[ClientResponse, ChainResponse, ChainRequest]
)

NextChainCallbackType = Callable[
    [ChainRequest], Awaitable[NextChainResponseCallbackType]
]


class RequestMiddlewareBase(ABC):
    """Base class for middleware: implement `handle` and call `next` to continue."""

    @abstractmethod
    async def handle(
        self, next: NextChainCallbackType, request: ChainRequest
    ) -> NextChainResponseCallbackType:
        pass

    def handle_gen(self, next: NextChainCallbackType) -> NextChainCallbackType:
        """Bind `next` and return a callback that runs this middleware's `handle`."""

        async def next_invoke(request: ChainRequest) -> NextChainResponseCallbackType:
            return await self.handle(next, request)

        return next_invoke


class StatsdMiddleware(RequestMiddlewareBase):
    """Records a timer and a counter for every request; reports exceptions to Sentry."""

    def __init__(
        self,
        statsd_client: TelegrafStatsdClient,
    ) -> None:
        super().__init__()
        self._statsd_client = statsd_client
        # Strips every non-word character so the URL can be used as a tag value.
        self._regex = re.compile(r"\W+")

    async def handle(
        self, next: NextChainCallbackType, request: ChainRequest
    ) -> NextChainResponseCallbackType:
        start_time = time()
        try:
            return await next(request)
        except Exception as e:
            sentry_sdk.capture_exception(e)
            # Bare raise keeps the original traceback untouched.
            raise
        finally:
            url_tag = self._regex.sub("", str(request.url))
            method_tag = request.method.lower()
            # TODO: Don't record URL. Cardinality is too high here and will cause issues.
            self._statsd_client.timer(
                "aip.client.request.time",
                time() - start_time,
                tag_dict={"url": url_tag, "method": method_tag},
            )
            self._statsd_client.increment(
                "aip.client.request.count",
                1,
                tag_dict={"url": url_tag, "method": method_tag},
            )


class DebugMiddleware(RequestMiddlewareBase):
    """Logs every request and its response at debug level.

    NOTE(review): headers and bodies are logged in full, which may include
    credentials.
    """

    async def handle(
        self, next: NextChainCallbackType, request: ChainRequest
    ) -> NextChainResponseCallbackType:
        logger.debug(
            f"Request: {request.method} {request.url} {request.headers} {request.kwargs}"
        )
        response = await next(request)
        logger.debug(
            f"Response: {response[1].status} {response[1].headers} {response[1].body}"
        )
        return response


class GenerateClaimAssertionMiddleware(RequestMiddlewareBase):
    """Adds a freshly signed `client_assertion` JWT to the request's form data."""

    def __init__(
        self,
        signing_key: jwk.JWK,
        client_assertion_header: Dict[str, Any],
        client_assertion_claims: Dict[str, Any],
    ) -> None:
        super().__init__()
        self._signing_key = signing_key
        self._client_assertion_header = client_assertion_header
        self._client_assertion_claims = client_assertion_claims

    async def handle(
        self, next: NextChainCallbackType, request: ChainRequest
    ) -> NextChainResponseCallbackType:

        # Without kwargs there is nowhere to attach form data; pass through.
        if request.kwargs is None:
            return await next(request)

        # Each assertion gets its own unique token id.
        self._client_assertion_claims["jti"] = secrets.token_urlsafe(32)
        claims_assertation = jwt.JWT(
            header=self._client_assertion_header,
            claims=self._client_assertion_claims,
        )
        claims_assertation.make_signed_token(self._signing_key)
        claims_assertation_token = claims_assertation.serialize()

        data: Optional[FormData] = request.kwargs.get("data", None)
        if data is None:
            data = FormData()

        data.add_field("client_assertion", claims_assertation_token)

        request.kwargs["data"] = data

        return await next(request)


class GenerateDpopMiddleware(RequestMiddlewareBase):
    """Signs a DPoP proof into the `DPoP` header and requests a retry on nonce errors."""

    def __init__(
        self,
        dpop_key: jwk.JWK,
        dop_assertion_header: Dict[str, Any],
        dop_assertion_claims: Dict[str, Any],
    ) -> None:
        super().__init__()
        self._dpop_key = dpop_key
        self._dpop_assertion_header = dop_assertion_header
        self._dpop_assertion_claims = dop_assertion_claims

    async def handle(
        self, next: NextChainCallbackType, request: ChainRequest
    ) -> NextChainResponseCallbackType:
        self._dpop_assertion_claims["jti"] = secrets.token_urlsafe(32)

        dpop_assertation = jwt.JWT(
            header=self._dpop_assertion_header,
            claims=self._dpop_assertion_claims,
        )
        dpop_assertation.make_signed_token(self._dpop_key)
        dpop_assertation_token = dpop_assertation.serialize()

        if request.headers is None:
            request.headers = {}
        request.headers["DPoP"] = dpop_assertation_token

        response = await next(request)
        client_response = response[0]
        chain_response = response[1]
        new_request = None
        if len(response) == 3:
            new_request = response[2]

        if chain_response.status == 401 or chain_response.status == 400:
            if chain_response.headers is None:
                raise ValueError("Response headers are None")

            if chain_response.body_matches_kv(
                "error", "invalid_dpop_proof"
            ) or chain_response.body_matches_kv("error", "use_dpop_nonce"):
                # Remember the server-supplied nonce so the retry's proof carries it.
                self._dpop_assertion_claims["nonce"] = chain_response.headers.get(
                    "DPoP-Nonce", ""
                )

                if new_request is None:
                    new_request = ChainRequest.from_chain_request(request)

        if new_request is None:
            return client_response, chain_response
        return client_response, chain_response, new_request


class EndOfLineChainMiddleware:
    """Final step of the chain: performs the actual HTTP request."""

    def __init__(
        self,
        request_func: RequestFunc,
        logger: _LoggerType,
        raise_for_status: bool = False,
    ) -> None:
        super().__init__()
        self._request_func = request_func
        self._raise_for_status = raise_for_status
        self._logger = logger

    async def handle(self, request: ChainRequest) -> NextChainResponseCallbackType:
        response: ClientResponse = await self._request_func(
            request.method.lower(),
            request.url,
            headers=request.headers,
            trace_request_ctx={
                **(request.trace_request_ctx or {}),
            },
            **(request.kwargs or {}),
        )

        if self._raise_for_status:
            response.raise_for_status()

        return response, await ChainResponse.from_aiohttp_response(response)


class ChainMiddlewareContext:
    """Awaitable / async context manager that runs a request through the chain.

    The chain is invoked repeatedly while a step hands back a follow-up
    request, up to `attempt_max` invocations.
    """

    def __init__(
        self,
        chain_callback: NextChainCallbackType,
        chain_request: ChainRequest,
        logger: _LoggerType,
        raise_for_status: bool = False,
        attempt_max: int = 3,
    ) -> None:
        self._chain_callback = chain_callback
        self._chain_request = chain_request
        self._logger = logger
        self._raise_for_status = raise_for_status

        self._chain_response: ChainResponse | None = None
        self.client_response: ClientResponse | None = None

        self._attempt_max = attempt_max

    async def _do_request(self) -> Tuple[ClientResponse, ChainResponse]:
        current_attempt = 0

        chain_request = self._chain_request

        while True:
            current_attempt += 1

            if current_attempt > self._attempt_max:
                raise Exception("Max attempts reached")

            response = await self._chain_callback(chain_request)
            client_response = response[0]
            chain_response = response[1]
            new_request = None
            if len(response) == 3:
                new_request = response[2]

            self._chain_response = chain_response
            # Must be the attribute `__aexit__` inspects, otherwise the
            # response is never closed when the context exits.
            self.client_response = client_response

            if new_request is None:
                return client_response, chain_response

            chain_request = new_request

            if self._raise_for_status:
                client_response.raise_for_status()

    def __await__(self) -> Generator[Any, None, Tuple[ClientResponse, ChainResponse]]:
        return self.__aenter__().__await__()

    async def __aenter__(self) -> Tuple[ClientResponse, ChainResponse]:
        return await self._do_request()

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        if self.client_response is not None and not self.client_response.closed:
            self.client_response.close()


class ChainMiddlewareClient:
    """HTTP client that sends every request through a sequence of middleware.

    Wraps an existing `ClientSession`, or creates one from `*args`/`**kwargs`
    when none is given. Every verb method accepts an optional
    `raise_for_status` that overrides the client-wide default for that call.
    """

    def __init__(
        self,
        client_session: ClientSession | None = None,
        logger: _LoggerType | None = None,
        middleware: Sequence[RequestMiddlewareBase] | None = None,
        raise_for_status: bool = False,
        *args: Any,
        **kwargs: Any,
    ) -> None:
        if client_session is not None:
            client = client_session
            # None means "session supplied by the caller": `__del__` stays quiet.
            closed = None
        else:
            client = ClientSession(*args, **kwargs)
            closed = False

        self._middleware = middleware

        self._client = client
        self._closed = closed

        self._logger: _LoggerType = logger or logging.getLogger("aiohttp_chain")
        self._raise_for_status = raise_for_status

    def request(
        self,
        method: str,
        url: StrOrURL,
        raise_for_status: bool | None = None,
        **kwargs: Any,
    ) -> ChainMiddlewareContext:
        return self._make_request(
            method=method,
            url=url,
            raise_for_status=raise_for_status,
            **kwargs,
        )

    def get(
        self,
        url: StrOrURL,
        raise_for_status: bool | None = None,
        **kwargs: Any,
    ) -> ChainMiddlewareContext:
        return self._make_request(
            method=hdrs.METH_GET,
            url=url,
            raise_for_status=raise_for_status,
            **kwargs,
        )

    def options(
        self,
        url: StrOrURL,
        raise_for_status: bool | None = None,
        **kwargs: Any,
    ) -> ChainMiddlewareContext:
        return self._make_request(
            method=hdrs.METH_OPTIONS,
            url=url,
            raise_for_status=raise_for_status,
            **kwargs,
        )

    def head(
        self,
        url: StrOrURL,
        raise_for_status: bool | None = None,
        **kwargs: Any,
    ) -> ChainMiddlewareContext:
        return self._make_request(
            method=hdrs.METH_HEAD,
            url=url,
            raise_for_status=raise_for_status,
            **kwargs,
        )

    def post(
        self,
        url: StrOrURL,
        raise_for_status: bool | None = None,
        **kwargs: Any,
    ) -> ChainMiddlewareContext:
        return self._make_request(
            method=hdrs.METH_POST,
            url=url,
            raise_for_status=raise_for_status,
            **kwargs,
        )

    def put(
        self,
        url: StrOrURL,
        raise_for_status: bool | None = None,
        **kwargs: Any,
    ) -> ChainMiddlewareContext:
        return self._make_request(
            method=hdrs.METH_PUT,
            url=url,
            raise_for_status=raise_for_status,
            **kwargs,
        )

    def patch(
        self,
        url: StrOrURL,
        raise_for_status: bool | None = None,
        **kwargs: Any,
    ) -> ChainMiddlewareContext:
        return self._make_request(
            method=hdrs.METH_PATCH,
            url=url,
            raise_for_status=raise_for_status,
            **kwargs,
        )

    def delete(
        self,
        url: StrOrURL,
        raise_for_status: bool | None = None,
        **kwargs: Any,
    ) -> ChainMiddlewareContext:
        return self._make_request(
            method=hdrs.METH_DELETE,
            url=url,
            raise_for_status=raise_for_status,
            **kwargs,
        )

    async def close(self) -> None:
        """Close the underlying client session."""
        await self._client.close()
        self._closed = True

    def _make_request(
        self,
        method: str,
        url: StrOrURL,
        raise_for_status: bool | None = None,
        **kwargs: Any,
    ) -> ChainMiddlewareContext:
        """Build the middleware chain for one request and return its context."""
        chain_request = ChainRequest(
            method=method,
            url=url,
            headers=kwargs.pop("headers", {}),
            trace_request_ctx=kwargs.pop("trace_request_ctx", None),
            kwargs=kwargs,
        )

        if raise_for_status is None:
            raise_for_status = self._raise_for_status

        end_of_line_middleware = EndOfLineChainMiddleware(
            request_func=self._client.request,
            logger=self._logger,
            raise_for_status=raise_for_status,
        )

        chain_callback: NextChainCallbackType = end_of_line_middleware.handle

        # Wrap from the inside out so the first middleware runs first.
        for mw in reversed(self._middleware or []):
            chain_callback = mw.handle_gen(chain_callback)

        return ChainMiddlewareContext(
            chain_callback=chain_callback,
            chain_request=chain_request,
            logger=self._logger,
            raise_for_status=raise_for_status,
        )

    async def __aenter__(self) -> "ChainMiddlewareClient":
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        await self.close()

    def __del__(self) -> None:
        if getattr(self, "_closed", None) is None:
            # in case object was not initialized (__init__ raised an exception)
            return

        if not self._closed:
            self._logger.warning("Aiohttp chain client was not closed")
async def oauth_authorization_server(
    session: ClientSession, authorization_server: str
) -> Optional[Dict[str, Any]]:
    """Fetch the OAuth authorization-server metadata document.

    Requests ``{authorization_server}/.well-known/oauth-authorization-server``
    and returns the decoded JSON body, or ``None`` when the server answers
    with any status other than 200.
    """
    # Tolerate a trailing slash on the issuer so the path never starts with "//".
    base = authorization_server.rstrip("/")
    async with session.get(f"{base}/.well-known/oauth-authorization-server") as resp:
        if resp.status != 200:
            return None
        return await resp.json()
28 | """ -------------------------------------------------------------------------------- /src/social/graze/aip/model/app_password.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | from sqlalchemy import String, DateTime 3 | from sqlalchemy.orm import Mapped, mapped_column 4 | 5 | from social.graze.aip.model.base import Base, str1024 6 | 7 | 8 | class AppPassword(Base): 9 | __tablename__ = "atproto_app_passwords" 10 | 11 | guid: Mapped[str] = mapped_column(String(64), primary_key=True) 12 | app_password: Mapped[str] 13 | created_at: Mapped[datetime] = mapped_column( 14 | DateTime(timezone=True), nullable=False 15 | ) 16 | 17 | 18 | class AppPasswordSession(Base): 19 | __tablename__ = "atproto_app_password_sessions" 20 | 21 | guid: Mapped[str] = mapped_column(String(64), primary_key=True) 22 | access_token: Mapped[str1024] 23 | access_token_expires_at: Mapped[datetime] = mapped_column( 24 | DateTime(timezone=True), nullable=False 25 | ) 26 | refresh_token: Mapped[str] 27 | refresh_token_expires_at: Mapped[datetime] = mapped_column( 28 | DateTime(timezone=True), nullable=False 29 | ) 30 | created_at: Mapped[datetime] = mapped_column( 31 | DateTime(timezone=True), nullable=False 32 | ) 33 | -------------------------------------------------------------------------------- /src/social/graze/aip/model/base.py: -------------------------------------------------------------------------------- 1 | from sqlalchemy import String, orm 2 | from sqlalchemy.orm import mapped_column 3 | 4 | from typing_extensions import Annotated 5 | 6 | str512 = Annotated[str, 512] 7 | str1024 = Annotated[str, 1024] 8 | guidpk = Annotated[str, mapped_column(String(512), primary_key=True)] 9 | 10 | 11 | class Base(orm.DeclarativeBase): 12 | type_annotation_map = { 13 | str512: String(512), 14 | str1024: String(1024), 15 | guidpk: String(512), 16 | } 17 | 
-------------------------------------------------------------------------------- /src/social/graze/aip/model/handles.py: -------------------------------------------------------------------------------- 1 | from sqlalchemy.orm import Mapped 2 | from sqlalchemy.dialects.postgresql import insert 3 | from ulid import ULID 4 | 5 | from social.graze.aip.model.base import Base, str512, guidpk 6 | 7 | 8 | class Handle(Base): 9 | __tablename__ = "handles" 10 | 11 | guid: Mapped[guidpk] 12 | did: Mapped[str512] 13 | handle: Mapped[str512] 14 | pds: Mapped[str512] 15 | 16 | 17 | def upsert_handle_stmt(did: str, handle: str, pds: str): 18 | return ( 19 | insert(Handle) 20 | .values( 21 | [ 22 | { 23 | "guid": str(ULID()), 24 | "did": did, 25 | "handle": handle, 26 | "pds": pds, 27 | } 28 | ] 29 | ) 30 | .on_conflict_do_update( 31 | index_elements=["did"], 32 | set_={ 33 | "handle": handle, 34 | "pds": pds, 35 | }, 36 | ) 37 | .returning(Handle.guid) 38 | ) 39 | -------------------------------------------------------------------------------- /src/social/graze/aip/model/health.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | 3 | 4 | class HealthGauge: 5 | """ 6 | This is a makeshift health check system. 7 | 8 | This health gauge is used to track the health of the application and provide a way to return a somewhat meaningful 9 | value to readiness probes. 10 | 11 | The health gauge is a simple counter that can be incremented and decremented. When an exception occurs that is 12 | outside of regular application flow-control (an actual error and not a warning), then the counter is 13 | incremented. As time goes on, the counter decrements. When a burst of exceptions occurs, the is_healthy method will 14 | return false, triggering a failed readiness check. 
15 | """ 16 | 17 | def __init__(self, value: int = 0, health_threshold: int = 100) -> None: 18 | self._value = value 19 | self._health_threshold = health_threshold 20 | self._lock = asyncio.Lock() 21 | 22 | async def womp(self, d=1) -> int: 23 | async with self._lock: 24 | self._value += int(d) 25 | return self._value 26 | 27 | async def tick(self) -> None: 28 | async with self._lock: 29 | if self._value > 0: 30 | self._value -= 1 31 | 32 | async def is_healthy(self) -> bool: 33 | async with self._lock: 34 | return self._value <= self._health_threshold 35 | -------------------------------------------------------------------------------- /src/social/graze/aip/model/oauth.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | from datetime import datetime 3 | from sqlalchemy import Integer, String, DateTime 4 | from sqlalchemy.orm import Mapped, mapped_column 5 | from sqlalchemy.dialects.postgresql import JSON, insert 6 | 7 | from social.graze.aip.model.base import Base, str512, str1024 8 | 9 | 10 | class OAuthRequest(Base): 11 | __tablename__ = "oauth_requests" 12 | 13 | oauth_state: Mapped[str] = mapped_column(String(64), primary_key=True) 14 | issuer: Mapped[str512] 15 | guid: Mapped[str512] 16 | pkce_verifier: Mapped[str] 17 | secret_jwk_id: Mapped[str] 18 | dpop_jwk: Mapped[Any] = mapped_column(JSON, nullable=False) 19 | destination: Mapped[str] 20 | created_at: Mapped[datetime] = mapped_column( 21 | DateTime(timezone=True), nullable=False 22 | ) 23 | expires_at: Mapped[datetime] = mapped_column( 24 | DateTime(timezone=True), nullable=False 25 | ) 26 | 27 | 28 | class OAuthSession(Base): 29 | __tablename__ = "oauth_sessions" 30 | 31 | session_group: Mapped[str] = mapped_column(String(64), primary_key=True) 32 | issuer: Mapped[str512] 33 | guid: Mapped[str512] 34 | access_token: Mapped[str1024] 35 | refresh_token: Mapped[str512] 36 | secret_jwk_id: Mapped[str512] 37 | dpop_jwk: Mapped[Any] = 
mapped_column(JSON, nullable=False) 38 | created_at: Mapped[datetime] = mapped_column( 39 | DateTime(timezone=True), nullable=False 40 | ) 41 | access_token_expires_at: Mapped[datetime] = mapped_column( 42 | DateTime(timezone=True), nullable=False 43 | ) 44 | hard_expires_at: Mapped[datetime] = mapped_column( 45 | DateTime(timezone=True), nullable=False 46 | ) 47 | 48 | 49 | class Permission(Base): 50 | __tablename__ = "guid_permissions" 51 | 52 | guid: Mapped[str] = mapped_column(String(512), primary_key=True) 53 | target_guid: Mapped[str] = mapped_column(String(512), primary_key=True) 54 | permission: Mapped[int] = mapped_column(Integer, nullable=False) 55 | created_at: Mapped[datetime] = mapped_column( 56 | DateTime(timezone=True), nullable=False 57 | ) 58 | 59 | 60 | def upsert_permission_stmt( 61 | guid: str, target_guid: str, permission: int, created_at: datetime 62 | ): 63 | return ( 64 | insert(Permission) 65 | .values( 66 | [ 67 | { 68 | "guid": guid, 69 | "target_guid": target_guid, 70 | "permission": permission, 71 | "created_at": created_at, 72 | } 73 | ] 74 | ) 75 | .on_conflict_do_update( 76 | index_elements=["guid", "target_guid"], 77 | set_={ 78 | "permission": permission, 79 | }, 80 | ) 81 | ) 82 | -------------------------------------------------------------------------------- /src/social/graze/aip/resolve/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Identity Resolution 3 | 4 | This package provides utilities for resolving AT Protocol identifiers (DIDs, handles) 5 | to their canonical forms, implementing both DNS-based and HTTP-based resolution methods. 6 | 7 | Key Components: 8 | - handle.py: Handle resolution implementation 9 | - __main__.py: CLI interface for resolution 10 | 11 | Resolution Types: 12 | 1. Handle Resolution 13 | - DNS-based resolution via TXT records (_atproto.{handle}) 14 | - HTTP-based resolution via well-known endpoints (.well-known/atproto-did) 15 | 16 | 2. 
async def realMain() -> None:
    """Resolve every subject given on the command line and print the result.

    Each subject is resolved independently: a failure is reported to Sentry
    and logged, then the loop moves on to the next subject.
    """
    parser = argparse.ArgumentParser(prog="resolve", description="Resolve handles")
    parser.add_argument("subject", nargs="+", help="The subject(s) to resolve.")
    parser.add_argument(
        "--plc-hostname",
        default="plc.directory",
        help="The PLC hostname to use for resolving did-method-plc DIDs.",
    )

    args = parser.parse_args()

    # argparse always populates both attributes ("subject" is required and
    # "--plc-hostname" has a default), so no fallback values are needed.
    subjects: List[str] = args.subject
    plc_hostname: str = args.plc_hostname

    async with aiohttp.ClientSession() as session:
        for subject in subjects:
            try:
                resolved_handle = await resolve_subject(
                    session, plc_hostname, subject
                )
                print(f"resolved_handle {resolved_handle}")
            except Exception as e:
                # Top-level CLI boundary: report and keep going with the rest.
                sentry_sdk.capture_exception(e)
                logger.exception("Exception resolving subject %s", subject)
async def resolve_handle_dns(handle: str) -> Optional[str]:
    """Resolve ``handle`` to a DID via the ``_atproto.{handle}`` TXT records.

    Only records whose value starts with ``did=`` are considered. Returns the
    DID when exactly one distinct DID is published. Returns ``None`` when the
    lookup fails, no such record exists, or the records disagree, so the
    caller can fall back to another resolution method.
    """
    # TODO: Wrap hickory-dns and use that
    resolver = DNSResolver()
    try:
        results = await resolver.query(f"_atproto.{handle}", "TXT")
    except Exception as e:
        sentry_sdk.capture_exception(e)
        return None

    candidates = set()
    for record in results or []:
        text = getattr(record, "text", None)
        if not isinstance(text, str):
            # presumably bytes when the resolver could not decode the value;
            # such a record cannot carry a usable DID, so skip it
            continue
        text = text.strip()
        if text.startswith("did="):
            candidates.add(text.removeprefix("did="))

    # Unrelated TXT records are ignored; conflicting DIDs mean no answer.
    if len(candidates) != 1:
        return None
    return candidates.pop()


async def resolve_handle_http(session: ClientSession, handle: str) -> Optional[str]:
    """Resolve ``handle`` to a DID via ``https://{handle}/.well-known/atproto-did``.

    Returns the response body with surrounding whitespace removed, or ``None``
    when the status is not 200 or the body does not look like a DID.
    """
    async with session.get(f"https://{handle}/.well-known/atproto-did") as resp:
        if resp.status != 200:
            return None
        body = await resp.text()

    # Servers commonly append a newline; it must not leak into the DID,
    # which is later embedded in lookup URLs.
    did = body.strip()
    if not did.startswith("did:"):
        return None
    return did
def handle_predicate(value: str) -> bool:
    """Return True when an ``alsoKnownAs`` entry is an ``at://`` URI.

    DID documents are remote, untrusted JSON, so anything that is not a
    string is rejected instead of raising.
    """
    return isinstance(value, str) and value.startswith("at://")


def pds_predicate(value: Dict[str, Any]) -> bool:
    """Return True when a ``service`` entry describes a personal data server.

    The entry must be a JSON object whose ``type`` is
    ``AtprotoPersonalDataServer`` and which has a ``serviceEndpoint`` key.
    Non-object entries are rejected instead of raising.
    """
    return (
        isinstance(value, dict)
        and value.get("type", None) == "AtprotoPersonalDataServer"
        and "serviceEndpoint" in value
    )
async def resolve_did(
    session: ClientSession, plc_hostname: str, did: str
) -> Optional[ResolvedSubject]:
    """Dispatch ``did`` to the resolver for its DID method.

    ``did:plc:`` and ``did:web:`` are supported; any other prefix yields
    ``None``.
    """
    if did.startswith("did:plc:"):
        return await resolve_did_method_plc(plc_hostname, session, did)
    if did.startswith("did:web:"):
        return await resolve_did_method_web(session, did)
    return None


async def resolve_subject(
    session: ClientSession, plc_hostname: str, subject: str
) -> Optional[ResolvedSubject]:
    """Resolve a handle or DID string to its DID, handle and PDS.

    A hostname subject is first turned into a DID through handle resolution;
    a DID subject is used as given. Returns ``None`` when the input cannot be
    parsed or no DID could be determined.
    """
    parsed = parse_input(subject)
    if parsed is None:
        return None

    kind = parsed.subject_type
    if kind == SubjectType.hostname:
        did = await resolve_handle(session, parsed.subject)
    elif kind in (SubjectType.did_method_plc, SubjectType.did_method_web):
        did = parsed.subject
    else:
        did = None

    if did is None:
        return None
    return await resolve_did(session, plc_hostname, did)
subject=subject) 183 | 184 | # TODO: Validate this hostname 185 | return ParsedSubject(subject_type=SubjectType.hostname, subject=subject) 186 | -------------------------------------------------------------------------------- /telegraf.conf: -------------------------------------------------------------------------------- 1 | # Global agent configuration 2 | [agent] 3 | interval = "10s" 4 | round_interval = true 5 | metric_batch_size = 1000 6 | metric_buffer_limit = 10000 7 | collection_jitter = "0s" 8 | flush_interval = "10s" 9 | flush_jitter = "0s" 10 | precision = "" 11 | hostname = "aip_telegraf" 12 | omit_hostname = false 13 | 14 | # Input plugins 15 | [[inputs.cpu]] 16 | percpu = true 17 | totalcpu = true 18 | collect_cpu_time = false 19 | report_active = false 20 | 21 | [[inputs.mem]] 22 | fieldpass = ["used_percent"] 23 | 24 | [[inputs.disk]] 25 | mount_points = ["/"] 26 | ignore_fs = ["tmpfs", "devtmpfs"] 27 | 28 | [[inputs.statsd]] 29 | service_address = ":8125" 30 | 31 | [[outputs.file]] 32 | files = ["/tmp/telegraf.out"] 33 | data_format = "json" 34 | -------------------------------------------------------------------------------- /templates/atproto_debug.html: -------------------------------------------------------------------------------- 1 | {% extends "base.html" %} 2 | {% block content %} 3 |
4 |
5 |

Debug

6 | 7 |
{{ auth_token.header }}
8 |
{{ auth_token.claims }}
9 |
{{ handle.__dict__ }}
10 |
{{ oauth_session.__dict__ }}
11 |
12 |
13 |

New Login

14 |
15 |
16 | {% endblock %} -------------------------------------------------------------------------------- /templates/atproto_login.html: -------------------------------------------------------------------------------- 1 | {% extends "base.html" %} {% block content %} 2 |
3 |
6 |
7 |
8 | {% if svg_logo %} 9 | Graze 14 | {% endif %} {% if brand_name %} 15 |

19 | {{ brand_name }} 20 |

21 | {% endif %} 22 |
23 |
24 |
29 |
30 | 42 |
43 | 44 |
45 | 51 |
52 |
53 |
54 |
55 |
56 |
57 | {% endblock %} 58 | -------------------------------------------------------------------------------- /templates/base.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | {% block title %}Graze Social AIP{% endblock %} 9 | 10 | 11 | 83 | 84 | 85 | 86 | 87 | {% block content %}{% endblock %} 88 | 89 | 90 | 91 | -------------------------------------------------------------------------------- /templates/index.html: -------------------------------------------------------------------------------- 1 | {% extends "base.html" %} 2 | {% block content %} 3 |
4 |
5 |

AIP

6 |

The Graze Social Authentication, Identity, and Permissions Server

7 |
8 |
9 |
10 |
11 |

New Login

12 |
13 |
14 | {% endblock %} -------------------------------------------------------------------------------- /templates/settings.html: -------------------------------------------------------------------------------- 1 | {% extends "base.html" %} 2 | {% block content %} 3 |
4 |

AIP

5 |

The Graze Social Authentication, Identity, and Permissions Server

6 |
7 |
8 |

Settings

9 |
10 | {% endblock %} -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/graze-social/aip/c740a2cf63fdf5a9d2cef6da3bfb8dd660e15ea0/tests/__init__.py -------------------------------------------------------------------------------- /tests/test_example.py: -------------------------------------------------------------------------------- 1 | from social.graze.aip.calc import inc 2 | 3 | 4 | def test_answer(): 5 | assert inc(3) == 4 6 | --------------------------------------------------------------------------------