├── abode ├── __init__.py ├── lib │ ├── __init__.py │ ├── tests │ │ ├── __init__.py │ │ └── test_query.py │ └── query.py ├── schema │ ├── changelog.sql │ ├── users.sql │ ├── emoji.sql │ ├── guilds.sql │ ├── channels.sql │ └── messages.sql ├── client.py ├── backfill.py ├── db │ ├── users.py │ ├── guilds.py │ ├── emoji.py │ ├── messages.py │ ├── channels.py │ └── __init__.py ├── cli.py ├── events.py └── server.py ├── .gitignore ├── requirements.txt ├── frontend ├── templates │ ├── guild.html │ ├── user.html │ ├── message.html │ ├── emoji.html │ ├── channel.html │ └── results.html ├── styles.css ├── index.html └── script.js ├── .vscode └── settings.json └── README.md /abode/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /abode/lib/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /abode/lib/tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .venv/ 2 | config.json 3 | test.db* 4 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | asyncpg 2 | discord.py 3 | sanic 4 | -------------------------------------------------------------------------------- /frontend/templates/guild.html: -------------------------------------------------------------------------------- 1 |
2 | {{ row.name }} 3 |
-------------------------------------------------------------------------------- /frontend/templates/user.html: -------------------------------------------------------------------------------- 1 |
2 | {{ row.name }}#{{ row.discriminator|discrim }} 3 |
-------------------------------------------------------------------------------- /frontend/templates/message.html: -------------------------------------------------------------------------------- 1 |
2 | {{ row.user.name }}#{{ row.user.discriminator|discrim }} {{ row.content }} 3 |
-------------------------------------------------------------------------------- /frontend/templates/emoji.html: -------------------------------------------------------------------------------- 1 |
2 | 3 | {{ row.name }} 4 |
-------------------------------------------------------------------------------- /abode/schema/changelog.sql: -------------------------------------------------------------------------------- 1 | CREATE TABLE IF NOT EXISTS changelog ( 2 | entity_id text, 3 | version integer, 4 | field text, 5 | value text, 6 | 7 | PRIMARY KEY (entity_id, version) 8 | ); -------------------------------------------------------------------------------- /frontend/templates/channel.html: -------------------------------------------------------------------------------- 1 |
2 | {% if row.type == 0 or row.type == 2 or row.type == 4 %} 3 | {{ row.name }} ({{ row.guild.name }}) 4 | {% else %} 5 | {{ row.name }} 6 | {% endif %} 7 |
-------------------------------------------------------------------------------- /.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "python.pythonPath": "/home/andrei/code/python/abode/.venv/bin/python3.8", 3 | "python.linting.pylintEnabled": false, 4 | "python.linting.flake8Enabled": true, 5 | "python.linting.enabled": true 6 | } -------------------------------------------------------------------------------- /abode/schema/users.sql: -------------------------------------------------------------------------------- 1 | CREATE TABLE IF NOT EXISTS users ( 2 | id BIGINT PRIMARY KEY, 3 | name text NOT NULL, 4 | discriminator smallint NOT NULL, 5 | avatar text, 6 | bot boolean, 7 | system boolean 8 | ); 9 | 10 | CREATE INDEX IF NOT EXISTS users_name_trgm ON users USING gin (name gin_trgm_ops); -------------------------------------------------------------------------------- /abode/schema/emoji.sql: -------------------------------------------------------------------------------- 1 | CREATE TABLE IF NOT EXISTS emoji ( 2 | id BIGINT PRIMARY KEY, 3 | guild_id BIGINT, 4 | author_id BIGINT, 5 | name text, 6 | require_colons boolean, 7 | managed boolean, 8 | animated boolean, 9 | roles jsonb, 10 | created_at timestamp, 11 | deleted boolean 12 | ); -------------------------------------------------------------------------------- /frontend/templates/results.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | {% for field in fields %} 4 | 5 | {% endfor %} 6 | 7 | 8 | {% for row in rows %} 9 | 10 | {% for col in row %} 11 | 14 | {% endfor %} 15 | 16 | {% endfor %} 17 |
{{ field }}
12 |

{{ col|safe }}

13 |
-------------------------------------------------------------------------------- /frontend/styles.css: -------------------------------------------------------------------------------- 1 | body { 2 | font-family: sans-serif; 3 | font-size: 0.8rem; 4 | padding-bottom: 0.9rem; 5 | } 6 | 7 | .search-container { 8 | text-align: center; 9 | margin-bottom: 12px; 10 | } 11 | 12 | .search-container input { 13 | width: 33%; 14 | } 15 | 16 | table, td, th { 17 | border: 1px solid #ddd; 18 | text-align: left; 19 | } 20 | 21 | table { 22 | border-collapse: collapse; 23 | width: 100%; 24 | } 25 | 26 | th, td { 27 | padding: 8px; 28 | } -------------------------------------------------------------------------------- /abode/schema/guilds.sql: -------------------------------------------------------------------------------- 1 | CREATE EXTENSION IF NOT EXISTS pg_trgm; 2 | 3 | CREATE TABLE IF NOT EXISTS guilds ( 4 | id BIGINT PRIMARY KEY, 5 | owner_id BIGINT, 6 | name text, 7 | region text, 8 | icon text, 9 | is_currently_joined boolean, 10 | features jsonb, 11 | banner text, 12 | description text, 13 | splash text, 14 | discovery_splash text, 15 | premium_tier smallint, 16 | premium_subscription_count int 17 | ); 18 | 19 | CREATE INDEX IF NOT EXISTS guilds_name_trgm ON guilds USING gin (name gin_trgm_ops); 20 | CREATE INDEX IF NOT EXISTS guilds_owner_id_idx ON guilds (owner_id); -------------------------------------------------------------------------------- /abode/schema/channels.sql: -------------------------------------------------------------------------------- 1 | CREATE EXTENSION IF NOT EXISTS pg_trgm; 2 | 3 | CREATE TABLE IF NOT EXISTS channels ( 4 | id BIGINT PRIMARY KEY, 5 | type SMALLINT NOT NULL, 6 | name text, 7 | topic text, 8 | guild_id BIGINT, 9 | category_id BIGINT, 10 | position BIGINT, 11 | slowmode_delay BIGINT, 12 | overwrites jsonb, 13 | bitrate INT, 14 | user_limit SMALLINT, 15 | recipients jsonb, 16 | owner_id BIGINT, 17 | icon text 18 | ); 19 | 20 | CREATE 
INDEX IF NOT EXISTS channels_name_trgm ON channels USING gin (name gin_trgm_ops); 21 | CREATE INDEX IF NOT EXISTS channels_type_idx ON channels (type); 22 | CREATE INDEX IF NOT EXISTS channels_guild_id_idx ON channels (guild_id); -------------------------------------------------------------------------------- /frontend/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 |
11 | 12 |
13 |
14 | 15 | 17 |
18 |
19 |
20 |
21 | 22 | 23 | 24 | 25 | -------------------------------------------------------------------------------- /abode/client.py: -------------------------------------------------------------------------------- 1 | import discord 2 | import functools 3 | 4 | 5 | from .events import ( 6 | on_ready, 7 | on_guild_join, 8 | on_guild_update, 9 | on_guild_remove, 10 | on_message, 11 | ) 12 | 13 | 14 | USER_AGENT = "Mozilla/5.0 (X11; Linux x86_64; rv:72.0) Gecko/20100101 Firefox/72.0" 15 | 16 | 17 | def setup_client(config, loop): 18 | client = discord.Client(loop=loop) 19 | 20 | def bind(fn): 21 | return functools.partial(fn, client) 22 | 23 | client.on_ready = bind(on_ready) 24 | client.on_guild_join = bind(on_guild_join) 25 | client.on_guild_update = bind(on_guild_update) 26 | client.on_guild_remove = bind(on_guild_remove) 27 | client.on_message = bind(on_message) 28 | 29 | # TODO: cf cookies, IDENTIFY information 30 | client.http.user_agent = USER_AGENT 31 | return client.start(config["token"], bot=False), client.logout() 32 | -------------------------------------------------------------------------------- /abode/schema/messages.sql: -------------------------------------------------------------------------------- 1 | CREATE EXTENSION IF NOT EXISTS pg_trgm; 2 | 3 | CREATE TABLE IF NOT EXISTS messages ( 4 | id BIGINT PRIMARY KEY, 5 | guild_id BIGINT, 6 | channel_id BIGINT NOT NULL, 7 | author_id BIGINT NOT NULL, 8 | webhook_id BIGINT, 9 | 10 | tts boolean NOT NULL, 11 | type integer NOT NULL, 12 | content text NOT NULL, 13 | embeds jsonb, 14 | mention_everyone boolean NOT NULL, 15 | flags integer NOT NULL, 16 | activity jsonb, 17 | application jsonb, 18 | 19 | created_at timestamp NOT NULL, 20 | edited_at timestamp, 21 | deleted boolean NOT NULL 22 | ); 23 | 24 | CREATE INDEX IF NOT EXISTS messages_content_trgm ON messages USING gin (content gin_trgm_ops); 25 | CREATE INDEX IF NOT EXISTS messages_content_fts ON messages USING gin (to_tsvector('english', content)); 26 | CREATE 
INDEX IF NOT EXISTS messages_guild_id_idx ON messages (guild_id); 27 | CREATE INDEX IF NOT EXISTS messages_channel_id_idx ON messages (channel_id); 28 | CREATE INDEX IF NOT EXISTS messages_author_id_idx ON messages (author_id); -------------------------------------------------------------------------------- /abode/backfill.py: -------------------------------------------------------------------------------- 1 | from .db.messages import insert_message 2 | from .db.channels import upsert_channel 3 | from discord import TextChannel 4 | 5 | 6 | async def backfill_channel(channel): 7 | print(f"Backfilling channel {channel.id}") 8 | await upsert_channel(channel) 9 | idx = 0 10 | async for message in channel.history(limit=None, oldest_first=True): 11 | idx += 1 12 | try: 13 | await insert_message(message) 14 | except Exception as e: 15 | print(f" [{channel.id}] failed to backfill message {message.id}: {e}") 16 | 17 | if idx % 5000 == 0: 18 | print(f" [{channel.id}] {idx} messages") 19 | print(f"Done backfilling channel {channel.id}, scanned {idx}") 20 | 21 | 22 | async def backfill_guild(guild): 23 | print(f"Backfilling guild {guild.id}") 24 | for channel in guild.channels: 25 | if isinstance(channel, TextChannel): 26 | try: 27 | await upsert_channel(channel) 28 | await backfill_channel(channel) 29 | except Exception as e: 30 | print(f"failed to backfill channel {channel.id}: {e}") 31 | -------------------------------------------------------------------------------- /abode/db/users.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Optional 3 | from datetime import datetime 4 | from . 
import ( 5 | with_conn, 6 | build_insert_query, 7 | build_select_query, 8 | JSONB, 9 | Snowflake, 10 | BaseModel, 11 | ) 12 | 13 | 14 | @dataclass 15 | class User(BaseModel): 16 | id: Snowflake 17 | name: str 18 | discriminator: int 19 | avatar: Optional[str] 20 | bot: bool 21 | system: bool 22 | 23 | _pk = "id" 24 | 25 | @classmethod 26 | def from_discord(cls, user): 27 | return cls( 28 | id=user.id, 29 | name=user.name, 30 | discriminator=int(user.discriminator), 31 | avatar=user.avatar, 32 | bot=user.bot, 33 | system=user.system, 34 | ) 35 | 36 | 37 | @with_conn 38 | async def upsert_user(conn, user): 39 | user = User.from_discord(user) 40 | 41 | # TODO: calculate diff 42 | existing_user = await conn.fetchrow(build_select_query(user, "id = $1"), user.id) 43 | 44 | query, args = build_insert_query(user, upsert=True) 45 | await conn.execute(query, *args) 46 | 47 | if existing_user is not None: 48 | existing_user = User.from_record(existing_user) 49 | diff = list(user.diff(existing_user)) 50 | if diff: 51 | print(f"[user] diff is {diff}") 52 | -------------------------------------------------------------------------------- /abode/cli.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import asyncio 4 | import argparse 5 | 6 | from .db import init_db, close_db 7 | from .client import setup_client 8 | from .server import setup_server 9 | 10 | parser = argparse.ArgumentParser("abode") 11 | parser.add_argument("--run-api", action="store_true") 12 | parser.add_argument("--run-client", action="store_true") 13 | 14 | 15 | def main(): 16 | args = parser.parse_args() 17 | 18 | with open(os.getenv("ABODE_CONFIG_PATH", "config.json"), "r") as f: 19 | config = json.load(f) 20 | 21 | start_tasks = [] 22 | cleanup_tasks = [] 23 | 24 | if args.run_api: 25 | server_start = setup_server(config) 26 | start_tasks.append(server_start) 27 | 28 | loop = asyncio.get_event_loop() 29 | loop.create_task(init_db(config, 
loop)) 30 | cleanup_tasks.append(close_db()) 31 | 32 | if args.run_client: 33 | client_start, client_logout = setup_client(config, loop) 34 | start_tasks.append(client_start) 35 | cleanup_tasks.append(client_logout) 36 | 37 | try: 38 | for task in start_tasks: 39 | loop.create_task(task) 40 | loop.run_forever() 41 | except KeyboardInterrupt: 42 | for task in cleanup_tasks: 43 | loop.run_until_complete(task) 44 | finally: 45 | loop.close() 46 | 47 | 48 | if __name__ == "__main__": 49 | main() 50 | -------------------------------------------------------------------------------- /frontend/script.js: -------------------------------------------------------------------------------- 1 | const models = ["message", "emoji", "channel", "guild", "user"]; 2 | const templates = { "results": null }; 3 | 4 | const env = nunjucks.configure({ autoescape: true }); 5 | env.addFilter("discrim", (str) => { 6 | return ('0000' + str).slice(-4); 7 | }); 8 | 9 | for (const model of Object.keys(templates).concat(models)) { 10 | fetch(`/templates/${model}.html`).then((response) => { 11 | return response.text(); 12 | }).then((body) => { 13 | templates[model] = body; 14 | console.log(body); 15 | }); 16 | } 17 | 18 | function renderModelRow(name, row) { 19 | return nunjucks.renderString(templates[name], { row }); 20 | } 21 | 22 | function renderResult(results, fields) { 23 | $("#error").hide(); 24 | $("#results").html(nunjucks.renderString(templates["results"], { 25 | fields: fields, 26 | rows: results, 27 | })) 28 | } 29 | 30 | function renderError(error) { 31 | $("#error").show().text(error); 32 | } 33 | 34 | function handleSearchChange(event) { 35 | var currentModel = $("#model option:selected").text(); 36 | var query = { 37 | "query": $(event.target).val(), 38 | "limit": 1000, 39 | "order_by": "id", 40 | "order_dir": "DESC", 41 | }; 42 | 43 | fetch(`/search/${currentModel}`, { 44 | method: 'POST', 45 | body: JSON.stringify(query), 46 | }).then((response) => { 47 | return response.json(); 48 
| }).then((data) => { 49 | console.log("[Debug]", data); 50 | if (data.results && !data.error) { 51 | renderResult(data.results, data.fields); 52 | } else if (data.error) { 53 | renderError(data.error); 54 | } 55 | }); 56 | } 57 | 58 | $(document).ready(function () { 59 | for (model of models) { 60 | $("#model").append(``); 61 | } 62 | 63 | var val = $("#search").val(); 64 | $("#search").focus().val("").val(val); 65 | $("#search").keyup((e) => { 66 | if (e.keyCode == 13) { 67 | handleSearchChange(e); 68 | } 69 | }); 70 | }); -------------------------------------------------------------------------------- /abode/db/guilds.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, fields 2 | from typing import Optional 3 | from . import ( 4 | with_conn, 5 | build_insert_query, 6 | build_select_query, 7 | convert_to_type, 8 | Snowflake, 9 | BaseModel, 10 | JSONB, 11 | ) 12 | from .users import User 13 | 14 | 15 | @dataclass 16 | class Guild(BaseModel): 17 | id: Snowflake 18 | owner_id: Snowflake 19 | name: str 20 | region: str 21 | icon: Optional[str] 22 | features: JSONB 23 | banner: Optional[str] 24 | description: Optional[str] 25 | splash: Optional[str] 26 | discovery_splash: Optional[str] 27 | premium_tier: int 28 | premium_subscription_count: int 29 | is_currently_joined: bool = None 30 | 31 | _pk = "id" 32 | _refs = {"owner": (User, ("owner_id", "id"), True)} 33 | _fts = set() 34 | 35 | @classmethod 36 | def from_attrs(cls, guild, is_currently_joined=None): 37 | kwargs = {"is_currently_joined": is_currently_joined} 38 | for field in fields(cls): 39 | if not hasattr(guild, field.name): 40 | continue 41 | 42 | kwargs[field.name] = convert_to_type(getattr(guild, field.name), field.type) 43 | return cls(**kwargs) 44 | 45 | 46 | @with_conn 47 | async def upsert_guild(conn, guild, is_currently_joined=None): 48 | from .emoji import upsert_emoji 49 | from .users import upsert_user 50 | from .channels 
import upsert_channel 51 | 52 | new_guild = Guild.from_attrs(guild, is_currently_joined=is_currently_joined) 53 | 54 | # TODO: calculate diff 55 | existing_guild = await conn.fetchrow( 56 | build_select_query(new_guild, "id = $1"), new_guild.id 57 | ) 58 | 59 | query, args = build_insert_query(new_guild, upsert=True) 60 | await conn.execute(query, *args) 61 | 62 | if existing_guild is not None: 63 | existing_guild = Guild.from_record(existing_guild) 64 | diff = list(new_guild.diff(existing_guild)) 65 | if diff: 66 | print(f"[guilds] diff is {diff}") 67 | 68 | for channel in guild.channels: 69 | await upsert_channel(channel, conn=conn) 70 | 71 | for emoji in guild.emojis: 72 | await upsert_emoji(emoji, conn=conn) 73 | 74 | for member in guild.members: 75 | await upsert_user(member, conn=conn) 76 | -------------------------------------------------------------------------------- /abode/db/emoji.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Optional 3 | from datetime import datetime 4 | from . 
import ( 5 | with_conn, 6 | build_insert_query, 7 | build_select_query, 8 | JSONB, 9 | Snowflake, 10 | BaseModel, 11 | ) 12 | from .guilds import Guild 13 | 14 | 15 | @dataclass 16 | class Emoji(BaseModel): 17 | id: Snowflake 18 | guild_id: Snowflake 19 | author_id: Optional[Snowflake] 20 | name: str 21 | require_colons: bool 22 | animated: bool 23 | managed: bool 24 | roles: JSONB 25 | created_at: datetime 26 | deleted: bool 27 | 28 | _pk = "id" 29 | _table_name = "emoji" 30 | _refs = {"guild": (Guild, ("guild_id", "id"), True)} 31 | _virtual_fields = {"image": ("id", "animated"), "image_url": ("id", "animated")} 32 | 33 | @staticmethod 34 | def image(id, animated): 35 | # TODO: leaky, idk how to avoid 36 | ext = "gif" if animated else "png" 37 | return f'' 38 | 39 | @staticmethod 40 | def image_url(id, animated): 41 | ext = "gif" if animated else "png" 42 | return f"https://cdn.discordapp.com/emojis/{id}.{ext}" 43 | 44 | @classmethod 45 | def from_discord(cls, emoji, deleted=False): 46 | return cls( 47 | id=emoji.id, 48 | guild_id=emoji.guild.id, 49 | author_id=emoji.user.id if emoji.user else None, 50 | name=emoji.name, 51 | require_colons=bool(emoji.require_colons), 52 | managed=bool(emoji.managed), 53 | animated=bool(emoji.animated), 54 | roles=[str(i.id) for i in emoji.roles], 55 | created_at=emoji.created_at, 56 | deleted=deleted, 57 | ) 58 | 59 | 60 | @with_conn 61 | async def upsert_emoji(conn, emoji): 62 | emoji = Emoji.from_discord(emoji) 63 | 64 | # TODO: calculate diff 65 | existing_emoji = await conn.fetchrow(build_select_query(emoji, "id = $1"), emoji.id) 66 | 67 | query, args = build_insert_query(emoji, upsert=True) 68 | await conn.execute(query, *args) 69 | 70 | if existing_emoji is not None: 71 | existing_emoji = Emoji.from_record(existing_emoji) 72 | diff = list(emoji.diff(existing_emoji)) 73 | if diff: 74 | print(f"[emoji] diff is {diff}") 75 | -------------------------------------------------------------------------------- /abode/events.py: 
-------------------------------------------------------------------------------- 1 | import asyncio 2 | from abode.db.guilds import upsert_guild 3 | from abode.db.messages import insert_message 4 | from abode.db.channels import upsert_channel 5 | from abode.backfill import backfill_channel, backfill_guild 6 | 7 | 8 | async def backfill(client, message, args): 9 | snowflake = int(args) 10 | channel = client.get_channel(snowflake) 11 | if not channel: 12 | user = client.get_user(snowflake) 13 | channel = user.dm_channel or await user.create_dm() 14 | 15 | if channel: 16 | await message.add_reaction(client.get_emoji(580596825128697874)) 17 | await backfill_channel(channel) 18 | else: 19 | await message.add_reaction(client.get_emoji(494901623731126272)) 20 | 21 | 22 | async def backfillg(client, message, args): 23 | guild = client.get_guild(int(args)) 24 | if guild: 25 | await message.add_reaction(client.get_emoji(580596825128697874)) 26 | await backfill_guild(guild) 27 | else: 28 | await message.add_reaction(client.get_emoji(494901623731126272)) 29 | 30 | 31 | async def backfilldms(client, message, args): 32 | await message.add_reaction(client.get_emoji(580596825128697874)) 33 | for channel in client.private_channels: 34 | await backfill_channel(channel) 35 | 36 | 37 | commands = {"backfill": backfill, "backfillg": backfillg, "backfilldms": backfilldms} 38 | 39 | 40 | async def on_ready(client): 41 | print("Connected!") 42 | 43 | await asyncio.wait([upsert_channel(channel) for channel in client.private_channels]) 44 | 45 | await asyncio.wait( 46 | [upsert_guild(guild, is_currently_joined=True) for guild in client.guilds] 47 | ) 48 | 49 | 50 | async def on_guild_join(client, guild): 51 | await upsert_guild(guild, is_currently_joined=True) 52 | 53 | 54 | async def on_guild_update(client, old, new): 55 | await upsert_guild(new, is_currently_joined=True) 56 | 57 | 58 | async def on_guild_remove(client, guild): 59 | await upsert_guild(guild, is_currently_joined=False) 60 | 
61 | 62 | async def on_message(client, message): 63 | await insert_message(message) 64 | 65 | if message.author.id == client.user.id: 66 | if message.content.startswith(";"): 67 | command, _, args = message.content.partition(" ")  # partition always yields 3 parts; unpacking split(" ") raised ValueError unless there was exactly one space 68 | fn = commands.get(command[1:]) 69 | if fn: 70 | await fn(client, message, args) 71 | -------------------------------------------------------------------------------- /abode/server.py: -------------------------------------------------------------------------------- 1 | import time 2 | from sanic import Sanic 3 | from sanic.response import json 4 | from abode.lib.query import compile_query, decode_query_results 5 | from abode.db.guilds import Guild 6 | from abode.db.messages import Message 7 | from abode.db.emoji import Emoji 8 | from abode.db.users import User 9 | from abode.db.channels import Channel 10 | from abode.db import get_pool 11 | from traceback import format_exc 12 | 13 | app = Sanic() 14 | app.static("/", "./frontend/index.html") 15 | app.static("/styles.css", "./frontend/styles.css") 16 | app.static("/script.js", "./frontend/script.js") 17 | app.static("/templates/", "./frontend/templates") 18 | 19 | 20 | SUPPORTED_MODELS = { 21 | "guild": Guild, 22 | "message": Message, 23 | "emoji": Emoji, 24 | "user": User, 25 | "channel": Channel, 26 | } 27 | 28 | 29 | def setup_server(config): 30 | return app.create_server( 31 | host=config.get("host", "0.0.0.0"), 32 | port=config.get("port", 9999), 33 | return_asyncio_server=True, 34 | ) 35 | 36 | 37 | @app.route("/search/<model>", methods=["POST"])  # the <model> path parameter supplies route_search's `model` argument 38 | async def route_search(request, model): 39 | model = SUPPORTED_MODELS.get(model) 40 | if not model: 41 | return json({"error": "unsupported model"}, status=404) 42 | 43 | limit = request.json.get("limit", 100) 44 | page = request.json.get("page", 1) 45 | order_by = request.json.get("order_by") 46 | order_dir = request.json.get("order_dir", "ASC") 47 | include_foreign_data = request.json.get("foreign_data", True) 48 | 49 | query = 
request.json.get("query", "") 50 | try: 51 | sql, args, models, return_fields = compile_query( 52 | query, 53 | model, 54 | limit=limit, 55 | offset=(limit * max(page - 1, 0)),  # clamp: Postgres rejects a negative OFFSET when page < 1 56 | order_by=order_by, 57 | order_dir=order_dir, 58 | include_foreign_data=include_foreign_data, 59 | returns=True, 60 | ) 61 | except Exception: 62 | return json({"error": format_exc()}) 63 | 64 | _debug = { 65 | "args": args, 66 | "sql": sql, 67 | "request": request.json, 68 | "models": [i.__name__ for i in models], 69 | } 70 | 71 | results = [] 72 | try: 73 | async with get_pool().acquire() as conn: 74 | start = time.perf_counter()  # monotonic clock for measuring durations 75 | results = await conn.fetch(sql, *args) 76 | _debug["ms"] = int((time.perf_counter() - start) * 1000) 77 | except Exception: 78 | return json({"error": format_exc(), "_debug": _debug}) 79 | 80 | try: 81 | results, field_names = decode_query_results(models, return_fields, results) 82 | return json({"results": results, "fields": field_names, "_debug": _debug}) 83 | except Exception: 84 | return json({"error": format_exc(), "_debug": _debug}) 85 | -------------------------------------------------------------------------------- /abode/db/messages.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, fields 2 | from datetime import datetime 3 | from typing import Optional 4 | from . 
import ( 5 | with_conn, 6 | build_insert_query, 7 | convert_to_type, 8 | JSONB, 9 | Snowflake, 10 | BaseModel, 11 | ) 12 | from .users import User, upsert_user 13 | from .guilds import Guild 14 | from .channels import Channel 15 | 16 | 17 | @dataclass 18 | class Message(BaseModel): 19 | id: Snowflake 20 | channel_id: Snowflake 21 | guild_id: Optional[Snowflake] 22 | author_id: Optional[Snowflake] 23 | webhook_id: Optional[Snowflake] 24 | tts: bool 25 | type: int 26 | content: str 27 | embeds: Optional[JSONB] 28 | mention_everyone: bool 29 | flags: int 30 | activity: Optional[JSONB] 31 | application: Optional[JSONB] 32 | created_at: datetime 33 | edited_at: Optional[datetime] 34 | deleted: bool 35 | 36 | # TODO: eventually these could be types, but I am far too baked for that 37 | # refactor at the moment. 38 | _pk = "id" 39 | _refs = { 40 | "guild": (Guild, ("guild_id", "id"), False), 41 | "author": (User, ("author_id", "id"), True), 42 | "channel": (Channel, ("channel_id", "id"), True), 43 | } 44 | _fts = {"content"} 45 | 46 | @classmethod 47 | def from_discord(cls, message, deleted=False): 48 | return cls( 49 | id=message.id, 50 | guild_id=message.guild.id if message.guild else None, 51 | channel_id=message.channel.id, 52 | author_id=message.author.id, 53 | webhook_id=message.webhook_id if message.webhook_id else None, 54 | tts=bool(message.tts), 55 | type=message.type.value, 56 | content=message.content, 57 | embeds=[i.to_dict() for i in message.embeds], 58 | mention_everyone=bool(message.mention_everyone), 59 | flags=message.flags.value, 60 | activity=message.activity, 61 | application=message.application, 62 | created_at=message.created_at, 63 | edited_at=message.edited_at, 64 | deleted=deleted, 65 | ) 66 | 67 | @classmethod 68 | def from_attrs(cls, instance, deleted=None): 69 | kwargs = {} 70 | for field in fields(cls): 71 | kwargs[field.name] = convert_to_type( 72 | getattr(instance, field.name), field.type 73 | ) 74 | if deleted is not None:  # None means "keep the instance's value"; an 
explicit False must still override 75 | kwargs["deleted"] = 
deleted 76 | return cls(**kwargs) 77 | 78 | 79 | @with_conn 80 | async def insert_message(conn, message): 81 | new_message = Message.from_discord(message) 82 | 83 | query, args = build_insert_query(new_message, ignore_existing=True) 84 | try: 85 | await conn.execute(query, *args) 86 | except Exception: 87 | print(query) 88 | print(args) 89 | raise 90 | 91 | await upsert_user(message.author, conn=conn) 92 | 93 | 94 | @with_conn 95 | async def update_message(conn, message): 96 | pass 97 | -------------------------------------------------------------------------------- /abode/db/channels.py: -------------------------------------------------------------------------------- 1 | import discord 2 | from dataclasses import dataclass 3 | from typing import Optional, List 4 | from .guilds import Guild 5 | from .users import User 6 | from . import ( 7 | with_conn, 8 | build_insert_query, 9 | build_select_query, 10 | JSONB, 11 | Snowflake, 12 | BaseModel, 13 | ) 14 | 15 | 16 | @dataclass 17 | class Channel(BaseModel): 18 | id: Snowflake 19 | type: int 20 | 21 | name: Optional[str] = None 22 | topic: Optional[str] = None 23 | 24 | # Guild Specific 25 | guild_id: Optional[Snowflake] = None 26 | category_id: Optional[Snowflake] = None 27 | position: Optional[int] = None 28 | slowmode_delay: Optional[int] = None 29 | overwrites: Optional[JSONB] = None 30 | 31 | # Voice Specific 32 | bitrate: Optional[int] = None 33 | user_limit: Optional[int] = None 34 | 35 | # DMs 36 | recipients: Optional[JSONB[List[str]]] = None 37 | owner_id: Optional[Snowflake] = None 38 | icon: Optional[str] = None 39 | 40 | _pk = "id" 41 | _refs = { 42 | "guild": (Guild, ("guild_id", "id"), False), 43 | "owner": (User, ("owner_id", "id"), False), 44 | } 45 | _fts = set() 46 | 47 | @classmethod 48 | def from_discord(cls, channel): 49 | inst = cls(id=channel.id, type=channel.type.value) 50 | 51 | if isinstance( 52 | channel, 53 | (discord.TextChannel, discord.VoiceChannel, discord.CategoryChannel), 54 | ): 55 
| inst.guild_id = channel.guild.id 56 | inst.name = channel.name 57 | inst.category_id = channel.category_id 58 | inst.position = channel.position 59 | inst.overwrites = { 60 | str(k.id): [i.value for i in v.pair()] 61 | for k, v in channel.overwrites.items() 62 | } 63 | 64 | if isinstance(channel, discord.TextChannel): 65 | inst.slowmode_delay = channel.slowmode_delay 66 | inst.topic = channel.topic 67 | elif isinstance(channel, discord.VoiceChannel): 68 | inst.bitrate = channel.bitrate 69 | inst.user_limit = channel.user_limit 70 | elif isinstance(channel, discord.DMChannel): 71 | inst.recipients = [channel.recipient.id] 72 | elif isinstance(channel, discord.GroupChannel): 73 | inst.recipients = [i.id for i in channel.recipients] 74 | inst.owner_id = channel.owner.id 75 | inst.icon = channel.icon 76 | inst.name = channel.name 77 | 78 | return inst 79 | 80 | 81 | @with_conn 82 | async def upsert_channel(conn, channel): 83 | new_channel = Channel.from_discord(channel) 84 | 85 | # TODO: calculate diff 86 | existing_channel = await conn.fetchrow( 87 | build_select_query(new_channel, "id = $1"), channel.id 88 | ) 89 | 90 | query, args = build_insert_query(new_channel, upsert=True) 91 | await conn.execute(query, *args) 92 | 93 | if existing_channel is not None: 94 | existing_channel = Channel.from_record(existing_channel) 95 | diff = list(new_channel.diff(existing_channel)) 96 | if diff: 97 | print(f"[channel] diff is {diff}") 98 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # abode 2 | 3 | abode is a self-hosted home server which aggregates your discord data in a discoverable format. 4 | 5 | ## status 6 | 7 | abode is in early personal testing but I've already found it invaluable for easily querying things.
The motivated individual could figure out how to properly run a version locally (tl;dr: a postgres database, a config with the postgres dsn and your token, then run one client and one api), but I'm not currently planning to document the process extensively (you should understand it to use it). 8 | 9 | ## ok but why 10 | 11 | Discord in general is only somewhat friendly to data export, as is evident from the [Discord Data Package](https://support.discordapp.com/hc/en-us/articles/360004957991-Your-Discord-Data-Package), which lacks messages not sent by you, attachments, and various other relevant data (take it from me, I built the original system while there). The data package is also a poor real-time solution due to the multi-week delay between package requests and the long generation time of each individual package. abode fills in for this lack of flexibility and features on Discord's side by tracking, storing, and indexing your data. abode was designed to be used by individuals for tracking only data that "belongs" to them (any data discord makes visible to the end user), and thus won't work for large-scale or indiscriminate data mining. 12 | 13 | The explanation is all well and good, but it doesn't really explain why you would even _want_ your data archived locally.
While different folks can have different reasons, the primary advantages of running abode are: 14 | 15 | - data backup (for when sinister deletes all the channels again) 16 | - fast and powerful querying (for finding **anything** quickly and easily) 17 | - audit log (for when things happened more than 90 days ago) 18 | - data insights (for understanding how you use and experience discord) 19 | 20 | ## how 21 | 22 | abode includes a client component which connects to discord on behalf of your user (with your token, so this is considered ["self botting"](https://support.discordapp.com/hc/en-us/articles/115002192352-Automated-user-accounts-self-bots-) by discord, and could get your account banned), an api server which can process queries written in a custom DSL, and a lightweight frontend for querying. abode's primary use is power-user-level querying of its internal data archive, but it can also be used as a generic data API for other applications. 23 | 24 | ## query language 25 | 26 | abode is built around a custom search query language, which is intentionally simple but can be quite powerful when used correctly. The real power of abode's queries comes directly from postgres, where abode leans on full-text search and unique indexes to keep query response times fast.
27 | 28 | | Syntax | Meaning | 29 | |--------|---------| 30 | | field:value | fuzzy match of value | 31 | | field:"value" | case insensitive exact match of value | 32 | | field="value" | case sensitive exact match of value | 33 | | field:(x y) | fuzzy match of x and y | 34 | | field:(x OR y) | fuzzy match of x or y | 35 | | field:x AND NOT field:y | fuzzy match of x and not y | 36 | | (field:a AND field:b) OR (field:c AND field:d) | fuzzy match of a and b or c and d | 37 | | -> x y z | select fields x, y, and z | 38 | 39 | ## screenshots 40 | 41 | ![](https://i.imgur.com/LIFBQAR.png) 42 | ![](https://i.imgur.com/eFFvES0.png) 43 | ![](https://i.imgur.com/p0fNmWG.png) -------------------------------------------------------------------------------- /abode/db/__init__.py: -------------------------------------------------------------------------------- 1 | import asyncpg 2 | import functools 3 | import os 4 | import dataclasses 5 | import json 6 | import typing 7 | from datetime import datetime 8 | 9 | pool = None 10 | 11 | 12 | T = typing.TypeVar("T") 13 | 14 | 15 | class JSONB(typing.Generic[T]): 16 | inner: T 17 | 18 | 19 | class FTS: 20 | def __init__(self, inner): 21 | self.inner = inner 22 | 23 | 24 | def Snowflake(i): 25 | return int(i) 26 | 27 | 28 | def to_json_str(obj): 29 | if isinstance(obj, str): 30 | return obj 31 | return json.dumps(obj) 32 | 33 | 34 | async def init_db(config, loop): 35 | global pool 36 | 37 | pool = await asyncpg.create_pool(dsn=config.get("postgres_dsn")) 38 | 39 | async with pool.acquire() as connection: 40 | sql_dir = os.path.abspath( 41 | os.path.join(os.path.dirname(__file__), "..", "schema") 42 | ) 43 | for sql_file in os.listdir(sql_dir): 44 | with open(os.path.join(sql_dir, sql_file), "r") as f: 45 | await connection.execute(f.read()) 46 | 47 | 48 | async def close_db(): 49 | await get_pool().close() 50 | 51 | 52 | def get_pool(): 53 | return pool 54 | 55 | 56 | def with_conn(func): 57 | @functools.wraps(func) 58 | async def 
wrapped(*args, **kwargs): 59 | if "conn" in kwargs: 60 | return await func(kwargs.pop("conn"), *args, **kwargs) 61 | 62 | async with pool.acquire() as connection: 63 | return await func(connection, *args, **kwargs) 64 | 65 | return wrapped 66 | 67 | 68 | def build_insert_query(instance, upsert=False, ignore_existing=False): 69 | dataclass = instance.__class__ 70 | 71 | column_names = [] 72 | column_values = [] 73 | updates = [] 74 | for field in dataclasses.fields(dataclass): 75 | column_names.append(field.name) 76 | column_values.append( 77 | convert_to_type(getattr(instance, field.name), field.type, to_pg=True) 78 | ) 79 | 80 | if upsert: 81 | updates.append(f"{field.name}=excluded.{field.name}") 82 | values = ", ".join([f"${i}" for i in range(1, len(column_names) + 1)]) 83 | 84 | upsert_contents = "" 85 | if upsert: 86 | updates = ",\n".join(updates) 87 | upsert_contents = f""" 88 | ON CONFLICT (id) DO UPDATE SET 89 | {updates} 90 | """ 91 | 92 | if ignore_existing: 93 | assert not upsert 94 | upsert_contents = """ 95 | ON CONFLICT (id) DO NOTHING 96 | """ 97 | 98 | return ( 99 | f""" 100 | INSERT INTO {table_name(dataclass)} ({', '.join(column_names)}) 101 | VALUES ({values}) 102 | {upsert_contents} 103 | """, 104 | tuple(column_values), 105 | ) 106 | 107 | 108 | def build_select_query(instance, where=None): 109 | dataclass = instance.__class__ 110 | select_fields = ", ".join([field.name for field in dataclasses.fields(dataclass)]) 111 | where = f"WHERE {where}" if where else "" 112 | 113 | return f""" 114 | SELECT {select_fields} FROM {table_name(dataclass)} 115 | {where} 116 | """ 117 | 118 | 119 | def convert_to_type(value, target_type, to_pg=False, from_pg=False, to_js=False): 120 | if typing.get_origin(target_type) is typing.Union: 121 | if type(None) in typing.get_args(target_type): 122 | if value is None: 123 | return None 124 | target_type = next( 125 | i for i in typing.get_args(target_type) if i is not type(None) 126 | ) 127 | else: 128 | assert False 
129 | 130 | if ( 131 | to_js 132 | and target_type == Snowflake 133 | or (target_type == typing.Optional[Snowflake] and value is not None) 134 | ): 135 | return str(value) 136 | 137 | if ( 138 | to_js 139 | and target_type == datetime 140 | or (target_type == typing.Optional[datetime] and value is not None) 141 | ): 142 | return value.isoformat() 143 | 144 | if typing.get_origin(target_type) == list: 145 | return list(value) 146 | 147 | if type(value) == target_type: 148 | return value 149 | 150 | if to_pg and typing.get_origin(target_type) == JSONB: 151 | return json.dumps(value) 152 | 153 | if from_pg and typing.get_origin(target_type) == JSONB: 154 | return json.loads(value) 155 | 156 | try: 157 | return target_type(value) 158 | except Exception: 159 | print(type(value)) 160 | print(target_type) 161 | print(typing.get_origin(target_type)) 162 | raise 163 | 164 | 165 | def table_name(model): 166 | return getattr(model, "_table_name", model.__name__.lower() + "s") 167 | 168 | 169 | class BaseModel: 170 | _refs = {} 171 | _fts = set() 172 | _virtual_fields = {} 173 | 174 | def serialize(self, **kwargs): 175 | return { 176 | field.name: convert_to_type( 177 | getattr(self, field.name), field.type, to_js=True 178 | ) 179 | for field in dataclasses.fields(self) 180 | } 181 | 182 | def diff(self, other): 183 | for field in dataclasses.fields(self): 184 | if getattr(other, field.name) != getattr(self, field.name): 185 | yield { 186 | "field": field.name, 187 | "old": getattr(other, field.name), 188 | "new": getattr(self, field.name), 189 | } 190 | 191 | @classmethod 192 | def from_record(cls, record): 193 | return cls( 194 | **{ 195 | field.name: convert_to_type(record[idx], field.type, from_pg=True) 196 | for idx, field in enumerate(dataclasses.fields(cls)) 197 | } 198 | ) 199 | -------------------------------------------------------------------------------- /abode/lib/tests/test_query.py: -------------------------------------------------------------------------------- 
1 | from abode.lib.query import QueryParser, compile_query, _compile_selector 2 | from abode.db.guilds import Guild 3 | from abode.db.messages import Message 4 | from abode.db.users import User 5 | from abode.db.channels import Channel 6 | 7 | 8 | def test_parse_basic_queries(): 9 | assert QueryParser.parsed("hello world") == [ 10 | {"type": "symbol", "value": "hello"}, 11 | {"type": "symbol", "value": "AND"}, 12 | {"type": "symbol", "value": "world"}, 13 | ] 14 | 15 | assert QueryParser.parsed('"Hello \\" World"') == [ 16 | {"type": "string", "value": 'Hello " World'} 17 | ] 18 | 19 | assert QueryParser.parsed("(group me daddy)") == [ 20 | { 21 | "type": "group", 22 | "value": [ 23 | {"type": "symbol", "value": "group"}, 24 | {"type": "symbol", "value": "AND"}, 25 | {"type": "symbol", "value": "me"}, 26 | {"type": "symbol", "value": "AND"}, 27 | {"type": "symbol", "value": "daddy"}, 28 | ], 29 | } 30 | ] 31 | 32 | assert QueryParser.parsed("x:y") == [ 33 | { 34 | "type": "label", 35 | "name": "x", 36 | "value": {"type": "symbol", "value": "y"}, 37 | "exact": False, 38 | } 39 | ] 40 | 41 | assert QueryParser.parsed("x=y") == [ 42 | { 43 | "type": "label", 44 | "name": "x", 45 | "value": {"type": "symbol", "value": "y"}, 46 | "exact": True, 47 | } 48 | ] 49 | 50 | assert QueryParser.parsed("x:(y z)") == [ 51 | { 52 | "type": "label", 53 | "name": "x", 54 | "value": { 55 | "type": "group", 56 | "value": [ 57 | {"type": "symbol", "value": "y"}, 58 | {"type": "symbol", "value": "AND"}, 59 | {"type": "symbol", "value": "z"}, 60 | ], 61 | }, 62 | "exact": False, 63 | } 64 | ] 65 | 66 | assert QueryParser.parsed("x:/.* lol \\d me daddy/") == [ 67 | { 68 | "type": "label", 69 | "name": "x", 70 | "value": {"type": "regex", "value": ".* lol \\d me daddy", "flags": []}, 71 | "exact": False, 72 | } 73 | ] 74 | 75 | assert QueryParser.parsed("x:/.* lol \\d me daddy/i") == [ 76 | { 77 | "type": "label", 78 | "name": "x", 79 | "value": {"type": "regex", "value": ".* lol \\d me 
daddy", "flags": ["i"]}, 80 | "exact": False, 81 | } 82 | ] 83 | 84 | assert QueryParser.parsed("-> a b c") == [ 85 | { 86 | "type": "return", 87 | "value": [ 88 | {"type": "symbol", "value": "a"}, 89 | {"type": "symbol", "value": "b"}, 90 | {"type": "symbol", "value": "c"}, 91 | ], 92 | } 93 | ] 94 | 95 | assert QueryParser.parsed("x:y -> a b c") == [ 96 | { 97 | "type": "label", 98 | "name": "x", 99 | "value": {"type": "symbol", "value": "y"}, 100 | "exact": False, 101 | }, 102 | { 103 | "type": "return", 104 | "value": [ 105 | {"type": "symbol", "value": "a"}, 106 | {"type": "symbol", "value": "b"}, 107 | {"type": "symbol", "value": "c"}, 108 | ], 109 | }, 110 | ] 111 | 112 | 113 | def test_parse_complex_queries(): 114 | assert QueryParser.parsed( 115 | 'type:attachment guild:"discord api" (from:Jake#0001 OR from=danny#0007)' 116 | ) == [ 117 | { 118 | "type": "label", 119 | "name": "type", 120 | "value": {"type": "symbol", "value": "attachment"}, 121 | "exact": False, 122 | }, 123 | {"type": "symbol", "value": "AND"}, 124 | { 125 | "type": "label", 126 | "name": "guild", 127 | "value": {"type": "string", "value": "discord api"}, 128 | "exact": False, 129 | }, 130 | {"type": "symbol", "value": "AND"}, 131 | { 132 | "type": "group", 133 | "value": [ 134 | { 135 | "type": "label", 136 | "name": "from", 137 | "value": {"type": "symbol", "value": "Jake#0001"}, 138 | "exact": False, 139 | }, 140 | {"type": "symbol", "value": "OR"}, 141 | { 142 | "type": "label", 143 | "name": "from", 144 | "value": {"type": "symbol", "value": "danny#0007"}, 145 | "exact": True, 146 | }, 147 | ], 148 | }, 149 | ] 150 | 151 | 152 | def test_compile_basic_queries(): 153 | assert compile_query("name:blob", Guild) == ( 154 | "SELECT guilds.* FROM guilds WHERE guilds.name ILIKE $1", 155 | ("%blob%",), 156 | (Guild,), 157 | ) 158 | 159 | assert compile_query('name:"blob"', Guild) == ( 160 | "SELECT guilds.* FROM guilds WHERE guilds.name ILIKE $1", 161 | ("blob",), 162 | (Guild,), 163 | ) 
164 | 165 | assert compile_query("name:(blob emoji)", Guild) == ( 166 | "SELECT guilds.* FROM guilds WHERE (guilds.name ILIKE $1 AND guilds.name ILIKE $2)", 167 | ("%blob%", "%emoji%",), 168 | (Guild,), 169 | ) 170 | 171 | assert compile_query("name:(blob AND emoji)", Guild) == ( 172 | "SELECT guilds.* FROM guilds WHERE (guilds.name ILIKE $1 AND guilds.name ILIKE $2)", 173 | ("%blob%", "%emoji%",), 174 | (Guild,), 175 | ) 176 | 177 | assert compile_query("name:(discord AND NOT api)", Guild) == ( 178 | "SELECT guilds.* FROM guilds WHERE (guilds.name ILIKE $1 AND NOT guilds.name ILIKE $2)", 179 | ("%discord%", "%api%",), 180 | (Guild,), 181 | ) 182 | 183 | assert compile_query("id:1", Guild) == ( 184 | "SELECT guilds.* FROM guilds WHERE guilds.id = $1", 185 | (1,), 186 | (Guild,), 187 | ) 188 | 189 | assert compile_query("", Guild, limit=100, offset=150, order_by="id") == ( 190 | "SELECT guilds.* FROM guilds ORDER BY guilds.id ASC LIMIT 100 OFFSET 150", 191 | (), 192 | (Guild,), 193 | ) 194 | 195 | assert compile_query("", Guild, order_by="id", order_dir="DESC") == ( 196 | "SELECT guilds.* FROM guilds ORDER BY guilds.id DESC", 197 | (), 198 | (Guild,), 199 | ) 200 | 201 | assert compile_query("id=1", Guild) == ( 202 | "SELECT guilds.* FROM guilds WHERE guilds.id = $1", 203 | (1,), 204 | (Guild,), 205 | ) 206 | 207 | 208 | def test_compile_complex_queries(): 209 | assert compile_query("name:blob OR name:api", Guild) == ( 210 | "SELECT guilds.* FROM guilds WHERE guilds.name ILIKE $1 OR guilds.name ILIKE $2", 211 | ("%blob%", "%api%"), 212 | (Guild,), 213 | ) 214 | 215 | assert compile_query("guild.name:blob", Message) == ( 216 | "SELECT messages.* FROM messages JOIN guilds ON messages.guild_id = guilds.id WHERE guilds.name ILIKE $1", 217 | ("%blob%",), 218 | (Message,), 219 | ) 220 | 221 | assert compile_query("content:yeet", Message) == ( 222 | "SELECT messages.* FROM messages WHERE to_tsvector('english', messages.content) @@ phraseto_tsquery($1)", 223 | ("yeet",), 
224 | (Message,), 225 | ) 226 | 227 | assert compile_query('guild.name:(a "b")', Message) == ( 228 | "SELECT messages.* FROM messages JOIN guilds ON messages.guild_id = guilds.id WHERE (guilds.name ILIKE $1 AND " 229 | "guilds.name ILIKE $2)", 230 | ("%a%", "b"), 231 | (Message,), 232 | ) 233 | 234 | assert compile_query("guild.owner.name:Danny", Message) == ( 235 | "SELECT messages.* FROM messages JOIN guilds ON messages.guild_id = guilds.id JOIN users ON " 236 | "guilds.owner_id = users.id WHERE users.name ILIKE $1", 237 | ("%Danny%",), 238 | (Message,), 239 | ) 240 | 241 | message_selector = _compile_selector(Message) 242 | guild_selector = _compile_selector(Guild) 243 | author_selector = _compile_selector(User) 244 | channel_selector = _compile_selector(Channel) 245 | 246 | assert compile_query("", Message, include_foreign_data=True) == ( 247 | f"SELECT {message_selector}, {author_selector}, {channel_selector} FROM messages " 248 | "JOIN users ON messages.author_id = users.id JOIN channels ON " 249 | "messages.channel_id = channels.id", 250 | (), 251 | (Message, User, Channel), 252 | ) 253 | 254 | assert compile_query("guild.id:1", Message, include_foreign_data=True) == ( 255 | f"SELECT {message_selector}, {guild_selector}, {author_selector}, {channel_selector} FROM messages JOIN guilds" 256 | " ON messages.guild_id = guilds.id JOIN users ON messages.author_id = users.id JOIN channels ON " 257 | "messages.channel_id = channels.id WHERE guilds.id = $1", 258 | (1,), 259 | (Message, Guild, User, Channel), 260 | ) 261 | 262 | assert compile_query("name: /xxx.*xxx/i", Guild) == ( 263 | f"SELECT guilds.* FROM guilds WHERE guilds.name ~* $1", 264 | ("xxx.*xxx",), 265 | (Guild,), 266 | ) 267 | 268 | guild_selector = _compile_selector(Guild) 269 | user_selector = _compile_selector(User) 270 | assert compile_query( 271 | "name: /xxx.*xxx/i -> name owner.name", Guild, returns=True 272 | ) == ( 273 | f"SELECT {guild_selector}, {user_selector} FROM guilds JOIN users ON 
guilds.owner_id = users.id WHERE guilds.name ~* $1", 274 | ("xxx.*xxx",), 275 | (Guild, User), 276 | ("name", "owner.name"), 277 | ) 278 | 279 | assert compile_query("name: /xxx.*xxx/i ->", Guild, returns=True) == ( 280 | f"SELECT guilds.* FROM guilds WHERE guilds.name ~* $1", 281 | ("xxx.*xxx",), 282 | (Guild,), 283 | (), 284 | ) 285 | 286 | -------------------------------------------------------------------------------- /abode/lib/query.py: -------------------------------------------------------------------------------- 1 | """ 2 | This module implements everything required for processing and executing user-written 3 | queries in our custom search DSL. As the language is small and simple enough there 4 | is no real concept of a lexer, and instead we directly parse search queries into 5 | an extremely simple AST. 6 | 7 | From an AST we can then generate a query based on a given model and query options. 8 | This query is returned in a form that can be easily passed to SQL interfaces 9 | (query, (args...)). 
10 | """ 11 | import dataclasses 12 | import typing 13 | from abode.db import table_name, JSONB, FTS, Snowflake 14 | 15 | JOINERS = ("AND", "OR") 16 | 17 | 18 | class QueryParser: 19 | def __init__(self, query_string): 20 | self.idx = 0 21 | self.buffer = query_string 22 | 23 | @classmethod 24 | def parsed(cls, query): 25 | inst = cls(query) 26 | return inst._fix(inst._parse()) 27 | 28 | def _next_char(self): 29 | if self.idx >= len(self.buffer): 30 | return None 31 | char = self.buffer[self.idx] 32 | self.idx += 1 33 | return char 34 | 35 | def _peek_char(self, n=0): 36 | if self.idx + n >= len(self.buffer): 37 | return None 38 | return self.buffer[self.idx + n] 39 | 40 | def _parse_string(self, chr='"'): 41 | escaped = False 42 | parts = "" 43 | while True: 44 | char = self._next_char() 45 | assert char is not None 46 | if char == chr: 47 | if escaped: 48 | escaped = False 49 | else: 50 | return parts 51 | elif char == "\\" and not escaped: 52 | escaped = True 53 | continue 54 | elif escaped: 55 | escaped = False 56 | parts += "\\" 57 | parts += char 58 | 59 | def _parse_symbol(self): 60 | parts = "" 61 | while True: 62 | char = self._peek_char() 63 | if char in (" ", ":", "=", '"', "(", ")", "/", None): 64 | return parts 65 | parts += self._next_char() 66 | 67 | def _parse_one(self): 68 | while True: 69 | char = self._next_char() 70 | if char is None or char == ")": 71 | return None 72 | elif char == '"': 73 | string = self._parse_string() 74 | return {"type": "string", "value": string} 75 | elif char == "(": 76 | return {"type": "group", "value": self._parse()} 77 | elif char in ("-", "=") and self._peek_char() == ">": 78 | self._next_char() 79 | value = self._parse() 80 | if any(i["type"] != "symbol" for i in value): 81 | raise Exception("only symbols are allowed in a returns section") 82 | return {"type": "return", "value": value} 83 | elif char == " ": 84 | continue 85 | elif char == "/": 86 | value = self._parse_string("/") 87 | 88 | flags = [] 89 | while 
self._peek_char() in ("i",): 90 | flags.append(self._next_char()) 91 | 92 | return { 93 | "type": "regex", 94 | "value": value, 95 | "flags": flags, 96 | } 97 | else: 98 | self.idx -= 1 99 | symbol = self._parse_symbol() 100 | if not symbol: 101 | return None 102 | 103 | if self._peek_char() in (":", "="): 104 | exact = self._next_char() == "=" 105 | return { 106 | "type": "label", 107 | "name": symbol, 108 | "value": self._parse_one(), 109 | "exact": exact, 110 | } 111 | 112 | return {"type": "symbol", "value": symbol} 113 | 114 | def _parse(self): 115 | parts = [] 116 | 117 | while True: 118 | one = self._parse_one() 119 | if not one: 120 | return parts 121 | parts.append(one) 122 | 123 | def _fix(self, tree): 124 | result = [] 125 | previous_node = None 126 | for node in tree: 127 | if node["type"] == "group": 128 | node["value"] = self._fix(node["value"]) 129 | elif node["type"] == "symbol": 130 | if node["value"] == "NOT": 131 | if not previous_node or ( 132 | previous_node["type"] == "symbol" 133 | and previous_node["value"] not in JOINERS 134 | ): 135 | raise Exception("NOT requires a joiner prefix (and/or)") 136 | elif node["value"] in JOINERS: 137 | if ( 138 | previous_node 139 | and previous_node["type"] == "symbol" 140 | and previous_node["value"] in JOINERS 141 | ): 142 | raise Exception("One side of joiners cannot be another joiner") 143 | elif node["type"] == "label" and node["value"]["type"] == "group": 144 | # TODO: HMMMMM, need _fix_one(node, previous=xxx) ?? 
145 | node["value"]["value"] = self._fix(node["value"]["value"]) 146 | 147 | # Injects 'AND' in between bare non-joiners 148 | if ( 149 | (node["type"] != "symbol" or node["value"] not in JOINERS) 150 | and node["type"] != "return" 151 | and previous_node 152 | and ( 153 | previous_node["type"] != "symbol" 154 | or previous_node["value"] not in JOINERS + ("NOT",) 155 | ) 156 | ): 157 | result.append({"type": "symbol", "value": "AND"}) 158 | previous_node = node 159 | result.append(node) 160 | return result 161 | 162 | 163 | def _resolve_foreign_model_field(field_name, model, joins=None): 164 | rest = None 165 | foreign_field_name, field_name = field_name.split(".", 1) 166 | if "." in field_name: 167 | field_name, rest = field_name.split(".", 1) 168 | 169 | ref_model, on, _ = model._refs[foreign_field_name] 170 | 171 | if not joins: 172 | joins = {} 173 | 174 | joins.update( 175 | {ref_model: f"{table_name(model)}.{on[0]} = {table_name(ref_model)}.{on[1]}"} 176 | ) 177 | 178 | if rest: 179 | return _resolve_foreign_model_field(field_name + "." + rest, ref_model, joins) 180 | 181 | _, ref_type, _ = resolve_model_field(field_name, ref_model) 182 | 183 | return (f"{table_name(ref_model)}.{field_name}", ref_type, joins) 184 | 185 | 186 | def resolve_model_field(field_name, model, allow_virtual=False): 187 | """ 188 | Resolves a field name within a given model. This function will generate joins 189 | for cases where the target field is on a relation or is stored within an 190 | external index. 191 | 192 | The result of this function is a tuple of the target field name, the field 193 | result type, and a dictionary of joins. 194 | 195 | >>> resolve_model_field("guilds.name", Message) 196 | ("guilds.name", str, {"guilds": "messages.guild_id = guilds.id"}) 197 | 198 | >>> resolve_model_field("content", Message) 199 | ("messages_fts.content", FTS(str), {"messages_fts": "messages.id = messages_fts.rowid"}) 200 | """ 201 | if "." 
in field_name: 202 | return _resolve_foreign_model_field(field_name, model) 203 | 204 | if allow_virtual and field_name in model._virtual_fields: 205 | return (None, None, {}) 206 | 207 | for field in dataclasses.fields(model): 208 | if field.name == field_name: 209 | if field.name in model._fts: 210 | return ( 211 | f"{table_name(model)}.{field.name}", 212 | FTS(field.type), 213 | {}, 214 | ) 215 | 216 | return f"{table_name(model)}.{field.name}", field.type, {} 217 | raise Exception(f"no such field on {model}: `{field_name}``") 218 | 219 | 220 | def _compile_field_query_op(field, field_type, token, varidx): 221 | """ 222 | Compiles a single token against a given field type into a single query filter. 223 | Sadly this function also encodes some more complex logic about querying, such 224 | as wildcard processing and exact matching. 225 | 226 | Returns a tuple of the filter operator and the processed token value as an 227 | argument to the operator. 228 | """ 229 | var = f"${varidx}" 230 | assert token["type"] in ("symbol", "string") 231 | 232 | if typing.get_origin(field_type) is typing.Union: 233 | args = typing.get_args(field_type) 234 | field_type = next(i for i in args if i != type(None)) 235 | 236 | if isinstance(field_type, FTS): 237 | if token.get("exact"): 238 | return (field, "=", token["value"], var) 239 | elif token["type"] == "string": 240 | # TODO: jank replace 241 | return (field, "ILIKE", token["value"].replace("*", "%"), var) 242 | return ( 243 | f"to_tsvector('english', {field})", 244 | "@@", 245 | token["value"], 246 | f"phraseto_tsquery({var})", 247 | ) 248 | elif typing.get_origin(field_type) == JSONB: 249 | (inner,) = typing.get_args(field_type) 250 | 251 | if typing.get_origin(inner) == list: 252 | (type_fn,) = typing.get_args(inner) 253 | return (field, "@>", type_fn(token["value"]), var) 254 | 255 | assert False 256 | elif field_type == Snowflake: 257 | return (field, "=", Snowflake(token["value"]), var) 258 | elif field_type == str or 
field_type == typing.Optional[str]: 259 | if token.get("exact"): 260 | return (field, "=", token["value"], var) 261 | elif token["type"] == "symbol": 262 | # TODO: regex this so we can handle escapes? 263 | if "*" in token["value"]: 264 | return (field, "ILIKE", token["value"].replace("*", "%"), var) 265 | return (field, "ILIKE", "%" + token["value"] + "%", var) 266 | else: 267 | # Like just gives us case insensitivity here 268 | return (field, "ILIKE", token["value"], var) 269 | elif field_type == int or field_type == typing.Optional[int]: 270 | return (field, "=", int(token["value"]), var) 271 | elif field_type == bool or field_type == typing.Optional[bool]: 272 | return ( 273 | field, 274 | "=", 275 | {"t": True, "f": False, "true": True, "false": False}[ 276 | token["value"].lower() 277 | ], 278 | var, 279 | ) 280 | else: 281 | print(token) 282 | print(field_type) 283 | print(typing.get_origin(field_type)) 284 | print(typing.get_args(field_type)) 285 | raise Exception(f"cannot query against field of type {field_type}") 286 | 287 | 288 | def _compile_model_refs_join(model, value): 289 | joins = {} 290 | ref_model = model 291 | while value.split(".", 1)[0] in ref_model._refs: 292 | ref_model, join_on, _ = model._refs[value] 293 | 294 | joins.update( 295 | { 296 | ref_model: f"{table_name(model)}.{join_on[0]} = {table_name(ref_model)}.{join_on[1]}" 297 | } 298 | ) 299 | 300 | if "." not in value: 301 | return joins 302 | 303 | value = value.split(".", 1)[1] 304 | 305 | raise Exception(f"unlabeled symbol cannot be matched: `{value}`") 306 | 307 | 308 | def _compile_token_for_query(token, model, field=None, field_type=None, varidx=0): 309 | """ 310 | Compile a single token into a single filter against the model. 311 | 312 | Returns a tuple of the where clause, variables, and joins. 
313 | """ 314 | 315 | if token["type"] == "label": 316 | field, field_type, field_joins = resolve_model_field(token["name"], model) 317 | token["value"]["exact"] = token["exact"] 318 | where, variables, joins, varidx, returns = _compile_token_for_query( 319 | token["value"], model, field=field, field_type=field_type, varidx=varidx 320 | ) 321 | joins.update(field_joins) 322 | return where, variables, joins, varidx, returns 323 | elif token["type"] == "symbol": 324 | if token["value"] in ("AND", "OR", "NOT"): 325 | return (token["value"], [], {}, varidx, None) 326 | elif field: 327 | varidx += 1 328 | field, op, arg, var = _compile_field_query_op( 329 | field, field_type, token, varidx 330 | ) 331 | return (f"{field} {op} {var}", [arg], {}, varidx, None) 332 | else: 333 | joins = _compile_model_refs_join(model, token["value"]) 334 | return (f"true", [], joins, varidx, None) 335 | elif token["type"] == "string" and field: 336 | varidx += 1 337 | field, op, arg, var = _compile_field_query_op(field, field_type, token, varidx) 338 | return (f"{field} {op} {var}", [arg], {}, varidx, None) 339 | elif token["type"] == "regex" and field: 340 | varidx += 1 341 | op = "~" 342 | if "i" in token["flags"]: 343 | op = "~*" 344 | return ( 345 | f"{field} {op} ${varidx}", 346 | [token["value"]], 347 | {}, 348 | varidx, 349 | None, 350 | ) 351 | elif token["type"] == "group": 352 | where = [] 353 | variables = [] 354 | joins = {} 355 | for child_token in token["value"]: 356 | child_token["exact"] = token.get("exact", False) 357 | ( 358 | where_part, 359 | variables_part, 360 | joins_part, 361 | varidx, 362 | returns, 363 | ) = _compile_token_for_query( 364 | child_token, model, field=field, field_type=field_type, varidx=varidx 365 | ) 366 | joins.update(joins_part) 367 | where.append(where_part) 368 | variables.extend(variables_part) 369 | return ("(" + " ".join(where) + ")", variables, joins, varidx, returns) 370 | elif token["type"] == "return": 371 | return (None, None, None, 
varidx, [i["value"] for i in token["value"]]) 372 | else: 373 | print(token) 374 | assert False 375 | 376 | 377 | def _compile_selector(model): 378 | return ", ".join( 379 | f"{table_name(model)}.{field.name}" for field in dataclasses.fields(model) 380 | ) 381 | 382 | 383 | def _compile_query_for_model( 384 | tokens, 385 | model, 386 | limit=None, 387 | offset=None, 388 | order_by=None, 389 | order_dir="ASC", 390 | include_foreign_data=False, 391 | returns=False, 392 | ): 393 | return_fields = None 394 | parts = [] 395 | varidx = 0 396 | for token in tokens: 397 | a, b, c, varidx, _returns = _compile_token_for_query( 398 | token, model, varidx=varidx 399 | ) 400 | if _returns is not None: 401 | if return_fields is not None: 402 | raise Exception("multiple returns? bad juju!") 403 | return_fields = tuple(_returns) 404 | continue 405 | if a is not None: 406 | parts.append((a, b, c)) 407 | 408 | where = [] 409 | variables = [] 410 | joins = {} 411 | 412 | for where_part, variables_part, joins_part in parts: 413 | where.append(where_part) 414 | variables.extend(variables_part) 415 | joins.update(joins_part) 416 | 417 | if order_by: 418 | field, field_type, order_joins = resolve_model_field(order_by, model) 419 | joins.update(order_joins) 420 | assert order_dir in ("ASC", "DESC") 421 | order_by = f" ORDER BY {field} {order_dir}" 422 | else: 423 | order_by = "" 424 | 425 | models = {model: None} 426 | if return_fields: 427 | for field in return_fields: 428 | _, _, joins_part = resolve_model_field(field, model, allow_virtual=True) 429 | if joins is not None: 430 | joins.update(joins_part) 431 | models.update({k: None for k in joins.keys()}) 432 | 433 | if include_foreign_data: 434 | for ref_model, join_on, always in model._refs.values(): 435 | if ref_model in joins: 436 | models[ref_model] = None 437 | continue 438 | 439 | if always: 440 | joins.update( 441 | { 442 | ref_model: f"{table_name(model)}.{join_on[0]} = {table_name(ref_model)}.{join_on[1]}" 443 | } 444 | ) 445 
| models[ref_model] = None 446 | 447 | if len(models) > 1: 448 | selectors = ", ".join(_compile_selector(model) for model in models.keys()) 449 | else: 450 | selectors = f"{table_name(model)}.*" 451 | 452 | if joins: 453 | joins = "".join( 454 | f" JOIN {table_name(model)} ON {cond}" for model, cond in joins.items() 455 | ) 456 | else: 457 | joins = "" 458 | 459 | if where: 460 | where = " WHERE " + " ".join(where) 461 | else: 462 | where = "" 463 | 464 | suffix = [] 465 | if limit is not None and limit > 0: 466 | suffix.append(f" LIMIT {limit}") 467 | 468 | if offset is not None and offset > 0: 469 | suffix.append(f" OFFSET {offset}") 470 | 471 | suffix = "".join(suffix) 472 | 473 | query = ( 474 | f"SELECT {selectors} FROM {table_name(model)}{joins}{where}{order_by}{suffix}" 475 | ) 476 | variables = tuple(variables) 477 | models = tuple(models.keys()) 478 | 479 | if returns: 480 | return query, variables, models, return_fields 481 | return query, variables, models 482 | 483 | 484 | def compile_query(query, model, **kwargs): 485 | tokens = QueryParser.parsed(query) 486 | return _compile_query_for_model(tokens, model, **kwargs) 487 | 488 | 489 | def decode_query_record(record, models): 490 | idx = 0 491 | for model in models: 492 | num_fields = len(dataclasses.fields(model)) 493 | model_data = record[idx : idx + num_fields] 494 | idx += num_fields 495 | yield model.from_record(model_data) 496 | 497 | 498 | def _resolve_return_field(model, field): 499 | if "." 
in field: 500 | field, rest = field.split(".", 1) 501 | model = model._refs[field][0] 502 | return _resolve_return_field(model, rest) 503 | return model, field 504 | 505 | 506 | def _get_field_offset(record_offsets, model, field_name): 507 | fields = dataclasses.fields(model) 508 | offset = record_offsets[model] + [i.name for i in fields].index(field_name) 509 | return offset, next(i for i in fields if i.name == field_name) 510 | 511 | 512 | def decode_query_results(models, return_fields, results): 513 | # TODO: leaky af 514 | from abode.db import convert_to_type 515 | 516 | # I guess why not 517 | if return_fields is None: 518 | return_fields = [i.name for i in dataclasses.fields(models[0])] 519 | 520 | record_offsets = {} 521 | idx = 0 522 | for model in models: 523 | record_offsets[model] = idx 524 | idx += len(dataclasses.fields(model)) 525 | 526 | columns = [] 527 | for field in return_fields: 528 | model, field = _resolve_return_field(models[0], field) 529 | if field in model._virtual_fields: 530 | offsets = {} 531 | for field_dep in model._virtual_fields[field]: 532 | offset, field_dep = _get_field_offset(record_offsets, model, field_dep) 533 | offsets[offset] = field_dep 534 | columns.append(("virtual", field, offsets, getattr(model, field))) 535 | else: 536 | offset, field = _get_field_offset(record_offsets, model, field) 537 | columns.append(("offset", field, offset)) 538 | 539 | rows = [] 540 | for result_row in results: 541 | row = [] 542 | 543 | for column in columns: 544 | if column[0] == "offset": 545 | field, offset = column[1:] 546 | row.append( 547 | convert_to_type( 548 | convert_to_type(result_row[offset], field.type, from_pg=True), 549 | field.type, 550 | to_js=True, 551 | ) 552 | ) 553 | elif column[0] == "virtual": 554 | field, offsets, fn = column[1:] 555 | args = [ 556 | convert_to_type( 557 | convert_to_type(result_row[offset], field.type, from_pg=True), 558 | field.type, 559 | to_js=True, 560 | ) 561 | for offset, field in offsets.items() 
562 | ] 563 | 564 | # TODO: typecast based on fn 565 | row.append(fn(*args)) 566 | 567 | rows.append(row) 568 | 569 | return rows, return_fields 570 | --------------------------------------------------------------------------------