├── abode
├── __init__.py
├── lib
│ ├── __init__.py
│ ├── tests
│ │ ├── __init__.py
│ │ └── test_query.py
│ └── query.py
├── schema
│ ├── changelog.sql
│ ├── users.sql
│ ├── emoji.sql
│ ├── guilds.sql
│ ├── channels.sql
│ └── messages.sql
├── client.py
├── backfill.py
├── db
│ ├── users.py
│ ├── guilds.py
│ ├── emoji.py
│ ├── messages.py
│ ├── channels.py
│ └── __init__.py
├── cli.py
├── events.py
└── server.py
├── .gitignore
├── requirements.txt
├── frontend
├── templates
│ ├── guild.html
│ ├── user.html
│ ├── message.html
│ ├── emoji.html
│ ├── channel.html
│ └── results.html
├── styles.css
├── index.html
└── script.js
├── .vscode
└── settings.json
└── README.md
/abode/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/abode/lib/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/abode/lib/tests/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | .venv/
2 | config.json
3 | test.db*
4 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | asyncpg
2 | discord.py
3 | sanic
4 |
--------------------------------------------------------------------------------
/frontend/templates/guild.html:
--------------------------------------------------------------------------------
1 |
2 | {{ row.name }}
3 |
--------------------------------------------------------------------------------
/frontend/templates/user.html:
--------------------------------------------------------------------------------
1 |
2 | {{ row.name }}#{{ row.discriminator|discrim }}
3 |
--------------------------------------------------------------------------------
/frontend/templates/message.html:
--------------------------------------------------------------------------------
1 |
2 | {{ row.user.name }}#{{ row.user.discriminator|discrim }} {{ row.content }}
3 |
--------------------------------------------------------------------------------
/frontend/templates/emoji.html:
--------------------------------------------------------------------------------
1 |
2 |

3 |
{{ row.name }}
4 |
--------------------------------------------------------------------------------
/abode/schema/changelog.sql:
--------------------------------------------------------------------------------
1 | CREATE TABLE IF NOT EXISTS changelog (
2 | entity_id text,
3 | version integer,
4 | field text,
5 | value text,
6 |
7 | PRIMARY KEY (entity_id, version)
8 | );
--------------------------------------------------------------------------------
/frontend/templates/channel.html:
--------------------------------------------------------------------------------
1 |
2 | {% if row.type == 0 or row.type == 2 or row.type == 4 %}
3 | {{ row.name }} ({{ row.guild.name }})
4 | {% else %}
5 | {{ row.name }}
6 | {% endif %}
7 |
--------------------------------------------------------------------------------
/.vscode/settings.json:
--------------------------------------------------------------------------------
1 | {
2 | "python.pythonPath": "/home/andrei/code/python/abode/.venv/bin/python3.8",
3 | "python.linting.pylintEnabled": false,
4 | "python.linting.flake8Enabled": true,
5 | "python.linting.enabled": true
6 | }
--------------------------------------------------------------------------------
/abode/schema/users.sql:
--------------------------------------------------------------------------------
1 | CREATE TABLE IF NOT EXISTS users (
2 | id BIGINT PRIMARY KEY,
3 | name text NOT NULL,
4 | discriminator smallint NOT NULL,
5 | avatar text,
6 | bot boolean,
7 | system boolean
8 | );
9 |
10 | CREATE INDEX IF NOT EXISTS users_name_trgm ON users USING gin (name gin_trgm_ops);
--------------------------------------------------------------------------------
/abode/schema/emoji.sql:
--------------------------------------------------------------------------------
1 | CREATE TABLE IF NOT EXISTS emoji (
2 | id BIGINT PRIMARY KEY,
3 | guild_id BIGINT,
4 | author_id BIGINT,
5 | name text,
6 | require_colons boolean,
7 | managed boolean,
8 | animated boolean,
9 | roles jsonb,
10 | created_at timestamp,
11 | deleted boolean
12 | );
--------------------------------------------------------------------------------
/frontend/templates/results.html:
--------------------------------------------------------------------------------
1 |
2 |
3 | {% for field in fields %}
4 | | {{ field }} |
5 | {% endfor %}
6 |
7 |
8 | {% for row in rows %}
9 |
10 | {% for col in row %}
11 | |
12 | {{ col|safe }}
13 | |
14 | {% endfor %}
15 |
16 | {% endfor %}
17 |
--------------------------------------------------------------------------------
/frontend/styles.css:
--------------------------------------------------------------------------------
1 | body {
2 | font-family: sans-serif;
3 | font-size: 0.8rem;
4 | padding-bottom: 0.9rem;
5 | }
6 |
7 | .search-container {
8 | text-align: center;
9 | margin-bottom: 12px;
10 | }
11 |
12 | .search-container input {
13 | width: 33%;
14 | }
15 |
16 | table, td, th {
17 | border: 1px solid #ddd;
18 | text-align: left;
19 | }
20 |
21 | table {
22 | border-collapse: collapse;
23 | width: 100%;
24 | }
25 |
26 | th, td {
27 | padding: 8px;
28 | }
--------------------------------------------------------------------------------
/abode/schema/guilds.sql:
--------------------------------------------------------------------------------
1 | CREATE EXTENSION IF NOT EXISTS pg_trgm;
2 |
3 | CREATE TABLE IF NOT EXISTS guilds (
4 | id BIGINT PRIMARY KEY,
5 | owner_id BIGINT,
6 | name text,
7 | region text,
8 | icon text,
9 | is_currently_joined boolean,
10 | features jsonb,
11 | banner text,
12 | description text,
13 | splash text,
14 | discovery_splash text,
15 | premium_tier smallint,
16 | premium_subscription_count int
17 | );
18 |
19 | CREATE INDEX IF NOT EXISTS guilds_name_trgm ON guilds USING gin (name gin_trgm_ops);
20 | CREATE INDEX IF NOT EXISTS guilds_owner_id_idx ON guilds (owner_id);
--------------------------------------------------------------------------------
/abode/schema/channels.sql:
--------------------------------------------------------------------------------
1 | CREATE EXTENSION IF NOT EXISTS pg_trgm;
2 |
3 | CREATE TABLE IF NOT EXISTS channels (
4 | id BIGINT PRIMARY KEY,
5 | type SMALLINT NOT NULL,
6 | name text,
7 | topic text,
8 | guild_id BIGINT,
9 | category_id BIGINT,
10 | position BIGINT,
11 | slowmode_delay BIGINT,
12 | overwrites jsonb,
13 | bitrate INT,
14 | user_limit SMALLINT,
15 | recipients jsonb,
16 | owner_id BIGINT,
17 | icon text
18 | );
19 |
20 | CREATE INDEX IF NOT EXISTS channels_name_trgm ON channels USING gin (name gin_trgm_ops);
21 | CREATE INDEX IF NOT EXISTS channels_type_idx ON channels (type);
22 | CREATE INDEX IF NOT EXISTS channels_guild_id_idx ON channels (guild_id);
--------------------------------------------------------------------------------
/frontend/index.html:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
13 |
14 |
15 |
17 |
18 |
21 |
22 |
23 |
24 |
25 |
--------------------------------------------------------------------------------
/abode/client.py:
--------------------------------------------------------------------------------
1 | import discord
2 | import functools
3 |
4 |
5 | from .events import (
6 | on_ready,
7 | on_guild_join,
8 | on_guild_update,
9 | on_guild_remove,
10 | on_message,
11 | )
12 |
13 |
14 | USER_AGENT = "Mozilla/5.0 (X11; Linux x86_64; rv:72.0) Gecko/20100101 Firefox/72.0"
15 |
16 |
def setup_client(config, loop):
    """Create a discord client wired to abode's event handlers.

    Returns a ``(start, logout)`` pair of coroutines: the caller schedules
    the first on the loop and awaits the second on shutdown (see cli.main).
    """
    client = discord.Client(loop=loop)

    # Bind each handler with the client as its first argument.
    handlers = (on_ready, on_guild_join, on_guild_update, on_guild_remove, on_message)
    for handler in handlers:
        setattr(client, handler.__name__, functools.partial(handler, client))

    # TODO: cf cookies, IDENTIFY information
    client.http.user_agent = USER_AGENT
    return client.start(config["token"], bot=False), client.logout()
32 |
--------------------------------------------------------------------------------
/abode/schema/messages.sql:
--------------------------------------------------------------------------------
1 | CREATE EXTENSION IF NOT EXISTS pg_trgm;
2 |
3 | CREATE TABLE IF NOT EXISTS messages (
4 | id BIGINT PRIMARY KEY,
5 | guild_id BIGINT,
6 | channel_id BIGINT NOT NULL,
7 | author_id BIGINT NOT NULL,
8 | webhook_id BIGINT,
9 |
10 | tts boolean NOT NULL,
11 | type integer NOT NULL,
12 | content text NOT NULL,
13 | embeds jsonb,
14 | mention_everyone boolean NOT NULL,
15 | flags integer NOT NULL,
16 | activity jsonb,
17 | application jsonb,
18 |
19 | created_at timestamp NOT NULL,
20 | edited_at timestamp,
21 | deleted boolean NOT NULL
22 | );
23 |
24 | CREATE INDEX IF NOT EXISTS messages_content_trgm ON messages USING gin (content gin_trgm_ops);
25 | CREATE INDEX IF NOT EXISTS messages_content_fts ON messages USING gin (to_tsvector('english', content));
26 | CREATE INDEX IF NOT EXISTS messages_guild_id_idx ON messages (guild_id);
27 | CREATE INDEX IF NOT EXISTS messages_channel_id_idx ON messages (channel_id);
28 | CREATE INDEX IF NOT EXISTS messages_author_id_idx ON messages (author_id);
--------------------------------------------------------------------------------
/abode/backfill.py:
--------------------------------------------------------------------------------
1 | from .db.messages import insert_message
2 | from .db.channels import upsert_channel
3 | from discord import TextChannel
4 |
5 |
async def backfill_channel(channel):
    """Walk a channel's entire history (oldest first) into the database.

    Per-message failures are logged and skipped so one bad message cannot
    abort the whole scan; progress is reported every 5000 messages.
    """
    print(f"Backfilling channel {channel.id}")
    await upsert_channel(channel)

    scanned = 0
    async for message in channel.history(limit=None, oldest_first=True):
        scanned += 1
        try:
            await insert_message(message)
        except Exception as e:
            print(f" [{channel.id}] failed to backfill message {message.id}: {e}")

        if scanned % 5000 == 0:
            print(f" [{channel.id}] {scanned} messages")
    print(f"Done backfilling channel {channel.id}, scanned {scanned}")
20 |
21 |
async def backfill_guild(guild):
    """Backfill every text channel in a guild, continuing past failures."""
    print(f"Backfilling guild {guild.id}")
    text_channels = (c for c in guild.channels if isinstance(c, TextChannel))
    for channel in text_channels:
        try:
            await upsert_channel(channel)
            await backfill_channel(channel)
        except Exception as e:
            print(f"failed to backfill channel {channel.id}: {e}")
31 |
--------------------------------------------------------------------------------
/abode/db/users.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from typing import Optional
3 | from datetime import datetime
4 | from . import (
5 | with_conn,
6 | build_insert_query,
7 | build_select_query,
8 | JSONB,
9 | Snowflake,
10 | BaseModel,
11 | )
12 |
13 |
@dataclass
class User(BaseModel):
    """Snapshot of a discord user as stored in the ``users`` table."""

    id: Snowflake
    name: str
    discriminator: int
    avatar: Optional[str]
    bot: bool
    system: bool

    _pk = "id"

    @classmethod
    def from_discord(cls, user):
        """Build a row model from a discord.py user/member object."""
        attrs = {f: getattr(user, f) for f in ("id", "name", "avatar", "bot", "system")}
        # Discord serves the discriminator as a string; store it numeric.
        attrs["discriminator"] = int(user.discriminator)
        return cls(**attrs)
35 |
36 |
@with_conn
async def upsert_user(conn, user):
    """Insert or update a user row, logging any field-level changes."""
    row = User.from_discord(user)

    # Snapshot the previous row first so we can diff after the write.
    previous = await conn.fetchrow(build_select_query(row, "id = $1"), row.id)

    query, args = build_insert_query(row, upsert=True)
    await conn.execute(query, *args)

    if previous is not None:
        changes = list(row.diff(User.from_record(previous)))
        if changes:
            print(f"[user] diff is {changes}")
52 |
--------------------------------------------------------------------------------
/abode/cli.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import asyncio
4 | import argparse
5 |
6 | from .db import init_db, close_db
7 | from .client import setup_client
8 | from .server import setup_server
9 |
10 | parser = argparse.ArgumentParser("abode")
11 | parser.add_argument("--run-api", action="store_true")
12 | parser.add_argument("--run-client", action="store_true")
13 |
14 |
def main():
    """Entry point: load config, wire the requested services, run the loop.

    The database is always initialized; the API server and the discord
    client are opt-in via --run-api / --run-client.
    """
    args = parser.parse_args()

    config_path = os.getenv("ABODE_CONFIG_PATH", "config.json")
    with open(config_path, "r") as f:
        config = json.load(f)

    start_tasks = []
    cleanup_tasks = []

    if args.run_api:
        start_tasks.append(setup_server(config))

    loop = asyncio.get_event_loop()
    loop.create_task(init_db(config, loop))
    cleanup_tasks.append(close_db())

    if args.run_client:
        client_start, client_logout = setup_client(config, loop)
        start_tasks.append(client_start)
        cleanup_tasks.append(client_logout)

    try:
        for task in start_tasks:
            loop.create_task(task)
        loop.run_forever()
    except KeyboardInterrupt:
        # Graceful shutdown: close the db pool and log the client out.
        for task in cleanup_tasks:
            loop.run_until_complete(task)
    finally:
        loop.close()
46 |
47 |
48 | if __name__ == "__main__":
49 | main()
50 |
--------------------------------------------------------------------------------
/frontend/script.js:
--------------------------------------------------------------------------------
1 | const models = ["message", "emoji", "channel", "guild", "user"];
2 | const templates = { "results": null };
3 |
4 | const env = nunjucks.configure({ autoescape: true });
5 | env.addFilter("discrim", (str) => {
6 | return ('0000' + str).slice(-4);
7 | });
8 |
9 | for (const model of Object.keys(templates).concat(models)) {
10 | fetch(`/templates/${model}.html`).then((response) => {
11 | return response.text();
12 | }).then((body) => {
13 | templates[model] = body;
14 | console.log(body);
15 | });
16 | }
17 |
// Render a single result row using the per-model nunjucks template.
function renderModelRow(name, row) {
  const template = templates[name];
  return nunjucks.renderString(template, { row });
}
21 |
// Render the results table and clear any previous error banner.
function renderResult(results, fields) {
  $("#error").hide();
  const html = nunjucks.renderString(templates["results"], {
    fields: fields,
    rows: results,
  });
  $("#results").html(html);
}
29 |
// Show the error banner containing the given message.
function renderError(error) {
  $("#error").show().text(error);
}
33 |
// Run the search box contents against the currently selected model.
function handleSearchChange(event) {
  const currentModel = $("#model option:selected").text();
  const payload = {
    "query": $(event.target).val(),
    "limit": 1000,
    "order_by": "id",
    "order_dir": "DESC",
  };

  fetch(`/search/${currentModel}`, {
    method: 'POST',
    body: JSON.stringify(payload),
  })
    .then((response) => response.json())
    .then((data) => {
      console.log("[Debug]", data);
      if (data.results && !data.error) {
        renderResult(data.results, data.fields);
      } else if (data.error) {
        renderError(data.error);
      }
    });
}
57 |
$(document).ready(function () {
  // `const` fixes the accidental global leaked by the bare `for (model of ...)`.
  // NOTE(review): the option markup was stripped from this copy of the source;
  // reconstructed so the selected option's text is the model name, which is
  // what handleSearchChange reads via `option:selected`.
  for (const model of models) {
    $("#model").append(`<option>${model}</option>`);
  }

  // Re-focus the search box with the cursor at the end of any existing text.
  var val = $("#search").val();
  $("#search").focus().val("").val(val);
  $("#search").keyup((e) => {
    if (e.keyCode == 13) {
      handleSearchChange(e);
    }
  });
});
--------------------------------------------------------------------------------
/abode/db/guilds.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass, fields
2 | from typing import Optional
3 | from . import (
4 | with_conn,
5 | build_insert_query,
6 | build_select_query,
7 | convert_to_type,
8 | Snowflake,
9 | BaseModel,
10 | JSONB,
11 | )
12 | from .users import User
13 |
14 |
@dataclass
class Guild(BaseModel):
    """Snapshot of a guild as stored in the ``guilds`` table."""

    id: Snowflake
    owner_id: Snowflake
    name: str
    region: str
    icon: Optional[str]
    features: JSONB
    banner: Optional[str]
    description: Optional[str]
    splash: Optional[str]
    discovery_splash: Optional[str]
    premium_tier: int
    premium_subscription_count: int
    # NOTE(review): annotated ``bool`` yet defaults to None ("unknown").
    # from_attrs passes field.type to convert_to_type, so the annotation is
    # load-bearing — confirm before changing it to Optional[bool].
    is_currently_joined: bool = None

    _pk = "id"
    # foreign ref: attr name -> (model, (local column, remote column), eager?)
    _refs = {"owner": (User, ("owner_id", "id"), True)}
    _fts = set()

    @classmethod
    def from_attrs(cls, guild, is_currently_joined=None):
        """Build a row from any object exposing guild-shaped attributes.

        Attributes missing on *guild* are simply skipped; present values
        are coerced to the declared field type via convert_to_type.
        """
        kwargs = {"is_currently_joined": is_currently_joined}
        for field in fields(cls):
            if not hasattr(guild, field.name):
                continue
            kwargs[field.name] = convert_to_type(getattr(guild, field.name), field.type)
        return cls(**kwargs)
44 |
45 |
@with_conn
async def upsert_guild(conn, guild, is_currently_joined=None):
    """Upsert a guild row plus all of its channels, emoji, and members."""
    # Local imports avoid a circular dependency between the db modules.
    from .channels import upsert_channel
    from .emoji import upsert_emoji
    from .users import upsert_user

    new_guild = Guild.from_attrs(guild, is_currently_joined=is_currently_joined)

    # Snapshot the previous row first so we can diff after the write.
    previous = await conn.fetchrow(
        build_select_query(new_guild, "id = $1"), new_guild.id
    )

    query, args = build_insert_query(new_guild, upsert=True)
    await conn.execute(query, *args)

    if previous is not None:
        changes = list(new_guild.diff(Guild.from_record(previous)))
        if changes:
            print(f"[guilds] diff is {changes}")

    # Cascade: child entities reuse this connection.
    for channel in guild.channels:
        await upsert_channel(channel, conn=conn)
    for emoji in guild.emojis:
        await upsert_emoji(emoji, conn=conn)
    for member in guild.members:
        await upsert_user(member, conn=conn)
76 |
--------------------------------------------------------------------------------
/abode/db/emoji.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from typing import Optional
3 | from datetime import datetime
4 | from . import (
5 | with_conn,
6 | build_insert_query,
7 | build_select_query,
8 | JSONB,
9 | Snowflake,
10 | BaseModel,
11 | )
12 | from .guilds import Guild
13 |
14 |
@dataclass
class Emoji(BaseModel):
    """Snapshot of a custom emoji as stored in the ``emoji`` table."""

    id: Snowflake
    guild_id: Snowflake
    author_id: Optional[Snowflake]
    name: str
    require_colons: bool
    animated: bool
    managed: bool
    roles: JSONB
    created_at: datetime
    deleted: bool

    _pk = "id"
    _table_name = "emoji"
    _refs = {"guild": (Guild, ("guild_id", "id"), True)}
    # Computed columns exposed to the query layer; each maps to the concrete
    # fields the staticmethod of the same name consumes.
    _virtual_fields = {"image": ("id", "animated"), "image_url": ("id", "animated")}

    @staticmethod
    def image(id, animated):
        # TODO: leaky, idk how to avoid
        ext = "gif" if animated else "png"
        # NOTE(review): the original markup was mangled in this copy of the
        # source; reconstructed from image_url() below — confirm the tag and
        # attributes against the frontend templates that render this column.
        return f'<img src="https://cdn.discordapp.com/emojis/{id}.{ext}">'

    @staticmethod
    def image_url(id, animated):
        ext = "gif" if animated else "png"
        return f"https://cdn.discordapp.com/emojis/{id}.{ext}"

    @classmethod
    def from_discord(cls, emoji, deleted=False):
        """Build a row from a discord.py Emoji object."""
        return cls(
            id=emoji.id,
            guild_id=emoji.guild.id,
            author_id=emoji.user.id if emoji.user else None,
            name=emoji.name,
            require_colons=bool(emoji.require_colons),
            managed=bool(emoji.managed),
            animated=bool(emoji.animated),
            roles=[str(i.id) for i in emoji.roles],
            created_at=emoji.created_at,
            deleted=deleted,
        )
58 |
59 |
@with_conn
async def upsert_emoji(conn, emoji):
    """Insert or update an emoji row, logging any field-level changes."""
    row = Emoji.from_discord(emoji)

    # Snapshot the previous row first so we can diff after the write.
    previous = await conn.fetchrow(build_select_query(row, "id = $1"), row.id)

    query, args = build_insert_query(row, upsert=True)
    await conn.execute(query, *args)

    if previous is not None:
        changes = list(row.diff(Emoji.from_record(previous)))
        if changes:
            print(f"[emoji] diff is {changes}")
75 |
--------------------------------------------------------------------------------
/abode/events.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | from abode.db.guilds import upsert_guild
3 | from abode.db.messages import insert_message
4 | from abode.db.channels import upsert_channel
5 | from abode.backfill import backfill_channel, backfill_guild
6 |
7 |
async def backfill(client, message, args):
    """;backfill <id> — backfill a channel (or a user's DM) by snowflake."""
    snowflake = int(args)
    target = client.get_channel(snowflake)
    if not target:
        # Not a channel id: resolve it as a user and open (or reuse) a DM.
        user = client.get_user(snowflake)
        target = user.dm_channel or await user.create_dm()

    if target:
        await message.add_reaction(client.get_emoji(580596825128697874))
        await backfill_channel(target)
    else:
        await message.add_reaction(client.get_emoji(494901623731126272))
20 |
21 |
async def backfillg(client, message, args):
    """;backfillg <id> — backfill every text channel in a guild."""
    guild = client.get_guild(int(args))
    if not guild:
        await message.add_reaction(client.get_emoji(494901623731126272))
        return
    await message.add_reaction(client.get_emoji(580596825128697874))
    await backfill_guild(guild)
29 |
30 |
async def backfilldms(client, message, args):
    """;backfilldms — backfill every private (DM/group) channel."""
    await message.add_reaction(client.get_emoji(580596825128697874))
    for dm in client.private_channels:
        await backfill_channel(dm)
35 |
36 |
37 | commands = {"backfill": backfill, "backfillg": backfillg, "backfilldms": backfilldms}
38 |
39 |
async def on_ready(client):
    """Sync all private channels and guilds once the gateway is ready.

    Uses asyncio.gather instead of asyncio.wait: passing bare coroutines to
    wait() is deprecated since 3.8 (removed in 3.11), and wait() raises
    ValueError on an empty set (e.g. an account with no private channels).
    """
    print("Connected!")

    await asyncio.gather(
        *(upsert_channel(channel) for channel in client.private_channels)
    )
    await asyncio.gather(
        *(upsert_guild(guild, is_currently_joined=True) for guild in client.guilds)
    )
48 |
49 |
async def on_guild_join(client, guild):
    """Record a newly-joined guild as currently joined."""
    await upsert_guild(guild, is_currently_joined=True)
52 |
53 |
async def on_guild_update(client, old, new):
    """Persist the post-update state of a guild we are a member of."""
    await upsert_guild(new, is_currently_joined=True)
56 |
57 |
async def on_guild_remove(client, guild):
    """Keep the guild row but mark it as no longer joined."""
    await upsert_guild(guild, is_currently_joined=False)
60 |
61 |
async def on_message(client, message):
    """Archive every incoming message; dispatch ";"-prefixed self-commands."""
    await insert_message(message)

    if message.author.id == client.user.id and message.content.startswith(";"):
        # partition() tolerates a command with no argument or an argument
        # containing spaces; the previous split(" ") two-value unpack raised
        # ValueError in both cases.
        command, _, args = message.content.partition(" ")
        fn = commands.get(command[1:])
        if fn:
            await fn(client, message, args)
71 |
--------------------------------------------------------------------------------
/abode/server.py:
--------------------------------------------------------------------------------
1 | import time
2 | from sanic import Sanic
3 | from sanic.response import json
4 | from abode.lib.query import compile_query, decode_query_results
5 | from abode.db.guilds import Guild
6 | from abode.db.messages import Message
7 | from abode.db.emoji import Emoji
8 | from abode.db.users import User
9 | from abode.db.channels import Channel
10 | from abode.db import get_pool
11 | from traceback import format_exc
12 |
13 | app = Sanic()
14 | app.static("/", "./frontend/index.html")
15 | app.static("/styles.css", "./frontend/styles.css")
16 | app.static("/script.js", "./frontend/script.js")
17 | app.static("/templates/", "./frontend/templates")
18 |
19 |
20 | SUPPORTED_MODELS = {
21 | "guild": Guild,
22 | "message": Message,
23 | "emoji": Emoji,
24 | "user": User,
25 | "channel": Channel,
26 | }
27 |
28 |
def setup_server(config):
    """Return a coroutine that starts the API server on the configured bind."""
    host = config.get("host", "0.0.0.0")
    port = config.get("port", 9999)
    return app.create_server(host=host, port=port, return_asyncio_server=True)
35 |
36 |
@app.route("/search/<model>", methods=["POST"])
async def route_search(request, model):
    """Compile a DSL query against a model and return matching rows as JSON.

    Body parameters: query, limit, page, order_by, order_dir, foreign_data.
    Errors are returned as {"error": traceback} (with ``_debug`` info once
    available) rather than HTTP failures so the frontend can show them inline.

    NOTE(review): the route path segment was mangled in this copy of the
    source; reconstructed as ``<model>`` since the handler takes ``model``.
    """
    model = SUPPORTED_MODELS.get(model)
    if not model:
        return json({"error": "unsupported model"}, status=404)

    limit = request.json.get("limit", 100)
    page = request.json.get("page", 1)
    order_by = request.json.get("order_by")
    order_dir = request.json.get("order_dir", "ASC")
    include_foreign_data = request.json.get("foreign_data", True)

    query = request.json.get("query", "")
    try:
        sql, args, models, return_fields = compile_query(
            query,
            model,
            limit=limit,
            offset=(limit * (page - 1)),
            order_by=order_by,
            order_dir=order_dir,
            include_foreign_data=include_foreign_data,
            returns=True,
        )
    except Exception:
        return json({"error": format_exc()})

    # Echoed back to the client to make query debugging possible from the UI.
    _debug = {
        "args": args,
        "sql": sql,
        "request": request.json,
        "models": [i.__name__ for i in models],
    }

    results = []
    try:
        async with get_pool().acquire() as conn:
            start = time.time()
            results = await conn.fetch(sql, *args)
            _debug["ms"] = int((time.time() - start) * 1000)
    except Exception:
        return json({"error": format_exc(), "_debug": _debug})

    try:
        results, field_names = decode_query_results(models, return_fields, results)
        return json({"results": results, "fields": field_names, "_debug": _debug})
    except Exception:
        return json({"error": format_exc(), "_debug": _debug})
85 |
--------------------------------------------------------------------------------
/abode/db/messages.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass, fields
2 | from datetime import datetime
3 | from typing import Optional
4 | from . import (
5 | with_conn,
6 | build_insert_query,
7 | convert_to_type,
8 | JSONB,
9 | Snowflake,
10 | BaseModel,
11 | )
12 | from .users import User, upsert_user
13 | from .guilds import Guild
14 | from .channels import Channel
15 |
16 |
@dataclass
class Message(BaseModel):
    """Snapshot of a message as stored in the ``messages`` table."""

    id: Snowflake
    channel_id: Snowflake
    guild_id: Optional[Snowflake]
    author_id: Optional[Snowflake]
    webhook_id: Optional[Snowflake]
    tts: bool
    type: int
    content: str
    embeds: Optional[JSONB]
    mention_everyone: bool
    flags: int
    activity: Optional[JSONB]
    application: Optional[JSONB]
    created_at: datetime
    edited_at: Optional[datetime]
    deleted: bool

    # TODO: eventually these could be types, but I am far too baked for that
    # refactor at the moment.
    _pk = "id"
    # Consistency fix: the "channel" ref previously used a list while every
    # other _refs entry in this codebase uses a tuple; normalized to a tuple.
    _refs = {
        "guild": (Guild, ("guild_id", "id"), False),
        "author": (User, ("author_id", "id"), True),
        "channel": (Channel, ("channel_id", "id"), True),
    }
    _fts = {"content"}

    @classmethod
    def from_discord(cls, message, deleted=False):
        """Build a row from a discord.py Message object."""
        return cls(
            id=message.id,
            guild_id=message.guild.id if message.guild else None,
            channel_id=message.channel.id,
            author_id=message.author.id,
            webhook_id=message.webhook_id if message.webhook_id else None,
            tts=bool(message.tts),
            type=message.type.value,
            content=message.content,
            embeds=[i.to_dict() for i in message.embeds],
            mention_everyone=bool(message.mention_everyone),
            flags=message.flags.value,
            activity=message.activity,
            application=message.application,
            created_at=message.created_at,
            edited_at=message.edited_at,
            deleted=deleted,
        )

    @classmethod
    def from_attrs(cls, instance, deleted=None):
        """Build a row by copying every dataclass field off *instance*.

        *deleted*, when truthy, overrides the copied ``deleted`` flag.
        """
        kwargs = {
            field.name: convert_to_type(getattr(instance, field.name), field.type)
            for field in fields(cls)
        }
        if deleted:
            kwargs["deleted"] = deleted
        return cls(**kwargs)
77 |
78 |
@with_conn
async def insert_message(conn, message):
    """Insert a message row (no-op if it exists) and upsert its author."""
    row = Message.from_discord(message)

    query, args = build_insert_query(row, ignore_existing=True)
    try:
        await conn.execute(query, *args)
    except Exception:
        # Surface the failing statement and bindings before propagating.
        print(query)
        print(args)
        raise

    await upsert_user(message.author, conn=conn)
92 |
93 |
@with_conn
async def update_message(conn, message):
    """Persist a message edit. TODO: not implemented yet."""
    pass
97 |
--------------------------------------------------------------------------------
/abode/db/channels.py:
--------------------------------------------------------------------------------
1 | import discord
2 | from dataclasses import dataclass
3 | from typing import Optional, List
4 | from .guilds import Guild
5 | from .users import User
6 | from . import (
7 | with_conn,
8 | build_insert_query,
9 | build_select_query,
10 | JSONB,
11 | Snowflake,
12 | BaseModel,
13 | )
14 |
15 |
@dataclass
class Channel(BaseModel):
    """Snapshot of any discord channel type in the ``channels`` table.

    Most columns are nullable because the table is a union over guild text,
    voice, category, DM, and group-DM channels.
    """

    id: Snowflake
    type: int

    name: Optional[str] = None
    topic: Optional[str] = None

    # Guild specific
    guild_id: Optional[Snowflake] = None
    category_id: Optional[Snowflake] = None
    position: Optional[int] = None
    slowmode_delay: Optional[int] = None
    overwrites: Optional[JSONB] = None

    # Voice specific
    bitrate: Optional[int] = None
    user_limit: Optional[int] = None

    # DMs
    recipients: Optional[JSONB[List[str]]] = None
    owner_id: Optional[Snowflake] = None
    icon: Optional[str] = None

    _pk = "id"
    _refs = {
        "guild": (Guild, ("guild_id", "id"), False),
        "owner": (User, ("owner_id", "id"), False),
    }
    _fts = set()

    @classmethod
    def from_discord(cls, channel):
        """Build a row from a discord.py channel of any concrete type."""
        inst = cls(id=channel.id, type=channel.type.value)

        guild_types = (
            discord.TextChannel,
            discord.VoiceChannel,
            discord.CategoryChannel,
        )
        if isinstance(channel, guild_types):
            inst.guild_id = channel.guild.id
            inst.name = channel.name
            inst.category_id = channel.category_id
            inst.position = channel.position
            # Flatten overwrites to {target_id: [allow_bits, deny_bits]}.
            inst.overwrites = {
                str(target.id): [perms.value for perms in overwrite.pair()]
                for target, overwrite in channel.overwrites.items()
            }

        if isinstance(channel, discord.TextChannel):
            inst.slowmode_delay = channel.slowmode_delay
            inst.topic = channel.topic
        elif isinstance(channel, discord.VoiceChannel):
            inst.bitrate = channel.bitrate
            inst.user_limit = channel.user_limit
        elif isinstance(channel, discord.DMChannel):
            inst.recipients = [channel.recipient.id]
        elif isinstance(channel, discord.GroupChannel):
            inst.recipients = [member.id for member in channel.recipients]
            inst.owner_id = channel.owner.id
            inst.icon = channel.icon
            inst.name = channel.name

        return inst
79 |
80 |
@with_conn
async def upsert_channel(conn, channel):
    """Insert or update a channel row, logging any field-level changes."""
    row = Channel.from_discord(channel)

    # Snapshot the previous row first so we can diff after the write.
    previous = await conn.fetchrow(
        build_select_query(row, "id = $1"), channel.id
    )

    query, args = build_insert_query(row, upsert=True)
    await conn.execute(query, *args)

    if previous is not None:
        changes = list(row.diff(Channel.from_record(previous)))
        if changes:
            print(f"[channel] diff is {changes}")
98 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # abode
2 |
3 | abode is a self-hosted home server which aggregates your discord data in a discoverable format.
4 |
5 | ## status
6 |
7 | abode is in early personal testing but I've already found it invaluable for easily querying things. The motivated individual could figure out how to properly run a version locally (tl;dr: postgres database, config with postgres dsn and token, run one client and one api), but I'm not currently planning to extensively document the process (you should understand it to use it).
8 |
9 | ## ok but why
10 |
11 | Discord in general is only relatively friendly to data export, as is evident from the [Discord Data Package](https://support.discordapp.com/hc/en-us/articles/360004957991-Your-Discord-Data-Package), which lacks messages not sent by you, attachments, and various other relevant data (take it from me, I built the original system while there). The data package also serves as a poor real-time solution due to its multi-week delay in between requesting packages, and the long generation time of each individual package. abode fills in for this lack of flexibility and features on Discord's side by tracking, storing and indexing your data. abode was designed to be used by individuals for tracking only data that "belongs" (any data discord makes visible to the end user) to them, and thus won't work for large-scale or wide-berth data-mining.
12 |
13 | The explanation is all well and good, but it doesn't really explain why you would even _want_ your data archived locally. While different folks can have different reasons, the primary advantages to running abode come from:
14 |
15 | - data backup (for when sinister deletes all the channels again)
16 | - fast and powerful querying (for finding **anything** quickly and easily)
17 | - audit log (for when things happened more than 90 days ago)
18 | - data insights (for understanding how you use and experience discord)
19 |
20 | ## how
21 |
22 | abode includes a client component which connects to discord on behalf of your user (with your token, so this is considered ["self botting"](https://support.discordapp.com/hc/en-us/articles/115002192352-Automated-user-accounts-self-bots-) by discord, and could get your account banned), an api server which can process queries written in a custom DSL, and a lightweight frontend for querying. abode's primary use is power-user-level querying of its internal data archive, but it can also be used as a generic data API for other applications.
23 |
24 | ## query language
25 |
26 | abode is built around a custom search query language, which is intentionally simple but can be quite powerful when used correctly. The real power of abode's queries comes directly from postgres, where abode leans on full text search and unique indexes to keep query response time fast.
27 |
28 | | Syntax | Meaning |
29 | |--------|---------|
30 | | field:value | fuzzy match of value |
31 | | field:"value" | case insensitive exact match of value |
32 | | field="value" | case sensitive exact match of value |
33 | | field:(x y) | fuzzy match of x and y |
34 | | field:(x OR y) | fuzzy match of x or y |
35 | | field:x AND NOT field:y | fuzzy match of x and not y |
36 | | (field:a AND field:b) OR (field:c AND field:d) | fuzzy match of a and b or c and d |
37 | | -> x y z | select fields x, y, and z |
38 |
39 | ## screenshots
40 |
41 | 
42 | 
43 | 
--------------------------------------------------------------------------------
/abode/db/__init__.py:
--------------------------------------------------------------------------------
1 | import asyncpg
2 | import functools
3 | import os
4 | import dataclasses
5 | import json
6 | import typing
7 | from datetime import datetime
8 |
9 | pool = None
10 |
11 |
12 | T = typing.TypeVar("T")
13 |
14 |
class JSONB(typing.Generic[T]):
    # Marker generic for model fields stored as postgres JSONB columns.
    # convert_to_type() dispatches on typing.get_origin(...) == JSONB to
    # json.dumps/json.loads values at the database boundary, so this class
    # must stay a distinct generic (do not alias it away).
    inner: T
17 |
18 |
class FTS:
    """Marks a field type as backed by postgres full-text search.

    The query compiler checks ``isinstance(field_type, FTS)`` and emits
    ``to_tsvector(...) @@ phraseto_tsquery(...)`` filters for such fields
    instead of plain comparisons.
    """

    def __init__(self, inner):
        # The wrapped underlying field type (e.g. str).
        self.inner = inner
22 |
23 |
def Snowflake(i):
    """Coerce a discord snowflake id (str or int) to an int.

    Kept as a distinct named callable rather than an alias of ``int`` so the
    query compiler can tell snowflake columns apart from plain integers via
    ``field_type == Snowflake``.
    """
    return int(i)
26 |
27 |
def to_json_str(obj):
    """Return *obj* serialized as JSON text; strings pass through untouched."""
    return obj if isinstance(obj, str) else json.dumps(obj)
32 |
33 |
async def init_db(config, loop):
    """Create the module-level asyncpg pool and execute every ../schema file.

    ``loop`` is accepted for call-site compatibility but is not used here.
    NOTE(review): the schema files are re-executed on every init — presumably
    they are written idempotently (CREATE ... IF NOT EXISTS style); confirm
    against the .sql files.
    """
    global pool

    pool = await asyncpg.create_pool(dsn=config.get("postgres_dsn"))

    schema_dir = os.path.abspath(
        os.path.join(os.path.dirname(__file__), "..", "schema")
    )
    async with pool.acquire() as connection:
        for filename in os.listdir(schema_dir):
            with open(os.path.join(schema_dir, filename), "r") as f:
                await connection.execute(f.read())
46 |
47 |
async def close_db():
    """Close the module-level connection pool."""
    await get_pool().close()
50 |
51 |
def get_pool():
    """Return the module-level asyncpg pool (None until init_db() has run)."""
    return pool
54 |
55 |
def with_conn(func):
    """Decorator that supplies a database connection as *func*'s first argument.

    If the caller passes an explicit ``conn=`` keyword it is forwarded as-is
    (e.g. to reuse an existing connection); otherwise a connection is checked
    out of the module pool for the duration of the call.
    """

    @functools.wraps(func)
    async def wrapped(*args, **kwargs):
        sentinel = object()
        explicit = kwargs.pop("conn", sentinel)
        if explicit is not sentinel:
            return await func(explicit, *args, **kwargs)

        async with pool.acquire() as acquired:
            return await func(acquired, *args, **kwargs)

    return wrapped
66 |
67 |
def build_insert_query(instance, upsert=False, ignore_existing=False):
    """Build an INSERT covering every dataclass field of *instance*.

    ``upsert`` adds ``ON CONFLICT (id) DO UPDATE`` for all columns (the id
    column harmlessly updates to itself); ``ignore_existing`` (mutually
    exclusive with upsert) adds ``ON CONFLICT (id) DO NOTHING``.
    Returns (sql, args-tuple); values are pre-converted to their postgres
    representation via convert_to_type.
    """
    dataclass = instance.__class__

    column_names = []
    column_values = []
    updates = []
    for field in dataclasses.fields(dataclass):
        column_names.append(field.name)
        column_values.append(
            convert_to_type(getattr(instance, field.name), field.type, to_pg=True)
        )

        if upsert:
            updates.append(f"{field.name}=excluded.{field.name}")
    # Positional placeholders $1..$n, matching column_values order.
    values = ", ".join([f"${i}" for i in range(1, len(column_names) + 1)])

    upsert_contents = ""
    if upsert:
        updates = ",\n".join(updates)
        upsert_contents = f"""
            ON CONFLICT (id) DO UPDATE SET
            {updates}
        """

    if ignore_existing:
        assert not upsert
        upsert_contents = """
            ON CONFLICT (id) DO NOTHING
        """

    return (
        f"""
        INSERT INTO {table_name(dataclass)} ({', '.join(column_names)})
        VALUES ({values})
        {upsert_contents}
        """,
        tuple(column_values),
    )
106 |
107 |
def build_select_query(instance, where=None):
    """Build a SELECT returning every dataclass column of *instance*'s model.

    ``where`` is interpolated verbatim (callers supply placeholder text such
    as ``id = $1`` and pass values separately), so it must never contain
    untrusted input.
    """
    dataclass = instance.__class__
    # Explicit column list keeps result order aligned with dataclasses.fields(),
    # which from_record() relies on when decoding rows positionally.
    select_fields = ", ".join([field.name for field in dataclasses.fields(dataclass)])
    where = f"WHERE {where}" if where else ""

    return f"""
        SELECT {select_fields} FROM {table_name(dataclass)}
        {where}
    """
117 |
118 |
def convert_to_type(value, target_type, to_pg=False, from_pg=False, to_js=False):
    """Coerce *value* to *target_type* for one of three boundaries.

    ``to_pg``/``from_pg`` translate JSONB columns to/from their stored JSON
    text; ``to_js`` renders snowflakes and datetimes in JSON-safe string form.
    With no flags set this is a plain type coercion.

    Raises whatever ``target_type(value)`` raises when no special case applies
    and the cast fails (after printing debug breadcrumbs).
    """
    # Unwrap Optional[T]: None passes through untouched, otherwise convert as T.
    if typing.get_origin(target_type) is typing.Union:
        if type(None) in typing.get_args(target_type):
            if value is None:
                return None
            target_type = next(
                i for i in typing.get_args(target_type) if i is not type(None)
            )
        else:
            # Only Optional-style unions are supported by the models.
            assert False

    # JS numbers lose precision beyond 2**53, so snowflake ids are serialized
    # as strings. Bug fix: the `to_js` guard previously bound only to the
    # first operand of an `or`, leaving an unguarded Optional[...] clause;
    # Optional is already unwrapped above, so the plain check suffices.
    if to_js and target_type == Snowflake:
        return str(value)

    if to_js and target_type == datetime:
        return value.isoformat()

    if typing.get_origin(target_type) == list:
        return list(value)

    # Exact type match (not isinstance): e.g. a bool with target int still
    # falls through to the int(value) coercion below.
    if type(value) == target_type:
        return value

    # JSONB columns are stored as JSON text in postgres.
    if to_pg and typing.get_origin(target_type) == JSONB:
        return json.dumps(value)

    if from_pg and typing.get_origin(target_type) == JSONB:
        return json.loads(value)

    try:
        return target_type(value)
    except Exception:
        # Breadcrumbs: conversion failures surface deep inside record decoding.
        print(type(value))
        print(target_type)
        print(typing.get_origin(target_type))
        raise
163 |
164 |
def table_name(model):
    """SQL table for *model*: explicit ``_table_name`` when declared,
    otherwise the lowercased class name naively pluralized with "s"."""
    try:
        return model._table_name
    except AttributeError:
        return model.__name__.lower() + "s"
167 |
168 |
class BaseModel:
    """Mixin shared by the dataclass models.

    Subclasses may override the class-level registries:
    ``_refs``           -- relation name -> (model, (local col, remote col), always-join)
    ``_fts``            -- names of fields backed by full text search
    ``_virtual_fields`` -- computed field name -> tuple of dependency field names
    """

    _refs = {}
    _fts = set()
    _virtual_fields = {}

    def serialize(self, **kwargs):
        """Render every dataclass field in its JSON-safe form."""
        out = {}
        for field in dataclasses.fields(self):
            out[field.name] = convert_to_type(
                getattr(self, field.name), field.type, to_js=True
            )
        return out

    def diff(self, other):
        """Yield a {field, old, new} dict for each field differing from *other*."""
        for field in dataclasses.fields(self):
            new_value = getattr(self, field.name)
            old_value = getattr(other, field.name)
            if old_value != new_value:
                yield {
                    "field": field.name,
                    "old": old_value,
                    "new": new_value,
                }

    @classmethod
    def from_record(cls, record):
        """Build an instance from a positional DB record (dataclass field order)."""
        converted = {}
        for idx, field in enumerate(dataclasses.fields(cls)):
            converted[field.name] = convert_to_type(
                record[idx], field.type, from_pg=True
            )
        return cls(**converted)
199 |
--------------------------------------------------------------------------------
/abode/lib/tests/test_query.py:
--------------------------------------------------------------------------------
1 | from abode.lib.query import QueryParser, compile_query, _compile_selector
2 | from abode.db.guilds import Guild
3 | from abode.db.messages import Message
4 | from abode.db.users import User
5 | from abode.db.channels import Channel
6 |
7 |
def test_parse_basic_queries():
    """Exercise the parser over each token type in isolation.

    Covers implicit AND injection between bare symbols, escaped strings,
    groups, fuzzy (:) vs exact (=) labels, regex literals with flags, and the
    `->` returns section.
    """
    # Bare symbols get an implicit AND injected between them by _fix().
    assert QueryParser.parsed("hello world") == [
        {"type": "symbol", "value": "hello"},
        {"type": "symbol", "value": "AND"},
        {"type": "symbol", "value": "world"},
    ]

    # Backslash escapes the quote inside a string literal.
    assert QueryParser.parsed('"Hello \\" World"') == [
        {"type": "string", "value": 'Hello " World'}
    ]

    assert QueryParser.parsed("(group me daddy)") == [
        {
            "type": "group",
            "value": [
                {"type": "symbol", "value": "group"},
                {"type": "symbol", "value": "AND"},
                {"type": "symbol", "value": "me"},
                {"type": "symbol", "value": "AND"},
                {"type": "symbol", "value": "daddy"},
            ],
        }
    ]

    # `:` produces a fuzzy label ...
    assert QueryParser.parsed("x:y") == [
        {
            "type": "label",
            "name": "x",
            "value": {"type": "symbol", "value": "y"},
            "exact": False,
        }
    ]

    # ... while `=` produces an exact one.
    assert QueryParser.parsed("x=y") == [
        {
            "type": "label",
            "name": "x",
            "value": {"type": "symbol", "value": "y"},
            "exact": True,
        }
    ]

    assert QueryParser.parsed("x:(y z)") == [
        {
            "type": "label",
            "name": "x",
            "value": {
                "type": "group",
                "value": [
                    {"type": "symbol", "value": "y"},
                    {"type": "symbol", "value": "AND"},
                    {"type": "symbol", "value": "z"},
                ],
            },
            "exact": False,
        }
    ]

    # Regex literals: /pattern/ with optional trailing flags.
    assert QueryParser.parsed("x:/.* lol \\d me daddy/") == [
        {
            "type": "label",
            "name": "x",
            "value": {"type": "regex", "value": ".* lol \\d me daddy", "flags": []},
            "exact": False,
        }
    ]

    assert QueryParser.parsed("x:/.* lol \\d me daddy/i") == [
        {
            "type": "label",
            "name": "x",
            "value": {"type": "regex", "value": ".* lol \\d me daddy", "flags": ["i"]},
            "exact": False,
        }
    ]

    # `->` starts a returns (field selection) section.
    assert QueryParser.parsed("-> a b c") == [
        {
            "type": "return",
            "value": [
                {"type": "symbol", "value": "a"},
                {"type": "symbol", "value": "b"},
                {"type": "symbol", "value": "c"},
            ],
        }
    ]

    # A returns section does not get an AND injected before it.
    assert QueryParser.parsed("x:y -> a b c") == [
        {
            "type": "label",
            "name": "x",
            "value": {"type": "symbol", "value": "y"},
            "exact": False,
        },
        {
            "type": "return",
            "value": [
                {"type": "symbol", "value": "a"},
                {"type": "symbol", "value": "b"},
                {"type": "symbol", "value": "c"},
            ],
        },
    ]
111 |
112 |
def test_parse_complex_queries():
    """A mixed query: fuzzy and exact labels, a quoted string, and an OR group."""
    assert QueryParser.parsed(
        'type:attachment guild:"discord api" (from:Jake#0001 OR from=danny#0007)'
    ) == [
        {
            "type": "label",
            "name": "type",
            "value": {"type": "symbol", "value": "attachment"},
            "exact": False,
        },
        # Implicit ANDs injected between the top-level terms.
        {"type": "symbol", "value": "AND"},
        {
            "type": "label",
            "name": "guild",
            "value": {"type": "string", "value": "discord api"},
            "exact": False,
        },
        {"type": "symbol", "value": "AND"},
        {
            "type": "group",
            "value": [
                {
                    "type": "label",
                    "name": "from",
                    "value": {"type": "symbol", "value": "Jake#0001"},
                    "exact": False,
                },
                {"type": "symbol", "value": "OR"},
                {
                    "type": "label",
                    "name": "from",
                    "value": {"type": "symbol", "value": "danny#0007"},
                    "exact": True,
                },
            ],
        },
    ]
150 |
151 |
def test_compile_basic_queries():
    """compile_query returns an (sql, args, models) triple; spot-check the
    fuzzy/exact/grouped/negated forms plus limit/offset/ordering."""
    assert compile_query("name:blob", Guild) == (
        "SELECT guilds.* FROM guilds WHERE guilds.name ILIKE $1",
        ("%blob%",),
        (Guild,),
    )

    # Quoted value: no surrounding wildcards, but still case insensitive.
    assert compile_query('name:"blob"', Guild) == (
        "SELECT guilds.* FROM guilds WHERE guilds.name ILIKE $1",
        ("blob",),
        (Guild,),
    )

    assert compile_query("name:(blob emoji)", Guild) == (
        "SELECT guilds.* FROM guilds WHERE (guilds.name ILIKE $1 AND guilds.name ILIKE $2)",
        ("%blob%", "%emoji%",),
        (Guild,),
    )

    # Explicit AND compiles the same as the implicit form above.
    assert compile_query("name:(blob AND emoji)", Guild) == (
        "SELECT guilds.* FROM guilds WHERE (guilds.name ILIKE $1 AND guilds.name ILIKE $2)",
        ("%blob%", "%emoji%",),
        (Guild,),
    )

    assert compile_query("name:(discord AND NOT api)", Guild) == (
        "SELECT guilds.* FROM guilds WHERE (guilds.name ILIKE $1 AND NOT guilds.name ILIKE $2)",
        ("%discord%", "%api%",),
        (Guild,),
    )

    # Snowflake fields compile to an exact integer match.
    assert compile_query("id:1", Guild) == (
        "SELECT guilds.* FROM guilds WHERE guilds.id = $1",
        (1,),
        (Guild,),
    )

    assert compile_query("", Guild, limit=100, offset=150, order_by="id") == (
        "SELECT guilds.* FROM guilds ORDER BY guilds.id ASC LIMIT 100 OFFSET 150",
        (),
        (Guild,),
    )

    assert compile_query("", Guild, order_by="id", order_dir="DESC") == (
        "SELECT guilds.* FROM guilds ORDER BY guilds.id DESC",
        (),
        (Guild,),
    )

    assert compile_query("id=1", Guild) == (
        "SELECT guilds.* FROM guilds WHERE guilds.id = $1",
        (1,),
        (Guild,),
    )
206 |
207 |
def test_compile_complex_queries():
    """Joins over relations, full text search, regex matching, and the
    returns (->) / include_foreign_data selector expansions."""
    assert compile_query("name:blob OR name:api", Guild) == (
        "SELECT guilds.* FROM guilds WHERE guilds.name ILIKE $1 OR guilds.name ILIKE $2",
        ("%blob%", "%api%"),
        (Guild,),
    )

    # Dotted labels join through the relation.
    assert compile_query("guild.name:blob", Message) == (
        "SELECT messages.* FROM messages JOIN guilds ON messages.guild_id = guilds.id WHERE guilds.name ILIKE $1",
        ("%blob%",),
        (Message,),
    )

    # FTS-marked fields compile to tsvector matching.
    assert compile_query("content:yeet", Message) == (
        "SELECT messages.* FROM messages WHERE to_tsvector('english', messages.content) @@ phraseto_tsquery($1)",
        ("yeet",),
        (Message,),
    )

    assert compile_query('guild.name:(a "b")', Message) == (
        "SELECT messages.* FROM messages JOIN guilds ON messages.guild_id = guilds.id WHERE (guilds.name ILIKE $1 AND "
        "guilds.name ILIKE $2)",
        ("%a%", "b"),
        (Message,),
    )

    # Two-hop relation path: Message -> Guild -> User.
    assert compile_query("guild.owner.name:Danny", Message) == (
        "SELECT messages.* FROM messages JOIN guilds ON messages.guild_id = guilds.id JOIN users ON "
        "guilds.owner_id = users.id WHERE users.name ILIKE $1",
        ("%Danny%",),
        (Message,),
    )

    message_selector = _compile_selector(Message)
    guild_selector = _compile_selector(Guild)
    author_selector = _compile_selector(User)
    channel_selector = _compile_selector(Channel)

    # include_foreign_data pulls in the "always" refs and their columns.
    assert compile_query("", Message, include_foreign_data=True) == (
        f"SELECT {message_selector}, {author_selector}, {channel_selector} FROM messages "
        "JOIN users ON messages.author_id = users.id JOIN channels ON "
        "messages.channel_id = channels.id",
        (),
        (Message, User, Channel),
    )

    assert compile_query("guild.id:1", Message, include_foreign_data=True) == (
        f"SELECT {message_selector}, {guild_selector}, {author_selector}, {channel_selector} FROM messages JOIN guilds"
        " ON messages.guild_id = guilds.id JOIN users ON messages.author_id = users.id JOIN channels ON "
        "messages.channel_id = channels.id WHERE guilds.id = $1",
        (1,),
        (Message, Guild, User, Channel),
    )

    # Regex labels compile to ~ (or ~* with the i flag).
    assert compile_query("name: /xxx.*xxx/i", Guild) == (
        f"SELECT guilds.* FROM guilds WHERE guilds.name ~* $1",
        ("xxx.*xxx",),
        (Guild,),
    )

    guild_selector = _compile_selector(Guild)
    user_selector = _compile_selector(User)
    # returns=True adds the selected field names as a fourth element.
    assert compile_query(
        "name: /xxx.*xxx/i -> name owner.name", Guild, returns=True
    ) == (
        f"SELECT {guild_selector}, {user_selector} FROM guilds JOIN users ON guilds.owner_id = users.id WHERE guilds.name ~* $1",
        ("xxx.*xxx",),
        (Guild, User),
        ("name", "owner.name"),
    )

    # An empty returns section yields an empty field tuple.
    assert compile_query("name: /xxx.*xxx/i ->", Guild, returns=True) == (
        f"SELECT guilds.* FROM guilds WHERE guilds.name ~* $1",
        ("xxx.*xxx",),
        (Guild,),
        (),
    )
285 |
286 |
--------------------------------------------------------------------------------
/abode/lib/query.py:
--------------------------------------------------------------------------------
1 | """
2 | This module implements everything required for processing and executing user-written
3 | queries in our custom search DSL. As the language is small and simple enough there
4 | is no real concept of a lexer, and instead we directly parse search queries into
5 | an extremely simple AST.
6 |
7 | From an AST we can then generate a query based on a given model and query options.
8 | This query is returned in a form that can be easily passed to SQL interfaces
9 | (query, (args...)).
10 | """
11 | import dataclasses
12 | import typing
13 | from abode.db import table_name, JSONB, FTS, Snowflake
14 |
15 | JOINERS = ("AND", "OR")
16 |
17 |
class QueryParser:
    """Recursive-descent parser for the search DSL.

    Produces a flat list of token dicts (symbol / string / group / label /
    regex / return). Use the ``parsed`` classmethod, which also post-processes
    the tree via ``_fix`` (implicit ANDs, joiner validation).
    """

    def __init__(self, query_string):
        # Cursor into the raw query text.
        self.idx = 0
        self.buffer = query_string

    @classmethod
    def parsed(cls, query):
        """Parse *query* and return the fixed-up token list."""
        inst = cls(query)
        return inst._fix(inst._parse())

    def _next_char(self):
        # Consume and return one character, or None at end of input.
        if self.idx >= len(self.buffer):
            return None
        char = self.buffer[self.idx]
        self.idx += 1
        return char

    def _peek_char(self, n=0):
        # Look ahead n characters without consuming; None past end of input.
        if self.idx + n >= len(self.buffer):
            return None
        return self.buffer[self.idx + n]

    def _parse_string(self, chr='"'):
        # Read until the unescaped delimiter *chr*. A backslash escapes the
        # delimiter; any other escaped character keeps its backslash.
        escaped = False
        parts = ""
        while True:
            char = self._next_char()
            assert char is not None
            if char == chr:
                if escaped:
                    escaped = False
                else:
                    return parts
            elif char == "\\" and not escaped:
                escaped = True
                continue
            elif escaped:
                escaped = False
                parts += "\\"
            parts += char

    def _parse_symbol(self):
        # Read a bare word up to whitespace or any structural character.
        parts = ""
        while True:
            char = self._peek_char()
            if char in (" ", ":", "=", '"', "(", ")", "/", None):
                return parts
            parts += self._next_char()

    def _parse_one(self):
        # Parse one token; returns None at end of input or at a group close.
        while True:
            char = self._next_char()
            if char is None or char == ")":
                return None
            elif char == '"':
                string = self._parse_string()
                return {"type": "string", "value": string}
            elif char == "(":
                # Recurse; _parse stops at the matching ")".
                return {"type": "group", "value": self._parse()}
            elif char in ("-", "=") and self._peek_char() == ">":
                # "->" (or "=>") starts the returns section; only bare
                # symbols may follow it.
                self._next_char()
                value = self._parse()
                if any(i["type"] != "symbol" for i in value):
                    raise Exception("only symbols are allowed in a returns section")
                return {"type": "return", "value": value}
            elif char == " ":
                continue
            elif char == "/":
                # Regex literal /.../ with optional single-letter flags.
                value = self._parse_string("/")

                flags = []
                while self._peek_char() in ("i",):
                    flags.append(self._next_char())

                return {
                    "type": "regex",
                    "value": value,
                    "flags": flags,
                }
            else:
                # Bare symbol; back up so its first character is included.
                self.idx -= 1
                symbol = self._parse_symbol()
                if not symbol:
                    return None

                # "name:value" (fuzzy) or "name=value" (exact) label.
                if self._peek_char() in (":", "="):
                    exact = self._next_char() == "="
                    return {
                        "type": "label",
                        "name": symbol,
                        "value": self._parse_one(),
                        "exact": exact,
                    }

                return {"type": "symbol", "value": symbol}

    def _parse(self):
        # Parse tokens until end of input or a group close.
        parts = []

        while True:
            one = self._parse_one()
            if not one:
                return parts
            parts.append(one)

    def _fix(self, tree):
        # Post-process: validate NOT/joiner placement and inject implicit
        # ANDs between adjacent non-joiner tokens (recursing into groups).
        result = []
        previous_node = None
        for node in tree:
            if node["type"] == "group":
                node["value"] = self._fix(node["value"])
            elif node["type"] == "symbol":
                if node["value"] == "NOT":
                    if not previous_node or (
                        previous_node["type"] == "symbol"
                        and previous_node["value"] not in JOINERS
                    ):
                        raise Exception("NOT requires a joiner prefix (and/or)")
                elif node["value"] in JOINERS:
                    if (
                        previous_node
                        and previous_node["type"] == "symbol"
                        and previous_node["value"] in JOINERS
                    ):
                        raise Exception("One side of joiners cannot be another joiner")
            elif node["type"] == "label" and node["value"]["type"] == "group":
                # TODO: HMMMMM, need _fix_one(node, previous=xxx) ??
                node["value"]["value"] = self._fix(node["value"]["value"])

            # Injects 'AND' in between bare non-joiners
            if (
                (node["type"] != "symbol" or node["value"] not in JOINERS)
                and node["type"] != "return"
                and previous_node
                and (
                    previous_node["type"] != "symbol"
                    or previous_node["value"] not in JOINERS + ("NOT",)
                )
            ):
                result.append({"type": "symbol", "value": "AND"})
            previous_node = node
            result.append(node)
        return result
161 |
162 |
def _resolve_foreign_model_field(field_name, model, joins=None):
    """Walk a dotted field path across model relations, accumulating joins.

    Each path segment is looked up in the current model's ``_refs``; the final
    segment resolves as an ordinary field on the last relation. Returns
    (qualified column, field type, {model: join condition}).
    """
    joins = joins or {}

    head, tail = field_name.split(".", 1)
    ref_model, on, _ = model._refs[head]
    joins[ref_model] = (
        f"{table_name(model)}.{on[0]} = {table_name(ref_model)}.{on[1]}"
    )

    # More segments remain: keep walking from the relation we just joined.
    if "." in tail:
        return _resolve_foreign_model_field(tail, ref_model, joins)

    _, ref_type, _ = resolve_model_field(tail, ref_model)
    return (f"{table_name(ref_model)}.{tail}", ref_type, joins)
184 |
185 |
def resolve_model_field(field_name, model, allow_virtual=False):
    """
    Resolves a field name within a given model. This function will generate
    joins for cases where the target field is on a relation.

    The result is a tuple of the target field name (qualified with its table),
    the field result type, and a dictionary of joins keyed by the joined model.

    >>> resolve_model_field("guild.name", Message)
    ("guilds.name", str, {Guild: "messages.guild_id = guilds.id"})

    >>> resolve_model_field("content", Message)
    ("messages.content", FTS(str), {})

    Virtual (computed-at-decode-time) fields resolve to (None, None, {}) when
    ``allow_virtual`` is set. Raises Exception for unknown fields.
    """
    # Dotted paths traverse model._refs relations and accumulate joins.
    if "." in field_name:
        return _resolve_foreign_model_field(field_name, model)

    # Virtual fields have no backing column.
    if allow_virtual and field_name in model._virtual_fields:
        return (None, None, {})

    for field in dataclasses.fields(model):
        if field.name == field_name:
            # Full-text-search fields keep their column but are wrapped in FTS
            # so the compiler emits tsvector matching instead of comparisons.
            if field.name in model._fts:
                return (
                    f"{table_name(model)}.{field.name}",
                    FTS(field.type),
                    {},
                )

            return f"{table_name(model)}.{field.name}", field.type, {}
    raise Exception(f"no such field on {model}: `{field_name}`")
218 |
219 |
def _compile_field_query_op(field, field_type, token, varidx):
    """
    Compiles a single token against a given field type into a single query filter.
    Sadly this function also encodes some more complex logic about querying, such
    as wildcard processing and exact matching.

    Returns a (field expression, operator, argument value, value-placeholder
    expression) tuple; the argument is bound to placeholder ``$varidx``.
    """
    var = f"${varidx}"
    assert token["type"] in ("symbol", "string")

    # Unwrap Optional[T] so the dispatch below sees the concrete type.
    if typing.get_origin(field_type) is typing.Union:
        args = typing.get_args(field_type)
        field_type = next(i for i in args if i != type(None))

    if isinstance(field_type, FTS):
        # Full-text fields: exact (=) and quoted strings bypass the tsvector
        # matching in favor of direct comparison / ILIKE.
        if token.get("exact"):
            return (field, "=", token["value"], var)
        elif token["type"] == "string":
            # TODO: jank replace
            return (field, "ILIKE", token["value"].replace("*", "%"), var)
        return (
            f"to_tsvector('english', {field})",
            "@@",
            token["value"],
            f"phraseto_tsquery({var})",
        )
    elif typing.get_origin(field_type) == JSONB:
        (inner,) = typing.get_args(field_type)

        # JSONB list columns only support containment (@>) matching.
        if typing.get_origin(inner) == list:
            (type_fn,) = typing.get_args(inner)
            return (field, "@>", type_fn(token["value"]), var)

        assert False
    elif field_type == Snowflake:
        return (field, "=", Snowflake(token["value"]), var)
    elif field_type == str or field_type == typing.Optional[str]:
        if token.get("exact"):
            return (field, "=", token["value"], var)
        elif token["type"] == "symbol":
            # TODO: regex this so we can handle escapes?
            # Bare symbols: explicit * wildcards map to %, otherwise the value
            # is wrapped in % for substring matching.
            if "*" in token["value"]:
                return (field, "ILIKE", token["value"].replace("*", "%"), var)
            return (field, "ILIKE", "%" + token["value"] + "%", var)
        else:
            # Like just gives us case insensitivity here
            return (field, "ILIKE", token["value"], var)
    elif field_type == int or field_type == typing.Optional[int]:
        return (field, "=", int(token["value"]), var)
    elif field_type == bool or field_type == typing.Optional[bool]:
        return (
            field,
            "=",
            {"t": True, "f": False, "true": True, "false": False}[
                token["value"].lower()
            ],
            var,
        )
    else:
        # Breadcrumbs for unsupported field types before raising.
        print(token)
        print(field_type)
        print(typing.get_origin(field_type))
        print(typing.get_args(field_type))
        raise Exception(f"cannot query against field of type {field_type}")
286 |
287 |
288 | def _compile_model_refs_join(model, value):
289 | joins = {}
290 | ref_model = model
291 | while value.split(".", 1)[0] in ref_model._refs:
292 | ref_model, join_on, _ = model._refs[value]
293 |
294 | joins.update(
295 | {
296 | ref_model: f"{table_name(model)}.{join_on[0]} = {table_name(ref_model)}.{join_on[1]}"
297 | }
298 | )
299 |
300 | if "." not in value:
301 | return joins
302 |
303 | value = value.split(".", 1)[1]
304 |
305 | raise Exception(f"unlabeled symbol cannot be matched: `{value}`")
306 |
307 |
def _compile_token_for_query(token, model, field=None, field_type=None, varidx=0):
    """
    Compile a single token into a single filter against the model.

    Returns a (where fragment, variables, joins, next varidx, return-fields)
    tuple; return-fields is None except when compiling a "return" token.
    """

    if token["type"] == "label":
        # Resolve the labeled field (possibly across relations), then compile
        # the label's value against it, propagating exactness downward.
        field, field_type, field_joins = resolve_model_field(token["name"], model)
        token["value"]["exact"] = token["exact"]
        where, variables, joins, varidx, returns = _compile_token_for_query(
            token["value"], model, field=field, field_type=field_type, varidx=varidx
        )
        joins.update(field_joins)
        return where, variables, joins, varidx, returns
    elif token["type"] == "symbol":
        # Joiners pass straight through into the WHERE clause.
        if token["value"] in ("AND", "OR", "NOT"):
            return (token["value"], [], {}, varidx, None)
        elif field:
            varidx += 1
            field, op, arg, var = _compile_field_query_op(
                field, field_type, token, varidx
            )
            return (f"{field} {op} {var}", [arg], {}, varidx, None)
        else:
            # Bare symbol outside a label: only valid as a relation path,
            # which contributes joins but no real filter.
            joins = _compile_model_refs_join(model, token["value"])
            return (f"true", [], joins, varidx, None)
    elif token["type"] == "string" and field:
        varidx += 1
        field, op, arg, var = _compile_field_query_op(field, field_type, token, varidx)
        return (f"{field} {op} {var}", [arg], {}, varidx, None)
    elif token["type"] == "regex" and field:
        varidx += 1
        # ~ is case sensitive regex matching; ~* with the `i` flag.
        op = "~"
        if "i" in token["flags"]:
            op = "~*"
        return (
            f"{field} {op} ${varidx}",
            [token["value"]],
            {},
            varidx,
            None,
        )
    elif token["type"] == "group":
        # Compile children in sequence, threading varidx through and pushing
        # the group's exactness down onto each child.
        where = []
        variables = []
        joins = {}
        for child_token in token["value"]:
            child_token["exact"] = token.get("exact", False)
            (
                where_part,
                variables_part,
                joins_part,
                varidx,
                returns,
            ) = _compile_token_for_query(
                child_token, model, field=field, field_type=field_type, varidx=varidx
            )
            joins.update(joins_part)
            where.append(where_part)
            variables.extend(variables_part)
        return ("(" + " ".join(where) + ")", variables, joins, varidx, returns)
    elif token["type"] == "return":
        return (None, None, None, varidx, [i["value"] for i in token["value"]])
    else:
        print(token)
        assert False
375 |
376 |
def _compile_selector(model):
    """Explicit "table.col, ..." column list for *model*, mirroring its
    dataclass field order (which record decoding relies on)."""
    prefix = table_name(model)
    columns = [f"{prefix}.{field.name}" for field in dataclasses.fields(model)]
    return ", ".join(columns)
381 |
382 |
def _compile_query_for_model(
    tokens,
    model,
    limit=None,
    offset=None,
    order_by=None,
    order_dir="ASC",
    include_foreign_data=False,
    returns=False,
):
    """Assemble the final SELECT for *model* from parsed query tokens.

    Returns (sql, args, models) — or (sql, args, models, return_fields) when
    ``returns`` is set. ``models`` lists, in column order, every model whose
    columns appear in the SELECT; the decoding helpers rely on that order.
    """
    return_fields = None
    parts = []
    varidx = 0
    for token in tokens:
        a, b, c, varidx, _returns = _compile_token_for_query(
            token, model, varidx=varidx
        )
        # A "return" token contributes selected fields instead of a filter.
        if _returns is not None:
            if return_fields is not None:
                raise Exception("multiple returns? bad juju!")
            return_fields = tuple(_returns)
            continue
        if a is not None:
            parts.append((a, b, c))

    where = []
    variables = []
    joins = {}

    for where_part, variables_part, joins_part in parts:
        where.append(where_part)
        variables.extend(variables_part)
        joins.update(joins_part)

    if order_by:
        # Ordering may itself require joins (e.g. ordering by a relation field).
        field, field_type, order_joins = resolve_model_field(order_by, model)
        joins.update(order_joins)
        assert order_dir in ("ASC", "DESC")
        order_by = f" ORDER BY {field} {order_dir}"
    else:
        order_by = ""

    # Dict used as an ordered set of models contributing columns to the SELECT.
    models = {model: None}
    if return_fields:
        for field in return_fields:
            _, _, joins_part = resolve_model_field(field, model, allow_virtual=True)
            # NOTE(review): this guard looks like it was meant to be
            # `joins_part is not None` — `joins` is always a dict here.
            if joins is not None:
                joins.update(joins_part)
        models.update({k: None for k in joins.keys()})

    if include_foreign_data:
        for ref_model, join_on, always in model._refs.values():
            # Already joined by a filter: just include its columns.
            if ref_model in joins:
                models[ref_model] = None
                continue

            if always:
                joins.update(
                    {
                        ref_model: f"{table_name(model)}.{join_on[0]} = {table_name(ref_model)}.{join_on[1]}"
                    }
                )
                models[ref_model] = None

    # Single-model queries keep the compact "table.*" selector.
    if len(models) > 1:
        selectors = ", ".join(_compile_selector(model) for model in models.keys())
    else:
        selectors = f"{table_name(model)}.*"

    if joins:
        joins = "".join(
            f" JOIN {table_name(model)} ON {cond}" for model, cond in joins.items()
        )
    else:
        joins = ""

    if where:
        where = " WHERE " + " ".join(where)
    else:
        where = ""

    suffix = []
    if limit is not None and limit > 0:
        suffix.append(f" LIMIT {limit}")

    if offset is not None and offset > 0:
        suffix.append(f" OFFSET {offset}")

    suffix = "".join(suffix)

    query = (
        f"SELECT {selectors} FROM {table_name(model)}{joins}{where}{order_by}{suffix}"
    )
    variables = tuple(variables)
    models = tuple(models.keys())

    if returns:
        return query, variables, models, return_fields
    return query, variables, models
482 |
483 |
def compile_query(query, model, **kwargs):
    """Parse *query* (search DSL) and compile it into SQL for *model*.

    Keyword arguments are forwarded to the compiler (limit, offset, order_by,
    order_dir, include_foreign_data, returns).
    """
    return _compile_query_for_model(QueryParser.parsed(query), model, **kwargs)
487 |
488 |
def decode_query_record(record, models):
    """Split a flat joined record into one instance per model, in order.

    Each model consumes as many columns as it has dataclass fields.
    """
    offset = 0
    for model in models:
        width = len(dataclasses.fields(model))
        yield model.from_record(record[offset : offset + width])
        offset += width
496 |
497 |
498 | def _resolve_return_field(model, field):
499 | if "." in field:
500 | field, rest = field.split(".", 1)
501 | model = model._refs[field][0]
502 | return _resolve_return_field(model, rest)
503 | return model, field
504 |
505 |
506 | def _get_field_offset(record_offsets, model, field_name):
507 | fields = dataclasses.fields(model)
508 | offset = record_offsets[model] + [i.name for i in fields].index(field_name)
509 | return offset, next(i for i in fields if i.name == field_name)
510 |
511 |
def decode_query_results(models, return_fields, results):
    """Project raw joined records into rows of just the requested fields.

    ``models`` is the column-ordered model tuple produced by the compiler;
    ``return_fields`` may contain dotted relation paths and virtual fields.
    Returns (rows, return_fields).
    """
    # TODO: leaky af
    from abode.db import convert_to_type

    # No explicit selection: return every field of the primary model.
    if return_fields is None:
        return_fields = [i.name for i in dataclasses.fields(models[0])]

    # Absolute column offset at which each model's columns start.
    record_offsets = {}
    idx = 0
    for model in models:
        record_offsets[model] = idx
        idx += len(dataclasses.fields(model))

    # Pre-resolve each requested field into either a direct column offset or a
    # virtual-field callable plus the offsets of its dependency columns.
    columns = []
    for field in return_fields:
        model, field = _resolve_return_field(models[0], field)
        if field in model._virtual_fields:
            offsets = {}
            for field_dep in model._virtual_fields[field]:
                offset, field_dep = _get_field_offset(record_offsets, model, field_dep)
                offsets[offset] = field_dep
            columns.append(("virtual", field, offsets, getattr(model, field)))
        else:
            offset, field = _get_field_offset(record_offsets, model, field)
            columns.append(("offset", field, offset))

    rows = []
    for result_row in results:
        row = []

        for column in columns:
            if column[0] == "offset":
                field, offset = column[1:]
                # Decode from the pg representation, then render JSON-safe.
                row.append(
                    convert_to_type(
                        convert_to_type(result_row[offset], field.type, from_pg=True),
                        field.type,
                        to_js=True,
                    )
                )
            elif column[0] == "virtual":
                field, offsets, fn = column[1:]
                # Decode each dependency column, then apply the virtual fn.
                args = [
                    convert_to_type(
                        convert_to_type(result_row[offset], field.type, from_pg=True),
                        field.type,
                        to_js=True,
                    )
                    for offset, field in offsets.items()
                ]

                # TODO: typecast based on fn
                row.append(fn(*args))

        rows.append(row)

    return rows, return_fields
570 |
--------------------------------------------------------------------------------