Debug
6 | 7 |{{ auth_token.header }}
8 | {{ auth_token.claims }}
9 | {{ handle.__dict__ }}
10 | {{ oauth_session.__dict__ }}
├── .devcontainer
├── devcontainer.json
├── docker-compose.yml
├── post-create.sh
├── postgres_init.sql
└── telegraf.conf
├── .flake8
├── .github
└── workflows
│ └── release.yaml
├── .gitignore
├── .svu.yaml
├── .vscode
├── launch.json
└── settings.json
├── Dockerfile
├── LICENSE
├── README.md
├── aip.png
├── alembic.ini
├── alembic
├── README
├── env.py
├── script.py.mako
└── versions
│ └── 380b77abd479_init.py
├── bin
└── scripts
│ ├── entrypoint.sh
│ ├── init-db.sh
│ └── init.sh
├── docker-compose.yml
├── pdm.lock
├── pyproject.toml
├── src
└── social
│ ├── __init__.py
│ └── graze
│ ├── __init__.py
│ └── aip
│ ├── __init__.py
│ ├── app
│ ├── __init__.py
│ ├── __main__.py
│ ├── config.py
│ ├── cors.py
│ ├── handlers
│ │ ├── __init__.py
│ │ ├── app_password.py
│ │ ├── credentials.py
│ │ ├── helpers.py
│ │ ├── internal.py
│ │ ├── oauth.py
│ │ ├── permissions.py
│ │ └── proxy.py
│ ├── server.py
│ ├── tasks.py
│ └── util
│ │ ├── __init__.py
│ │ └── __main__.py
│ ├── atproto
│ ├── __init__.py
│ ├── app_password.py
│ ├── chain.py
│ ├── oauth.py
│ └── pds.py
│ ├── model
│ ├── __init__.py
│ ├── app_password.py
│ ├── base.py
│ ├── handles.py
│ ├── health.py
│ └── oauth.py
│ └── resolve
│ ├── __init__.py
│ ├── __main__.py
│ └── handle.py
├── static
├── login.css
├── pico.classless.blue.css
└── pico.colors.min.css
├── telegraf.conf
├── templates
├── atproto_debug.html
├── atproto_login.html
├── base.html
├── index.html
└── settings.html
└── tests
├── __init__.py
└── test_example.py
/.devcontainer/devcontainer.json:
--------------------------------------------------------------------------------
1 | {
2 | "dockerComposeFile": [
3 | "docker-compose.yml"
4 | ],
5 | "service": "devcontainer",
6 | "workspaceFolder": "/workspace",
7 | "features": {
8 | "ghcr.io/devcontainers/features/docker-in-docker:2": {}
9 | },
10 | "postCreateCommand": "/workspace/.devcontainer/post-create.sh",
11 | "forwardPorts": [5100]
12 | }
--------------------------------------------------------------------------------
/.devcontainer/docker-compose.yml:
--------------------------------------------------------------------------------
1 | version: '3.8'
2 | volumes:
3 | aip_db:
4 | aip_ts:
5 | services:
6 | devcontainer:
7 | image: mcr.microsoft.com/devcontainers/python:3.13
8 | volumes:
9 | - ..:/workspace:cached
10 | - /var/run/docker.sock:/var/run/docker.sock
11 | command: sleep infinity
12 | environment:
13 | - HTTP_PORT=5100
14 | - TZ=America/New_York
15 |       - DATABASE_URL=postgres://postgres:password@postgres/aip # NOTE(review): root docker-compose.yml uses postgresql+asyncpg:// — confirm which URL scheme the app expects
16 | - JSON_WEB_KEYS=/workspace/signing_keys.json
17 |
18 | postgres:
19 | image: postgres:17-alpine
20 | restart: unless-stopped
21 | volumes:
22 | - aip_db:/var/lib/postgresql/data
23 | - ./postgres_init.sql:/docker-entrypoint-initdb.d/init.sql
24 | environment:
25 | - POSTGRES_PASSWORD=password
26 | healthcheck:
27 | test: 'pg_isready -U postgres'
28 | interval: 500ms
29 | timeout: 10s
30 | retries: 20
31 |
32 | telegraf:
33 | image: docker.io/telegraf:latest
34 | volumes:
35 | - ./telegraf.conf:/etc/telegraf/telegraf.conf
36 |
37 | valkey:
38 | image: valkey/valkey:8-alpine
39 |
40 | tailscale:
41 | image: tailscale/tailscale:latest
42 | restart: unless-stopped
43 | environment:
44 | - TS_STATE_DIR=/var/run/tailscale
45 | - TS_EXTRA_ARGS=--advertise-tags=tag:aip
46 | volumes:
47 | - aip_ts:/var/run/tailscale
48 |
--------------------------------------------------------------------------------
/.devcontainer/post-create.sh:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 |
3 | sudo usermod -a -G docker vscode
4 |
5 | sudo apt-get update
6 | sudo apt-get install -y postgresql-client
7 |
8 | curl -sSL https://pdm-project.org/install-pdm.py | python3 -
9 |
--------------------------------------------------------------------------------
/.devcontainer/postgres_init.sql:
--------------------------------------------------------------------------------
1 | -- aip
2 | CREATE DATABASE aip;
3 | GRANT ALL PRIVILEGES ON DATABASE aip TO postgres;
4 |
--------------------------------------------------------------------------------
/.devcontainer/telegraf.conf:
--------------------------------------------------------------------------------
1 |
2 | [global_tags]
3 | [agent]
4 | interval = "10s"
5 | round_interval = true
6 | metric_batch_size = 1000
7 | metric_buffer_limit = 10000
8 | collection_jitter = "0s"
9 | flush_interval = "10s"
10 | flush_jitter = "0s"
11 | precision = "0s"
12 | [[outputs.influxdb_v2]]
13 | urls = ["http://metrics.bowfin-woodpecker.ts.net:8086"]
14 |   token = "7eAc0CgtNV4-yeDwnKl01tBYdxMrMPtlmmz3h-urW6uBGt2Uv3byhsTkgwiHnkn45Vr0gdaqnc6tbyfhUWxpEw==" # SECURITY(review): credential committed to the repo — rotate this token and load it from an env var or secret store instead
15 | organization = "graze"
16 | bucket = "telegraf"
17 | [[inputs.statsd]]
18 | protocol = "udp"
19 | service_address = ":8125"
20 | datadog_extensions = true
21 |
--------------------------------------------------------------------------------
/.flake8:
--------------------------------------------------------------------------------
1 | [flake8]
2 | extend-ignore = E203
3 | exclude =
4 | .git,
5 | .venv,
6 | .pdm-build,
7 | .pytest_cache,
8 | __pycache__,
9 | alembic,
10 | dist
11 | max-complexity = 20
12 | max-line-length = 120
--------------------------------------------------------------------------------
/.github/workflows/release.yaml:
--------------------------------------------------------------------------------
1 | name: Create AIP Release
2 | on:
3 | pull_request:
4 | types:
5 | - closed
6 | branches:
7 | - main
8 | jobs:
9 | release_aip:
10 | runs-on: ubuntu-latest
11 | name: create_release
12 | permissions:
13 | contents: write
14 | id-token: write
15 | steps:
16 | - name: checkout
17 | uses: actions/checkout@v3
18 | with:
19 | fetch-depth: 0
20 |
21 | - name: setup_svu
22 |         run: curl -kL https://github.com/caarlos0/svu/releases/download/v3.2.2/svu_3.2.2_linux_amd64.tar.gz | tar zx && mv svu /usr/local/bin/svu && chmod +x /usr/local/bin/svu # NOTE(review): -k disables TLS certificate verification for this download — drop it
23 |
24 | - name: create_tag
25 | id: create_tag
26 | run: |
27 | echo "VERSION_TAG=$(svu next)" >> $GITHUB_ENV
28 | echo "VERSION_TAG=$(svu next)" >> $GITHUB_OUTPUT
29 |
30 | - name: create_release
31 | env:
32 | GH_TOKEN: ${{ github.token }}
33 | run: |-
34 | gh release create ${{ env.VERSION_TAG }} -t ${{ env.VERSION_TAG }} --generate-notes
35 |
36 | - name: repo_dispatch
37 | uses: peter-evans/repository-dispatch@v3
38 | with:
39 | repository: graze-social/turbo-deploy
40 | event-type: aip-release
41 | token: ${{ secrets.DISPATCH_GH_TOKEN }}
42 | client-payload: |-
43 | {
44 | "ref": "${{ github.ref }}",
45 | "sha": "${{ github.sha }}",
46 | "version_tag": "${{env.VERSION_TAG}}",
47 | "pr_context": ${{toJson(github.event.pull_request)}}
48 | }
49 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # UV
98 | # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | #uv.lock
102 |
103 | # poetry
104 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
105 | # This is especially recommended for binary packages to ensure reproducibility, and is more
106 | # commonly ignored for libraries.
107 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
108 | #poetry.lock
109 |
110 | # pdm
111 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
112 | #pdm.lock
113 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
114 | # in version control.
115 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
116 | .pdm.toml
117 | .pdm-python
118 | .pdm-build/
119 |
120 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
121 | __pypackages__/
122 |
123 | # Celery stuff
124 | celerybeat-schedule
125 | celerybeat.pid
126 |
127 | # SageMath parsed files
128 | *.sage.py
129 |
130 | # Environments
131 | .env
132 | .venv
133 | env/
134 | venv/
135 | ENV/
136 | env.bak/
137 | venv.bak/
138 |
139 | # Spyder project settings
140 | .spyderproject
141 | .spyproject
142 |
143 | # Rope project settings
144 | .ropeproject
145 |
146 | # mkdocs documentation
147 | /site
148 |
149 | # mypy
150 | .mypy_cache/
151 | .dmypy.json
152 | dmypy.json
153 |
154 | # Pyre type checker
155 | .pyre/
156 |
157 | # pytype static type analyzer
158 | .pytype/
159 |
160 | # Cython debug symbols
161 | cython_debug/
162 |
163 | # PyCharm
164 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
165 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
166 | # and can be added to the global gitignore or merged into this file. For a more nuclear
167 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
168 | #.idea/
169 |
170 | # PyPI configuration file
171 | .pypirc
172 | scratch/
173 | .DS_Store
174 | **/*/.DS_Store
175 |
--------------------------------------------------------------------------------
/.svu.yaml:
--------------------------------------------------------------------------------
1 | always: true
2 | v0: true
--------------------------------------------------------------------------------
/.vscode/launch.json:
--------------------------------------------------------------------------------
1 | {
2 | "configurations": [
3 | {
4 | "name": "Python Debugger: Module",
5 | "type": "debugpy",
6 | "request": "launch",
7 | "module": "social.graze.aip.app.__main__",
8 | "env": {
9 | "ACTIVE_SIGNING_KEYS": "[\"01JK8QQSW7YSVDEYQS3J7ZE3XJ\"]",
10 | "EXTERNAL_HOSTNAME": "grazeaip.tunn.dev",
11 | "WORKER_ID": "dev1",
12 | "SERVICE_AUTH_KEYS": "[\"01JK8QQSW7YSVDEYQS3J7ZE3XJ\"]",
13 | "DATABASE_URL": "postgres://postgres:password@postgres/aip",
14 | "ENCRYPTION_KEY": "TFZTT0lOSWlTSk8xbWhJMzJYY0lBS2dQcXFQMTVjX0o4ZUlJTFNyWVpzQT0=",
15 | "JSON_WEB_KEYS": "/workspace/signing_keys.json",
16 | "PLC_HOSTNAME": "plc.bowfin-woodpecker.ts.net",
17 | "DEBUG": "true"
18 | },
19 | }
20 | ]
21 | }
--------------------------------------------------------------------------------
/.vscode/settings.json:
--------------------------------------------------------------------------------
1 | {
2 | "python.analysis.inlayHints.pytestParameters": true,
3 | "python.analysis.inlayHints.functionReturnTypes": true,
4 | "python.analysis.inlayHints.variableTypes": true,
5 | "python.analysis.typeCheckingMode": "standard"
6 | }
--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------
1 | # Use a lightweight Python image
2 | FROM python:3.13-slim
3 |
4 | # Set working directory
5 | WORKDIR /app
6 |
7 | # Install system dependencies
8 | RUN apt-get update && apt-get install -y \
9 | curl \
10 | libpq-dev \
11 | postgresql-client \
12 | gcc \
13 | python3-dev \
14 | jq \
15 | && rm -rf /var/lib/apt/lists/*
16 |
17 | # Install PDM globally
18 | RUN pip install pdm
19 |
20 | # TODO: we can split into a multi-stage build and just ship the .venv
21 | # just keeping everything for now
22 | COPY . .
23 |
24 | # Ensure PDM always uses the in-project virtual environment
25 | ENV PDM_VENV_IN_PROJECT=1
26 | ENV PATH="/app/.venv/bin:$PATH"
27 |
28 | # Install dependencies properly inside the virtual environment
29 | RUN pdm install
30 |
31 | # Expose the application port
32 | # TODO: These should be configurable, not hard-coded and thus publishing them here is moot.
33 | EXPOSE 8080
34 | EXPOSE 5100
35 |
36 | # Available CMDs
37 | # See pyproject.toml for more details
38 | # CMD ["pdm", "run", "aipserver"]
39 | # CMD ["pdm", "run", "resolve"]
40 | # CMD ["pdm", "run", "aiputil"]
41 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2025 Graze Social
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # ATmosphere Authentication, Identity, and Permission Proxy
2 |
3 | 
4 | ## Running Locally
5 | http://localhost:8080/internal/api/resolve?subject=ngerakines.me&subject=mattie.thegem.city
6 |
7 | 1. Install dependencies with `pdm install` (note that `pdm install` may require `sudo apt install -y clang libpq-dev python3-dev build-essential` to build the postgres requirements!)
8 | 2. Start up a postgres server, `export DATABASE_URL=address` for that server
9 | 3. Run migrations `pdm run alembic upgrade head`
10 | 4. Populate signing keys: `OUT=$(pdm run aiputil gen-jwk)` and then `echo "{\"keys\":[$OUT]}" > signing_keys.json`
11 |
12 | 5. Set hostname: `export EXTERNAL_HOSTNAME=grazeaip.tunn.dev`
13 | 6. Set plc hostname: `export PLC_HOSTNAME=plc.bowfin-woodpecker.ts.net`
14 | 7. Set active signing keys: `export ACTIVE_SIGNING_KEYS='["{KEY_ID}"]'`
15 | 8. Start service `pdm run aipserver`
16 | 9. Generate a handle: https://pdsdns.bowfin-woodpecker.ts.net
17 | 10. Verify resolution: `pdm run resolve --plc-hostname ${PLC_HOSTNAME} enabling-boxer.pyroclastic.cloud`
18 | 11. Auth with it: https://grazeaip.tunn.dev/auth/atproto
19 |
20 | ## Running via Docker
21 |
22 | 1. Install dependencies with `pdm install` (note that `pdm install` may require `sudo apt install -y clang libpq-dev python3-dev build-essential` to build the postgres requirements!)
23 | 2. Populate signing keys: `OUT=$(pdm run aiputil gen-jwk)` and then `echo "{\"keys\":[$OUT]}" > signing_keys.json`
24 | 3. Add `EXTERNAL_HOSTNAME: your-host` to your docker compose file[1].
25 | 4. `docker compose build && docker compose up`
26 | 5. Then navigate to your host at the path /auth/atproto. This is your login page!
27 | 6. Note that in the config you can change colors, text, display image, and default post-login destination. You can *also* forward the response to *any* URL with a ?destination={URL} parameter on the sign-in page
28 |
29 | [1]: This must match the URL this service is running on. For development purposes, you can install ngrok then run `ngrok http 8080` which will forward traffic on https to a specified ngrok URL. You would then take that host (without https) and put it in your docker compose before starting up.
30 |
31 | ## How to Use AIP tokens to access ATProto / Bluesky:
32 |
33 | Please see this [example usage file in python](https://gist.github.com/DGaffney/99f209e5ff9bb01cc50c4202c9c46554) and port to your use case!
34 |
--------------------------------------------------------------------------------
/aip.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/graze-social/aip/c740a2cf63fdf5a9d2cef6da3bfb8dd660e15ea0/aip.png
--------------------------------------------------------------------------------
/alembic.ini:
--------------------------------------------------------------------------------
1 | [alembic]
2 | script_location = alembic
3 |
4 | prepend_sys_path = .
5 |
6 | version_path_separator = os
7 |
8 | sqlalchemy.url = %(DATABASE_URL)s
9 |
10 |
11 | [post_write_hooks]
12 |
13 | [loggers]
14 | keys = root,sqlalchemy,alembic
15 |
16 | [handlers]
17 | keys = console
18 |
19 | [formatters]
20 | keys = generic
21 |
22 | [logger_root]
23 | level = WARNING
24 | handlers = console
25 | qualname =
26 |
27 | [logger_sqlalchemy]
28 | level = WARNING
29 | handlers =
30 | qualname = sqlalchemy.engine
31 |
32 | [logger_alembic]
33 | level = INFO
34 | handlers =
35 | qualname = alembic
36 |
37 | [handler_console]
38 | class = StreamHandler
39 | args = (sys.stderr,)
40 | level = NOTSET
41 | formatter = generic
42 |
43 | [formatter_generic]
44 | format = %(levelname)-5.5s [%(name)s] %(message)s
45 | datefmt = %H:%M:%S
46 |
--------------------------------------------------------------------------------
/alembic/README:
--------------------------------------------------------------------------------
1 | Generic single-database configuration.
--------------------------------------------------------------------------------
/alembic/env.py:
--------------------------------------------------------------------------------
"""Alembic migration environment.

Builds the synchronous database URL from the DATABASE_URL environment
variable and dispatches to offline or online migration mode.
"""

import os
from logging.config import fileConfig

from sqlalchemy import engine_from_config, pool
from alembic import context

# Alembic Config object; provides access to values in alembic.ini.
config = context.config

# Configure Python logging from the ini file, when one is in use.
if config.config_file_name is not None:
    fileConfig(config.config_file_name)

# Read the database URL from the environment. Fail fast with a clear
# message instead of the confusing AttributeError that `.replace` on
# None would otherwise raise when the variable is unset.
DATABASE_URL = os.getenv("DATABASE_URL")
if not DATABASE_URL:
    raise RuntimeError(
        "DATABASE_URL environment variable must be set to run migrations"
    )
# Migrations run synchronously: swap the async driver for psycopg2.
DATABASE_URL = DATABASE_URL.replace("postgresql+asyncpg", "postgresql+psycopg2")

# Override sqlalchemy.url in the Alembic config dynamically.
config.set_main_option("sqlalchemy.url", DATABASE_URL)

# Model metadata for autogenerate support (None: autogenerate unused).
target_metadata = None


def run_migrations_offline() -> None:
    """Run migrations in 'offline' mode: emit SQL without connecting."""
    context.configure(
        url=DATABASE_URL,  # Use the dynamically built URL here
        target_metadata=target_metadata,
        literal_binds=True,
        dialect_opts={"paramstyle": "named"},
    )

    with context.begin_transaction():
        context.run_migrations()


def run_migrations_online() -> None:
    """Run migrations in 'online' mode: connect and apply directly."""
    connectable = engine_from_config(
        config.get_section(config.config_ini_section, {}),
        prefix="sqlalchemy.",
        poolclass=pool.NullPool,
    )

    with connectable.connect() as connection:
        context.configure(connection=connection, target_metadata=target_metadata)

        with context.begin_transaction():
            context.run_migrations()


if context.is_offline_mode():
    run_migrations_offline()
else:
    run_migrations_online()
--------------------------------------------------------------------------------
/alembic/script.py.mako:
--------------------------------------------------------------------------------
1 | """${message}
2 |
3 | Revision ID: ${up_revision}
4 | Revises: ${down_revision | comma,n}
5 | Create Date: ${create_date}
6 |
7 | """
8 | from typing import Sequence, Union
9 |
10 | from alembic import op
11 | import sqlalchemy as sa
12 | ${imports if imports else ""}
13 |
14 | # revision identifiers, used by Alembic.
15 | revision: str = ${repr(up_revision)}
16 | down_revision: Union[str, None] = ${repr(down_revision)}
17 | branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)}
18 | depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)}
19 |
20 |
21 | def upgrade() -> None:
22 | ${upgrades if upgrades else "pass"}
23 |
24 |
25 | def downgrade() -> None:
26 | ${downgrades if downgrades else "pass"}
27 |
--------------------------------------------------------------------------------
/alembic/versions/380b77abd479_init.py:
--------------------------------------------------------------------------------
1 | """init
2 |
3 | Revision ID: 380b77abd479
4 | Revises:
5 | Create Date: 2025-02-03 16:15:52.235674
6 |
7 | """
8 |
9 | from typing import Sequence, Union
10 |
11 | from alembic import op
12 | import sqlalchemy as sa
13 |
14 |
15 | # revision identifiers, used by Alembic.
16 | revision: str = "380b77abd479"
17 | down_revision: Union[str, None] = None
18 | branch_labels: Union[str, Sequence[str], None] = None
19 | depends_on: Union[str, Sequence[str], None] = None
20 |
21 |
22 | def upgrade() -> None:
23 | op.create_table(
24 | "handles",
25 | sa.Column("guid", sa.String(512), primary_key=True),
26 | sa.Column("did", sa.String(512), nullable=False),
27 | sa.Column("handle", sa.String(512), nullable=False),
28 | sa.Column("pds", sa.String(512), nullable=False),
29 | )
30 | op.create_index("idx_handles_did", "handles", ["did"], unique=True)
31 | op.create_index("idx_handles_handle", "handles", ["handle"])
32 |
33 | op.create_table(
34 | "oauth_requests",
35 | sa.Column("oauth_state", sa.String(64), primary_key=True),
36 | sa.Column("issuer", sa.String(512), nullable=False),
37 | sa.Column("guid", sa.String(512), nullable=False),
38 | sa.Column("pkce_verifier", sa.String(128), nullable=False),
39 | sa.Column("secret_jwk_id", sa.String(32), nullable=False),
40 | sa.Column("dpop_jwk", sa.JSON, nullable=False),
41 | sa.Column("destination", sa.String(512), nullable=False),
42 | sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
43 | sa.Column("expires_at", sa.DateTime(timezone=True), nullable=False),
44 | )
45 | op.create_index("idx_oauth_requests_guid", "oauth_requests", ["guid"])
46 | op.create_index("idx_oauth_requests_expires", "oauth_requests", ["expires_at"])
47 |
48 | op.create_table(
49 | "oauth_sessions",
50 | sa.Column("session_group", sa.String(64), primary_key=True),
51 | sa.Column("access_token", sa.String(1024), nullable=False),
52 | sa.Column("guid", sa.String(512), nullable=False),
53 | sa.Column("refresh_token", sa.String(512), nullable=False),
54 | sa.Column("issuer", sa.String(512), nullable=False),
55 | sa.Column("secret_jwk_id", sa.String(32), nullable=False),
56 | sa.Column("dpop_jwk", sa.JSON, nullable=False),
57 | sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
58 | sa.Column(
59 | "access_token_expires_at", sa.DateTime(timezone=True), nullable=False
60 | ),
61 | sa.Column("hard_expires_at", sa.DateTime(timezone=True), nullable=False),
62 | )
63 | op.create_index("idx_oauth_sessions_guid", "oauth_sessions", ["guid"])
64 | op.create_index(
65 | "idx_oauth_sessions_expires", "oauth_sessions", ["access_token_expires_at"]
66 | )
67 |
68 | op.create_table(
69 | "atproto_app_passwords",
70 | sa.Column("guid", sa.String(512), primary_key=True),
71 | sa.Column("app_password", sa.String(512), nullable=False),
72 | sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
73 | )
74 |
75 | op.create_table(
76 | "atproto_app_password_sessions",
77 | sa.Column("guid", sa.String(512), primary_key=True),
78 | sa.Column("access_token", sa.String(1024), nullable=False),
79 | sa.Column(
80 | "access_token_expires_at", sa.DateTime(timezone=True), nullable=False
81 | ),
82 | sa.Column("refresh_token", sa.String(512), nullable=False),
83 | sa.Column(
84 | "refresh_token_expires_at", sa.DateTime(timezone=True), nullable=False
85 | ),
86 | sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
87 | )
88 |
89 | # `guid` has permission `permission` on `target_guid`
90 | op.create_table(
91 | "guid_permissions",
92 | sa.Column("guid", sa.String(512), nullable=False),
93 | sa.Column("target_guid", sa.String(512), nullable=False),
94 | sa.Column("permission", sa.Integer, nullable=False),
95 | sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
96 | )
97 | op.create_primary_key(
98 | "pk_guid_permissions", "guid_permissions", ["guid", "target_guid"]
99 | )
100 | op.create_index(
101 | "idx_guid_permissions_target_guid", "guid_permissions", ["target_guid"]
102 | )
103 |
104 |
105 | def downgrade() -> None:
106 | op.drop_table("handles")
107 | op.drop_table("oauth_requests")
108 | op.drop_table("oauth_sessions")
109 | op.drop_table("atproto_app_passwords")
110 | op.drop_table("guid_permissions")
111 | op.drop_table("atproto_app_password_sessions")
112 |
--------------------------------------------------------------------------------
/bin/scripts/entrypoint.sh:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 | set -e # Exit immediately if a command exits with a non-zero status
3 |
4 | # Run init script
5 | /app/init.sh # NOTE(review): repo stores this script at bin/scripts/init.sh; with the Dockerfile's `COPY . .` it lands at /app/bin/scripts/init.sh — confirm the image layout puts it at /app/init.sh
6 |
7 | # Start the AIP server
8 | echo "Starting AIP server..."
9 | exec pdm run aipserver
10 |
--------------------------------------------------------------------------------
/bin/scripts/init-db.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | set -e
3 |
4 | # Create the database if it doesn't exist
5 | psql -v ON_ERROR_STOP=1 --username "$POSTGRES_USER" <<-EOSQL
6 | CREATE DATABASE aip_db;
7 | EOSQL
8 |
--------------------------------------------------------------------------------
/bin/scripts/init.sh:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 | echo "Waiting for database connection..."
3 | until pg_isready -h db -p 5432 -U aip; do
4 | sleep 2
5 | done
6 | echo "Database is ready."
7 |
8 | # Ensure the database exists
9 | echo "Checking if database exists..."
10 | DB_EXISTS=$(PGPASSWORD="aip_password" psql -h db -U aip -tAc "SELECT 1 FROM pg_database WHERE datname='aip_db'")
11 | if [ "$DB_EXISTS" != "1" ]; then
12 | echo "Database aip_db not found, creating..."
13 | PGPASSWORD="aip_password" createdb -h db -U aip aip_db
14 | else
15 | echo "Database already exists."
16 | fi
17 |
18 | # Generate signing keys if they don't exist
19 | if [ ! -f signing_keys.json ]; then
20 | echo "Generating signing keys..."
21 | SIGNING_KEY=$(pdm run aiputil gen-jwk)
22 | if [ -n "$SIGNING_KEY" ]; then
23 | echo "{\"keys\":[$SIGNING_KEY]}" > signing_keys.json
24 | echo "Signing keys generated."
25 | else
26 | echo "Error generating signing keys!"
27 | exit 1
28 | fi
29 | else
30 | echo "Signing keys already exist."
31 | fi
32 |
33 | # Extract 'kid' values from signing_keys.json
34 | if command -v jq >/dev/null 2>&1; then
35 | export ACTIVE_SIGNING_KEYS=$(jq -c '[.keys[].kid]' signing_keys.json)
36 | else
37 | echo "Error: jq is required but not installed. Install jq to continue."
38 | exit 1
39 | fi
40 |
41 | # Run Alembic migrations
42 | echo "Running Alembic migrations..."
43 | pdm run alembic upgrade head || { echo "Alembic migrations failed!"; exit 1; }
44 |
45 |
46 | # Mark initialization as complete
47 | touch /app/init_done
48 | echo "Initialization complete."
49 |
--------------------------------------------------------------------------------
/docker-compose.yml:
--------------------------------------------------------------------------------
1 | services:
2 | db:
3 | image: postgres:17
4 | container_name: aip_db
5 | restart: always
6 | environment:
7 | POSTGRES_USER: aip
8 | POSTGRES_PASSWORD: aip_password
9 | POSTGRES_DB: aip_db
10 | volumes:
11 | - aip_db_data:/var/lib/postgresql/data
12 | ports:
13 | - "5432:5432"
14 | healthcheck:
15 | test: ["CMD-SHELL", "pg_isready -U aip -d aip_db"]
16 | interval: 5s
17 | retries: 5
18 | timeout: 3s
19 |
20 | valkey:
21 | image: valkey/valkey:8-alpine
22 | container_name: aip_valkey
23 | restart: always
24 | ports:
25 | - "6379:6379"
26 |
27 | telegraf:
28 | image: telegraf:latest
29 | container_name: aip_telegraf
30 | restart: always
31 | depends_on:
32 | - db
33 | - valkey
34 | ports:
35 | - "8125:8125/udp"
36 | volumes:
37 | - ./telegraf.conf:/etc/telegraf/telegraf.conf
38 |
39 | aip:
40 | build: .
41 | container_name: aip_service
42 | command: ["pdm", "run", "aipserver"]
43 | restart: always
44 | depends_on:
45 | db:
46 | condition: service_healthy
47 | valkey:
48 | condition: service_started
49 | telegraf:
50 | condition: service_started
51 | environment:
52 | DATABASE_URL: postgresql+asyncpg://aip:aip_password@db:5432/aip_db
53 | REDIS_URL: redis://valkey:6379/0
54 | DEBUG: "true"
55 | WORKER_ID: "dev1"
56 | ACTIVE_SIGNING_KEYS: '["01JNEKAHBPFQYJX3RS7HH7W2RY"]'
57 | JSON_WEB_KEYS: /app/signing_keys.json
58 | ports:
59 | - "8080:8080"
60 | - "5100:5100"
61 | volumes:
62 | - .:/app # Mount local directory to container
63 | - /app/.venv # Preserve virtual environment
64 | # TODO: move to Tilt-based debugging
65 | # volumes:
66 | # - .:/app
67 | # - /app/.venv
68 |
69 | volumes:
70 | aip_db_data:
71 | influxdb_data:
72 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["pdm-backend"]
3 | build-backend = "pdm.backend"
4 |
5 | [project]
6 | name = "aip"
7 | version = "0.2.4"
8 | description = "ATmosphere Authentication, Identity, and Permission Proxy"
9 | requires-python = ">=3.13"
10 |
11 | license = { file = "LICENSE" }
12 |
13 | authors = [{ name = "Nick Gerakines", email = "nick.gerakines@gmail.com" }]
14 | maintainers = [{ name = "Nick Gerakines", email = "nick.gerakines@gmail.com" }]
15 |
16 | classifiers = [
17 | "Development Status :: 3 - Alpha",
18 | "Programming Language :: Python :: 3",
19 | "Programming Language :: Python :: 3.13",
20 | ]
21 |
22 | dependencies = [
23 | "sqlalchemy>=2.0.37",
24 | "pydantic>=2.10.6",
25 | "aiohttp>=3.11.11",
26 | "jinja2>=3.1.5",
27 | "aiohttp-jinja2>=1.6",
28 | "alembic>=1.14.1",
29 | "psycopg2>=2.9.10",
30 | "jwcrypto>=1.5.6",
31 | "pydantic-settings>=2.7.1",
32 | "python-json-logger>=3.2.1",
33 | "aiodns>=3.2.0",
34 | "asyncpg>=0.30.0",
35 | "python-ulid>=3.0.0",
36 | "cryptography>=44.0.1",
37 | "redis>=5.2.1",
38 | "aio-statsd>=0.2.9",
39 | "sentry-sdk>=2.24.1",
40 | ]
41 | readme = "README.md"
42 |
43 | [project.optional-dependencies]
44 | dev = ["flake8>=7.1.1", "pytest>=8.3.4", "black>=25.1.0"]
45 | test = ["coverage"]
46 |
47 | [project.urls]
48 | "Homepage" = "https://github.com/graze-social/aip"
49 | "Bug Reports" = "https://github.com/graze-social/aip/issues"
50 | "Source" = "https://github.com/graze-social/aip"
51 |
52 | [project.scripts]
53 | aipserver = "social.graze.aip.app.__main__:main"
54 | resolve = "social.graze.aip.resolve.__main__:main"
55 | aiputil = "social.graze.aip.app.util.__main__:main"
56 |
57 | [tool.pdm]
58 | distribution = true
59 |
60 | [tool.pdm.build]
61 | includes = ["src/social", "LICENSE"]
62 |
63 | [tool.pytest.ini_options]
64 | pythonpath = ["src/"]
65 |
66 | [tool.pyright]
67 | venvPath = "."
68 | venv = ".venv"
69 |
--------------------------------------------------------------------------------
/src/social/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/graze-social/aip/c740a2cf63fdf5a9d2cef6da3bfb8dd660e15ea0/src/social/__init__.py
--------------------------------------------------------------------------------
/src/social/graze/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/graze-social/aip/c740a2cf63fdf5a9d2cef6da3bfb8dd660e15ea0/src/social/graze/__init__.py
--------------------------------------------------------------------------------
/src/social/graze/aip/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | AIP - AT Protocol Identity Provider
3 |
4 | This module implements an AT Protocol Identity Provider service that handles user authentication
5 | and identity resolution for the Bluesky/AT Protocol ecosystem. It serves as a bridge between
6 | users and AT Protocol services, managing authentication flows, token management, and identity
7 | resolution.
8 |
9 | Key Components:
10 | - app: Web application layer with request handlers and server configuration
11 | - atproto: Integration with AT Protocol, handling authentication and PDS communication
12 | - model: Database models for storing authentication data and user information
13 | - resolve: Identity resolution utilities for AT Protocol DIDs and handles
14 |
15 | Architecture Overview:
16 | 1. Authentication Flow:
17 | - User initiates OAuth flow through the service
18 | - Service verifies user identity with AT Protocol PDS
19 | - Secure tokens are issued and managed
20 |
21 | 2. Identity Resolution:
22 | - Resolves user handles to DIDs through DNS and HTTP mechanisms
23 | - Resolves DIDs to canonical data (handle, PDS location)
24 |
25 | 3. Token Management:
26 | - Background tasks refresh tokens before expiry
27 | - Redis-backed token caching
28 | - Task distribution using work queues
29 |
30 | The service is designed with security, performance, and reliability in mind,
31 | following OAuth 2.0 and AT Protocol specifications.
32 | """
--------------------------------------------------------------------------------
/src/social/graze/aip/app/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | AIP Application Layer
3 |
4 | This package implements the web application layer for the AIP service, handling HTTP requests
5 | and responses using the aiohttp framework. It provides handlers for OAuth flows, app password
6 | authentication, and internal API endpoints.
7 |
8 | Key Components:
9 | - __main__.py: Entry point for running the application
10 | - server.py: Web server configuration and middleware setup
11 | - config.py: Configuration management using Pydantic settings
12 | - handlers/: Request handlers for different endpoints
13 | - tasks.py: Background tasks for token refresh and health monitoring
14 | - cors.py: CORS handling for cross-origin requests
15 | - util/: Utility functions for the application layer
16 |
17 | The application uses several middleware layers:
18 | - CORS middleware for handling cross-origin requests
19 | - Statsd middleware for metrics collection
20 | - Sentry middleware for error reporting
21 |
22 | It provides the following main endpoints:
23 | - OAuth authentication endpoints (/auth/atproto/*)
24 | - Internal API endpoints (/internal/api/*)
25 | - XRPC proxy endpoints (/xrpc/*)
26 | """
--------------------------------------------------------------------------------
/src/social/graze/aip/app/__main__.py:
--------------------------------------------------------------------------------
1 | import os
2 | from aiohttp import web
3 | import logging
4 | from logging.config import dictConfig
5 | import json
6 |
7 |
def configure_logging():
    """Configure process-wide logging.

    If the LOGGING_CONFIG_FILE environment variable names a JSON file, its
    contents are passed to ``logging.config.dictConfig``. Otherwise a basic
    configuration is installed with the root logger at DEBUG level.
    """
    logging_config_file = os.getenv("LOGGING_CONFIG_FILE", "")

    # Idiomatic truthiness check instead of len(...) > 0; explicit encoding
    # so behavior does not depend on the platform locale.
    if logging_config_file:
        with open(logging_config_file, encoding="utf-8") as fl:
            dictConfig(json.load(fl))
        return

    logging.basicConfig()
    logging.getLogger().setLevel(logging.DEBUG)
18 |
19 |
def main():
    """Entry point for the ``aipserver`` console script.

    Configures logging first, then starts the aiohttp web application.
    """
    configure_logging()

    # Imported lazily (not at module top) so that logging is already
    # configured when server-module import-time code runs.
    from social.graze.aip.app.server import start_web_server

    web.run_app(start_web_server())


if __name__ == "__main__":
    main()
30 |
--------------------------------------------------------------------------------
/src/social/graze/aip/app/config.py:
--------------------------------------------------------------------------------
1 | """
2 | Configuration Module for AIP Service
3 |
4 | This module defines the configuration system for the AIP (AT Protocol Identity Provider) service,
5 | using Pydantic for settings validation and dependency injection through AppKeys.
6 |
7 | The configuration follows these principles:
8 | 1. Environment-based configuration with sensible defaults
9 | 2. Strong validation and typing through Pydantic
10 | 3. Dependency injection pattern using aiohttp's app context
11 | 4. Secure handling of cryptographic materials
12 |
13 | The Settings class serves as the central configuration point, loaded from environment variables
14 | with defaults suitable for development environments. All application components access settings
15 | and shared resources through typed AppKeys to maintain clean dependency injection.
16 |
17 | Key configuration areas include:
18 | - Service identification and networking
19 | - Database and cache connections
20 | - Cryptographic materials (signing keys, encryption)
21 | - Background processing configuration
22 | - UI customization
23 | """
24 |
25 | import os
26 | import asyncio
27 | from typing import Annotated, Final, List, Optional
28 | import logging
29 | from aio_statsd import TelegrafStatsdClient
30 | from jwcrypto import jwk
31 | from pydantic import (
32 | AliasChoices,
33 | Field,
34 | field_validator,
35 | PostgresDsn,
36 | RedisDsn,
37 | )
38 | import base64
39 | from pydantic_settings import BaseSettings, NoDecode
40 | from aiohttp import web
41 | from cryptography.fernet import Fernet
42 | from sqlalchemy.ext.asyncio import (
43 | AsyncEngine,
44 | async_sessionmaker,
45 | AsyncSession,
46 | )
47 | from aiohttp import ClientSession
48 | from redis import asyncio as redis
49 |
50 | from social.graze.aip.model.health import HealthGauge
51 |
52 |
53 | logger = logging.getLogger(__name__)
54 |
55 |
class Settings(BaseSettings):
    """
    Application settings for the AIP service.

    This class uses Pydantic's BaseSettings to automatically load values from environment
    variables, with sensible defaults for development environments. It handles validation,
    type conversion, and provides centralized configuration management.

    Settings are organized into the following categories:
    - Environment and debugging
    - Network and service identification
    - Database and cache connections
    - Security and cryptography
    - Background task configuration
    - Monitoring and observability
    - UI customization

    Environment variables are automatically mapped to settings fields, with aliases
    provided for backward compatibility. For example, the database connection string
    can be set with either PG_DSN or DATABASE_URL environment variables.
    """

    # Environment and debugging settings
    debug: bool = False
    """
    Enable debug mode for verbose logging and development features.
    Set with DEBUG=true environment variable.
    """

    allowed_domains: str = "https://www.graze.social, https://sky-feeder-git-astro-graze.vercel.app"
    """
    Comma-separated list of domains allowed for CORS.
    Set with ALLOWED_DOMAINS environment variable.
    """

    # Network and service identification settings
    http_port: int = Field(alias="port", default=5100)
    """
    HTTP port for the service to listen on.
    Set with PORT environment variable.
    """

    external_hostname: str = "aip_service"
    """
    Public hostname for the service, used for generating callback URLs.
    Set with EXTERNAL_HOSTNAME environment variable.
    """

    plc_hostname: str = "plc.directory"
    """
    Hostname for the PLC directory service for DID resolution.
    Set with PLC_HOSTNAME environment variable.
    """

    # Monitoring and error reporting
    sentry_dsn: Optional[str] = None
    """
    Sentry DSN for error reporting. Optional, no error reporting if not set.
    Set with SENTRY_DSN environment variable.
    """

    # Database and cache connections
    redis_dsn: RedisDsn = Field(
        "redis://valkey:6379/1?decode_responses=True",
        validation_alias=AliasChoices("redis_dsn", "redis_url"),
    )  # type: ignore
    """
    Redis connection string for caching and background tasks.
    Set with REDIS_DSN or REDIS_URL environment variables.
    Default: redis://valkey:6379/1?decode_responses=True
    """

    pg_dsn: PostgresDsn = Field(
        "postgresql+asyncpg://postgres:password@db/aip",
        validation_alias=AliasChoices("pg_dsn", "database_url"),
    )  # type: ignore
    """
    PostgreSQL connection string for database access.
    Set with PG_DSN or DATABASE_URL environment variables.
    Default: postgresql+asyncpg://postgres:password@db/aip
    """

    # Security and cryptography settings
    json_web_keys: Annotated[jwk.JWKSet, NoDecode] = jwk.JWKSet()
    """
    JSON Web Key Set containing signing keys for JWT operations.
    Can be set to a JWKSet object or path to a JSON file containing keys.
    Set with JSON_WEB_KEYS environment variable.
    """

    active_signing_keys: List[str] = list()
    """
    List of key IDs (kid) from json_web_keys that should be used for signing.
    Set with ACTIVE_SIGNING_KEYS environment variable as comma-separated values.
    """

    service_auth_keys: List[str] = list()
    """
    List of key IDs (kid) from json_web_keys used for service-to-service auth.
    Set with SERVICE_AUTH_KEYS environment variable as comma-separated values.
    """

    # NOTE(review): this default is evaluated once at import time, so a fresh
    # random key is generated per process when ENCRYPTION_KEY is unset —
    # previously encrypted data cannot be decrypted after a restart. Fine for
    # development; production must set ENCRYPTION_KEY.
    encryption_key: Fernet = Fernet(Fernet.generate_key())
    """
    Fernet symmetric encryption key for sensitive data.
    Can be set to a Fernet object or base64-encoded key string.
    Set with ENCRYPTION_KEY environment variable.
    """

    # Worker identification
    worker_id: str
    """
    Unique identifier for this worker instance (required, no default).
    Used to distribute work among multiple instances.
    Set with WORKER_ID environment variable.
    """

    # Background processing configuration
    refresh_queue_oauth: str = "refresh_queue:oauth"
    """
    Redis queue name for OAuth token refresh tasks.
    Set with REFRESH_QUEUE_OAUTH environment variable.
    """

    refresh_queue_app_password: str = "refresh_queue:app_password"
    """
    Redis queue name for App Password refresh tasks.
    Set with REFRESH_QUEUE_APP_PASSWORD environment variable.
    """

    # Token expiration settings
    app_password_access_token_expiry: int = 720  # 12 minutes
    """
    Expiration time in seconds for app password access tokens.
    Set with APP_PASSWORD_ACCESS_TOKEN_EXPIRY environment variable.
    Default: 720 (12 minutes)
    """

    app_password_refresh_token_expiry: int = 7776000  # 90 days
    """
    Expiration time in seconds for app password refresh tokens.
    Set with APP_PASSWORD_REFRESH_TOKEN_EXPIRY environment variable.
    Default: 7776000 (90 days)
    """

    token_refresh_before_expiry_ratio: float = 0.8
    """
    Ratio of token lifetime to wait before refreshing.
    For example, 0.8 means tokens are refreshed after 80% of their lifetime.
    Set with TOKEN_REFRESH_BEFORE_EXPIRY_RATIO environment variable.
    Default: 0.8
    """

    oauth_refresh_max_retries: int = 3
    """
    Maximum number of retry attempts for failed OAuth refresh operations.
    Set with OAUTH_REFRESH_MAX_RETRIES environment variable.
    Default: 3
    """

    oauth_refresh_retry_base_delay: int = 300
    """
    Base delay in seconds for OAuth refresh retry attempts (exponential backoff).
    Actual delay = base_delay * (2 ^ retry_attempt)
    Set with OAUTH_REFRESH_RETRY_BASE_DELAY environment variable.
    Default: 300 (5 minutes)
    """

    default_destination: str = "https://localhost:5100/auth/atproto/debug"
    """
    Default redirect destination after authentication if none specified.
    Set with DEFAULT_DESTINATION environment variable.
    """

    # Monitoring and observability settings
    statsd_host: str = Field(alias="TELEGRAF_HOST", default="telegraf")
    """
    StatsD/Telegraf host for metrics collection.
    Set with TELEGRAF_HOST environment variable.
    """

    statsd_port: int = Field(alias="TELEGRAF_PORT", default=8125)
    """
    StatsD/Telegraf port for metrics collection.
    Set with TELEGRAF_PORT environment variable.
    """

    statsd_prefix: str = "aip"
    """
    Prefix for all StatsD metrics from this service.
    Set with STATSD_PREFIX environment variable.
    """

    # UI customization settings for login page
    svg_logo: str = "https://www.graze.social/logo.svg"
    """URL for the logo displayed on the login page"""

    brand_name: str = "Graze"
    """Brand name displayed on the login page"""

    destination: str = "https://graze.social/app/auth/callback"
    """Default destination URL after authentication"""

    background_from: str = "#0588f0"
    """Starting gradient color for login page background"""

    background_to: str = "#5eb1ef"
    """Ending gradient color for login page background"""

    text_color: str = "#FFFFFF"
    """Text color for login page"""

    form_color: str = "#FFFFFF"
    """Form background color for login page"""

    @field_validator("json_web_keys", mode="before")
    @classmethod
    def decode_json_web_keys(cls, v) -> jwk.JWKSet:
        """
        Validate and process the json_web_keys setting.

        This validator accepts either:
        - An existing JWKSet object (for programmatic configuration)
        - A file path to a JSON file containing a JWK Set

        Args:
            v: The input value to validate

        Returns:
            jwk.JWKSet: A valid JWKSet object

        Raises:
            ValueError: If the input is neither a JWKSet nor a valid file path
        """
        if isinstance(v, jwk.JWKSet):  # If it's already a JWKSet, return it directly
            return v
        elif isinstance(v, str):  # If it's a file path, load from file
            with open(v) as fd:
                data = fd.read()
            return jwk.JWKSet.from_json(data)
        raise ValueError(
            "json_web_keys must be a JWKSet object or a valid JSON file path"
        )

    @field_validator("encryption_key", mode="before")
    @classmethod
    def decode_encryption_key(cls, v) -> Fernet:
        """
        Validate and process the encryption_key setting.

        This validator accepts either:
        - An existing Fernet object (for programmatic configuration)
        - A base64-encoded string containing a Fernet key

        Args:
            v: The input value to validate

        Returns:
            Fernet: A valid Fernet encryption object

        Raises:
            ValueError: If the input is neither a Fernet object nor a valid base64 key
        """
        if isinstance(v, Fernet):  # Already a Fernet instance, return it
            return v
        elif isinstance(v, str):  # Decode from a base64-encoded string
            # NOTE(review): the string is b64-decoded before being handed to
            # Fernet, and Fernet keys are themselves base64 — so the env value
            # appears to be a double-encoded key (base64 of the Fernet key).
            # TODO confirm against how deployments generate ENCRYPTION_KEY.
            key_data = base64.b64decode(v)
            return Fernet(key_data)
        raise ValueError(
            "encryption_key must be a Fernet object or a base64-encoded key string"
        )
327 |
328 |
# Background task queue constants
# Scores in these Redis sorted sets are unix timestamps (handlers enqueue
# entries with int(datetime.timestamp()) values).
OAUTH_REFRESH_QUEUE = "auth_session:oauth:refresh"
"""
Redis sorted set key for scheduling OAuth token refresh operations.
Contains session_group IDs with refresh timestamps as scores.
"""

OAUTH_REFRESH_RETRY_QUEUE = "auth_session:oauth:refresh:retry"
"""
Redis hash key for tracking OAuth refresh retry attempts.
Keys are session_group IDs, values are retry counts.
"""

APP_PASSWORD_REFRESH_QUEUE = "auth_session:app-password:refresh"
"""
Redis sorted set key for scheduling App Password refresh operations.
Contains user GUIDs with refresh timestamps as scores.
"""

# Application context keys for dependency injection.
# Typed aiohttp AppKeys: components read shared resources via app[Key].
SettingsAppKey: Final = web.AppKey("settings", Settings)
"""AppKey for accessing the application settings"""

DatabaseAppKey: Final = web.AppKey("database", AsyncEngine)
"""AppKey for accessing the SQLAlchemy async database engine"""

DatabaseSessionMakerAppKey: Final = web.AppKey(
    "database_session_maker", async_sessionmaker[AsyncSession]
)
"""AppKey for accessing the SQLAlchemy async session factory"""

SessionAppKey: Final = web.AppKey("http_session", ClientSession)
"""AppKey for accessing the shared aiohttp client session"""

RedisPoolAppKey: Final = web.AppKey("redis_pool", redis.ConnectionPool)
"""AppKey for accessing the Redis connection pool"""

RedisClientAppKey: Final = web.AppKey("redis_client", redis.Redis)
"""AppKey for accessing the Redis client"""

HealthGaugeAppKey: Final = web.AppKey("health_gauge", HealthGauge)
"""AppKey for accessing the health monitoring gauge"""

OAuthRefreshTaskAppKey: Final = web.AppKey("oauth_refresh_task", asyncio.Task[None])
"""AppKey for the background task that refreshes OAuth tokens"""

AppPasswordRefreshTaskAppKey: Final = web.AppKey(
    "app_password_refresh_task", asyncio.Task[None]
)
"""AppKey for the background task that refreshes App Passwords"""

TickHealthTaskAppKey: Final = web.AppKey("tick_health_task", asyncio.Task[None])
"""AppKey for the background task that monitors service health"""

TelegrafStatsdClientAppKey: Final = web.AppKey(
    "telegraf_statsd_client", TelegrafStatsdClient
)
"""AppKey for the Telegraf/StatsD metrics client"""
--------------------------------------------------------------------------------
/src/social/graze/aip/app/cors.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, Optional
2 | from urllib.parse import urlparse
3 |
def get_cors_headers(
    origin_value: Optional[str], path: str, debug: bool
) -> Dict[str, str]:
    """Build the CORS response headers for a request.

    Auth routes (``/auth/*``) are open to any origin. Other routes echo the
    caller's Origin header only when its scheme://host form is on the
    production allow-list, or — in debug mode — when the host is a local
    development host.
    """
    production_origins = {
        "https://graze.social",
        "https://www.graze.social",
        "https://sky-feeder-git-astro-graze.vercel.app",
    }
    local_dev_hosts = {"localhost", "127.0.0.1"}

    cors_headers: Dict[str, str] = {
        "Access-Control-Allow-Methods": "GET, POST, OPTIONS",
        "Access-Control-Allow-Headers": (
            "Keep-Alive, User-Agent, X-Requested-With, "
            "If-Modified-Since, Cache-Control, Content-Type, "
            "Authorization, X-Subject, X-Service"
        ),
        "Vary": "Origin",
    }

    # Auth endpoints are intentionally world-readable.
    if path.startswith("/auth/"):
        cors_headers["Access-Control-Allow-Origin"] = "*"
        return cors_headers

    if not origin_value:
        return cors_headers

    # Normalize to scheme://host (drops port and path) for the allow-list
    # comparison; fall back to the raw Origin value when unparseable.
    parsed = urlparse(origin_value)
    if parsed.scheme and parsed.hostname:
        normalized = f"{parsed.scheme}://{parsed.hostname}"
    else:
        normalized = origin_value

    permitted = normalized in production_origins or (
        debug and parsed.hostname in local_dev_hosts
    )
    if permitted:
        # Echo the original Origin (including any port) back to the caller.
        cors_headers["Access-Control-Allow-Origin"] = origin_value

    return cors_headers
41 |
--------------------------------------------------------------------------------
/src/social/graze/aip/app/handlers/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/graze-social/aip/c740a2cf63fdf5a9d2cef6da3bfb8dd660e15ea0/src/social/graze/aip/app/handlers/__init__.py
--------------------------------------------------------------------------------
/src/social/graze/aip/app/handlers/app_password.py:
--------------------------------------------------------------------------------
1 | from datetime import datetime, timezone, timedelta
2 | import logging
3 | from typing import Optional
4 | from aiohttp import web
5 | from pydantic import BaseModel, ValidationError, field_validator
6 | from sqlalchemy import delete
7 | from sqlalchemy.dialects.postgresql import insert
8 | import sentry_sdk
9 |
10 | from social.graze.aip.app.config import (
11 | APP_PASSWORD_REFRESH_QUEUE,
12 | DatabaseSessionMakerAppKey,
13 | RedisClientAppKey,
14 | TelegrafStatsdClientAppKey,
15 | )
16 | from social.graze.aip.app.handlers.helpers import auth_token_helper
17 | from social.graze.aip.model.app_password import AppPassword
18 |
19 |
20 | logger = logging.getLogger(__name__)
21 |
22 |
class AppPasswordOperation(BaseModel):
    """Request body for the internal app-password endpoint.

    A null ``value`` requests deletion of the stored app password; a non-null
    value stores or replaces it. App passwords are expected in the
    "xxxx-xxxx-xxxx-xxxx" shape: 19 characters containing exactly three
    dashes.
    """

    value: Optional[str] = None

    # Consistency fix: pydantic v2 field validators should be classmethods,
    # matching the @field_validator usage elsewhere in this codebase.
    @field_validator("value")
    @classmethod
    def app_password_check(cls, v: Optional[str]) -> Optional[str]:
        """Validate the app-password shape; None passes through (deletion)."""
        if v is None:
            return None

        if len(v) != 19:
            raise ValueError("invalid format")

        if v.count("-") != 3:
            raise ValueError("invalid format")

        return v
38 |
39 |
async def handle_internal_app_password(request: web.Request) -> web.Response:
    """Store or delete the authenticated user's app password.

    The request body is an AppPasswordOperation JSON document. A null
    ``value`` deletes the stored app password and removes the user from the
    refresh queue; a non-null ``value`` upserts it and schedules a
    near-immediate refresh so a session gets established.

    Returns:
        400 for malformed/invalid JSON, 401 when not authenticated,
        200 on success, 500 on unexpected errors.
    """
    database_session_maker = request.app[DatabaseSessionMakerAppKey]
    redis_session = request.app[RedisClientAppKey]
    statsd_client = request.app[TelegrafStatsdClientAppKey]

    try:
        data = await request.read()
        app_password_operation = AppPasswordOperation.model_validate_json(data)
    except (OSError, ValidationError) as e:
        # TODO: Fix the returned error message when JSON fails because of pydantic validation functions.
        sentry_sdk.capture_exception(e)
        return web.json_response(status=400, data={"error": "Invalid JSON"})

    try:
        async with (database_session_maker() as database_session,):
            # Permission delegation is disabled: only the account owner may
            # manage their own app password.
            auth_token = await auth_token_helper(
                database_session, statsd_client, request, allow_permissions=False
            )
            if auth_token is None:
                return web.json_response(status=401, data={"error": "Not Authorized"})

            now = datetime.now(timezone.utc)

            async with database_session.begin():
                if app_password_operation.value is None:
                    stmt = delete(AppPassword).where(
                        AppPassword.guid == auth_token.guid
                    )
                    await database_session.execute(stmt)
                    await redis_session.zrem(
                        APP_PASSWORD_REFRESH_QUEUE, auth_token.guid
                    )
                else:
                    stmt = (
                        insert(AppPassword)
                        .values(
                            [
                                {
                                    "guid": auth_token.guid,
                                    "app_password": app_password_operation.value,
                                    "created_at": now,
                                }
                            ]
                        )
                        .on_conflict_do_update(
                            index_elements=["guid"],
                            set_={"app_password": app_password_operation.value},
                        )
                    )
                    await database_session.execute(stmt)

                    # Bug fix: schedule the refresh only when a password was
                    # stored. Previously this ran unconditionally after the
                    # if/else, re-adding the guid to the refresh queue right
                    # after the delete path's zrem.
                    refresh_at = now + timedelta(0, 5)

                    await redis_session.zadd(
                        APP_PASSWORD_REFRESH_QUEUE,
                        {auth_token.guid: int(refresh_at.timestamp())},
                    )

            await database_session.commit()

        # Invalidate any cached app-password session so the change takes
        # effect immediately.
        app_password_key = f"auth_session:app-password:{auth_token.guid}"
        await redis_session.delete(app_password_key)

        return web.Response(status=200)
    except web.HTTPException as e:
        sentry_sdk.capture_exception(e)
        # Bug fix: log messages previously said "handle_internal_permissions"
        # (copy-paste from another handler); also use the module logger.
        logger.exception("handle_internal_app_password: web.HTTPException")
        raise e
    except Exception as e:
        sentry_sdk.capture_exception(e)
        logger.exception("handle_internal_app_password: Exception")
        return web.json_response(status=500, data={"error": "Internal Server Error"})
112 |
--------------------------------------------------------------------------------
/src/social/graze/aip/app/handlers/credentials.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from aiohttp import web
3 | import sentry_sdk
4 |
5 | from social.graze.aip.app.config import (
6 | DatabaseSessionMakerAppKey,
7 | TelegrafStatsdClientAppKey,
8 | )
9 | from social.graze.aip.app.handlers.helpers import auth_token_helper
10 |
11 | logger = logging.getLogger(__name__)
12 |
13 |
async def handle_internal_credentials(request: web.Request) -> web.Response:
    """Return the authenticated caller's upstream credentials as JSON.

    Responds with a bearer token when an app-password session exists, a dpop
    token (plus JWK and issuer) when an OAuth session exists, and type
    "none" when neither is available. Returns 401 for unauthenticated
    requests and 500 on unexpected errors.
    """
    statsd_client = request.app[TelegrafStatsdClientAppKey]
    database_session_maker = request.app[DatabaseSessionMakerAppKey]

    try:
        async with (database_session_maker() as database_session,):
            # TODO: Allow optional auth here.
            auth_token = await auth_token_helper(
                database_session, statsd_client, request
            )
            if auth_token is None:
                return web.json_response(status=401, data={"error": "Not Authorized"})
    except web.HTTPException as e:
        sentry_sdk.capture_exception(e)
        raise e
    except Exception as e:
        sentry_sdk.capture_exception(e)
        return web.json_response(status=500, data={"error": "Internal Server Error"})

    # App-password sessions take precedence over OAuth sessions.
    app_password_session = auth_token.app_password_session
    if app_password_session is not None:
        return web.json_response(
            {
                "type": "bearer",
                "token": app_password_session.access_token,
            }
        )

    oauth_session = auth_token.oauth_session
    if oauth_session is not None:
        return web.json_response(
            {
                "type": "dpop",
                "token": oauth_session.access_token,
                "jwk": oauth_session.dpop_jwk,
                "issuer": oauth_session.issuer,
            }
        )

    return web.json_response(
        {
            "type": "none",
        }
    )
56 |
--------------------------------------------------------------------------------
/src/social/graze/aip/app/handlers/helpers.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from datetime import datetime, timezone
3 | import json
4 | import logging
5 | from typing import (
6 | Optional,
7 | Dict,
8 | )
9 | from aio_statsd import TelegrafStatsdClient
10 | from aiohttp import web
11 | from jwcrypto import jwt
12 | from sqlalchemy import select
13 | from sqlalchemy.ext.asyncio import (
14 | AsyncSession,
15 | )
16 | import sentry_sdk
17 |
18 | from social.graze.aip.app.config import (
19 | SettingsAppKey,
20 | )
21 | from social.graze.aip.model.app_password import AppPasswordSession
22 | from social.graze.aip.model.handles import Handle
23 | from social.graze.aip.model.oauth import OAuthSession, Permission
24 |
25 | logger = logging.getLogger(__name__)
26 |
27 |
# repr and eq are disabled — NOTE(review): presumably so token material is
# never emitted by accidental repr()/logging; confirm before re-enabling.
@dataclass(repr=False, eq=False)
class AuthToken:
    """
    Represents an authenticated token with user identity and session information.

    This class contains both the authenticating user's information (guid, subject, handle)
    and the context for the current request (which may be different if using permissions
    to act on behalf of another user).

    Attributes:
        guid: The guid of the authenticating user
        subject: The DID of the authenticating user
        handle: The handle of the authenticating user

        context_service: The service URL for the context of the request
        context_guid: The guid for the context of the request
        context_subject: The DID for the context of the request
        context_pds: The PDS URL for the context of the request

        oauth_session: The OAuth session if authenticated via OAuth
        app_password_session: The App Password session if authenticated via App Password
    """
    guid: str
    subject: str
    handle: str

    context_service: str
    context_guid: str
    context_subject: str
    context_pds: str

    # At most one of the following is expected to be set, depending on the
    # authentication mechanism used — TODO confirm against auth_token_helper.
    oauth_session: Optional[OAuthSession] = None
    app_password_session: Optional[AppPasswordSession] = None
61 |
62 |
class AuthenticationException(Exception):
    """Raised when request authentication fails.

    Each factory method below builds an instance whose message carries a
    stable error code (``error-auth-helper-NNNN``) followed by a
    human-readable description, so failures can be grepped and correlated.
    """

    @staticmethod
    def jwt_subject_missing() -> "AuthenticationException":
        """The JWT lacks the required 'sub' claim."""
        return AuthenticationException("error-auth-helper-1000 JWT missing subject")

    @staticmethod
    def jwt_session_group_missing() -> "AuthenticationException":
        """The JWT lacks the required 'grp' claim."""
        return AuthenticationException(
            "error-auth-helper-1001 JWT missing session group"
        )

    @staticmethod
    def session_not_found() -> "AuthenticationException":
        """No valid session exists for the authenticated user."""
        return AuthenticationException("error-auth-helper-1002 No valid session found")

    @staticmethod
    def session_expired() -> "AuthenticationException":
        """The session is past its expiry and no longer usable."""
        return AuthenticationException("error-auth-helper-1003 Session has expired")

    @staticmethod
    def handle_not_found() -> "AuthenticationException":
        """No handle record exists for the authenticated user."""
        return AuthenticationException("error-auth-helper-1004 Handle record not found")

    @staticmethod
    def permission_denied() -> "AuthenticationException":
        """The caller is not allowed to perform the requested action."""
        return AuthenticationException("error-auth-helper-1005 Permission denied")

    @staticmethod
    def unexpected(msg: str = "") -> "AuthenticationException":
        """An unexpected error occurred during authentication."""
        return AuthenticationException(
            f"error-auth-helper-1999 Unexpected authentication error: {msg}"
        )
117 |
118 |
async def auth_token_helper(
    database_session: AsyncSession,
    statsd_client: TelegrafStatsdClient,
    request: web.Request,
    allow_permissions: bool = True,
) -> Optional[AuthToken]:
    """
    Authenticate a request and return an AuthToken with user and context information.

    This helper enforces the following policies:
    * All API calls must be authenticated and require an `Authorization` header with a bearer token.
    * The `X-Subject` header can optionally specify the subject of the request. The value must be a known guid.
    * The `X-Service` header can optionally specify the hostname of the service providing the invoked XRPC method.

    The function validates the JWT token, retrieves the associated session information, and checks permissions
    if the request is acting on behalf of another user.

    Args:
        database_session: SQLAlchemy async session for database queries
        statsd_client: Statsd client for metrics
        request: The HTTP request to authenticate
        allow_permissions: Whether to allow acting on behalf of another user via permissions

    Returns:
        An AuthToken object if authentication succeeds, None otherwise

    Raises:
        AuthenticationException: If there's a specific authentication failure that should be reported
    """

    # Check for Authorization header with Bearer token. A missing or malformed
    # header is not an error condition: the request is simply unauthenticated.
    authorizations: Optional[str] = request.headers.getone("Authorization", None)
    if (
        authorizations is None
        or not authorizations.startswith("Bearer ")
        or len(authorizations) < 8
    ):
        return None

    # Strip the "Bearer " prefix (7 characters) to get the raw serialized JWT.
    serialized_auth_token = authorizations[7:]

    settings = request.app[SettingsAppKey]

    try:
        # Validate the JWT token signature against this service's own keys.
        # An invalid/tampered token raises and is handled by the generic
        # except block below (treated as unauthenticated).
        validated_auth_token = jwt.JWT(
            jwt=serialized_auth_token, key=settings.json_web_keys, algs=["ES256"]
        )

        # Parse and validate claims
        auth_token_claims: Dict[str, str] = json.loads(validated_auth_token.claims)

        auth_token_subject: Optional[str] = auth_token_claims.get("sub", None)
        if auth_token_subject is None:
            raise AuthenticationException.jwt_subject_missing()

        # "grp" must be present, but it is deliberately NOT used to filter the
        # session query below (see the inline comment there).
        auth_token_session_group: Optional[str] = auth_token_claims.get("grp", None)
        if auth_token_session_group is None:
            raise AuthenticationException.jwt_session_group_missing()

        now = datetime.now(timezone.utc)

        async with database_session.begin():

            # 1. Get the OAuthSession from the database and validate it
            oauth_session_stmt = select(OAuthSession).where(
                OAuthSession.guid == auth_token_subject,

                # Sessions are not limited to the same session group because we
                # don't actually care about which session we end up getting. The
                # only requirement is that it's valid.
                # OAuthSession.session_group == auth_token_session_group,

                OAuthSession.access_token_expires_at > now,
                OAuthSession.hard_expires_at > now
            ).order_by(OAuthSession.created_at.desc())

            # Most recently created non-expired session wins.
            oauth_session: Optional[OAuthSession] = (
                await database_session.scalars(oauth_session_stmt)
            ).first()

            if oauth_session is None:
                raise AuthenticationException.session_not_found()

            # 2. Get the Handle from the database and validate it
            oauth_session_handle_stmt = select(Handle).where(
                Handle.guid == auth_token_subject
            )
            oauth_session_handle_result = await database_session.scalars(oauth_session_handle_stmt)
            oauth_session_handle = oauth_session_handle_result.first()

            if oauth_session_handle is None:
                raise AuthenticationException.handle_not_found()

            # 3. Get the X-Subject header, defaulting to the current oauth session handle's guid
            x_subject: str = request.headers.getone(
                "X-Subject", oauth_session_handle.guid
            )

            # If the subject of the request is the same as the subject of the auth token,
            # then we have everything we need and can return a fully formed AuthToken.
            # This early return exits the begin() context, which ends the
            # (read-only) transaction normally.
            if x_subject == oauth_session_handle.guid:
                x_service: str = request.headers.getone(
                    "X-Service", oauth_session_handle.pds
                )

                # Look up app password session if it exists
                app_password_session_stmt = select(AppPasswordSession).where(
                    AppPasswordSession.guid == oauth_session.guid,
                )
                app_password_session: Optional[AppPasswordSession] = (
                    await database_session.scalars(app_password_session_stmt)
                ).first()

                return AuthToken(
                    oauth_session=oauth_session,
                    app_password_session=app_password_session,
                    guid=oauth_session_handle.guid,
                    subject=oauth_session_handle.did,
                    handle=oauth_session_handle.handle,
                    context_service=x_service,
                    context_guid=oauth_session_handle.guid,
                    context_subject=oauth_session_handle.did,
                    context_pds=oauth_session_handle.pds,
                )

            # If permissions are not allowed but the subject differs, deny the request
            if allow_permissions is False:
                raise AuthenticationException.permission_denied()

            # 4. Get the permission record for the oauth session handle to the x_repository guid
            permission_stmt = select(Permission).where(
                Permission.guid == oauth_session_handle.guid,
                Permission.target_guid == x_subject,
                Permission.permission > 0,
            )
            permission: Optional[Permission] = (
                await database_session.scalars(permission_stmt)
            ).first()

            # If no permission is found, deny access
            if permission is None:
                raise AuthenticationException.permission_denied()

            # Get the handle for the target subject
            subject_handle_stmt = select(Handle).where(
                Handle.guid == permission.target_guid
            )
            subject_handle_result = await database_session.scalars(subject_handle_stmt)
            subject_handle = subject_handle_result.first()

            if subject_handle is None:
                raise AuthenticationException.handle_not_found()

            # Get the app password session for the target subject if it exists
            app_password_session_stmt = select(AppPasswordSession).where(
                AppPasswordSession.guid == subject_handle.guid,
            )
            app_password_session: Optional[AppPasswordSession] = (
                await database_session.scalars(app_password_session_stmt)
            ).first()

            # Get a valid OAuth session for the target subject.
            # NOTE(review): this may legitimately be None (the target user has
            # no live OAuth session); the AuthToken is built with it regardless.
            target_oauth_session_stmt = select(OAuthSession).where(
                OAuthSession.guid == subject_handle.guid,
                OAuthSession.access_token_expires_at > now,
                OAuthSession.hard_expires_at > now
            ).order_by(OAuthSession.created_at.desc())

            target_oauth_session: Optional[OAuthSession] = (
                await database_session.scalars(target_oauth_session_stmt)
            ).first()

            # Get the service endpoint from the X-Service header or use the subject's PDS
            x_service: str = request.headers.getone("X-Service", subject_handle.pds)

            # NOTE(review): explicit commit inside the begin() context; the
            # context manager would also commit on normal exit — confirm this
            # is intentional with the SQLAlchemy version in use.
            await database_session.commit()

            return AuthToken(
                oauth_session=target_oauth_session,
                app_password_session=app_password_session,
                guid=oauth_session_handle.guid,
                subject=oauth_session_handle.did,
                handle=oauth_session_handle.handle,
                context_service=x_service,
                context_guid=subject_handle.guid,
                context_subject=subject_handle.did,
                context_pds=subject_handle.pds,
            )
    except AuthenticationException as e:
        # Specific auth failures are reported to Sentry and re-raised so the
        # caller can translate them into an HTTP response.
        sentry_sdk.capture_exception(e)
        raise
    except Exception as e:
        # Anything unexpected (DB errors, malformed/invalid JWTs, etc.) is
        # recorded and the request is treated as unauthenticated (None)
        # rather than propagating an error.
        sentry_sdk.capture_exception(e)
        statsd_client.increment(
            "aip.auth.exception",
            1,
            tag_dict={"exception": type(e).__name__},
        )
        logger.exception("auth_token_helper: Exception")
        return None
--------------------------------------------------------------------------------
/src/social/graze/aip/app/handlers/internal.py:
--------------------------------------------------------------------------------
1 | import json
2 | import logging
3 |
4 | from aiohttp import web
5 | import sentry_sdk
6 |
7 | from social.graze.aip.app.config import (
8 | DatabaseSessionMakerAppKey,
9 | HealthGaugeAppKey,
10 | SessionAppKey,
11 | SettingsAppKey,
12 | TelegrafStatsdClientAppKey,
13 | )
14 | from social.graze.aip.app.handlers.helpers import auth_token_helper
15 | from social.graze.aip.model.handles import upsert_handle_stmt
16 | from social.graze.aip.resolve.handle import resolve_subject
17 |
18 | logger = logging.getLogger(__name__)
19 |
20 |
async def handle_internal_me(request: web.Request):
    """
    Return identity and session information for the authenticated caller.

    Requires a bearer token (see ``auth_token_helper``). Permission-based
    delegation is disabled (``allow_permissions=False``), so the response
    always describes the token's own subject.

    Returns:
        JSON with handle, pds, did, guid, and booleans indicating whether a
        live OAuth session and an app-password session exist.

    Raises:
        HTTPUnauthorized: When no valid auth token is presented (JSON body).
        HTTPInternalServerError: On any unexpected failure (JSON body).
    """
    database_session_maker = request.app[DatabaseSessionMakerAppKey]
    statsd_client = request.app[TelegrafStatsdClientAppKey]

    try:
        async with (database_session_maker() as database_session,):
            auth_token = await auth_token_helper(
                database_session, statsd_client, request, allow_permissions=False
            )
            if auth_token is None:
                raise web.HTTPUnauthorized(
                    body=json.dumps({"error": "Not Authorized"}),
                    content_type="application/json",
                )

            # TODO: Include has_app_password boolean in the response.

            return web.json_response(
                {
                    "handle": auth_token.handle,
                    "pds": auth_token.context_service,
                    "did": auth_token.subject,
                    "guid": auth_token.guid,
                    "oauth_session_valid": auth_token.oauth_session is not None,
                    "app_password_session_valid": auth_token.app_password_session
                    is not None,
                }
            )
    except web.HTTPException as e:
        # Deliberate HTTP errors (e.g. the 401 above) pass through unchanged,
        # but are still reported to Sentry.
        sentry_sdk.capture_exception(e)
        raise e
    except Exception as e:
        sentry_sdk.capture_exception(e)
        raise web.HTTPInternalServerError(
            body=json.dumps({"error": "Internal Server Error"}),
            content_type="application/json",
        )
58 |
59 |
async def handle_internal_ready(request: web.Request):
    """Readiness probe: 200 when the health gauge reports healthy, else 503."""
    gauge = request.app[HealthGaugeAppKey]
    status = 200 if await gauge.is_healthy() else 503
    return web.Response(status=status)
65 |
66 |
async def handle_internal_alive(request: web.Request):
    """Liveness probe: always responds 200 OK with an empty body."""
    return web.Response(status=200)
69 |
70 |
async def handle_internal_resolve(request: web.Request):
    """
    Resolve AT Protocol subjects (handles/DIDs) and cache them as Handle rows.

    Accepts repeated ``subject`` query parameters. Subjects that fail to
    resolve are silently skipped. Each successful resolution is upserted
    into the handles table and echoed back in the JSON response.

    Returns:
        JSON array of resolved subject objects (empty array when no
        subjects were supplied).
    """
    subjects = request.query.getall("subject", [])
    if len(subjects) == 0:
        return web.json_response([])

    # Nick: This could be improved by using Redis to cache results. Eventually inputs should go into a queue to be
    # processed in the background and the results streamed back to the client via SSE. If this becomes a high volume
    # endpoint, then we should run our own PLC replica or consider tapping into jetstream.

    database_session_maker = request.app[DatabaseSessionMakerAppKey]
    settings = request.app[SettingsAppKey]

    # TODO: Use a pydantic list structure for this.
    results = []
    async with database_session_maker() as database_session:
        for subject in subjects:
            resolved_subject = await resolve_subject(
                request.app[SessionAppKey], settings.plc_hostname, subject
            )
            if resolved_subject is None:
                continue
            # One transaction per subject, so a failure late in the list does
            # not roll back earlier upserts.
            async with database_session.begin():
                stmt = upsert_handle_stmt(
                    resolved_subject.did, resolved_subject.handle, resolved_subject.pds
                )
                await database_session.execute(stmt)
                # NOTE(review): explicit commit inside begin(); the context
                # manager also commits on exit — confirm this is intentional.
                await database_session.commit()
            results.append(resolved_subject.model_dump())
    return web.json_response(results)
100 |
--------------------------------------------------------------------------------
/src/social/graze/aip/app/handlers/oauth.py:
--------------------------------------------------------------------------------
1 | """
2 | AT Protocol OAuth Handlers
3 |
4 | This module implements the web request handlers for OAuth authentication with AT Protocol.
5 | It provides endpoints for initiating authentication, handling callbacks from the AT Protocol
6 | authorization server, refreshing tokens, and debugging authentication information.
7 |
8 | OAuth Flow with AT Protocol:
9 | 1. User enters their handle/DID in the login form
10 | 2. Application initiates OAuth flow by redirecting to AT Protocol authorization server
11 | 3. User authenticates with their AT Protocol PDS (Personal Data Server)
12 | 4. PDS redirects back to the application with an authorization code
13 | 5. Application exchanges the code for access and refresh tokens
14 | 6. Application stores tokens and returns an auth token to the client
15 | 7. Tokens are refreshed before they expire
16 |
17 | The handlers in this module provide the following endpoints:
18 | - GET /auth/atproto - Login form for entering AT Protocol handle/DID
19 | - POST /auth/atproto - Submit login form to initiate OAuth flow
20 | - GET /auth/atproto/callback - OAuth callback from AT Protocol authorization server
21 | - GET /auth/atproto/refresh - Manually refresh tokens
22 | - GET /auth/atproto/debug - Display debug information about authentication
23 | - GET /.well-known/jwks.json - JWKS endpoint for key verification
24 | - GET /auth/atproto/client-metadata.json - OAuth client metadata
25 | """
26 |
27 | import json
28 | import logging
29 | from typing import (
30 | Optional,
31 | Dict,
32 | List,
33 | Any,
34 | )
35 | from aiohttp import web
36 | import aiohttp_jinja2
37 | from jwcrypto import jwt
38 | from pydantic import BaseModel
39 | from sqlalchemy import select
40 | import sentry_sdk
41 | from urllib.parse import urlparse, urlencode, parse_qsl, urlunparse
42 |
43 |
44 | from social.graze.aip.app.config import (
45 | DatabaseSessionMakerAppKey,
46 | RedisClientAppKey,
47 | SessionAppKey,
48 | SettingsAppKey,
49 | TelegrafStatsdClientAppKey,
50 | )
51 | from social.graze.aip.app.cors import get_cors_headers
52 | from social.graze.aip.atproto.oauth import oauth_complete, oauth_init, oauth_refresh
53 | from social.graze.aip.model.handles import Handle
54 | from social.graze.aip.model.oauth import OAuthSession
55 |
56 | logger = logging.getLogger(__name__)
57 |
58 |
def context_vars(settings):
    """
    Build the template-rendering context holding UI customization values.

    Copies the branding/appearance attributes off the application settings
    object into a plain dict consumed by the Jinja templates.

    Args:
        settings: Application settings object exposing the UI attributes.

    Returns:
        Dict mapping template variable names to their configured values.
    """
    ui_attributes = (
        "svg_logo",
        "brand_name",
        "destination",
        "background_from",
        "background_to",
        "text_color",
        "form_color",
    )
    return {name: getattr(settings, name) for name in ui_attributes}
81 |
82 |
class ATProtocolOAuthClientMetadata(BaseModel):
    """
    OAuth 2.0 Client Metadata for AT Protocol integration.

    This model represents the client metadata used for OAuth registration with AT Protocol.
    It follows the OAuth 2.0 Dynamic Client Registration Protocol (RFC 7591) with
    additional fields specific to AT Protocol requirements.

    The metadata is exposed at the client-metadata.json endpoint and is used by
    AT Protocol authorization servers to validate OAuth requests.

    Note: pydantic serializes fields in declaration order, so reordering the
    fields below changes the emitted JSON document.
    """
    client_id: str
    """Client identifier URI"""

    dpop_bound_access_tokens: bool
    """Whether access tokens are bound to DPoP proofs"""

    application_type: str
    """Type of application (web, native)"""

    redirect_uris: List[str]
    """List of allowed redirect URIs for this client"""

    client_uri: str
    """URI of the client's homepage"""

    grant_types: List[str]
    """OAuth grant types supported by this client"""

    response_types: List[str]
    """OAuth response types supported by this client"""

    scope: str
    """OAuth scopes requested by this client"""

    client_name: str
    """Human-readable name of the client application"""

    token_endpoint_auth_method: str
    """Authentication method for the token endpoint"""

    jwks_uri: str
    """URI of the client's JWKS (JSON Web Key Set)"""

    logo_uri: str
    """URI of the client's logo"""

    tos_uri: str
    """URI of the client's terms of service"""

    policy_uri: str
    """URI of the client's policy document"""

    subject_type: str
    """Subject type requested for responses"""

    token_endpoint_auth_signing_alg: str
    """Algorithm used for signing token endpoint authentication assertions"""
142 |
async def handle_atproto_login(request: web.Request):
    """
    Render the AT Protocol login form (GET).

    Shows the page where a user types their AT Protocol handle or DID to
    begin authentication. An optional ``destination`` query parameter
    overrides the configured post-login redirect target.

    Args:
        request: HTTP request object

    Returns:
        HTTP response with the rendered login template
    """
    settings = request.app[SettingsAppKey]
    template_context = context_vars(settings)

    requested_destination = request.query.get("destination")
    if requested_destination:
        template_context["destination"] = requested_destination

    return await aiohttp_jinja2.render_template_async(
        "atproto_login.html", request, context=template_context
    )
166 |
167 |
async def handle_atproto_login_submit(request: web.Request):
    """
    Process the AT Protocol login form submission (POST).

    Reads the ``subject`` (handle or DID) and optional ``destination`` from
    the form, starts the OAuth flow via ``oauth_init``, and redirects the
    user to the AT Protocol authorization server. Any failure re-renders
    the login form with an error message.

    Request Parameters:
        subject: AT Protocol handle or DID
        destination: Optional redirect URL after authentication

    Raises:
        HTTPFound: Redirect to the authorization server on success
    """
    settings = request.app[SettingsAppKey]
    form = await request.post()
    subject: Optional[str] = form.get("subject", None)  # type: ignore
    destination: Optional[str] = form.get("destination", None)  # type: ignore

    def render_login_error(message: str):
        # Re-render the login form with an error banner.
        return aiohttp_jinja2.render_template_async(
            "atproto_login.html",
            request,
            context={**context_vars(settings), "error_message": message},
        )

    if subject is None:
        return await render_login_error("No subject provided")

    http_session = request.app[SessionAppKey]
    database_session_maker = request.app[DatabaseSessionMakerAppKey]
    statsd_client = request.app[TelegrafStatsdClientAppKey]
    redis_session = request.app[RedisClientAppKey]

    if destination is None:
        destination = settings.default_destination

    try:
        authorization_url = await oauth_init(
            settings,
            statsd_client,
            http_session,
            database_session_maker,
            redis_session,
            subject,
            destination,
        )
    except Exception as e:
        logger.exception("login error")

        sentry_sdk.capture_exception(e)
        # TODO: Return a localized error message.
        return await render_login_error(str(e))

    # Raised outside the try block so the redirect itself is never caught.
    raise web.HTTPFound(
        str(authorization_url),
        headers=get_cors_headers(
            request.headers.get("Origin"), request.path, settings.debug
        ),
    )
240 |
241 |
async def handle_atproto_callback(request: web.Request):
    """
    Complete the OAuth flow after the authorization server redirects back.

    Exchanges the authorization ``code`` for tokens via ``oauth_complete``,
    then forwards the user to their original destination with the issued
    auth token appended to the query string. On failure an alert page is
    rendered instead of redirecting.

    Query Parameters:
        state: OAuth state parameter to prevent CSRF
        iss: Issuer identifier (authorization server)
        code: Authorization code to exchange for tokens

    Raises:
        HTTPFound: Redirect to the final destination on success
    """
    state: Optional[str] = request.query.get("state", None)
    issuer: Optional[str] = request.query.get("iss", None)
    code: Optional[str] = request.query.get("code", None)

    settings = request.app[SettingsAppKey]
    http_session = request.app[SessionAppKey]
    database_session_maker = request.app[DatabaseSessionMakerAppKey]
    redis_session = request.app[RedisClientAppKey]
    statsd_client = request.app[TelegrafStatsdClientAppKey]

    try:
        auth_token_value, destination = await oauth_complete(
            settings,
            http_session,
            statsd_client,
            database_session_maker,
            redis_session,
            state,
            issuer,
            code,
        )
    except Exception as exc:
        return await aiohttp_jinja2.render_template_async(
            "alert.html",
            request,
            context={"error_message": str(exc)},
        )

    # Append auth_token=<jwt> to the destination's existing query string.
    url_parts = urlparse(destination)
    query_params = dict(parse_qsl(url_parts.query))
    query_params["auth_token"] = auth_token_value
    final_url = urlunparse(url_parts._replace(query=urlencode(query_params)))
    raise web.HTTPFound(final_url)
304 |
305 |
async def handle_atproto_refresh(request: web.Request):
    """
    Handle manual token refresh request.

    This handler allows for manual refreshing of OAuth tokens. It extracts the
    auth token from the query parameters, validates it, finds the associated
    OAuth session, and refreshes the token.

    Query Parameters:
        auth_token: JWT authentication token

    Args:
        request: HTTP request object

    Returns:
        HTTP redirect to debug page with refreshed token

    Raises:
        HTTPFound: To redirect to debug page
        Exception: If auth token is invalid or session not found

    Flow:
        1. Extract and validate auth token
        2. Find associated OAuth session
        3. Refresh tokens with oauth_refresh
        4. Redirect to debug page
    """
    settings = request.app[SettingsAppKey]
    http_session = request.app[SessionAppKey]
    statsd_client = request.app[TelegrafStatsdClientAppKey]
    database_session_maker = request.app[DatabaseSessionMakerAppKey]
    redis_session = request.app[RedisClientAppKey]

    serialized_auth_token: Optional[str] = request.query.get("auth_token", None)
    if serialized_auth_token is None:
        raise Exception("Invalid request")

    # Signature check; raises (jwcrypto) on an invalid or tampered token.
    validated_auth_token = jwt.JWT(
        jwt=serialized_auth_token, key=settings.json_web_keys, algs=["ES256"]
    )
    auth_token_claims = json.loads(validated_auth_token.claims)

    # NOTE(review): unlike auth_token_helper, "sub"/"grp" are not None-checked
    # here; if either claim is absent the query below matches no rows and
    # .one() raises.
    auth_token_subject: Optional[str] = auth_token_claims.get("sub", None)
    auth_token_session_group: Optional[str] = auth_token_claims.get("grp", None)

    # Fetch OAuth session first
    async with (database_session_maker() as database_session,):
        async with database_session.begin():
            oauth_session_stmt = select(OAuthSession).where(
                OAuthSession.guid == auth_token_subject,
                OAuthSession.session_group == auth_token_session_group,
            )
            # .one() raises when no matching session exists.
            oauth_session: OAuthSession = (
                await database_session.scalars(oauth_session_stmt)
            ).one()

    # Create fresh database session for oauth_refresh to avoid transaction conflicts
    async with (database_session_maker() as fresh_database_session,):
        await oauth_refresh(
            settings,
            http_session,
            statsd_client,
            fresh_database_session,
            redis_session,
            oauth_session,
        )

    # The same auth token is returned, but the access token is updated.
    raise web.HTTPFound(f"/auth/atproto/debug?auth_token={serialized_auth_token}")
375 |
376 |
async def handle_atproto_debug(request: web.Request):
    """
    Handle debug page request showing authentication information.

    This handler displays detailed information about the authentication session,
    including the JWT token contents, OAuth session details, and user handle.
    It's primarily used for debugging and development purposes.

    Query Parameters:
        auth_token: JWT authentication token

    Args:
        request: HTTP request object

    Returns:
        HTTP response with rendered debug template

    Raises:
        Exception: If auth token is invalid or session/handle not found
    """
    settings = request.app[SettingsAppKey]
    database_session_maker = request.app[DatabaseSessionMakerAppKey]

    serialized_auth_token: Optional[str] = request.query.get("auth_token", None)
    if serialized_auth_token is None:
        raise Exception("Invalid request")

    # Signature check; raises (jwcrypto) on an invalid or tampered token.
    validated_auth_token = jwt.JWT(
        jwt=serialized_auth_token, key=settings.json_web_keys, algs=["ES256"]
    )
    auth_token_claims = json.loads(validated_auth_token.claims)
    auth_token_header = json.loads(validated_auth_token.header)

    auth_token_subject: Optional[str] = auth_token_claims.get("sub", None)
    auth_token_session_group: Optional[str] = auth_token_claims.get("grp", None)

    async with database_session_maker() as database_session:

        async with database_session.begin():

            # NOTE(review): unlike auth_token_helper, this lookup DOES filter
            # on the session group and ignores the expiry timestamps.
            oauth_session_stmt = select(OAuthSession).where(
                OAuthSession.guid == auth_token_subject,
                OAuthSession.session_group == auth_token_session_group,
            )
            oauth_session: Optional[OAuthSession] = (
                await database_session.scalars(oauth_session_stmt)
            ).first()
            if oauth_session is None:
                raise Exception("Invalid request: no matching session")

            handle_stmt = select(Handle).where(Handle.guid == oauth_session.guid)
            handle: Optional[Handle] = (
                await database_session.scalars(handle_stmt)
            ).first()
            if handle is None:
                raise Exception("Invalid request: no matching handle")

            # NOTE(review): explicit commit inside begin(); the context
            # manager also commits on exit — confirm this is intentional.
            await database_session.commit()

    return await aiohttp_jinja2.render_template_async(
        "atproto_debug.html",
        request,
        context={
            "auth_token": {"claims": auth_token_claims, "header": auth_token_header},
            "oauth_session": oauth_session,
            "handle": handle,
            "serialized_auth_token": serialized_auth_token,
        },
    )
446 |
447 |
async def handle_jwks(request: web.Request):
    """
    Serve the JWKS (JSON Web Key Set) document.

    Exposes the public halves of the active signing keys so that JWTs issued
    by this service can be verified by clients. Key ids with no matching key
    in the key set are skipped.

    Args:
        request: HTTP request object

    Returns:
        HTTP JSON response with the JWKS document
    """
    settings = request.app[SettingsAppKey]
    public_keys: List[Dict[str, Any]] = []
    for key_id in settings.active_signing_keys:
        signing_key = settings.json_web_keys.get_key(key_id)
        if signing_key is not None:
            public_keys.append(signing_key.export_public(as_dict=True))
    return web.json_response({"keys": public_keys})
469 |
470 |
async def handle_atproto_client_metadata(request: web.Request):
    """
    Handle OAuth client metadata endpoint request.

    This handler provides OAuth client metadata according to the OAuth 2.0
    Dynamic Client Registration Protocol (RFC 7591). It returns a JSON document
    describing this client to AT Protocol authorization servers.

    The metadata includes client identification, capabilities, endpoints,
    and authentication methods. All URIs are rooted at the externally
    visible hostname from settings.

    Args:
        request: HTTP request object

    Returns:
        HTTP JSON response with client metadata
    """
    settings = request.app[SettingsAppKey]
    # Every advertised URI hangs off the externally visible origin.
    base = f"https://{settings.external_hostname}"
    client_metadata = ATProtocolOAuthClientMetadata(
        application_type="web",
        client_id=f"{base}/auth/atproto/client-metadata.json",
        client_name="Graze Social",
        client_uri=base,
        dpop_bound_access_tokens=True,
        grant_types=["authorization_code", "refresh_token"],
        jwks_uri=f"{base}/.well-known/jwks.json",
        logo_uri=f"{base}/logo.png",
        policy_uri=f"{base}/PLACEHOLDER",
        redirect_uris=[f"{base}/auth/atproto/callback"],
        response_types=["code"],
        scope="atproto transition:generic",
        token_endpoint_auth_method="private_key_jwt",
        token_endpoint_auth_signing_alg="ES256",
        subject_type="public",
        tos_uri=f"{base}/PLACEHOLDER",
    )
    # Pydantic v2: model_dump() replaces the deprecated .dict(); the rest of
    # this codebase already uses the v2 API (model_validate_json, model_dump).
    return web.json_response(client_metadata.model_dump())
--------------------------------------------------------------------------------
/src/social/graze/aip/app/handlers/permissions.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | import logging
3 | from typing import (
4 | Literal,
5 | Optional,
6 | Dict,
7 | List,
8 | Any,
9 | )
10 | from aiohttp import web
11 | from pydantic import BaseModel, PositiveInt, RootModel, ValidationError
12 | from sqlalchemy import delete, select
13 | import sentry_sdk
14 |
15 | from social.graze.aip.app.config import (
16 | DatabaseSessionMakerAppKey,
17 | TelegrafStatsdClientAppKey,
18 | )
19 | from social.graze.aip.app.handlers.helpers import auth_token_helper
20 | from social.graze.aip.model.oauth import (
21 | Permission,
22 | upsert_permission_stmt,
23 | )
24 |
25 | logger = logging.getLogger(__name__)
26 |
27 |
class PermissionOperation(BaseModel):
    """A single JSON-Patch-style operation against the permissions table.

    ``path`` is ``/<guid>`` of the user the operation concerns; ``value``
    is the positive integer permission level used by ``add``/``replace``
    (and optionally narrowing ``test``). ``remove`` ignores ``value``.
    """

    op: Literal["test", "add", "remove", "replace"]
    path: str

    # TODO: Make these permission values mean something.
    value: Optional[PositiveInt] = None


# Request body shape: a bare JSON array of PermissionOperation objects.
PermissionOperations = RootModel[list[PermissionOperation]]
37 |
38 |
async def handle_internal_permissions(request: web.Request) -> web.Response:
    """
    Apply a JSON-Patch-style batch of permission operations for the caller.

    The request body is a JSON array of operations (``test``, ``add``,
    ``remove``, ``replace``). Each operation's ``path`` is ``/<guid>`` of the
    user being granted (or tested for) access to the authenticated caller's
    account; ``value`` is the positive integer permission level.

    Returns:
        200 with the results of any ``test`` operations; 400 for malformed
        input or a self-referential permission; 401 when unauthenticated;
        500 on unexpected errors.
    """
    database_session_maker = request.app[DatabaseSessionMakerAppKey]
    statsd_client = request.app[TelegrafStatsdClientAppKey]

    # TODO: Support GET requests that returns paginated permission objects.

    try:
        data = await request.read()
        operations = PermissionOperations.model_validate_json(data)
    except (OSError, ValidationError) as e:
        sentry_sdk.capture_exception(e)
        return web.Response(text="Invalid JSON", status=400)

    try:
        async with (database_session_maker() as database_session,):
            auth_token = await auth_token_helper(
                database_session, statsd_client, request, allow_permissions=False
            )
            if auth_token is None:
                return web.json_response(status=401, data={"error": "Not Authorized"})

            # TODO: Fail with error if the user does not have an app-password set.

            # Reject self-referential permissions up front, before any database
            # work. Previously this was checked mid-transaction, and returning
            # 400 from inside the begin() block committed the operations that
            # had already executed, partially applying an invalid batch.
            for operation in operations.root:
                if operation.path.removeprefix("/") == auth_token.guid:
                    return web.json_response(
                        status=400, data={"error": "Invalid permission"}
                    )

            now = datetime.datetime.now(datetime.timezone.utc)

            async with database_session.begin():
                results: List[Dict[str, Any]] = []

                for operation in operations.root:
                    guid = operation.path.removeprefix("/")

                    if (
                        operation.op == "add" or operation.op == "replace"
                    ) and operation.value is not None:

                        # TODO: Fail if the guid is unknown. Clients should use
                        # /internal/api/resolve on all subjects prior to setting
                        # permissions.

                        stmt = upsert_permission_stmt(
                            guid=guid,
                            target_guid=auth_token.guid,
                            permission=operation.value,
                            created_at=now,
                        )
                        await database_session.execute(stmt)

                    if operation.op == "remove":
                        stmt = delete(Permission).where(
                            Permission.guid == guid,
                            Permission.target_guid == auth_token.guid,
                        )
                        await database_session.execute(stmt)

                    if operation.op == "test":
                        permission_stmt = select(Permission).where(
                            Permission.guid == guid,
                            Permission.target_guid == auth_token.guid,
                        )
                        if operation.value is not None:
                            permission_stmt = permission_stmt.where(
                                Permission.permission == operation.value
                            )
                        permission: Optional[Permission] = (
                            await database_session.scalars(permission_stmt)
                        ).first()
                        if permission is not None:
                            results.append(
                                {"path": operation.path, "value": permission.permission}
                            )

                await database_session.commit()

            return web.json_response(results)
    except web.HTTPException as e:
        sentry_sdk.capture_exception(e)
        # Use the module logger (not the root logger) for consistency with the
        # rest of this package.
        logger.exception("handle_internal_permissions: web.HTTPException")
        raise e
    except Exception as e:
        sentry_sdk.capture_exception(e)
        logger.exception("handle_internal_permissions: Exception")
        return web.json_response(status=500, data={"error": "Internal Server Error"})
124 |
--------------------------------------------------------------------------------
/src/social/graze/aip/app/handlers/proxy.py:
--------------------------------------------------------------------------------
1 | from datetime import datetime, timezone
2 | import logging
3 | from time import time
4 | from typing import (
5 | List,
6 | Optional,
7 | Dict,
8 | Any,
9 | )
10 | from aiohttp import web
11 | import hashlib
12 | import base64
13 | from jwcrypto import jwk
14 | import sentry_sdk
15 | from urllib.parse import urlparse, urlencode, parse_qsl, urlunparse
16 |
17 | from social.graze.aip.app.config import (
18 | DatabaseSessionMakerAppKey,
19 | SessionAppKey,
20 | SettingsAppKey,
21 | TelegrafStatsdClientAppKey,
22 | )
23 | from social.graze.aip.app.handlers.helpers import auth_token_helper
24 | from social.graze.aip.atproto.chain import (
25 | ChainMiddlewareClient,
26 | DebugMiddleware,
27 | GenerateDpopMiddleware,
28 | RequestMiddlewareBase,
29 | StatsdMiddleware,
30 | )
31 |
32 | logger = logging.getLogger(__name__)
33 |
34 |
async def handle_xrpc_proxy(request: web.Request) -> web.Response:
    """Proxy an XRPC call to the authenticated subject's PDS.

    Resolves the caller's auth token, forwards the request (preserving query
    string and, for POSTs, the body) to ``{context_service}/xrpc/{method}``,
    and attaches the appropriate credentials: ``Bearer`` for app-password
    sessions, ``DPoP`` (with a generated proof) for OAuth sessions.

    Responses mirror the upstream service; 400/401/500 are returned locally
    for a missing method, failed auth, or unexpected errors respectively.
    """
    statsd_client = request.app[TelegrafStatsdClientAppKey]

    # TODO: Validate this against an allowlist.
    xrpc_method: Optional[str] = request.match_info.get("method", None)
    if xrpc_method is None:
        statsd_client.increment("aip.proxy.invalid_method", 1)
        return web.json_response(status=400, data={"error": "Invalid XRPC method"})

    database_session_maker = request.app[DatabaseSessionMakerAppKey]

    try:
        async with (database_session_maker() as database_session,):
            # TODO: Allow optional auth here.
            auth_token = await auth_token_helper(
                database_session, statsd_client, request
            )
            if auth_token is None:
                statsd_client.increment(
                    "aip.proxy.unauthorized", 1, tag_dict={"method": xrpc_method}
                )
                return web.json_response(status=401, data={"error": "Not Authorized"})
    except web.HTTPException as e:
        sentry_sdk.capture_exception(e)
        raise
    except Exception as e:
        sentry_sdk.capture_exception(e)
        statsd_client.increment(
            "aip.proxy.exception",
            1,
            tag_dict={"exception": type(e).__name__, "method": xrpc_method},
        )
        return web.json_response(status=500, data={"error": "Internal Server Error"})

    now = datetime.now(timezone.utc)

    # TODO: Look for endpoint header and fallback to context PDS value.
    # TODO: Support more complex URLs that include prefixes and or suffixes.
    parsed_destination = urlparse(f"{auth_token.context_service}/xrpc/{xrpc_method}")
    parsed_destination_query = dict(parse_qsl(request.query_string))
    parsed_destination = parsed_destination._replace(
        query=urlencode(parsed_destination_query)
    )
    xrpc_url = urlunparse(parsed_destination)

    http_session = request.app[SessionAppKey]

    headers = {
        "Content-Type": request.headers.get("Content-Type", "application/json"),
    }

    chain_middleware: List[RequestMiddlewareBase] = [StatsdMiddleware(statsd_client)]

    # App password sessions use the `Authorization` header with `Bearer` scheme.
    if auth_token.app_password_session is not None:
        headers["Authorization"] = (
            f"Bearer {auth_token.app_password_session.access_token}"
        )

    # OAuth sessions use the `Authorization` header with `DPoP` scheme.
    elif auth_token.oauth_session is not None:
        # `ath` claim: unpadded base64url of SHA-256(access token).
        hashed_access_token = hashlib.sha256(
            str(auth_token.oauth_session.access_token).encode("ascii")
        ).digest()
        encoded_hashed_access_token = base64.urlsafe_b64encode(hashed_access_token)
        pkcs_access_token = encoded_hashed_access_token.decode("ascii").rstrip("=")

        dpop_key = jwk.JWK(**auth_token.oauth_session.dpop_jwk)
        dpop_key_public_key = dpop_key.export_public(as_dict=True)
        dpop_assertion_header = {
            "alg": "ES256",
            "jwk": dpop_key_public_key,
            "typ": "dpop+jwt",
        }
        # "nonce" starts as a placeholder; the DPoP middleware is expected to
        # refresh it from the server's DPoP-Nonce challenge.
        dpop_assertion_claims = {
            "htm": request.method,
            "htu": f"{auth_token.context_service}/xrpc/{xrpc_method}",
            "iat": int(now.timestamp()) - 1,
            "exp": int(now.timestamp()) + 30,
            "nonce": "tmp",
            "ath": pkcs_access_token,
            "iss": f"{auth_token.oauth_session.issuer}",
        }

        headers["Authorization"] = f"DPoP {auth_token.oauth_session.access_token}"

        chain_middleware.append(
            GenerateDpopMiddleware(
                dpop_key,
                dpop_assertion_header,
                dpop_assertion_claims,
            )
        )

    settings = request.app[SettingsAppKey]
    if settings.debug:
        chain_middleware.append(DebugMiddleware())

    rargs: Dict[str, Any] = {}

    if request.method == "POST":
        rargs["data"] = await request.read()

    chain_client = ChainMiddlewareClient(
        client_session=http_session, raise_for_status=False, middleware=chain_middleware
    )

    start_time = time()
    cross_subject = auth_token.guid != auth_token.context_guid
    auth_method = "anonymous"
    if auth_token.app_password_session is not None:
        auth_method = "app-password"
    elif auth_token.oauth_session is not None:
        auth_method = "oauth"

    try:
        async with chain_client.request(
            request.method, xrpc_url, raise_for_status=None, headers=headers, **rargs
        ) as (
            client_response,
            chain_response,
        ):
            # TODO: Figure out if websockets or SSE support is needed. Gut says no.
            # TODO: Think about using a header like `X-AIP-Error` for additional error context.
            return chain_response.to_web_response()
    finally:
        statsd_client.timer(
            "aip.proxy.request.time",
            time() - start_time,
            tag_dict={
                "xrpc_service": auth_token.context_service.removeprefix("https://"),
                "xrpc_method": xrpc_method,
                "method": request.method.lower(),
                "authentication": auth_method,
                "cross_subject": str(cross_subject),
            },
        )
176 |
--------------------------------------------------------------------------------
/src/social/graze/aip/app/server.py:
--------------------------------------------------------------------------------
1 | import re
2 | import asyncio
3 | import contextlib
4 | import os
5 | import logging
6 | from time import time
7 | from typing import (
8 | Optional,
9 | )
10 | from aio_statsd import TelegrafStatsdClient
11 | import jinja2
12 | from aiohttp import web
13 | import aiohttp_jinja2
14 | import aiohttp
15 | import redis.asyncio as redis
16 | from sqlalchemy.ext.asyncio import (
17 | create_async_engine,
18 | async_sessionmaker,
19 | AsyncSession,
20 | )
21 | import sentry_sdk
22 | from sentry_sdk.integrations.aiohttp import AioHttpIntegration
23 |
24 | from social.graze.aip.app.config import (
25 | DatabaseAppKey,
26 | DatabaseSessionMakerAppKey,
27 | HealthGaugeAppKey,
28 | RedisClientAppKey,
29 | RedisPoolAppKey,
30 | SessionAppKey,
31 | Settings,
32 | SettingsAppKey,
33 | AppPasswordRefreshTaskAppKey,
34 | OAuthRefreshTaskAppKey,
35 | TelegrafStatsdClientAppKey,
36 | TickHealthTaskAppKey,
37 | )
38 | from social.graze.aip.app.cors import get_cors_headers
39 | from social.graze.aip.app.handlers.app_password import handle_internal_app_password
40 | from social.graze.aip.app.handlers.credentials import handle_internal_credentials
41 | from social.graze.aip.app.handlers.internal import (
42 | handle_internal_alive,
43 | handle_internal_me,
44 | handle_internal_ready,
45 | handle_internal_resolve,
46 | )
47 | from social.graze.aip.app.handlers.oauth import (
48 | handle_atproto_callback,
49 | handle_atproto_client_metadata,
50 | handle_atproto_debug,
51 | handle_atproto_login,
52 | handle_atproto_login_submit,
53 | handle_atproto_refresh,
54 | handle_jwks,
55 | )
56 | from social.graze.aip.app.handlers.permissions import handle_internal_permissions
57 | from social.graze.aip.app.handlers.proxy import handle_xrpc_proxy
58 | from social.graze.aip.app.tasks import (
59 | oauth_refresh_task,
60 | tick_health_task,
61 | app_password_refresh_task,
62 | )
63 | from social.graze.aip.model.health import HealthGauge
64 |
logger = logging.getLogger(__name__)

# Origins permitted for CORS. NOTE: the previous version concatenated the raw
# strings with no "|" between them, producing a single pattern that required
# all four URLs back-to-back and therefore never matched a real Origin header.
# The alternatives are now joined with "|" and anchored at the end of the
# string so that e.g. "https://graze.social.evil.com" is rejected.
allowed_origin_pattern = re.compile(
    r"(?:https://(?:www\.)?graze\.social"
    r"|https://(?:www\.)?sky-feeder-git-astro-graze\.vercel\.app"
    r"|http://localhost:\d+"
    r"|http://127\.0\.0\.1:\d+)\Z"
)
72 |
73 |
async def handle_index(request: web.Request):
    """Render the site index page with an empty template context."""
    return await aiohttp_jinja2.render_template_async(
        "index.html", request, context={}
    )
76 |
77 |
async def background_tasks(app):
    """aiohttp cleanup-context generator.

    Before ``yield``: create the database engine/session maker, HTTP client
    session, Redis pool/client, statsd client, and spawn the background tasks.
    After ``yield`` (shutdown): cancel and await the tasks, then close every
    shared resource.
    """
    logger.info("Starting up")
    settings: Settings = app[SettingsAppKey]

    engine = create_async_engine(str(settings.pg_dsn))
    app[DatabaseAppKey] = engine
    database_session = async_sessionmaker(
        engine, class_=AsyncSession, expire_on_commit=False
    )
    app[DatabaseSessionMakerAppKey] = database_session

    trace_config = aiohttp.TraceConfig()

    if settings.debug:
        # Verbose client-request tracing, enabled only in debug mode.

        async def on_request_start(
            session, trace_config_ctx, params: aiohttp.TraceRequestStartParams
        ):
            logging.info("Starting request: %s", params)

        async def on_request_chunk_sent(
            session, trace_config_ctx, params: aiohttp.TraceRequestChunkSentParams
        ):
            logging.info("Chunk sent: %s", str(params.chunk))

        async def on_request_end(session, trace_config_ctx, params):
            logging.info("Ending request: %s", params)

        trace_config.on_request_start.append(on_request_start)
        trace_config.on_request_end.append(on_request_end)
        trace_config.on_request_chunk_sent.append(on_request_chunk_sent)

    app[SessionAppKey] = aiohttp.ClientSession(trace_configs=[trace_config])

    app[RedisPoolAppKey] = redis.ConnectionPool.from_url(str(settings.redis_dsn))

    # Reuse the shared pool. Previously a second pool was created from the URL
    # here, and that second pool was never closed on shutdown (only
    # RedisPoolAppKey is disconnected below), leaking connections.
    app[RedisClientAppKey] = redis.Redis(connection_pool=app[RedisPoolAppKey])

    statsd_client = TelegrafStatsdClient(
        host=settings.statsd_host, port=settings.statsd_port, debug=settings.debug
    )
    await statsd_client.connect()
    app[TelegrafStatsdClientAppKey] = statsd_client

    logger.info("Startup complete")

    app[TickHealthTaskAppKey] = asyncio.create_task(tick_health_task(app))
    app[OAuthRefreshTaskAppKey] = asyncio.create_task(oauth_refresh_task(app))
    app[AppPasswordRefreshTaskAppKey] = asyncio.create_task(
        app_password_refresh_task(app)
    )

    yield

    logger.info("Shutting down background tasks")

    app[TickHealthTaskAppKey].cancel()
    app[OAuthRefreshTaskAppKey].cancel()
    app[AppPasswordRefreshTaskAppKey].cancel()

    with contextlib.suppress(asyncio.exceptions.CancelledError):
        await app[TickHealthTaskAppKey]

    with contextlib.suppress(asyncio.exceptions.CancelledError):
        await app[OAuthRefreshTaskAppKey]

    with contextlib.suppress(asyncio.exceptions.CancelledError):
        await app[AppPasswordRefreshTaskAppKey]

    await app[DatabaseAppKey].dispose()
    await app[SessionAppKey].close()
    await app[RedisPoolAppKey].aclose()
    await app[TelegrafStatsdClientAppKey].close()
153 |
154 |
@web.middleware
async def cors_middleware(request: web.Request, handler):
    """Attach CORS headers to every response; short-circuit OPTIONS preflights."""
    settings = request.app[SettingsAppKey]

    # Fall back to the Host header when no Origin is supplied.
    cors_origin = request.headers.get("Origin") or request.headers.get("Host")
    request_path = request.path

    cors_headers = get_cors_headers(cors_origin, request_path, settings.debug)

    # Preflight requests only need the CORS headers, not the real handler.
    if request.method == "OPTIONS":
        logger.debug(f"[CORS] Returning early for OPTIONS request: {request_path}")
        return web.Response(status=200, headers=cors_headers)

    response = await handler(request)

    for header_name, header_value in cors_headers.items():
        response.headers[header_name] = header_value

    return response
176 |
@web.middleware
async def sentry_middleware(request: web.Request, handler):
    """Report any exception escaping a handler to Sentry, then re-raise it."""
    try:
        return await handler(request)
    except Exception as e:
        sentry_sdk.capture_exception(e)
        # Bare raise preserves the original traceback.
        raise
188 |
189 |
@web.middleware
async def statsd_middleware(request: web.Request, handler):
    """Record per-request timing, count, and exception metrics via statsd."""
    statsd_client = request.app[TelegrafStatsdClientAppKey]
    request_method: str = request.method
    request_path = request.path

    start_time: float = time()
    # Stays 0 when the handler raises, so failed requests are tagged status=0.
    response_status_code = 0

    try:
        response = await handler(request)
        response_status_code = response.status
        return response
    except Exception as e:
        statsd_client.increment(
            "aip.server.request.exception",
            1,
            tag_dict={
                "exception": type(e).__name__,
                "path": request_path,
                "method": request_method,
            },
        )
        # Bare raise preserves the original traceback.
        raise
    finally:
        statsd_client.timer(
            "aip.server.request.time",
            time() - start_time,
            tag_dict={"path": request_path, "method": request_method},
        )
        statsd_client.increment(
            "aip.server.request.count",
            1,
            tag_dict={
                "path": request_path,
                "method": request_method,
                "status": response_status_code,
            },
        )
229 |
230 |
async def shutdown(app):
    """Release every shared resource held on the application object."""
    engine = app[DatabaseAppKey]
    http_session = app[SessionAppKey]
    redis_pool = app[RedisPoolAppKey]
    statsd_client = app[TelegrafStatsdClientAppKey]

    await engine.dispose()
    await http_session.close()
    await redis_pool.aclose()
    await statsd_client.close()
236 |
237 |
async def start_web_server(settings: Optional[Settings] = None):
    """Build and return the configured aiohttp web application.

    Wires up Sentry (when a DSN is configured), middlewares, static assets,
    all HTTP routes, Jinja2 templating, and the background-task lifecycle.

    Args:
        settings: Optional pre-built Settings; loaded from the environment
            when omitted.
    """
    if settings is None:
        settings = Settings()  # type: ignore
    if settings.sentry_dsn:
        sentry_sdk.init(
            dsn=settings.sentry_dsn,
            send_default_pii=True,
            integrations=[AioHttpIntegration()],
        )

    app = web.Application(
        middlewares=[cors_middleware, statsd_middleware, sentry_middleware]
    )

    app[SettingsAppKey] = settings
    app[HealthGaugeAppKey] = HealthGauge()

    app.add_routes(
        [
            web.static(
                "/static", os.path.join(os.getcwd(), "static"), append_version=True
            )
        ]
    )

    app.add_routes([web.get("/.well-known/jwks.json", handle_jwks)])

    app.add_routes([web.get("/", handle_index)])
    app.add_routes([web.get("/auth/atproto", handle_atproto_login)])
    app.add_routes([web.post("/auth/atproto", handle_atproto_login_submit)])
    app.add_routes([web.get("/auth/atproto/callback", handle_atproto_callback)])
    app.add_routes([web.get("/auth/atproto/debug", handle_atproto_debug)])
    app.add_routes([web.get("/auth/atproto/refresh", handle_atproto_refresh)])

    app.add_routes(
        [web.get("/auth/atproto/client-metadata.json", handle_atproto_client_metadata)]
    )

    app.add_routes(
        [
            web.get("/internal/alive", handle_internal_alive),
            web.get("/internal/ready", handle_internal_ready),
            web.get("/internal/api/me", handle_internal_me),
            web.get("/internal/api/resolve", handle_internal_resolve),
            web.post("/internal/api/permissions", handle_internal_permissions),
            web.post("/internal/api/app_password", handle_internal_app_password),
            web.get("/internal/api/credentials", handle_internal_credentials),
        ]
    )

    app.add_routes(
        [
            web.get("/xrpc/{method}", handle_xrpc_proxy),
            web.post("/xrpc/{method}", handle_xrpc_proxy),
        ]
    )

    aiohttp_jinja2.setup(
        app,
        enable_async=True,
        loader=jinja2.FileSystemLoader(os.path.join(os.getcwd(), "templates")),
    )

    app["static_root_url"] = "/static"

    # Resource setup/teardown is handled entirely by the cleanup context.
    app.cleanup_ctx.append(background_tasks)

    return app
309 |
--------------------------------------------------------------------------------
/src/social/graze/aip/app/tasks.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | from datetime import datetime, timezone
3 | import logging
4 | from time import time
5 | from typing import List, NoReturn, Tuple
6 | from aiohttp import web
7 | from sqlalchemy import select, delete
8 | import sentry_sdk
9 |
10 | from social.graze.aip.app.config import (
11 | APP_PASSWORD_REFRESH_QUEUE,
12 | OAUTH_REFRESH_QUEUE,
13 | OAUTH_REFRESH_RETRY_QUEUE,
14 | DatabaseSessionMakerAppKey,
15 | HealthGaugeAppKey,
16 | RedisClientAppKey,
17 | SessionAppKey,
18 | SettingsAppKey,
19 | TelegrafStatsdClientAppKey,
20 | )
21 | from social.graze.aip.atproto.app_password import populate_session
22 | from social.graze.aip.atproto.oauth import oauth_refresh
23 | from social.graze.aip.model.oauth import OAuthSession, OAuthRequest
24 |
25 | logger = logging.getLogger(__name__)
26 |
27 |
async def tick_health_task(app: web.Application) -> NoReturn:
    """Decay the application health gauge on a fixed 30-second cadence.

    Each tick reduces the health score by 1; other code paths are expected
    to bump it back up while the service is healthy.
    """
    logger.info("Starting health gauge task")

    gauge = app[HealthGaugeAppKey]
    while True:
        await gauge.tick()
        await asyncio.sleep(30)
39 |
40 |
async def oauth_refresh_task(app: web.Application) -> NoReturn:
    """
    oauth_refresh_task is a background process that refreshes OAuth sessions immediately before they expire.

    The process is as follows:

    Given the queue name is "auth_refresh"
    Given the worker id is "worker1"
    Given `now = datetime.datetime.now(datetime.timezone.utc)`

    1. In a redis pipeline, get some work.
       * Populate the worker queue with work. This stores a range of things from the beginning of time to "now"
         into a new queue.
         ZRANGESTORE "auth_refresh_worker1" "auth_refresh" 1 {now} LIMIT 5

       * Get the work that we just populated.
         ZRANGE "auth_refresh_worker1" 0 -1

       * Store the difference between the worker queue and the main queue to remove the pulled work from the main
         queue.
         ZDIFFSTORE "auth_refresh" 2 "auth_refresh" "auth_refresh_worker1"

    2. For the work that we just got, process it all and remove each from the worker queue.
       ZREM "auth_refresh_worker1" {work_id}

    3. Sleep 15-30 seconds and repeat.

    This does a few things that are important to note.

    1. Work is queued up and indexed (redis zindex) against the time that it needs to be processed, not when
       it was queued. This lets the queue be lazily evaluated and also pull work that needs to be processed
       soonest.

    2. Work is batched into a worker queue outside of app instances, so it can be processed in parallel. If
       we need to scale up workers, we can do so by adjusting the deployment replica count.

    3. Work is grabbed in batches that don't need to be uniform, so there is no arbitrary delay. Workers
       don't have to wait for 5 jobs to be ready before taking them.

    4. If a worker dies, we have the temporary worker queue to recover the work that was in progress. If
       needed, we can create a watchdog worker that looks at orphaned worker queues and adds the work back to
       the main queue.
    """

    logger.info("Starting oauth refresh task")

    settings = app[SettingsAppKey]
    database_session_maker = app[DatabaseSessionMakerAppKey]
    http_session = app[SessionAppKey]

    redis_session = app[RedisClientAppKey]
    statsd_client = app[TelegrafStatsdClientAppKey]

    while True:

        await asyncio.sleep(10)

        now = datetime.now(timezone.utc)

        # Per-worker claim queue plus a heartbeat hash keyed by worker id.
        worker_queue = f"{OAUTH_REFRESH_QUEUE}:{settings.worker_id}"
        workers_heartbeat = f"{OAUTH_REFRESH_QUEUE}:workers"

        await redis_session.hset(
            workers_heartbeat, settings.worker_id, str(int(now.timestamp()))
        ) # type: ignore

        # Gauge how much due work is already claimed by this worker vs pending
        # globally; a new batch is only pulled when the local claim is empty.
        worker_queue_count: int = await redis_session.zcount(
            worker_queue, 0, int(now.timestamp())
        )
        statsd_client.gauge(
            "aip.task.oauth_refresh.worker_queue_count",
            worker_queue_count,
            tag_dict={"worker_id": settings.worker_id},
        )

        global_queue_count: int = await redis_session.zcount(
            OAUTH_REFRESH_QUEUE, 0, int(now.timestamp())
        )
        statsd_client.gauge(
            "aip.task.oauth_refresh.global_queue_count",
            global_queue_count,
            tag_dict={"worker_id": settings.worker_id},
        )

        if worker_queue_count == 0 and global_queue_count > 0:
            async with redis_session.pipeline() as redis_pipe:
                try:
                    logger.debug(
                        f"tick_task: processing {OAUTH_REFRESH_QUEUE} up to {int(now.timestamp())}"
                    )
                    # Claim up to 5 due entries into the worker queue, then
                    # remove the claimed entries from the global queue.
                    redis_pipe.zrangestore(
                        worker_queue,
                        OAUTH_REFRESH_QUEUE,
                        0,
                        int(now.timestamp()),
                        num=5,
                        offset=0,
                        byscore=True,
                    )

                    redis_pipe.zdiffstore(
                        OAUTH_REFRESH_QUEUE, [OAUTH_REFRESH_QUEUE, worker_queue]
                    )
                    (zrangestore_res, zdiffstore_res) = await redis_pipe.execute()
                    statsd_client.increment(
                        "aip.task.oauth_refresh.work_queued",
                        zrangestore_res,
                        tag_dict={"worker_id": settings.worker_id},
                    )
                except Exception as e:
                    sentry_sdk.capture_exception(e)
                    # NOTE(review): root-logger call; module `logger` would be
                    # more consistent with the rest of the file.
                    logging.exception("error populating worker queue")

        # Entries are (session_group, deadline-as-score) pairs.
        tasks: List[Tuple[str, float]] = await redis_session.zrange(
            worker_queue, 0, int(now.timestamp()), byscore=True, withscores=True
        )
        if len(tasks) > 0:
            async with (database_session_maker() as database_session,):
                for session_group, deadline in tasks:

                    logger.debug(
                        "tick_task: processing session_group %s deadline %s",
                        session_group,
                        deadline,
                    )

                    start_time = time()

                    try:
                        # begin() commits on success / rolls back on exception.
                        async with database_session.begin():
                            # Redis may hand back bytes; normalize to str.
                            if isinstance(session_group, bytes):
                                session_group = session_group.decode()
                            oauth_session_stmt = select(OAuthSession).where(
                                OAuthSession.session_group == session_group
                            )
                            # .one() raises if the session row is missing,
                            # landing in the except branch below.
                            oauth_session: OAuthSession = (
                                await database_session.scalars(oauth_session_stmt)
                            ).one()

                            await oauth_refresh(
                                settings,
                                http_session,
                                statsd_client,
                                database_session,
                                redis_session,
                                oauth_session,
                            )

                        # Clear retry count on successful refresh
                        session_group_str = session_group.decode() if isinstance(session_group, bytes) else session_group
                        await redis_session.hdel(OAUTH_REFRESH_RETRY_QUEUE, session_group_str)

                    except Exception as e:
                        sentry_sdk.capture_exception(e)
                        logging.exception(
                            "error processing session group %s", session_group
                        )

                        # Implement retry logic with exponential backoff
                        session_group_str = session_group.decode() if isinstance(session_group, bytes) else session_group
                        current_retries = await redis_session.hget(OAUTH_REFRESH_RETRY_QUEUE, session_group_str)
                        current_retries = int(current_retries) if current_retries else 0

                        if current_retries < settings.oauth_refresh_max_retries:
                            # Calculate exponential backoff delay
                            retry_delay = settings.oauth_refresh_retry_base_delay * (2 ** current_retries)
                            retry_timestamp = int(now.timestamp()) + retry_delay

                            # Re-queue with delay and increment retry count
                            await redis_session.zadd(OAUTH_REFRESH_QUEUE, {session_group_str: retry_timestamp})
                            await redis_session.hset(OAUTH_REFRESH_RETRY_QUEUE, session_group_str, current_retries + 1)

                            logging.info(
                                "Scheduled retry %d/%d for session_group %s in %d seconds",
                                current_retries + 1,
                                settings.oauth_refresh_max_retries,
                                session_group_str,
                                retry_delay
                            )

                            statsd_client.increment(
                                "aip.task.oauth_refresh.retry_scheduled",
                                1,
                                tag_dict={
                                    "retry_attempt": str(current_retries + 1),
                                    "worker_id": settings.worker_id,
                                },
                            )
                        else:
                            # Max retries exceeded, give up and clean up retry tracking
                            await redis_session.hdel(OAUTH_REFRESH_RETRY_QUEUE, session_group_str)

                            logging.error(
                                "Max retries exceeded for session_group %s, giving up after %d attempts",
                                session_group_str,
                                settings.oauth_refresh_max_retries
                            )

                            statsd_client.increment(
                                "aip.task.oauth_refresh.max_retries_exceeded",
                                1,
                                tag_dict={"worker_id": settings.worker_id},
                            )

                        # TODO: Don't actually tag session_group because cardinality will be very high.
                        statsd_client.increment(
                            "aip.task.oauth_refresh.exception",
                            1,
                            tag_dict={
                                "exception": type(e).__name__,
                                "session_group": session_group,
                                "worker_id": settings.worker_id,
                            },
                        )

                    finally:
                        statsd_client.timer(
                            "aip.task.oauth_refresh.time",
                            time() - start_time,
                            tag_dict={"worker_id": settings.worker_id},
                        )
                        # TODO: Probably don't need this because it is the same as `COUNT(aip.task.oauth_refresh.time)`
                        statsd_client.increment(
                            "aip.task.oauth_refresh.count",
                            1,
                            tag_dict={"worker_id": settings.worker_id},
                        )
                        # Work item is released from the claim queue whether it
                        # succeeded, failed, or was re-queued for retry.
                        await redis_session.zrem(worker_queue, session_group)
270 |
async def app_password_refresh_task(app: web.Application) -> NoReturn:
    """Background worker that refreshes app-password sessions.

    Uses the same redis claim-queue pattern as ``oauth_refresh_task`` (see its
    docstring): due entries are moved from the global queue into a per-worker
    queue, processed via ``populate_session``, and removed from the claim
    queue. Unlike the OAuth task, the whole tick is wrapped in a broad
    try/except so one failing iteration never kills the loop.
    """

    logger.info("Starting app password refresh task")

    settings = app[SettingsAppKey]
    database_session_maker = app[DatabaseSessionMakerAppKey]
    http_session = app[SessionAppKey]
    redis_session = app[RedisClientAppKey]
    statsd_client = app[TelegrafStatsdClientAppKey]

    while True:
        try:
            await asyncio.sleep(10)

            now = datetime.now(timezone.utc)

            # Per-worker claim queue plus a heartbeat hash keyed by worker id.
            worker_queue = f"{APP_PASSWORD_REFRESH_QUEUE}:{settings.worker_id}"
            workers_heartbeat = f"{APP_PASSWORD_REFRESH_QUEUE}:workers"

            await redis_session.hset(
                workers_heartbeat, settings.worker_id, str(int(now.timestamp()))
            ) # type: ignore

            worker_queue_count: int = await redis_session.zcount(
                worker_queue, 0, int(now.timestamp())
            )
            statsd_client.gauge(
                "aip.task.app_password_refresh.worker_queue_count",
                worker_queue_count,
                tag_dict={"worker_id": settings.worker_id},
            )

            global_queue_count: int = await redis_session.zcount(
                APP_PASSWORD_REFRESH_QUEUE, 0, int(now.timestamp())
            )
            statsd_client.gauge(
                "aip.task.app_password_refresh.global_queue_count",
                global_queue_count,
                tag_dict={"worker_id": settings.worker_id},
            )

            # Only claim new work when this worker's queue is empty.
            if worker_queue_count == 0 and global_queue_count > 0:
                async with redis_session.pipeline() as redis_pipe:
                    try:
                        logger.debug(
                            f"tick_task: processing {APP_PASSWORD_REFRESH_QUEUE} up to {int(now.timestamp())}"
                        )

                        # Claim up to 5 due entries, then drop them from the
                        # global queue via set difference.
                        redis_pipe.zrangestore(
                            worker_queue,
                            APP_PASSWORD_REFRESH_QUEUE,
                            0,
                            int(now.timestamp()),
                            num=5,
                            offset=0,
                            byscore=True,
                        )

                        redis_pipe.zdiffstore(
                            APP_PASSWORD_REFRESH_QUEUE,
                            [APP_PASSWORD_REFRESH_QUEUE, worker_queue],
                        )

                        (zrangestore_res, zdiffstore_res) = await redis_pipe.execute()
                        statsd_client.increment(
                            "aip.task.app_password_refresh.work_queued",
                            zrangestore_res,
                            tag_dict={"worker_id": settings.worker_id},
                        )
                    except Exception as e:
                        sentry_sdk.capture_exception(e)
                        # NOTE(review): root-logger call; module `logger` would
                        # be more consistent with the rest of the file.
                        logging.exception("error populating app password worker queue")

            # Entries are (handle_guid, deadline-as-score) pairs.
            tasks: List[Tuple[str, float]] = await redis_session.zrange(
                worker_queue, 0, int(now.timestamp()), byscore=True, withscores=True
            )

            if len(tasks) > 0:
                for handle_guid, deadline in tasks:

                    logger.debug(
                        "tick_task: processing guid %s deadline %s",
                        handle_guid,
                        deadline,
                    )

                    start_time = time()
                    try:
                        await populate_session(
                            http_session,
                            database_session_maker,
                            redis_session,
                            handle_guid,
                            settings,
                        )

                    except Exception as e:
                        sentry_sdk.capture_exception(e)
                        logging.exception("error processing guid %s", handle_guid)
                        # TODO: Don't actually tag session_group because cardinality will be very high.
                        statsd_client.increment(
                            "aip.task.app_password_refresh.exception",
                            1,
                            tag_dict={
                                "exception": type(e).__name__,
                                "guid": handle_guid,
                                "worker_id": settings.worker_id,
                            },
                        )

                    finally:
                        statsd_client.timer(
                            "aip.task.app_password_refresh.time",
                            time() - start_time,
                            tag_dict={"worker_id": settings.worker_id},
                        )
                        # TODO: Probably don't need this because it is the same as
                        # `COUNT(aip.task.app_password_refresh.time)`
                        statsd_client.increment(
                            "aip.task.app_password_refresh.count",
                            1,
                            tag_dict={"worker_id": settings.worker_id},
                        )
                        # Release the work item from the claim queue regardless
                        # of success or failure.
                        await redis_session.zrem(worker_queue, handle_guid)

        except Exception as e:
            sentry_sdk.capture_exception(e)
            logging.exception("app password tick failed")
399 |
400 |
async def oauth_cleanup_task(app: web.Application) -> NoReturn:
    """
    Background task to clean up expired OAuth records.

    This task runs every hour and removes:
    - OAuthRequest records that have expired
    - OAuthSession records that have reached their hard expiration

    This prevents database bloat from accumulating expired records.
    """
    logger.info("Starting OAuth cleanup task")

    settings = app[SettingsAppKey]
    database_session_maker = app[DatabaseSessionMakerAppKey]
    statsd_client = app[TelegrafStatsdClientAppKey]

    while True:
        try:
            # Run cleanup every hour
            await asyncio.sleep(3600)

            now = datetime.now(timezone.utc)

            async with (database_session_maker() as database_session,):
                # begin() commits on successful exit and rolls back on error;
                # the previous explicit commit() inside this block was
                # redundant and conflicts with the begin()-managed transaction.
                async with database_session.begin():
                    # Clean up expired OAuthRequest records
                    expired_requests_stmt = delete(OAuthRequest).where(
                        OAuthRequest.expires_at < now
                    )
                    expired_requests_result = await database_session.execute(
                        expired_requests_stmt
                    )
                    expired_requests_count = expired_requests_result.rowcount

                    # Clean up expired OAuthSession records
                    expired_sessions_stmt = delete(OAuthSession).where(
                        OAuthSession.hard_expires_at < now
                    )
                    expired_sessions_result = await database_session.execute(
                        expired_sessions_stmt
                    )
                    expired_sessions_count = expired_sessions_result.rowcount

                if expired_requests_count > 0 or expired_sessions_count > 0:
                    logger.info(
                        "Cleaned up %d expired OAuth requests and %d expired OAuth sessions",
                        expired_requests_count,
                        expired_sessions_count,
                    )

                # Report metrics
                statsd_client.increment(
                    "aip.task.oauth_cleanup.expired_requests_removed",
                    expired_requests_count,
                    tag_dict={"worker_id": settings.worker_id},
                )
                statsd_client.increment(
                    "aip.task.oauth_cleanup.expired_sessions_removed",
                    expired_sessions_count,
                    tag_dict={"worker_id": settings.worker_id},
                )

        except Exception as e:
            sentry_sdk.capture_exception(e)
            # Module logger for consistency with the rest of this file.
            logger.exception("OAuth cleanup task failed")
464 |
--------------------------------------------------------------------------------
/src/social/graze/aip/app/util/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/graze-social/aip/c740a2cf63fdf5a9d2cef6da3bfb8dd660e15ea0/src/social/graze/aip/app/util/__init__.py
--------------------------------------------------------------------------------
/src/social/graze/aip/app/util/__main__.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import asyncio
3 | import logging
4 | from typing import Any, Dict, Optional
5 | import aiohttp
6 | from jwcrypto import jwk
7 | from ulid import ULID
8 | import base64
9 | from cryptography.fernet import Fernet
10 |
11 | from social.graze.aip.resolve.handle import resolve_subject
12 |
13 | logger = logging.getLogger(__name__)
14 |
15 |
async def genAppPassword(
    handle: str,
    password: str,
    label: str,
    plc_hostname: str,
    auth_token: Optional[str] = None,
) -> None:
    """
    Create an app-password on the PDS that hosts *handle* and print it.

    Resolves the handle to a DID and PDS, authenticates with the account
    password (plus optional email auth token), then calls
    ``com.atproto.server.createAppPassword``.

    :param handle: The handle to authenticate with.
    :param password: The account (not app) password.
    :param label: Name for the new app-password.
    :param plc_hostname: PLC directory hostname used for DID resolution.
    :param auth_token: Optional auth token included in the createSession body.
    """
    # One session is enough for both resolution and the XRPC calls; the
    # original opened a second nested ClientSession unnecessarily.
    async with aiohttp.ClientSession() as http_session:
        resolved_handle = await resolve_subject(http_session, plc_hostname, handle)
        assert resolved_handle is not None

        create_session_url = (
            f"{resolved_handle.pds}/xrpc/com.atproto.server.createSession"
        )
        create_session_body = {
            "identifier": resolved_handle.did,
            "password": password,
        }
        if auth_token is not None:
            create_session_body["authToken"] = auth_token

        async with http_session.post(
            create_session_url, json=create_session_body
        ) as resp:
            assert resp.status == 200

            created_session: Dict[str, Any] = await resp.json()

        # Sanity-check the session belongs to the resolved identity.
        # (Previously `.get("did", str)` passed the `str` type itself as the
        # default value by mistake.)
        assert created_session.get("did") == resolved_handle.did
        assert created_session.get("handle") == handle

        create_app_password_url = (
            f"{resolved_handle.pds}/xrpc/com.atproto.server.createAppPassword"
        )
        create_app_password_body = {"name": label}
        create_app_password_headers = {
            "Authorization": f"Bearer {created_session['accessJwt']}"
        }

        async with http_session.post(
            create_app_password_url,
            headers=create_app_password_headers,
            json=create_app_password_body,
        ) as resp:
            if resp.status != 200:
                # The original used a %-style placeholder with print(),
                # which printed the literal "%s" instead of the body.
                print(f"Error creating app password: {await resp.text()}")
                return
            app_password = await resp.json()
            print(
                f"{app_password['name']} created "
                f"{app_password['createdAt']}: {app_password['password']}"
            )
68 |
69 |
async def genJwk() -> None:
    """Generate a fresh P-256 (ES256) signing key and print it as a private JWK."""
    generated = jwk.JWK.generate(kty="EC", crv="P-256", kid=str(ULID()), alg="ES256")
    print(generated.export(private_key=True))
73 |
74 |
async def genCryptoKey() -> None:
    """Generate a Fernet encryption key and print it base64-wrapped.

    NOTE(review): ``Fernet.generate_key()`` already returns urlsafe-base64
    bytes; the extra ``b64encode`` is preserved because downstream config
    presumably expects this doubly-encoded format — confirm before changing.
    """
    raw_key = Fernet.generate_key()
    print(base64.b64encode(raw_key).decode("utf-8"))
78 |
79 |
async def realMain() -> None:
    """
    Parse command-line arguments and dispatch to the requested utility.

    Subcommands:
    - ``gen-jwk``: generate a JWK signing key
    - ``gen-crypto``: generate an encryption key
    - ``gen-app-password``: create an app-password for a handle
    """
    parser = argparse.ArgumentParser(prog="aiputil", description="AIP utilities")

    parser.add_argument(
        "--plc-hostname",
        default="plc.directory",
        help="The PLC hostname to use for resolving did-method-plc DIDs.",
    )

    subparsers = parser.add_subparsers(dest="command", required=True)

    _ = subparsers.add_parser("gen-jwk", help="Generate a JWK")
    _ = subparsers.add_parser("gen-crypto", help="Generate an encryption key")
    gen_app_password = subparsers.add_parser(
        "gen-app-password", help="Generate an app-password"
    )

    gen_app_password.add_argument("handle", help="The handle to authenticate with.")
    gen_app_password.add_argument("password", help="The password to authenticate with.")
    gen_app_password.add_argument("label", help="The label for the app-password.")

    args = vars(parser.parse_args())
    command = args.get("command", None)

    if command == "gen-jwk":
        await genJwk()
    elif command == "gen-crypto":
        await genCryptoKey()
    elif command == "gen-app-password":
        # Previously these used `args.get(name, str)`, which would have
        # returned the `str` type itself if a key were missing. argparse
        # guarantees these keys exist for this subcommand, so index directly.
        handle: str = args["handle"]
        password: str = args["password"]
        label: str = args["label"]
        plc_hostname: str = args["plc_hostname"]
        await genAppPassword(handle, password, label, plc_hostname)
114 |
115 |
def main() -> None:
    """Synchronous entry point: run the async CLI under asyncio."""
    asyncio.run(realMain())


if __name__ == "__main__":
    main()
122 |
--------------------------------------------------------------------------------
/src/social/graze/aip/atproto/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | AT Protocol Integration
3 |
4 | This package provides integration with the AT Protocol, handling authentication flows
5 | and communication with Personal Data Server (PDS) instances.
6 |
7 | Key Components:
8 | - app_password.py: Implementation of app password authentication
9 | - oauth.py: Implementation of OAuth flows including authorization code and refresh token
10 | - chain.py: Middleware chain for API requests (DPoP, claims, metrics)
11 | - pds.py: Interaction with PDS (Personal Data Server) instances
12 |
13 | Key Features:
14 | - OAuth 2.0 flow implementation with PKCE
15 | - DPoP (Demonstrating Proof-of-Possession) for access tokens
16 | - JWT-based client assertion for secure client authentication
17 | - Proactive token refresh to maintain session validity
18 |
19 | The authentication flow follows these steps:
20 | 1. Initialize OAuth flow with subject (handle or DID)
21 | 2. Generate PKCE challenge and redirect to authorization server
22 | 3. Complete OAuth flow with authorization code
23 | 4. Store tokens and set up refresh schedule
24 | 5. Refresh tokens before expiry to maintain session validity
25 |
26 | All communication with AT Protocol services uses middleware chains for
27 | consistent handling of authentication, metrics, and error reporting.
28 | """
--------------------------------------------------------------------------------
/src/social/graze/aip/atproto/app_password.py:
--------------------------------------------------------------------------------
1 | from datetime import datetime, timezone, timedelta
2 | import logging
3 | from typing import Any, Dict, Optional
4 | from aiohttp import ClientSession
5 | import redis.asyncio as redis
6 | from sqlalchemy import select, update
7 | from sqlalchemy.ext.asyncio import (
8 | async_sessionmaker,
9 | AsyncSession,
10 | )
11 | from sqlalchemy.dialects.postgresql import insert
12 | import sentry_sdk
13 | from social.graze.aip.model.app_password import AppPassword, AppPasswordSession
14 | from social.graze.aip.model.handles import Handle
15 | from social.graze.aip.app.config import APP_PASSWORD_REFRESH_QUEUE, Settings
16 |
17 | logger = logging.getLogger(__name__)
18 |
19 |
async def populate_session(
    http_session: ClientSession,
    database_session_maker: async_sessionmaker[AsyncSession],
    redis_session: redis.Redis,
    subject_guid: str,
    settings: Optional[Settings] = None,
) -> None:
    """
    Create or refresh the app-password session for the handle *subject_guid*.

    Flow:
    1. Load the Handle, AppPassword, and optional AppPasswordSession rows.
    2. If a session exists, refresh it via ``com.atproto.server.refreshSession``
       and update the row.
    3. If no session exists — or the refresh failed, returned a different DID,
       or reported an inactive account — create a new session via
       ``com.atproto.server.createSession`` and upsert the row.
    4. In either success case, schedule the next refresh in the Redis sorted
       set ``APP_PASSWORD_REFRESH_QUEUE``.

    :param http_session: Shared aiohttp client session.
    :param database_session_maker: Factory for async database sessions.
    :param redis_session: Redis client holding the refresh queue.
    :param subject_guid: GUID shared by the Handle/AppPassword rows.
    :param settings: Optional settings; expiry and refresh-ratio defaults
        (720s access, 7776000s refresh, 0.8 ratio) are used when absent.
    :raises ValueError: If the Handle or AppPassword row is missing.
    """
    now = datetime.now(timezone.utc)
    async with database_session_maker() as database_session:
        async with database_session.begin():
            # 1. Get the handle by guid
            handle_stmt = select(Handle).where(Handle.guid == subject_guid)
            handle: Optional[Handle] = (
                await database_session.scalars(handle_stmt)
            ).first()

            if handle is None:
                raise ValueError(f"Handle not found for guid: {subject_guid}")

            # 2. Get the AppPassword by guid
            app_password_stmt = select(AppPassword).where(
                AppPassword.guid == subject_guid
            )
            app_password: Optional[AppPassword] = (
                await database_session.scalars(app_password_stmt)
            ).first()

            if app_password is None:
                raise ValueError(f"App password not found for guid: {subject_guid}")

            # 3. Get optional AppPasswordSession by guid
            app_password_session_stmt = select(AppPasswordSession).where(
                AppPasswordSession.guid == subject_guid
            )
            app_password_session: Optional[AppPasswordSession] = (
                await database_session.scalars(app_password_session_stmt)
            ).first()

            # When True, fall through to the createSession path below.
            start_over = False
            access_token: str | None = None

            # Use settings if provided, otherwise use defaults
            access_token_expiry = settings.app_password_access_token_expiry if settings else 720
            refresh_token_expiry = settings.app_password_refresh_token_expiry if settings else 7776000

            # TODO: Pull this from the access token JWT claims payload
            access_token_expires_at = now + timedelta(0, access_token_expiry)

            refresh_token: str | None = None

            # TODO: Pull this from the refresh token JWT claims payload
            refresh_token_expires_at = now + timedelta(0, refresh_token_expiry)

            # 4. If AppPasswordSession exists: refresh it, update row, and return
            if app_password_session is not None:
                try:
                    refresh_url = f"{handle.pds}/xrpc/com.atproto.server.refreshSession"
                    headers = {
                        "Authorization": f"Bearer {app_password_session.refresh_token}"
                    }
                    # This could fail if the app password was revoked or if the server isn't honoring the expiration
                    # time of the refresh token. It's more likely that a user will remove / replace an app password.
                    # When that happens, we want to start over and create a new session.
                    async with http_session.post(
                        refresh_url, headers=headers
                    ) as response:
                        if response.status != 200:
                            raise Exception(
                                f"Failed to refresh session: {response.status}"
                            )

                        body: Dict[str, Any] = await response.json()
                        access_token = body.get("accessJwt", None)
                        refresh_token = body.get("refreshJwt", None)
                        is_active = body.get("active", False)
                        found_did = body.get("did", "")

                        # A DID mismatch or inactive account invalidates the
                        # refreshed tokens; rebuild the session from scratch.
                        if found_did != handle.did:
                            start_over = True

                        if not is_active:
                            start_over = True

                except Exception as e:
                    sentry_sdk.capture_exception(e)
                    logger.exception("Error refreshing session")
                    start_over = True

                if (
                    access_token is not None
                    and refresh_token is not None
                    and not start_over
                ):
                    update_session_stmt = (
                        update(AppPasswordSession)
                        .where(
                            AppPasswordSession.guid == app_password_session.guid,
                        )
                        .values(
                            access_token=access_token,
                            access_token_expires_at=access_token_expires_at,
                            refresh_token=refresh_token,
                            refresh_token_expires_at=refresh_token_expires_at,
                        )
                    )
                    await database_session.execute(update_session_stmt)

                    # Schedule the next refresh before the access token
                    # expires (ratio of the expiry window).
                    refresh_ratio = settings.token_refresh_before_expiry_ratio if settings else 0.8
                    expires_in_mod = access_token_expiry * refresh_ratio
                    refresh_at = now + timedelta(0, expires_in_mod)
                    await redis_session.zadd(
                        APP_PASSWORD_REFRESH_QUEUE,
                        {handle.guid: int(refresh_at.timestamp())},
                    )

            if app_password_session is None or start_over:
                try:
                    refresh_url = f"{handle.pds}/xrpc/com.atproto.server.createSession"
                    headers = {}
                    payload = {
                        "identifier": handle.did,
                        "password": app_password.app_password,
                    }
                    async with http_session.post(
                        refresh_url, headers=headers, json=payload
                    ) as response:
                        if response.status != 200:
                            raise Exception(
                                f"Failed to refresh session: {response.status}"
                            )

                        body: Dict[str, Any] = await response.json()
                        access_token = body.get("accessJwt", None)
                        refresh_token = body.get("refreshJwt", None)
                        is_active = body.get("active", False)
                        found_did = body.get("did", "")

                        # It'd be pretty wild if this didn't match the handle, but would also lead to some really
                        # unexpected behavior.
                        if found_did != handle.did:
                            raise ValueError(
                                f"Handle did does not match found did: {handle.did} != {found_did}"
                            )

                        if not is_active:
                            raise ValueError("Handle is not active.")

                except Exception as e:
                    sentry_sdk.capture_exception(e)
                    logger.exception("Error creating session")

                # 5. Create new AppPasswordSession
                if access_token is not None and refresh_token is not None:
                    # Upsert: a concurrent worker may have inserted the row.
                    update_session_stmt = (
                        insert(AppPasswordSession)
                        .values(
                            [
                                {
                                    "guid": handle.guid,
                                    "access_token": access_token,
                                    "access_token_expires_at": access_token_expires_at,
                                    "refresh_token": refresh_token,
                                    "refresh_token_expires_at": refresh_token_expires_at,
                                    "created_at": now,
                                }
                            ]
                        )
                        .on_conflict_do_update(
                            index_elements=["guid"],
                            set_={
                                "access_token": access_token,
                                "access_token_expires_at": access_token_expires_at,
                                "refresh_token": refresh_token,
                                "refresh_token_expires_at": refresh_token_expires_at,
                            },
                        )
                    )
                    await database_session.execute(update_session_stmt)

                    # Schedule the next refresh for the new session.
                    refresh_ratio = settings.token_refresh_before_expiry_ratio if settings else 0.8
                    expires_in_mod = access_token_expiry * refresh_ratio
                    refresh_at = now + timedelta(0, expires_in_mod)
                    await redis_session.zadd(
                        APP_PASSWORD_REFRESH_QUEUE,
                        {handle.guid: int(refresh_at.timestamp())},
                    )
206 |
--------------------------------------------------------------------------------
/src/social/graze/aip/atproto/chain.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 | from dataclasses import dataclass
3 | import json
4 | import re
5 | import secrets
6 | from time import time
7 | from types import TracebackType
8 | from typing import (
9 | Any,
10 | Awaitable,
11 | Callable,
12 | Generator,
13 | Optional,
14 | Sequence,
15 | Tuple,
16 | Union,
17 | Protocol,
18 | Dict,
19 | )
20 | import logging
21 | from aio_statsd import TelegrafStatsdClient
22 | from aiohttp import web, ClientResponse, ClientSession, FormData, hdrs
23 | from aiohttp.typedefs import StrOrURL
24 | from multidict import CIMultiDictProxy
25 | from jwcrypto import jwt, jwk
26 | import sentry_sdk
27 |
28 | RequestFunc = Callable[..., Awaitable[ClientResponse]]
29 |
30 | logger = logging.getLogger(__name__)
31 |
32 |
class _LoggerStub(Protocol):
    """Structural type for logger-like objects.

    Any object exposing ``debug``, ``warning``, and ``exception`` with the
    standard ``logging.Logger`` signatures satisfies this protocol.
    """

    @abstractmethod
    def debug(self, msg: str, *args: Any, **kwargs: Any) -> None:
        pass

    @abstractmethod
    def warning(self, msg: str, *args: Any, **kwargs: Any) -> None:
        pass

    @abstractmethod
    def exception(self, msg: str, *args: Any, **kwargs: Any) -> None:
        pass
48 |
# Accept either a real logging.Logger or any logger-shaped object.
_LoggerType = Union[_LoggerStub, logging.Logger]
50 |
51 |
@dataclass
class ChainRequest:
    """Mutable description of an outbound HTTP request flowing through the chain."""

    method: str
    url: StrOrURL
    headers: dict[str, Any] | None = None
    trace_request_ctx: dict[str, Any] | None = None
    kwargs: dict[str, Any] | None = None

    # TODO: Fix bug where IDE doesn't like `def from_chain_request(request: Self) -> Self`
    @staticmethod
    def from_chain_request(request: "ChainRequest") -> "ChainRequest":
        """Return a shallow copy of *request*, e.g. for a retry attempt."""
        copied = ChainRequest(
            method=request.method,
            url=request.url,
            headers=request.headers,
            trace_request_ctx=request.trace_request_ctx,
            kwargs=request.kwargs,
        )
        return copied
70 |
71 |
@dataclass
class ChainResponse:
    """Parsed snapshot of an HTTP response, decoupled from aiohttp internals."""

    status: int
    headers: CIMultiDictProxy[str]
    # Set per Content-Type by from_aiohttp_response: dict for JSON,
    # str for text/*, bytes otherwise.
    body: str | bytes | dict[str, Any] | None = None
    # exception: BaseException | None = None

    # TODO: Fix bug where IDE doesn't like `-> Self`
    @staticmethod
    async def from_aiohttp_response(response: ClientResponse) -> "ChainResponse":
        """Read *response* fully, choosing the body type from its Content-Type."""
        status = response.status
        headers = response.headers

        content_type = response.headers.get(hdrs.CONTENT_TYPE, "")

        if content_type.startswith("application/json"):
            return ChainResponse(
                status=status, headers=headers, body=await response.json()
            )
        elif content_type.startswith("text/"):
            return ChainResponse(
                status=status, headers=headers, body=await response.text()
            )
        else:
            return ChainResponse(
                status=status, headers=headers, body=await response.read()
            )

    def body_contains(self, text: str) -> bool:
        """Return True if *text* occurs in the body (substring, byte-substring, or dict key)."""
        if self.body is None:
            return False

        if isinstance(self.body, str):
            return text in self.body

        elif isinstance(self.body, bytes):
            return text.encode("utf-8") in self.body

        elif isinstance(self.body, dict):
            return text in self.body

        return False

    def body_matches_kv(self, key: str, value: Any) -> bool:
        """Return True if the body is a dict containing *key* mapped exactly to *value*."""
        if self.body is None:
            return False

        return (
            isinstance(self.body, dict) and key in self.body and self.body[key] == value
        )

    def to_web_response(self) -> web.Response:
        """Convert to an aiohttp server response, forwarding only the Content-Type header."""
        rargs = {
            "status": self.status,
            "headers": {"Content-Type": self.headers.get(hdrs.CONTENT_TYPE, "")},
        }
        if isinstance(self.body, str):
            rargs["text"] = self.body
        elif isinstance(self.body, bytes):
            rargs["body"] = self.body
        elif isinstance(self.body, dict):
            rargs["body"] = json.dumps(self.body)
        return web.Response(**rargs)
135 |
136 |
# A chain step yields the raw aiohttp response plus the parsed ChainResponse,
# optionally followed by a new ChainRequest meaning "retry with this request".
NextChainResponseCallbackType = (
    Tuple[ClientResponse, ChainResponse]
    | Tuple[ClientResponse, ChainResponse, ChainRequest]
)

# Signature of each link in the middleware chain.
NextChainCallbackType = Callable[
    [ChainRequest], Awaitable[NextChainResponseCallbackType]
]
145 |
146 |
class RequestMiddlewareBase(ABC):
    """Base class for middleware links in the request chain."""

    @abstractmethod
    async def handle(
        self, next: NextChainCallbackType, request: ChainRequest
    ) -> NextChainResponseCallbackType:
        """Process *request*, delegating to *next* to continue the chain."""
        pass

    def handle_gen(self, next: NextChainCallbackType) -> NextChainCallbackType:
        """Bind this middleware in front of *next*, returning the combined callback."""

        async def invoke(request: ChainRequest) -> NextChainResponseCallbackType:
            return await self.handle(next, request)

        return invoke
159 |
160 |
class StatsdMiddleware(RequestMiddlewareBase):
    """Middleware that records per-request timing and count metrics via statsd."""

    def __init__(
        self,
        statsd_client: TelegrafStatsdClient,
    ) -> None:
        super().__init__()
        self._statsd_client = statsd_client
        # Strips non-word characters from URLs before using them as tags.
        self._regex = re.compile(r"\W+")

    async def handle(
        self, next: NextChainCallbackType, request: ChainRequest
    ) -> NextChainResponseCallbackType:
        """Time the downstream call, reporting exceptions to Sentry and metrics to statsd."""
        start_time = time()
        try:
            return await next(request)
        except Exception as e:
            sentry_sdk.capture_exception(e)
            # Bare `raise` preserves the original traceback; `raise e`
            # restarted the traceback from this frame.
            raise
        finally:
            # TODO: Don't record URL. Cardinality is too high here and will cause issues.
            self._statsd_client.timer(
                "aip.client.request.time",
                time() - start_time,
                tag_dict={
                    "url": self._regex.sub("", str(request.url)),
                    "method": request.method.lower(),
                },
            )
            self._statsd_client.increment(
                "aip.client.request.count",
                1,
                tag_dict={
                    "url": self._regex.sub("", str(request.url)),
                    "method": request.method.lower(),
                },
            )
197 |
198 |
class DebugMiddleware(RequestMiddlewareBase):
    """Middleware that logs each request and its response at DEBUG level."""

    async def handle(
        self, next: NextChainCallbackType, request: ChainRequest
    ) -> NextChainResponseCallbackType:
        logger.debug(
            f"Request: {request.method} {request.url} {request.headers} {request.kwargs}"
        )
        result = await next(request)
        chain_response = result[1]
        logger.debug(
            f"Response: {chain_response.status} {chain_response.headers} {chain_response.body}"
        )
        return result
211 |
212 |
class GenerateClaimAssertionMiddleware(RequestMiddlewareBase):
    """Middleware that attaches a signed JWT client assertion to form requests.

    Each outbound request that carries kwargs gets a freshly-signed
    ``client_assertion`` form field with a unique ``jti`` claim.
    """

    def __init__(
        self,
        signing_key: jwk.JWK,
        client_assertion_header: Dict[str, Any],
        client_assertion_claims: Dict[str, Any],
    ) -> None:
        super().__init__()
        self._signing_key = signing_key
        self._client_assertion_header = client_assertion_header
        self._client_assertion_claims = client_assertion_claims

    async def handle(
        self, next: NextChainCallbackType, request: ChainRequest
    ) -> NextChainResponseCallbackType:
        """Sign and attach a client assertion, then continue the chain."""
        # Only requests that carry kwargs (i.e. a form body) get an assertion.
        if request.kwargs is None:
            return await next(request)

        # Fresh `jti` per request so the assertion cannot be replayed.
        self._client_assertion_claims["jti"] = secrets.token_urlsafe(32)
        claims_assertation = jwt.JWT(
            header=self._client_assertion_header,
            claims=self._client_assertion_claims,
        )
        claims_assertation.make_signed_token(self._signing_key)
        claims_assertation_token = claims_assertation.serialize()

        # The early return above guarantees kwargs is not None here; the
        # original re-checked it redundantly.
        data: Optional[FormData] = request.kwargs.get("data", None)
        if data is None:
            data = FormData()

        data.add_field("client_assertion", claims_assertation_token)
        request.kwargs["data"] = data

        return await next(request)
252 |
253 |
class GenerateDpopMiddleware(RequestMiddlewareBase):
    """Middleware that attaches a DPoP proof JWT and handles nonce challenges.

    Each request gets a signed ``DPoP`` header with a fresh ``jti``. If the
    server rejects it on 400/401 with ``invalid_dpop_proof`` or
    ``use_dpop_nonce``, the server-supplied ``DPoP-Nonce`` header is folded
    into the proof claims and a new ChainRequest is returned so the chain
    driver retries the call.
    """

    def __init__(
        self,
        dpop_key: jwk.JWK,
        # NOTE(review): "dop_..." looks like a typo for "dpop_..." but these
        # names are part of the public interface, so they stay unchanged.
        dop_assertion_header: Dict[str, Any],
        dop_assertion_claims: Dict[str, Any],
    ) -> None:
        super().__init__()
        self._dpop_key = dpop_key
        self._dpop_assertion_header = dop_assertion_header
        self._dpop_assertion_claims = dop_assertion_claims

    async def handle(
        self, next: NextChainCallbackType, request: ChainRequest
    ) -> NextChainResponseCallbackType:
        """Sign a DPoP proof, send the request, and request a retry on nonce errors."""
        # Unique token identifier per proof to prevent replay.
        self._dpop_assertion_claims["jti"] = secrets.token_urlsafe(32)

        dpop_assertation = jwt.JWT(
            header=self._dpop_assertion_header,
            claims=self._dpop_assertion_claims,
        )
        dpop_assertation.make_signed_token(self._dpop_key)
        dpop_assertation_token = dpop_assertation.serialize()

        if request.headers is None:
            request.headers = {}
        request.headers["DPoP"] = dpop_assertation_token

        response = await next(request)
        client_response = response[0]
        chain_response = response[1]
        new_request = None
        if len(response) == 3:
            new_request = response[2]

        # DPoP nonce negotiation: the server rejects the first proof and
        # supplies the nonce it expects; store it and schedule a retry.
        if chain_response.status == 401 or chain_response.status == 400:
            if chain_response.headers is None:
                raise ValueError("Response headers are None")

            if chain_response.body_matches_kv(
                "error", "invalid_dpop_proof"
            ) or chain_response.body_matches_kv("error", "use_dpop_nonce"):
                self._dpop_assertion_claims["nonce"] = chain_response.headers.get(
                    "DPoP-Nonce", ""
                )

                if new_request is None:
                    new_request = ChainRequest.from_chain_request(request)

        if new_request is None:
            return client_response, chain_response
        return client_response, chain_response, new_request
306 |
307 |
class EndOfLineChainMiddleware:
    """Terminal link of the chain: performs the actual HTTP request."""

    def __init__(
        self,
        request_func: RequestFunc,
        logger: _LoggerType,
        raise_for_status: bool = False,
    ) -> None:
        super().__init__()
        self._request_func = request_func
        self._raise_for_status = raise_for_status
        self._logger = logger

    async def handle(self, request: ChainRequest) -> NextChainResponseCallbackType:
        """Execute *request* via the underlying session and wrap its response."""
        trace_ctx = {**(request.trace_request_ctx or {})}
        extra_kwargs = request.kwargs or {}

        response: ClientResponse = await self._request_func(
            request.method.lower(),
            request.url,
            headers=request.headers,
            trace_request_ctx=trace_ctx,
            **extra_kwargs,
        )

        if self._raise_for_status:
            response.raise_for_status()

        chain_response = await ChainResponse.from_aiohttp_response(response)
        return response, chain_response
335 |
336 |
class ChainMiddlewareContext:
    """Awaitable / async-context wrapper that drives the middleware chain.

    Repeatedly invokes the chain callback until no middleware asks for a
    retry (by returning a third tuple element), up to ``attempt_max``
    attempts. When used as an async context manager, the underlying
    ClientResponse is closed on exit.
    """

    def __init__(
        self,
        chain_callback: NextChainCallbackType,
        chain_request: ChainRequest,
        logger: _LoggerType,
        raise_for_status: bool = False,
        attempt_max: int = 3,
    ) -> None:
        self._chain_callback = chain_callback
        self._chain_request = chain_request
        self._logger = logger
        self._raise_for_status = raise_for_status

        self._chain_response: ChainResponse | None = None
        self.client_response: ClientResponse | None = None

        self._attempt_max = attempt_max

    async def _do_request(self) -> Tuple[ClientResponse, ChainResponse]:
        """Run the chain, retrying while a middleware returns a new request."""
        current_attempt = 0

        chain_request = self._chain_request

        while True:
            current_attempt += 1

            if current_attempt > self._attempt_max:
                raise Exception("Max attempts reached")

            response = await self._chain_callback(chain_request)
            client_response = response[0]
            chain_response = response[1]
            new_request = None
            if len(response) == 3:
                new_request = response[2]

            self._chain_response = chain_response
            # Bug fix: assign to the public `client_response` attribute.
            # Previously this set `self._client_response`, leaving
            # `client_response` forever None so __aexit__ never closed it.
            self.client_response = client_response

            if new_request is None:
                # Bug fix: this check used to sit *after* the `while True`
                # loop, where it was unreachable; apply it before returning.
                if self._raise_for_status:
                    client_response.raise_for_status()
                return client_response, chain_response

            chain_request = new_request

    def __await__(self) -> Generator[Any, None, Tuple[ClientResponse, ChainResponse]]:
        return self.__aenter__().__await__()

    async def __aenter__(self) -> Tuple[ClientResponse, ChainResponse]:
        return await self._do_request()

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        # Release the underlying connection if the caller didn't consume it.
        if self.client_response is not None and not self.client_response.closed:
            self.client_response.close()
399 |
400 |
class ChainMiddlewareClient:
    """aiohttp-style client whose requests run through a middleware chain.

    Wraps an existing ClientSession or creates its own, and for each request
    builds the middleware chain terminated by EndOfLineChainMiddleware. Each
    verb method returns a ChainMiddlewareContext that may be awaited or used
    as an async context manager.
    """

    def __init__(
        self,
        client_session: ClientSession | None = None,
        logger: _LoggerType | None = None,
        middleware: Sequence[RequestMiddlewareBase] | None = None,
        raise_for_status: bool = False,
        *args: Any,
        **kwargs: Any,
    ) -> None:
        if client_session is not None:
            # Borrowed session: `_closed is None` marks it caller-owned so
            # __del__ won't warn about it.
            client = client_session
            closed = None
        else:
            client = ClientSession(*args, **kwargs)
            closed = False

        self._middleware = middleware

        self._client = client
        self._closed = closed

        self._logger: _LoggerType = logger or logging.getLogger("aiohttp_chain")
        self._raise_for_status = raise_for_status

    # Bug fix: `request()` and every verb helper below previously accepted a
    # per-call `raise_for_status` override but silently dropped it instead of
    # forwarding it to `_make_request`. It is now passed through everywhere.

    def request(
        self,
        method: str,
        url: StrOrURL,
        raise_for_status: bool | None = None,
        **kwargs: Any,
    ) -> ChainMiddlewareContext:
        """Issue a request with an arbitrary HTTP method."""
        return self._make_request(
            method=method,
            url=url,
            raise_for_status=raise_for_status,
            **kwargs,
        )

    def get(
        self,
        url: StrOrURL,
        raise_for_status: bool | None = None,
        **kwargs: Any,
    ) -> ChainMiddlewareContext:
        """Issue a GET request."""
        return self._make_request(
            method=hdrs.METH_GET,
            url=url,
            raise_for_status=raise_for_status,
            **kwargs,
        )

    def options(
        self,
        url: StrOrURL,
        raise_for_status: bool | None = None,
        **kwargs: Any,
    ) -> ChainMiddlewareContext:
        """Issue an OPTIONS request."""
        return self._make_request(
            method=hdrs.METH_OPTIONS,
            url=url,
            raise_for_status=raise_for_status,
            **kwargs,
        )

    def head(
        self,
        url: StrOrURL,
        raise_for_status: bool | None = None,
        **kwargs: Any,
    ) -> ChainMiddlewareContext:
        """Issue a HEAD request."""
        return self._make_request(
            method=hdrs.METH_HEAD,
            url=url,
            raise_for_status=raise_for_status,
            **kwargs,
        )

    def post(
        self,
        url: StrOrURL,
        raise_for_status: bool | None = None,
        **kwargs: Any,
    ) -> ChainMiddlewareContext:
        """Issue a POST request."""
        return self._make_request(
            method=hdrs.METH_POST,
            url=url,
            raise_for_status=raise_for_status,
            **kwargs,
        )

    def put(
        self,
        url: StrOrURL,
        raise_for_status: bool | None = None,
        **kwargs: Any,
    ) -> ChainMiddlewareContext:
        """Issue a PUT request."""
        return self._make_request(
            method=hdrs.METH_PUT,
            url=url,
            raise_for_status=raise_for_status,
            **kwargs,
        )

    def patch(
        self,
        url: StrOrURL,
        raise_for_status: bool | None = None,
        **kwargs: Any,
    ) -> ChainMiddlewareContext:
        """Issue a PATCH request."""
        return self._make_request(
            method=hdrs.METH_PATCH,
            url=url,
            raise_for_status=raise_for_status,
            **kwargs,
        )

    def delete(
        self,
        url: StrOrURL,
        raise_for_status: bool | None = None,
        **kwargs: Any,
    ) -> ChainMiddlewareContext:
        """Issue a DELETE request."""
        return self._make_request(
            method=hdrs.METH_DELETE,
            url=url,
            raise_for_status=raise_for_status,
            **kwargs,
        )

    async def close(self) -> None:
        """Close the underlying ClientSession."""
        await self._client.close()
        self._closed = True

    def _make_request(
        self,
        method: str,
        url: StrOrURL,
        raise_for_status: bool | None = None,
        **kwargs: Any,
    ) -> ChainMiddlewareContext:
        """Build the middleware chain and wrap it in a ChainMiddlewareContext."""
        chain_request = ChainRequest(
            method=method,
            url=url,
            headers=kwargs.pop("headers", {}),
            trace_request_ctx=kwargs.pop("trace_request_ctx", None),
            kwargs=kwargs,
        )

        # Per-call override falls back to the client-wide default.
        if raise_for_status is None:
            raise_for_status = self._raise_for_status

        end_of_line_middleware = EndOfLineChainMiddleware(
            request_func=self._client.request,
            logger=self._logger,
            raise_for_status=raise_for_status,
        )

        chain_callback: NextChainCallbackType = end_of_line_middleware.handle

        # Wrap in reverse so the first middleware in the sequence runs first.
        for mw in reversed(self._middleware or []):
            chain_callback = mw.handle_gen(chain_callback)

        return ChainMiddlewareContext(
            chain_callback=chain_callback,
            chain_request=chain_request,
            logger=self._logger,
            raise_for_status=raise_for_status,
        )

    async def __aenter__(self) -> "ChainMiddlewareClient":
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        await self.close()

    def __del__(self) -> None:
        if getattr(self, "_closed", None) is None:
            # in case object was not initialized (__init__ raised an exception)
            # or the session is caller-owned: nothing to warn about
            return

        if not self._closed:
            self._logger.warning("Aiohttp chain client was not closed")
583 |
--------------------------------------------------------------------------------
/src/social/graze/aip/atproto/pds.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, Any, Dict
2 | from aiohttp import ClientSession
3 |
4 |
async def oauth_protected_resource(
    session: ClientSession, pds: str
) -> Optional[Dict[str, Any]]:
    """
    Fetch the OAuth protected-resource metadata for a PDS.

    GETs ``{pds}/.well-known/oauth-protected-resource`` and returns the
    parsed JSON document, or None on any non-200 response.
    (An unreachable trailing ``return None`` after the context manager
    has been removed.)
    """
    async with session.get(f"{pds}/.well-known/oauth-protected-resource") as resp:
        if resp.status != 200:
            return None
        return await resp.json()
13 |
14 |
async def oauth_authorization_server(
    session: ClientSession, authorization_server: str
) -> Optional[Dict[str, Any]]:
    """
    Fetch the OAuth authorization-server metadata document.

    GETs ``{authorization_server}/.well-known/oauth-authorization-server``
    and returns the parsed JSON document, or None on any non-200 response.
    (An unreachable trailing ``return None`` after the context manager
    has been removed.)
    """
    async with session.get(
        f"{authorization_server}/.well-known/oauth-authorization-server"
    ) as resp:
        if resp.status != 200:
            return None
        return await resp.json()
25 |
--------------------------------------------------------------------------------
/src/social/graze/aip/model/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Database Models
3 |
4 | This package defines the database models for the AIP service using SQLAlchemy ORM.
5 | These models represent the persistent data structures for authentication, user
6 | information, and service health.
7 |
8 | Key Models:
9 | - base.py: Base SQLAlchemy model with common type definitions
10 | - app_password.py: Models for app password authentication
11 | - oauth.py: Models for OAuth sessions and token management
12 | - handles.py: Models for user handles and DIDs
13 | - health.py: Health monitoring model
14 |
15 | The data models follow these relationships:
16 | - Handle: Represents a user identity with DID, handle, and PDS location
17 | - OAuthRequest: Temporary storage for OAuth request data
18 | - OAuthSession: Active OAuth session with tokens and expiration info
19 | - AppPassword: App-specific password credentials
20 |
21 | Each model includes:
22 | - Creation and expiration timestamps
23 | - Secure storage of credentials (encrypted where appropriate)
24 | - Relationships to other models where needed
25 |
26 | The models use SQLAlchemy's async interface for non-blocking database operations
27 | and include methods for common operations like upserts.
28 | """
--------------------------------------------------------------------------------
/src/social/graze/aip/model/app_password.py:
--------------------------------------------------------------------------------
1 | from datetime import datetime
2 | from sqlalchemy import String, DateTime
3 | from sqlalchemy.orm import Mapped, mapped_column
4 |
5 | from social.graze.aip.model.base import Base, str1024
6 |
7 |
class AppPassword(Base):
    """An AT Protocol app password credential, keyed by the user's guid.

    NOTE(review): the password appears to be stored as provided here;
    confirm whether encryption is applied before persistence.
    """

    __tablename__ = "atproto_app_passwords"

    # One credential row per user guid.
    guid: Mapped[str] = mapped_column(String(64), primary_key=True)
    # The app password value itself.
    app_password: Mapped[str]
    # Timezone-aware timestamp of when the credential was stored.
    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), nullable=False
    )
16 |
17 |
class AppPasswordSession(Base):
    """An active session created via app-password auth, keyed by user guid.

    Tracks both tokens with independent expirations so refresh logic can
    decide which one is still usable.
    """

    __tablename__ = "atproto_app_password_sessions"

    # One session row per user guid.
    guid: Mapped[str] = mapped_column(String(64), primary_key=True)
    # Short-lived access token (up to 1024 chars) and its expiry.
    access_token: Mapped[str1024]
    access_token_expires_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), nullable=False
    )
    # Longer-lived refresh token and its expiry.
    refresh_token: Mapped[str]
    refresh_token_expires_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), nullable=False
    )
    # Timezone-aware timestamp of session creation.
    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), nullable=False
    )
33 |
--------------------------------------------------------------------------------
/src/social/graze/aip/model/base.py:
--------------------------------------------------------------------------------
1 | from sqlalchemy import String, orm
2 | from sqlalchemy.orm import mapped_column
3 |
4 | from typing_extensions import Annotated
5 |
# Annotated string aliases; Base.type_annotation_map binds each to a concrete
# SQLAlchemy column type so models can use them directly in Mapped[...].
str512 = Annotated[str, 512]
str1024 = Annotated[str, 1024]
# Convenience alias for a String(512) primary-key column (used for guids).
guidpk = Annotated[str, mapped_column(String(512), primary_key=True)]


class Base(orm.DeclarativeBase):
    """Declarative base shared by all AIP models, with common type mappings."""

    type_annotation_map = {
        str512: String(512),
        str1024: String(1024),
        guidpk: String(512),
    }
17 |
--------------------------------------------------------------------------------
/src/social/graze/aip/model/handles.py:
--------------------------------------------------------------------------------
1 | from sqlalchemy.orm import Mapped
2 | from sqlalchemy.dialects.postgresql import insert
3 | from ulid import ULID
4 |
5 | from social.graze.aip.model.base import Base, str512, guidpk
6 |
7 |
class Handle(Base):
    """A user identity: internal guid, DID, handle, and PDS location.

    NOTE(review): upsert_handle_stmt conflicts on `did`, which implies a
    unique index on that column — defined in the migration, not here.
    """

    __tablename__ = "handles"

    # Internal primary key (a ULID string; see upsert_handle_stmt).
    guid: Mapped[guidpk]
    # The decentralized identifier (did:plc:... or did:web:...).
    did: Mapped[str512]
    # The human-readable handle (hostname form, no at:// prefix).
    handle: Mapped[str512]
    # URL of the user's Personal Data Server.
    pds: Mapped[str512]
15 |
16 |
def upsert_handle_stmt(did: str, handle: str, pds: str):
    """Build an idempotent PostgreSQL upsert for the handles table.

    Inserts a new row with a freshly generated ULID guid; if a row with the
    same DID already exists, its handle and pds are updated instead (the
    existing guid is preserved). The statement returns the row's guid.
    """
    row = {
        "guid": str(ULID()),
        "did": did,
        "handle": handle,
        "pds": pds,
    }
    stmt = insert(Handle).values([row])
    stmt = stmt.on_conflict_do_update(
        index_elements=["did"],
        set_={"handle": handle, "pds": pds},
    )
    return stmt.returning(Handle.guid)
39 |
--------------------------------------------------------------------------------
/src/social/graze/aip/model/health.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 |
3 |
4 | class HealthGauge:
5 | """
6 | This is a makeshift health check system.
7 |
8 | This health gauge is used to track the health of the application and provide a way to return a somewhat meaningful
9 | value to readiness probes.
10 |
11 | The health gauge is a simple counter that can be incremented and decremented. When an exception occurs that is
12 | outside of regular application flow-control (an actual error and not a warning), then the counter is
13 | incremented. As time goes on, the counter decrements. When a burst of exceptions occurs, the is_healthy method will
14 | return false, triggering a failed readiness check.
15 | """
16 |
17 | def __init__(self, value: int = 0, health_threshold: int = 100) -> None:
18 | self._value = value
19 | self._health_threshold = health_threshold
20 | self._lock = asyncio.Lock()
21 |
22 | async def womp(self, d=1) -> int:
23 | async with self._lock:
24 | self._value += int(d)
25 | return self._value
26 |
27 | async def tick(self) -> None:
28 | async with self._lock:
29 | if self._value > 0:
30 | self._value -= 1
31 |
32 | async def is_healthy(self) -> bool:
33 | async with self._lock:
34 | return self._value <= self._health_threshold
35 |
--------------------------------------------------------------------------------
/src/social/graze/aip/model/oauth.py:
--------------------------------------------------------------------------------
1 | from typing import Any
2 | from datetime import datetime
3 | from sqlalchemy import Integer, String, DateTime
4 | from sqlalchemy.orm import Mapped, mapped_column
5 | from sqlalchemy.dialects.postgresql import JSON, insert
6 |
7 | from social.graze.aip.model.base import Base, str512, str1024
8 |
9 |
class OAuthRequest(Base):
    """Temporary storage for an in-flight OAuth authorization request.

    Rows are short-lived: created when the flow starts and subject to the
    expires_at deadline.
    """

    __tablename__ = "oauth_requests"

    # The OAuth `state` value, used to correlate the callback.
    oauth_state: Mapped[str] = mapped_column(String(64), primary_key=True)
    # Authorization server issuer URL.
    issuer: Mapped[str512]
    # Internal guid of the user starting the flow.
    guid: Mapped[str512]
    # PKCE code verifier to present at token exchange.
    pkce_verifier: Mapped[str]
    # Identifier of the signing JWK used for this request.
    secret_jwk_id: Mapped[str]
    # DPoP key material, stored as JSON.
    dpop_jwk: Mapped[Any] = mapped_column(JSON, nullable=False)
    # Where to redirect the user after completion.
    destination: Mapped[str]
    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), nullable=False
    )
    # Deadline after which this pending request is invalid.
    expires_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), nullable=False
    )
26 |
27 |
class OAuthSession(Base):
    """An active OAuth session with its tokens and expiration bookkeeping."""

    __tablename__ = "oauth_sessions"

    # Groups related sessions/refreshes under one identifier.
    session_group: Mapped[str] = mapped_column(String(64), primary_key=True)
    # Authorization server issuer URL.
    issuer: Mapped[str512]
    # Internal guid of the session's user.
    guid: Mapped[str512]
    access_token: Mapped[str1024]
    refresh_token: Mapped[str512]
    # Identifier of the signing JWK bound to this session.
    secret_jwk_id: Mapped[str512]
    # DPoP key material, stored as JSON.
    dpop_jwk: Mapped[Any] = mapped_column(JSON, nullable=False)
    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), nullable=False
    )
    # When the access token expires (refresh may still succeed after this).
    access_token_expires_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), nullable=False
    )
    # Absolute deadline after which the session is dead regardless of refresh.
    hard_expires_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), nullable=False
    )
47 |
48 |
class Permission(Base):
    """A directed permission grant from one guid onto a target guid.

    The (guid, target_guid) pair is the composite primary key; `permission`
    is an integer level whose semantics are defined by the application
    (see upsert_permission_stmt).
    """

    __tablename__ = "guid_permissions"

    guid: Mapped[str] = mapped_column(String(512), primary_key=True)
    target_guid: Mapped[str] = mapped_column(String(512), primary_key=True)
    permission: Mapped[int] = mapped_column(Integer, nullable=False)
    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), nullable=False
    )
58 |
59 |
def upsert_permission_stmt(
    guid: str, target_guid: str, permission: int, created_at: datetime
):
    """Build an idempotent PostgreSQL upsert for guid_permissions.

    Inserts the (guid, target_guid) grant; when that pair already exists,
    only the permission level is updated (created_at is left untouched).
    """
    row = {
        "guid": guid,
        "target_guid": target_guid,
        "permission": permission,
        "created_at": created_at,
    }
    stmt = insert(Permission).values([row])
    return stmt.on_conflict_do_update(
        index_elements=["guid", "target_guid"],
        set_={"permission": permission},
    )
82 |
--------------------------------------------------------------------------------
/src/social/graze/aip/resolve/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Identity Resolution
3 |
4 | This package provides utilities for resolving AT Protocol identifiers (DIDs, handles)
5 | to their canonical forms, implementing both DNS-based and HTTP-based resolution methods.
6 |
7 | Key Components:
8 | - handle.py: Handle resolution implementation
9 | - __main__.py: CLI interface for resolution
10 |
11 | Resolution Types:
12 | 1. Handle Resolution
13 | - DNS-based resolution via TXT records (_atproto.{handle})
14 | - HTTP-based resolution via well-known endpoints (.well-known/atproto-did)
15 |
16 | 2. DID Resolution
17 | - did:plc method resolution via PLC directory
18 | - did:web method resolution via well-known endpoints
19 |
20 | The resolution flow typically follows these steps:
21 | 1. Parse the input to determine if it's a handle or DID
22 | 2. For handles, attempt DNS resolution first, then HTTP
23 | 3. For DIDs, use the appropriate method based on the DID type
24 | 4. Return the resolved canonical data (DID, handle, PDS location)
25 |
26 | This implementation follows the AT Protocol specification for identity
27 | resolution, ensuring compatibility with the broader ecosystem.
28 | """
--------------------------------------------------------------------------------
/src/social/graze/aip/resolve/__main__.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 | import argparse
3 | import aiohttp
4 | import asyncio
5 | import logging
6 | import sentry_sdk
7 |
8 | from social.graze.aip.resolve.handle import resolve_subject
9 |
10 | logger = logging.getLogger(__name__)
11 |
12 |
async def realMain() -> None:
    """CLI entry point: resolve each subject argument and print the result.

    Failures on individual subjects are reported to Sentry and logged, but
    do not stop resolution of the remaining subjects.
    """
    parser = argparse.ArgumentParser(prog="resolve", description="Resolve handles")
    parser.add_argument("subject", nargs="+", help="The subject(s) to resolve.")
    parser.add_argument(
        "--plc-hostname",
        default="plc.directory",
        help="The PLC hostname to use for resolving did-method-plc DIDs.",
    )

    args = vars(parser.parse_args())

    subjects: List[str] = args.get("subject", [])
    # argparse always populates this key via the default above. The previous
    # code passed the `str` type object as the .get() fallback by mistake.
    plc_hostname: str = args.get("plc_hostname", "plc.directory")

    async with aiohttp.ClientSession() as session:
        for subject in subjects:
            try:
                resolved_handle = await resolve_subject(
                    session, plc_hostname, subject
                )
                print(f"resolved_handle {resolved_handle}")
            except Exception as e:
                sentry_sdk.capture_exception(e)
                # Use the module logger for consistent logger naming.
                logger.exception("Exception resolving subject %s", subject)
36 |
37 |
def main() -> None:
    """Synchronous console-script entry point; runs the async CLI."""
    asyncio.run(realMain())
40 |
41 |
42 | if __name__ == "__main__":
43 | main()
44 |
--------------------------------------------------------------------------------
/src/social/graze/aip/resolve/handle.py:
--------------------------------------------------------------------------------
1 | from enum import IntEnum
2 | from aiohttp import ClientSession
3 | from pydantic import BaseModel
4 | from aiodns import DNSResolver
5 | from typing import Optional, Any, Dict
6 | import sentry_sdk
7 |
8 |
class SubjectType(IntEnum):
    """Kind of AT Protocol subject identifier recognized by parse_input."""

    did_method_plc = 1  # did:plc:... — resolved via the PLC directory
    did_method_web = 2  # did:web:... — resolved via a well-known endpoint
    hostname = 3  # bare handle / hostname — resolved via DNS TXT or HTTP
13 |
14 |
class ParsedSubject(BaseModel):
    """A normalized subject string paired with its detected type."""

    subject_type: SubjectType
    subject: str
18 |
19 |
class ResolvedSubject(BaseModel):
    """Canonical identity data produced by the resolvers.

    Attributes:
        did: The subject's canonical DID.
        handle: The handle, with the at:// prefix stripped.
        pds: The subject's PDS service endpoint URL.
    """

    did: str
    handle: str
    pds: str

    # NOTE(review): an earlier draft validated alsoKnownAs entries with a
    # model validator here; that filtering now happens in the resolvers via
    # handle_predicate, so the validator was left commented out and removed.
35 |
36 |
async def resolve_handle_dns(handle: str) -> Optional[str]:
    """Resolve a handle to a DID via its _atproto.{handle} TXT record.

    Returns the TXT payload with the ``did=`` prefix stripped, or ``None``
    when the lookup fails or yields no records. Lookup failures are reported
    to Sentry rather than raised.
    """
    # TODO: Wrap hickory-dns and use that
    resolver = DNSResolver()
    try:
        answers = await resolver.query(f"_atproto.{handle}", "TXT")
    except Exception as exc:
        sentry_sdk.capture_exception(exc)
        return None
    record = next(iter(answers or []), None)
    if record is None:
        return None
    return record.text.removeprefix("did=")
49 |
50 |
async def resolve_handle_http(session: ClientSession, handle: str) -> Optional[str]:
    """Resolve a handle to a DID via the well-known HTTP endpoint.

    Fetches ``https://{handle}/.well-known/atproto-did`` and returns the
    response body, or ``None`` on a non-200 status. The previous
    ``body is not None`` check and trailing ``return None`` were dead code:
    ``resp.text()`` always returns a ``str``.
    """
    async with session.get(f"https://{handle}/.well-known/atproto-did") as resp:
        if resp.status != 200:
            return None
        return await resp.text()
59 |
60 |
async def resolve_handle(session: ClientSession, handle: str) -> Optional[str]:
    """Resolve a handle to a DID, preferring DNS over the HTTP well-known.

    Returns ``None`` when neither method yields a DID.
    """
    via_dns = await resolve_handle_dns(handle)
    if via_dns is not None:
        return via_dns
    return await resolve_handle_http(session, handle)

    # NOTE(Nick): the two lookups could instead run concurrently in an
    # asyncio.TaskGroup, preferring the DNS result when both succeed.
80 |
81 |
def handle_predicate(value: str) -> bool:
    """Return True when *value* is an at:// URI (a usable handle alias)."""
    return value[:5] == "at://"
84 |
85 |
def pds_predicate(value: Dict[str, Any]) -> bool:
    """Return True when a DID-document service entry describes a PDS.

    The entry must be typed AtprotoPersonalDataServer and carry a
    serviceEndpoint key.
    """
    if value.get("type", None) != "AtprotoPersonalDataServer":
        return False
    return "serviceEndpoint" in value
91 |
92 |
async def resolve_did_method_plc(
    plc_directory: str, session: ClientSession, did: str
) -> Optional[ResolvedSubject]:
    """Resolve a did:plc DID by querying the PLC directory.

    Fetches ``https://{plc_directory}/{did}``, then extracts the first
    at:// alias and the first PDS service entry from the DID document.
    Returns ``None`` on a non-200 status, an empty body, or when either
    the handle or PDS entry is missing.
    """
    async with session.get(f"https://{plc_directory}/{did}") as resp:
        if resp.status != 200:
            return None
        document = await resp.json()
    if document is None:
        return None
    handle = next(filter(handle_predicate, document.get("alsoKnownAs", [])), None)
    pds = next(filter(pds_predicate, document.get("service", [])), None)
    if handle is None or pds is None:
        return None
    return ResolvedSubject(
        did=did,
        handle=handle.removeprefix("at://"),
        pds=pds.get("serviceEndpoint"),
    )
111 |
112 |
async def resolve_did_method_web(
    session: ClientSession, did: str
) -> Optional[ResolvedSubject]:
    """Resolve a did:web DID by fetching its did.json document.

    ``did:web:example.com`` maps to
    ``https://example.com/.well-known/did.json``;
    ``did:web:example.com:path:to`` maps to
    ``https://example.com/path/to/did.json``.

    Returns ``None`` on malformed DIDs, non-200 responses, empty bodies,
    or documents lacking a handle alias / PDS service entry.
    """
    parts = did.removeprefix("did:web:").split(":")
    # str.split always yields at least one element, so the old
    # `len(parts) == 0` check was dead; reject the empty label produced by
    # a bare "did:web:" instead (it would build "https:///..." otherwise).
    if not parts or not parts[0]:
        return None

    if len(parts) == 1:
        parts.append(".well-known")

    url = "https://{inner}/did.json".format(inner="/".join(parts))

    async with session.get(url) as resp:
        if resp.status != 200:
            return None
        body = await resp.json()
        if body is None:
            return None
        handle = next(filter(handle_predicate, body.get("alsoKnownAs", [])), None)
        pds = next(filter(pds_predicate, body.get("service", [])), None)
        if handle is not None and pds is not None:
            return ResolvedSubject(
                did=did,
                handle=handle.removeprefix("at://"),
                pds=pds.get("serviceEndpoint"),
            )
    return None
141 |
142 |
async def resolve_did(
    session: ClientSession, plc_hostname: str, did: str
) -> Optional[ResolvedSubject]:
    """Dispatch DID resolution to the method-specific resolver.

    Supports did:plc (via the PLC directory) and did:web (via the
    well-known endpoint); any other method returns ``None``.
    """
    if did.startswith("did:plc:"):
        return await resolve_did_method_plc(plc_hostname, session, did)
    if did.startswith("did:web:"):
        return await resolve_did_method_web(session, did)
    return None
151 |
152 |
async def resolve_subject(
    session: ClientSession, plc_hostname: str, subject: str
) -> Optional[ResolvedSubject]:
    """Resolve an arbitrary subject (handle or DID) to canonical identity data.

    Parses the input, converts handles to DIDs first (DNS then HTTP), and
    finally resolves the DID document. Returns ``None`` when any step fails.
    """
    parsed = parse_input(subject)
    if parsed is None:
        return None

    if parsed.subject_type == SubjectType.hostname:
        did = await resolve_handle(session, parsed.subject)
    else:
        # did:plc and did:web subjects are already DIDs.
        did = parsed.subject

    if did is None:
        return None

    return await resolve_did(session, plc_hostname, did)
172 |
173 |
def parse_input(subject: str) -> Optional[ParsedSubject]:
    """Classify raw user input as a did:plc, did:web, or hostname subject.

    Strips surrounding whitespace plus any leading ``at://`` or ``@``
    before classifying; anything that is not a recognized DID method is
    treated as a hostname/handle.
    """
    cleaned = subject.strip().removeprefix("at://").removeprefix("@")

    for prefix, subject_type in (
        ("did:plc:", SubjectType.did_method_plc),
        ("did:web:", SubjectType.did_method_web),
    ):
        if cleaned.startswith(prefix):
            return ParsedSubject(subject_type=subject_type, subject=cleaned)

    # TODO: Validate this hostname
    return ParsedSubject(subject_type=SubjectType.hostname, subject=cleaned)
186 |
--------------------------------------------------------------------------------
/telegraf.conf:
--------------------------------------------------------------------------------
1 | # Global agent configuration
2 | [agent]
3 | interval = "10s"
4 | round_interval = true
5 | metric_batch_size = 1000
6 | metric_buffer_limit = 10000
7 | collection_jitter = "0s"
8 | flush_interval = "10s"
9 | flush_jitter = "0s"
10 | precision = ""
11 | hostname = "aip_telegraf"
12 | omit_hostname = false
13 |
14 | # Input plugins
15 | [[inputs.cpu]]
16 | percpu = true
17 | totalcpu = true
18 | collect_cpu_time = false
19 | report_active = false
20 |
21 | [[inputs.mem]]
22 | fieldpass = ["used_percent"]
23 |
24 | [[inputs.disk]]
25 | mount_points = ["/"]
26 | ignore_fs = ["tmpfs", "devtmpfs"]
27 |
28 | [[inputs.statsd]]
29 | service_address = ":8125"
30 |
31 | [[outputs.file]]
32 | files = ["/tmp/telegraf.out"]
33 | data_format = "json"
34 |
--------------------------------------------------------------------------------
/templates/atproto_debug.html:
--------------------------------------------------------------------------------
1 | {% extends "base.html" %}
2 | {% block content %}
3 | Debug
6 |
7 |
8 | {{ auth_token.header }}
9 | {{ auth_token.claims }}
10 | {{ handle.__dict__ }}
11 | {{ oauth_session.__dict__ }}
14 | {% endif %} {% if brand_name %}
15 |
19 | {{ brand_name }}
20 |
21 | {% endif %}
22 |
The Graze Social Authentication, Identity, and Permissions Server
7 | 8 |The Graze Social Authentication, Identity, and Permissions Server
6 | 7 |