├── .github └── workflows │ └── lint.yml ├── .gitignore ├── .gitmodules ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── bot.conf.example ├── bot ├── __init__.py ├── acl.py ├── autoload.py ├── client.py ├── cogs.py ├── commands.py ├── config.py ├── interactions.py ├── main_tasks.py ├── message_tracker.py ├── reactions.py └── tasks.py ├── discord ├── docker-compose.yml ├── docker ├── Dockerfile └── requirements_core.txt ├── log_setup.py ├── main.py ├── migrations ├── plugins.consensus-9043f7f4cbcd154392e59579f64e32418f269840-404a3478aeb168e45611267af447900b1212f9f2.sql ├── plugins.factoids-3809d321fe36da4f650357bf78d02670b8b1a4f2-87a9a7795dabb98da513f93556237eca230119fc.sql ├── plugins.factoids-87a9a7795dabb98da513f93556237eca230119fc-aa369b1260ecbabe80e9a33f62099e7afea2bace.sql ├── plugins.factoids-aa369b1260ecbabe80e9a33f62099e7afea2bace-409508100f9b343fcf5bf512b57cfd12b166f5b8.sql ├── plugins.log-30418475489ba40f09ca4ee0051f09bb8bcf150b-2c5e3570c48030fa81308961343a62f3379c5b6c.sql ├── plugins.log-4809e558c405bbde63a7dc65cbbdf505ebd14cd7-65ed62be1b679154b675097517dd25c7fc27ad8a.sql ├── plugins.log-65ed62be1b679154b675097517dd25c7fc27ad8a-9f4dee1a807ac59d97ed694526009f900c93d22e.sql ├── plugins.log-9f4dee1a807ac59d97ed694526009f900c93d22e-30418475489ba40f09ca4ee0051f09bb8bcf150b.sql ├── plugins.modmail-4f5dea7d15ddaa90bec050ba4d8a6460aefe44b8-6e01739d17b263c9005268bcfe60fea051fbc673.sql ├── plugins.modmail-6e01739d17b263c9005268bcfe60fea051fbc673-8515db1e187aade4392fa2755dc62355808f2e6a.sql ├── plugins.modmail-bf5948e3b8d67d57dfda8e9ccb34a8dd44cf5d39-4f5dea7d15ddaa90bec050ba4d8a6460aefe44b8.sql ├── plugins.persistence-cff14827533eafa51a65684a41e83d93ce27b44c-79d5069c83aabeebf0be39f73e2d8828f9ce2080.sql ├── plugins.roles_review-5305e3042b889ffbd0a5fa96f5bf33e85ad930f9-6705b50d5a48f1ce96f62352069198b94655eba8.sql ├── plugins.roles_review-87439a30c949b57f6a003ae2e2cc626a7cb967a4-5305e3042b889ffbd0a5fa96f5bf33e85ad930f9.sql ├── 
plugins.tickets-6712376268cc661947b8ed360d1e1c78c33bcc7f-9da100e1067dfc8cab4bdc43a559a1a96907b628.sql ├── plugins.tickets-9da100e1067dfc8cab4bdc43a559a1a96907b628-f69269c5ed1edd9bc84989ab9a218b220dbb4810.sql ├── plugins.tickets-b108465ff07d357fe70fb67e68dbdb592e14a68a-6712376268cc661947b8ed360d1e1c78c33bcc7f.sql ├── plugins.tickets-cecbe3f6862e737833c0252aecaba0fbe79c0463-b108465ff07d357fe70fb67e68dbdb592e14a68a.sql ├── plugins.tickets-db25e78666828bd5fbbbf387796c88a8859b70b8-cecbe3f6862e737833c0252aecaba0fbe79c0463.sql ├── plugins.tickets-f69269c5ed1edd9bc84989ab9a218b220dbb4810-424a6e89e1a95a089f089e92a7f2751416dd575b.sql └── util.db.kv-234193f845d63b034c0d83e99a0282f94278c0fe-9e12e532a2abd0c7ea4b4aff6357b3affd7d8352.sql ├── plugins ├── __init__.py ├── appeals.py ├── automod.py ├── bot_manager.py ├── bulk_perms.py ├── clopen.py ├── consensus.py ├── db_manager.py ├── discord_log.py ├── eval.py ├── factoids.py ├── help.py ├── keepvanity.py ├── log.py ├── modmail.py ├── persistence.py ├── phish.py ├── pins.py ├── reminders.py ├── roleoverride.py ├── rolereactions.py ├── roles_dialog.py ├── roles_review.py ├── tickets.py ├── update.py ├── version.py └── whois.py ├── pyproject.toml ├── requirements.txt ├── requirements_core.txt ├── requirements_linting.txt ├── static_config.py └── util ├── __init__.py ├── asyncio.py ├── db ├── __init__.py ├── dsn.py ├── initialization.py ├── kv │ └── __init__.py └── log.py ├── digraph.py ├── discord.py ├── frozen_dict.py ├── frozen_list.py ├── restart.py └── setup └── __main__.py /.github/workflows/lint.yml: -------------------------------------------------------------------------------- 1 | name: lint 2 | 3 | on: [push, pull_request] 4 | 5 | jobs: 6 | pyright: 7 | runs-on: ubuntu-latest 8 | steps: 9 | - uses: actions/checkout@v1 10 | - name: checkout submodules 11 | run: git submodule update --init --recursive 12 | - uses: actions/setup-python@v1 13 | with: 14 | python-version: 3.9 15 | - name: pip 16 | run: pip install -r 
requirements_core.txt 17 | - run: pyright 18 | format: 19 | runs-on: ubuntu-latest 20 | steps: 21 | - uses: actions/checkout@v1 22 | - name: checkout submodules 23 | run: git submodule update --init --recursive 24 | - uses: actions/setup-python@v1 25 | with: 26 | python-version: 3.9 27 | - name: pip 28 | run: pip install -r requirements_linting.txt 29 | - name: black 30 | run: black --check . 31 | - name: isort 32 | run: isort --check . 33 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | */__pycache__ 3 | logs/ 4 | bot.conf 5 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "discord.py"] 2 | path = discord.py 3 | url = https://github.com/Rapptz/discord.py 4 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # What Needs Doing 2 | 3 | Feel free to work on minor improvements or quality of life features. Also feel free to take any of the [open issues](https://github.com/discord-math/bot/issues), unless they are marked as "spec needed". 4 | 5 | For broader scope changes such as introducing brand new functionality, you should seek approval of the Meta Council. 6 | 7 | # Tools 8 | 9 | The code is type-checked with [pyright](https://microsoft.github.io/pyright/), and formatted with [black](https://black.readthedocs.io/en/stable/) and [isort](https://isort.readthedocs.io/en/latest/). These are run automatically on every pull request, and you should make sure all of the checks pass. 
10 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2021, mniip et al. 2 | All rights reserved. 3 | 4 | Redistribution and use in source and binary forms, with or without 5 | modification, are permitted provided that the following conditions are met: 6 | * Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 7 | * Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. 8 | * Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. 9 | 10 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
11 | -------------------------------------------------------------------------------- /bot.conf.example: -------------------------------------------------------------------------------- 1 | # A config with values that are very unlikely to change during execution 2 | [DB] 3 | # Database connection string, should work with the provided docker image 4 | dsn = host=db user=bot password=bot dbname=discord 5 | # Directory with migration files, should work with the provided docker image 6 | migrations = migrations/ 7 | [Log] 8 | # Log directory 9 | directory = logs/ 10 | [Discord] 11 | # Bot token 12 | token = AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA 13 | -------------------------------------------------------------------------------- /bot/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/discord-math/bot/1231f964ea2d701e18dda7b63d565de43846de9b/bot/__init__.py -------------------------------------------------------------------------------- /bot/autoload.py: -------------------------------------------------------------------------------- 1 | """ 2 | Automatically load certain plugins after bot initialization. 3 | """ 4 | 5 | import logging 6 | from typing import TYPE_CHECKING 7 | 8 | from sqlalchemy import TEXT, BigInteger, select 9 | from sqlalchemy.ext.asyncio import AsyncSession 10 | import sqlalchemy.orm 11 | from sqlalchemy.orm import Mapped, mapped_column 12 | 13 | import bot.main_tasks 14 | import plugins 15 | import util.asyncio 16 | import util.db.kv 17 | 18 | 19 | registry = sqlalchemy.orm.registry() # Module-private registry; init() builds its DDL from registry.metadata.create_all 20 | 21 | 22 | @registry.mapped 23 | class AutoloadedPlugin: # One row per plugin name to load at startup; init() loads rows in ascending `order` 24 | __tablename__ = "autoload" 25 | 26 | name: Mapped[str] = mapped_column(TEXT, primary_key=True) 27 | order: Mapped[int] = mapped_column(BigInteger, nullable=False) 28 | 29 | if TYPE_CHECKING: # Constructor signature for the type checker only; presumably the SQLAlchemy mapper supplies the runtime __init__ -- confirm 30 | 31 | def __init__(self, *, name: str, order: int) -> None: ...
32 | 33 | 34 | logger: logging.Logger = logging.getLogger(__name__) 35 | 36 | 37 | @plugins.init 38 | async def init() -> None: 39 | if (manager := plugins.PluginManager.of(__name__)) is None: 40 | logger.error("No plugin manager") 41 | return 42 | await util.db.init(util.db.get_ddl(registry.metadata.create_all)) 43 | 44 | async def autoload() -> None: 45 | async with AsyncSession(util.db.engine) as session: 46 | conf = await util.db.kv.load(__name__) 47 | for key in [key for key, in conf]: 48 | session.add(AutoloadedPlugin(name=key, order=0)) 49 | conf[key] = None 50 | await session.commit() 51 | await conf 52 | 53 | stmt = select(AutoloadedPlugin).order_by(AutoloadedPlugin.order) 54 | for plugin in (await session.execute(stmt)).scalars(): 55 | try: 56 | # Sidestep plugin dependency tracking 57 | await manager.load(plugin.name) 58 | except: 59 | logger.critical("Exception during autoload of {}".format(plugin.name), exc_info=True) 60 | 61 | bot.main_tasks.create_task(autoload(), name="Plugin autoload") 62 | -------------------------------------------------------------------------------- /bot/client.py: -------------------------------------------------------------------------------- 1 | """ 2 | This module defines the "client" singleton. Reloading the module restarts the connection to discord. 
3 | """ 4 | 5 | import logging 6 | 7 | from discord import AllowedMentions, Intents 8 | from discord.ext.commands import Bot 9 | 10 | import bot.main_tasks 11 | import plugins 12 | import static_config 13 | 14 | 15 | logger = logging.getLogger(__name__) 16 | 17 | intents = Intents.all() 18 | intents.presences = False 19 | client = Bot( 20 | command_prefix=(), max_messages=None, intents=intents, allowed_mentions=AllowedMentions(everyone=False, roles=False) 21 | ) 22 | 23 | 24 | # Disable command functionality until reenabled again in bot.commands 25 | @client.event 26 | async def on_message(*args: object, **kwargs: object) -> None: 27 | pass 28 | 29 | 30 | @client.event 31 | async def on_error(event: str, *args: object, **kwargs: object) -> None: 32 | logger.error("Uncaught exception in {}".format(event), exc_info=True) 33 | 34 | 35 | async def main_task() -> None: 36 | try: 37 | async with client: 38 | await client.start(static_config.Discord["token"], reconnect=True) 39 | except: 40 | logger.critical("Exception in main Discord task", exc_info=True) 41 | finally: 42 | await client.close() 43 | 44 | 45 | @plugins.init 46 | def init() -> None: 47 | task = bot.main_tasks.create_task(main_task(), name="Discord client") 48 | plugins.finalizer(task.cancel) 49 | -------------------------------------------------------------------------------- /bot/cogs.py: -------------------------------------------------------------------------------- 1 | from typing import Type, TypeVar 2 | 3 | from discord.ext.commands import Cog as Cog, command as command, group as group 4 | 5 | from bot.client import client 6 | import plugins 7 | 8 | 9 | T = TypeVar("T", bound=Cog) 10 | 11 | 12 | def cog(cls: Type[T]) -> T: 13 | """Decorator for cog classes that are loaded/unloaded from the bot together with the plugin.""" 14 | cog = cls() 15 | cog_name = "{}:{}:{}".format(cog.__module__, cog.__cog_name__, hex(id(cog))) 16 | cog.__cog_name__ = cog_name 17 | 18 | async def initialize_cog() -> None: 19 
| await client.add_cog(cog) 20 | 21 | async def finalize_cog() -> None: 22 | await client.remove_cog(cog_name) 23 | 24 | plugins.finalizer(finalize_cog) 25 | 26 | plugins.init(initialize_cog) 27 | 28 | return cog 29 | -------------------------------------------------------------------------------- /bot/commands.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utilities for registering basic commands. Commands are triggered by a configurable prefix. 3 | """ 4 | 5 | import asyncio 6 | import logging 7 | from typing import TYPE_CHECKING, Any, Optional, Set, TypeVar 8 | 9 | import discord 10 | from discord import AllowedMentions, Message, PartialMessage 11 | import discord.ext.commands 12 | from discord.ext.commands import ( 13 | BadUnionArgument, 14 | Bot, 15 | CheckFailure, 16 | Cog, 17 | Command, 18 | CommandError, 19 | CommandInvokeError, 20 | CommandNotFound, 21 | NoPrivateMessage, 22 | PrivateMessageOnly, 23 | UserInputError, 24 | ) 25 | from sqlalchemy import TEXT, BigInteger, Computed 26 | from sqlalchemy.ext.asyncio import AsyncSession 27 | import sqlalchemy.orm 28 | from sqlalchemy.orm import Mapped, mapped_column 29 | 30 | from bot.client import client 31 | from bot.cogs import cog 32 | import plugins 33 | import util.db.kv 34 | from util.discord import format 35 | 36 | 37 | registry = sqlalchemy.orm.registry() 38 | 39 | 40 | @registry.mapped 41 | class GlobalConfig: 42 | __tablename__ = "commands_config" 43 | id: Mapped[int] = mapped_column(BigInteger, Computed("0"), primary_key=True) 44 | prefix: Mapped[str] = mapped_column(TEXT, nullable=False) 45 | 46 | if TYPE_CHECKING: 47 | 48 | def __init__(self, *, prefix: str, id: int = ...) -> None: ... 
49 | 50 | 51 | logger: logging.Logger = logging.getLogger(__name__) 52 | prefix: str 53 | 54 | 55 | @plugins.init 56 | async def init() -> None: 57 | global prefix 58 | await util.db.init(util.db.get_ddl(registry.metadata.create_all)) 59 | 60 | async with AsyncSession(util.db.engine, expire_on_commit=False) as session: 61 | conf = await session.get(GlobalConfig, 0) 62 | if not conf: 63 | conf = GlobalConfig(prefix=str((await util.db.kv.load(__name__)).prefix)) 64 | session.add(conf) 65 | await session.commit() 66 | 67 | prefix = client.command_prefix = conf.prefix 68 | 69 | 70 | @plugins.finalizer 71 | def cleanup_prefix() -> None: 72 | client.command_prefix = () 73 | 74 | 75 | Context = discord.ext.commands.Context[Bot] 76 | 77 | 78 | @cog 79 | class Commands(Cog): 80 | @Cog.listener() 81 | async def on_command(self, ctx: Context) -> None: 82 | logger.info( 83 | format( 84 | "Command {!r} from {!m} in {!c}", 85 | ctx.command and ctx.command.qualified_name, 86 | ctx.author.id, 87 | ctx.channel.id, 88 | ) 89 | ) 90 | 91 | @Cog.listener() 92 | async def on_command_error(self, ctx: Context, exc: Exception) -> None: 93 | try: 94 | if isinstance(exc, CommandNotFound): 95 | return 96 | elif isinstance(exc, CheckFailure) and not isinstance(exc, (NoPrivateMessage, PrivateMessageOnly)): 97 | return 98 | elif isinstance(exc, UserInputError): 99 | if isinstance(exc, BadUnionArgument): 100 | 101 | def conv_name(conv: type) -> str: 102 | try: 103 | return conv.__name__ 104 | except AttributeError: 105 | if hasattr(conv, "__origin__"): 106 | return repr(conv) 107 | return conv.__class__.__name__ 108 | 109 | exc_str = 'Could not interpret "{}" as:\n{}'.format( 110 | exc.param.name, 111 | "\n".join( 112 | "- {}: {}".format(conv_name(conv), sub_exc) 113 | for conv, sub_exc in zip(exc.converters, exc.errors) 114 | ), 115 | ) 116 | else: 117 | exc_str = str(exc) 118 | message = "Error: {}".format(exc_str) 119 | if ctx.command is not None: 120 | if getattr(ctx.command, 
"suppress_usage", False): 121 | return 122 | if ctx.invoked_with is not None: 123 | usage = " ".join( 124 | s for s in ctx.invoked_parents + [ctx.invoked_with, ctx.command.signature] if s 125 | ) 126 | else: 127 | usage = " ".join(s for s in [ctx.command.qualified_name, ctx.command.signature] if s) 128 | message += format("\nUsage: {!i}", usage) 129 | await ctx.send(message, allowed_mentions=AllowedMentions.none()) 130 | return 131 | elif isinstance(exc, CommandInvokeError): 132 | logger.error( 133 | format( 134 | "Error in command {} {!r} {!r} from {!m} in {!c}", 135 | ctx.command and ctx.command.qualified_name, 136 | tuple(ctx.args), 137 | ctx.kwargs, 138 | ctx.author.id, 139 | ctx.channel.id, 140 | ), 141 | exc_info=exc.__cause__, 142 | ) 143 | return 144 | elif isinstance(exc, CommandError): 145 | await ctx.send("Error: {}".format(str(exc)), allowed_mentions=AllowedMentions.none()) 146 | return 147 | else: 148 | logger.error( 149 | format( 150 | "Unknown exception in command {} {!r} {!r} from {!m} in {!c}", 151 | ctx.command and ctx.command.qualified_name, 152 | tuple(ctx.args), 153 | ctx.kwargs, 154 | ), 155 | exc_info=exc, 156 | ) 157 | return 158 | finally: 159 | await finalize_cleanup(ctx) 160 | 161 | @Cog.listener() 162 | async def on_message(self, msg: Message) -> None: 163 | await client.process_commands(msg) 164 | 165 | 166 | CommandT = TypeVar("CommandT", bound=Command[Any, Any, Any]) 167 | 168 | 169 | def plugin_command(cmd: CommandT) -> CommandT: 170 | """ 171 | Register a command to be added/removed together with the plugin. The command must be already wrapped in 172 | discord.ext.commands.command or discord.ext.commands.group. 
173 | """ 174 | client.add_command(cmd) 175 | plugins.finalizer(lambda: client.remove_command(cmd.name)) 176 | return cmd 177 | 178 | 179 | def suppress_usage(cmd: CommandT) -> CommandT: 180 | """This decorator on a command suppresses the usage instructions if the command is invoked incorrectly.""" 181 | cmd.suppress_usage = True # type: ignore 182 | return cmd 183 | 184 | 185 | BotT = TypeVar("BotT", bound=Bot, covariant=True) 186 | 187 | 188 | class CleanupContext(discord.ext.commands.Context[BotT]): 189 | cleanup: "CleanupReference" 190 | 191 | 192 | class CleanupReference: 193 | __slots__ = "messages", "task" 194 | messages: Set[PartialMessage] 195 | task: Optional[asyncio.Task[None]] 196 | 197 | def __init__(self, ctx: CleanupContext[BotT]): 198 | self.messages = set() 199 | chan_id = ctx.channel.id 200 | msg_id = ctx.message.id 201 | 202 | async def cleanup_task() -> None: 203 | await ctx.bot.wait_for( 204 | "raw_message_delete", check=lambda m: m.channel_id == chan_id and m.message_id == msg_id 205 | ) 206 | 207 | self.task = asyncio.create_task(cleanup_task(), name="Cleanup task for {}-{}".format(chan_id, msg_id)) 208 | 209 | def __del__(self) -> None: 210 | if self.task is not None: 211 | self.task.cancel() 212 | self.task = None 213 | 214 | def add(self, msg: Message) -> None: 215 | self.messages.add(PartialMessage(channel=msg.channel, id=msg.id)) 216 | 217 | async def finalize(self) -> None: 218 | if self.task is None: 219 | return 220 | try: 221 | if len(self.messages) != 0: 222 | await asyncio.wait_for(self.task, 300) 223 | except (asyncio.TimeoutError, asyncio.CancelledError): 224 | pass 225 | else: 226 | for msg in self.messages: 227 | try: 228 | await msg.delete() 229 | except (discord.NotFound, discord.Forbidden): 230 | pass 231 | finally: 232 | self.task.cancel() 233 | self.task = None 234 | 235 | 236 | def init_cleanup(ctx: CleanupContext[BotT]) -> None: 237 | if not hasattr(ctx, "cleanup"): 238 | ref = CleanupReference(ctx) 239 | ctx.cleanup = 
ref 240 | 241 | old_send = ctx.send 242 | 243 | async def send(*args: Any, **kwargs: Any) -> Message: 244 | msg = await old_send(*args, **kwargs) 245 | ref.add(msg) 246 | return msg 247 | 248 | ctx.send = send 249 | 250 | 251 | async def finalize_cleanup(ctx: object) -> None: 252 | if (ref := getattr(ctx, "cleanup", None)) is not None: 253 | await ref.finalize() 254 | 255 | 256 | def add_cleanup(ctx: object, msg: Message) -> None: 257 | """Mark a message as "output" of a cleanup command.""" 258 | if (ref := getattr(ctx, "cleanup", None)) is not None: 259 | ref.add(msg) 260 | 261 | 262 | def cleanup(cmd: CommandT) -> CommandT: 263 | """Make the command watch out for the deletion of the invoking message, and in that case, delete all output.""" 264 | old_invoke = cmd.invoke 265 | 266 | async def invoke(ctx: CleanupContext[BotT]) -> None: 267 | init_cleanup(ctx) 268 | await old_invoke(ctx) 269 | await finalize_cleanup(ctx) 270 | 271 | cmd.invoke = invoke # type: ignore 272 | 273 | old_on_error = getattr(cmd, "on_error", None) 274 | 275 | async def on_error(*args: Any) -> None: 276 | if len(args) == 3: 277 | _, ctx, _ = args 278 | else: 279 | ctx, _ = args 280 | init_cleanup(ctx) 281 | if old_on_error is not None: 282 | await old_on_error(*args) 283 | 284 | cmd.on_error = on_error 285 | 286 | old_ensure_assignment_on_copy = cmd._ensure_assignment_on_copy 287 | 288 | def ensure_assignment_on_copy(other: CommandT) -> CommandT: 289 | return cleanup(old_ensure_assignment_on_copy(other)) 290 | 291 | cmd._ensure_assignment_on_copy = ensure_assignment_on_copy 292 | 293 | return cmd 294 | -------------------------------------------------------------------------------- /bot/config.py: -------------------------------------------------------------------------------- 1 | import json 2 | from typing import Any, Iterator, Optional, Sequence, TypeVar, Union 3 | 4 | from discord.ext.commands import Command, group 5 | 6 | from bot.acl import privileged 7 | from bot.commands import 
Context, cleanup, plugin_command 8 | import plugins 9 | import util.db 10 | import util.db.kv 11 | from util.discord import CodeBlock, CodeItem, Inline, PlainItem, Quoted, chunk_messages, format 12 | 13 | 14 | @plugin_command 15 | @cleanup 16 | @group("config", invoke_without_command=True) 17 | @privileged 18 | async def config_command( 19 | ctx: Context, namespace: Optional[str], key: Optional[str], value: Optional[Union[CodeBlock, Inline, Quoted]] 20 | ) -> None: 21 | """Edit the key-value configs.""" 22 | if namespace is None: 23 | 24 | def namespace_items(nsps: Sequence[str]) -> Iterator[PlainItem]: 25 | first = True 26 | for nsp in nsps: 27 | if first: 28 | first = False 29 | else: 30 | yield PlainItem(", ") 31 | yield PlainItem(format("{!i}", nsp)) 32 | 33 | for content, _ in chunk_messages(namespace_items(await util.db.kv.get_namespaces())): 34 | await ctx.send(content) 35 | return 36 | 37 | conf = await util.db.kv.load(namespace) 38 | 39 | if key is None: 40 | 41 | def keys_items() -> Iterator[PlainItem]: 42 | first = True 43 | for keys in conf: 44 | if first: 45 | first = False 46 | else: 47 | yield PlainItem("; ") 48 | yield PlainItem(",".join(format("{!i}", key) for key in keys)) 49 | 50 | for content, _ in chunk_messages(keys_items()): 51 | await ctx.send(content) 52 | return 53 | 54 | keys = key.split(",") 55 | 56 | if value is None: 57 | for content, files in chunk_messages( 58 | (CodeItem(util.db.kv.json_encode(conf[keys]) or "", language="json", filename="{}.json".format(key)),) 59 | ): 60 | await ctx.send(content, files=files) 61 | return 62 | 63 | conf[keys] = json.loads(value.text) 64 | await conf 65 | await ctx.send("\u2705") 66 | 67 | 68 | @config_command.command("--delete") 69 | @privileged 70 | async def config_delete(ctx: Context, namespace: str, key: str) -> None: 71 | """Delete the provided key from the config.""" 72 | conf = await util.db.kv.load(namespace) 73 | keys = key.split(",") 74 | conf[keys] = None 75 | await conf 76 | await 
ctx.send("\u2705") 77 | 78 | 79 | CommandT = TypeVar("CommandT", bound=Command[Any, Any, Any]) 80 | 81 | 82 | def plugin_config_command(cmd: CommandT) -> CommandT: 83 | """Register a subcommand of the config command to be added/removed together with the plugin.""" 84 | config_command.add_command(cmd) 85 | plugins.finalizer(lambda: config_command.remove_command(cmd.name)) 86 | return cmd 87 | -------------------------------------------------------------------------------- /bot/interactions.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import Any, Callable, Coroutine, Optional, TypeVar, Union 3 | from typing_extensions import Concatenate, ParamSpec 4 | 5 | from discord import Interaction, Member, Message, User 6 | import discord.app_commands 7 | from discord.app_commands import AppCommandError, CheckFailure, Command, ContextMenu, Group 8 | from discord.ui import View 9 | 10 | from bot.client import client 11 | from bot.tasks import task 12 | import plugins 13 | from util.discord import format 14 | 15 | 16 | logger = logging.getLogger(__name__) 17 | 18 | old_on_error = client.tree.on_error 19 | 20 | 21 | @client.tree.error 22 | async def on_error(interaction: Interaction, exc: AppCommandError) -> None: 23 | if isinstance(exc, CheckFailure): 24 | message = "Error: {}".format(str(exc)) 25 | if interaction.response.is_done(): 26 | await interaction.followup.send(message, ephemeral=True) 27 | else: 28 | await interaction.response.send_message(message, ephemeral=True) 29 | return 30 | else: 31 | logger.error( 32 | format( 33 | "Error in command {!r} {!r} from {!m} in {!c}: {}", 34 | interaction.command, 35 | interaction.data, 36 | interaction.user, 37 | interaction.channel_id, 38 | str(exc), 39 | ), 40 | exc_info=exc.__cause__, 41 | ) 42 | return 43 | 44 | 45 | @plugins.finalizer 46 | def restore_on_error() -> None: 47 | client.tree.error(old_on_error) # type: ignore 48 | 49 | 50 | 
@task(name="Command tree sync task", exc_backoff_base=1) 51 | async def sync_task() -> None: 52 | await client.wait_until_ready() 53 | logger.debug("Syncing command tree") 54 | await client.tree.sync() 55 | 56 | 57 | P = ParamSpec("P") 58 | T = TypeVar("T") 59 | 60 | 61 | def command( 62 | name: str, description: Optional[str] = None 63 | ) -> Callable[[Callable[Concatenate[Interaction, P], Coroutine[Any, Any, T]]], Command[Any, P, T]]: 64 | """Decorator for a slash command that is added/removed together with the plugin.""" 65 | 66 | def decorator(fun: Callable[Concatenate[Interaction, P], Coroutine[Any, Any, T]]) -> Command[Any, P, T]: 67 | if description is None: 68 | cmd = discord.app_commands.command(name=name)(fun) 69 | else: 70 | cmd = discord.app_commands.command(name=name, description=description)(fun) 71 | 72 | client.tree.add_command(cmd) 73 | sync_task.run_coalesced(5) 74 | 75 | def finalizer(): 76 | client.tree.remove_command(cmd.name) 77 | sync_task.run_coalesced(5) 78 | 79 | plugins.finalizer(finalizer) 80 | 81 | return cmd 82 | 83 | return decorator 84 | 85 | 86 | def group(name: str, *, description: str, **kwargs: Any) -> Group: 87 | """Decorator for a slash command group that is added/removed together with the plugin.""" 88 | cmd = Group(name=name, description=description, **kwargs) 89 | 90 | client.tree.add_command(cmd) 91 | sync_task.run_coalesced(5) 92 | 93 | def finalizer(): 94 | client.tree.remove_command(cmd.name) 95 | sync_task.run_coalesced(5) 96 | 97 | plugins.finalizer(finalizer) 98 | 99 | return cmd 100 | 101 | 102 | def context_menu( 103 | name: str, 104 | ) -> Callable[ 105 | [ 106 | Union[ 107 | Callable[[Interaction, Member], Coroutine[Any, Any, object]], 108 | Callable[[Interaction, User], Coroutine[Any, Any, object]], 109 | Callable[[Interaction, Message], Coroutine[Any, Any, object]], 110 | Callable[[Interaction, Union[Member, User]], Coroutine[Any, Any, object]], 111 | ] 112 | ], 113 | ContextMenu, 114 | ]: 115 | """Decorator for 
a context menu command that is added/removed together with the plugin.""" 116 | 117 | def decorator( 118 | fun: Union[ 119 | Callable[[Interaction, Member], Coroutine[Any, Any, object]], 120 | Callable[[Interaction, User], Coroutine[Any, Any, object]], 121 | Callable[[Interaction, Message], Coroutine[Any, Any, object]], 122 | Callable[[Interaction, Union[Member, User]], Coroutine[Any, Any, object]], 123 | ] 124 | ) -> ContextMenu: 125 | cmd = discord.app_commands.context_menu(name=name)(fun) 126 | 127 | client.tree.add_command(cmd) 128 | sync_task.run_coalesced(5) 129 | 130 | def finalizer(): 131 | client.tree.remove_command(cmd.name) 132 | sync_task.run_coalesced(5) 133 | 134 | plugins.finalizer(finalizer) 135 | 136 | return cmd 137 | 138 | return decorator 139 | 140 | 141 | V = TypeVar("V", bound=View) 142 | 143 | 144 | def persistent_view(view: V) -> V: 145 | """Declare a given view as persistent (for as long as the plugin is loaded).""" 146 | assert view.is_persistent() 147 | client.add_view(view) 148 | 149 | def finalizer(): 150 | view.stop() 151 | 152 | plugins.finalizer(finalizer) 153 | 154 | return view 155 | -------------------------------------------------------------------------------- /bot/main_tasks.py: -------------------------------------------------------------------------------- 1 | """ 2 | This module manages a registry of "main" tasks that are extending the runtime of the entire asyncio program. Once all 3 | main tasks complete (by returning or raising an exception), the program terminates. 
4 | """ 5 | 6 | import asyncio 7 | import logging 8 | from typing import Any, Coroutine, List, Optional, TypeVar 9 | 10 | 11 | tasks: List[asyncio.Task[object]] 12 | try: 13 | # Keep the list of tasks if we're being reloaded 14 | tasks # type: ignore 15 | except NameError: 16 | tasks = [] 17 | else: 18 | tasks = tasks # type: ignore 19 | 20 | logger = logging.getLogger(__name__) 21 | 22 | T = TypeVar("T") 23 | 24 | 25 | def create_task(coro: Coroutine[Any, Any, T], *, name: Optional[str] = None) -> asyncio.Task[T]: 26 | """Register a task as a "main" task. If the task finishes or raises an exception, it is removed from the list""" 27 | task = asyncio.create_task(coro, name=name) 28 | task.add_done_callback(tasks.remove) 29 | tasks.append(task) 30 | return task 31 | 32 | 33 | def cancel() -> None: 34 | """Cancel all currently registered tasks""" 35 | for t in tasks: 36 | t.cancel() 37 | 38 | 39 | async def wait() -> None: 40 | """Return when all registered tasks are done, or if any tasks raises an exception, raise that exception""" 41 | while tasks: 42 | logger.debug("Waiting for tasks: {}".format(tasks)) 43 | await asyncio.gather(*tasks) 44 | 45 | 46 | async def wait_all() -> None: 47 | """Return when all registered tasks are done, accumulating exceptions""" 48 | try: 49 | while tasks: 50 | logger.debug("Waiting for tasks: {}".format(tasks)) 51 | await asyncio.gather(*tasks) 52 | except: 53 | logger.debug("Exception when waiting for main tasks", exc_info=True) 54 | await wait_all() 55 | raise 56 | -------------------------------------------------------------------------------- /bot/reactions.py: -------------------------------------------------------------------------------- 1 | """Utilities for waiting for Discord reactions.""" 2 | 3 | from __future__ import annotations 4 | 5 | import asyncio 6 | from typing import ( 7 | Any, 8 | AsyncIterator, 9 | Callable, 10 | ContextManager, 11 | Dict, 12 | Generic, 13 | Literal, 14 | Optional, 15 | Tuple, 16 | TypeVar, 17 | 
Union, 18 | cast, 19 | overload, 20 | ) 21 | from weakref import WeakSet 22 | 23 | import discord 24 | from discord import ( 25 | Emoji, 26 | Message, 27 | PartialEmoji, 28 | RawReactionActionEvent, 29 | RawReactionClearEmojiEvent, 30 | RawReactionClearEvent, 31 | ) 32 | from discord.abc import Snowflake 33 | 34 | from bot.client import client 35 | from bot.cogs import Cog, cog 36 | import util.asyncio 37 | 38 | 39 | T = TypeVar("T") 40 | 41 | 42 | class FilteredQueue(asyncio.Queue[T], Generic[T]): 43 | """An async queue that only accepts values that match the given filter""" 44 | 45 | __slots__ = "filter" 46 | 47 | def __init__(self, maxsize: int = 0, *, filter: Optional[Callable[[T], bool]] = None): 48 | self.filter: Callable[[T], bool] 49 | self.filter = filter if filter is not None else lambda _: True 50 | return super().__init__(maxsize) 51 | 52 | async def put(self, item: T) -> None: 53 | if self.filter(item): 54 | return await super().put(item) 55 | 56 | def put_nowait(self, item: T) -> None: 57 | if self.filter(item): 58 | return super().put_nowait(item) 59 | 60 | 61 | ReactionEvent = Union[RawReactionActionEvent, RawReactionClearEvent, RawReactionClearEmojiEvent] 62 | 63 | reaction_queues: WeakSet[FilteredQueue[Union[BaseException, Tuple[str, ReactionEvent]]]] 64 | reaction_queues = WeakSet() 65 | 66 | 67 | class ReactionMonitor(ContextManager["ReactionMonitor[T]"], Generic[T]): 68 | """ 69 | A reaction monitor waits for reaction events matching particular rules, or until specified timeouts expire. 
Example 70 | use-case: 71 | 72 | try: 73 | with ReactionMonitor(event="add", message_id=..., timeout_each=120, timeout_total=300) as mon: 74 | ev, payload = await mon 75 | # ev = "add" 76 | # payload: RawReactionActionEvent 77 | except asyncio.TimeoutError: 78 | # either 2 minutes have passed since most recent reaction, or 5 minutes have passed in total 79 | """ 80 | 81 | __slots__ = ("loop", "queue", "end_time", "timeout_each") 82 | loop: asyncio.AbstractEventLoop 83 | queue: FilteredQueue[Union[BaseException, Tuple[str, ReactionEvent]]] 84 | end_time: Optional[float] 85 | timeout_each: Optional[float] 86 | 87 | @overload 88 | def __init__( 89 | self: ReactionMonitor[RawReactionActionEvent], 90 | *, 91 | event: Literal["add", "remove"], 92 | filter: Optional[Callable[[str, RawReactionActionEvent], bool]] = None, 93 | guild_id: Optional[int] = None, 94 | channel_id: Optional[int] = None, 95 | message_id: Optional[int] = None, 96 | author_id: Optional[int] = None, 97 | emoji: Optional[Union[PartialEmoji, Emoji, str, int]] = None, 98 | loop: Optional[asyncio.AbstractEventLoop] = None, 99 | timeout_each: Optional[float] = None, 100 | timeout_total: Optional[float] = None, 101 | ) -> None: ... 102 | 103 | @overload 104 | def __init__( 105 | self: ReactionMonitor[RawReactionClearEvent], 106 | *, 107 | event: Literal["clear"], 108 | filter: Optional[Callable[[str, RawReactionClearEvent], bool]] = None, 109 | guild_id: Optional[int] = None, 110 | channel_id: Optional[int] = None, 111 | message_id: Optional[int] = None, 112 | author_id: Optional[int] = None, 113 | emoji: Optional[Union[PartialEmoji, Emoji, str, int]] = None, 114 | loop: Optional[asyncio.AbstractEventLoop] = None, 115 | timeout_each: Optional[float] = None, 116 | timeout_total: Optional[float] = None, 117 | ) -> None: ... 
118 | 119 | @overload 120 | def __init__( 121 | self: ReactionMonitor[RawReactionClearEmojiEvent], 122 | *, 123 | event: Literal["clear_emoji"], 124 | filter: Optional[Callable[[str, RawReactionClearEmojiEvent], bool]] = None, 125 | guild_id: Optional[int] = None, 126 | channel_id: Optional[int] = None, 127 | message_id: Optional[int] = None, 128 | author_id: Optional[int] = None, 129 | emoji: Optional[Union[PartialEmoji, Emoji, str, int]] = None, 130 | loop: Optional[asyncio.AbstractEventLoop] = None, 131 | timeout_each: Optional[float] = None, 132 | timeout_total: Optional[float] = None, 133 | ) -> None: ... 134 | 135 | @overload 136 | def __init__( 137 | self: ReactionMonitor[ReactionEvent], 138 | *, 139 | event: None = None, 140 | filter: Optional[Callable[[str, ReactionEvent], bool]] = None, 141 | guild_id: Optional[int] = None, 142 | channel_id: Optional[int] = None, 143 | message_id: Optional[int] = None, 144 | author_id: Optional[int] = None, 145 | emoji: Optional[Union[PartialEmoji, Emoji, str, int]] = None, 146 | loop: Optional[asyncio.AbstractEventLoop] = None, 147 | timeout_each: Optional[float] = None, 148 | timeout_total: Optional[float] = None, 149 | ) -> None: ... 
150 | 151 | def __init__( 152 | self: ReactionMonitor[object], 153 | *, 154 | event: Optional[str] = None, 155 | filter: Optional[Callable[[str, Any], bool]] = None, 156 | guild_id: Optional[int] = None, 157 | channel_id: Optional[int] = None, 158 | message_id: Optional[int] = None, 159 | author_id: Optional[int] = None, 160 | emoji: Optional[Union[PartialEmoji, Emoji, str, int]] = None, 161 | loop: Optional[asyncio.AbstractEventLoop] = None, 162 | timeout_each: Optional[float] = None, 163 | timeout_total: Optional[float] = None, 164 | ): 165 | self.loop = loop if loop is not None else asyncio.get_running_loop() 166 | 167 | # for "add" and "remove", RawReactionActionEvent has the fields 168 | # guild_id, channel_id, message_id, author_id, emoji 169 | # for "clear", RawReactionClearEvent has the fields 170 | # guild_id, channel_id, message_id 171 | # for "clear_emoji", RawReactionClearEmojiEvent has the fields 172 | # guild_id, channel_id, message_id, emoji 173 | def event_filter(ev: str, payload: ReactionEvent) -> bool: 174 | return ( 175 | (guild_id is None or payload.guild_id == guild_id) 176 | and (channel_id is None or payload.channel_id == channel_id) 177 | and (message_id is None or payload.message_id == message_id) 178 | and ( 179 | author_id is None or not hasattr(payload, "user_id") or payload.user_id == author_id # type: ignore 180 | ) 181 | and (event is None or ev == event) 182 | and ( 183 | emoji is None 184 | or not hasattr(payload, "emoji") 185 | or payload.emoji == emoji # type: ignore 186 | or payload.emoji.name == emoji # type: ignore 187 | or payload.emoji.id == emoji # type: ignore 188 | ) 189 | and (filter is None or filter(ev, payload)) 190 | ) 191 | 192 | self.timeout_each = timeout_each 193 | if timeout_total is None: 194 | self.end_time = None 195 | else: 196 | self.end_time = self.loop.time() + timeout_total 197 | 198 | def queue_filter(value: Union[BaseException, Tuple[str, ReactionEvent]]) -> bool: 199 | return isinstance(value, 
BaseException) or event_filter(*value) 200 | 201 | self.queue = FilteredQueue(maxsize=0, filter=queue_filter) 202 | 203 | def __enter__(self) -> ReactionMonitor[T]: 204 | reaction_queues.add(self.queue) 205 | return self 206 | 207 | def __exit__(self, exc_type, exc_val, tb) -> None: # type: ignore 208 | reaction_queues.discard(self.queue) 209 | 210 | @util.asyncio.__await__ 211 | async def __await__(self) -> Tuple[str, T]: 212 | timeout = self.timeout_each 213 | if self.end_time is not None: 214 | remaining = self.end_time - self.loop.time() 215 | if timeout is None or timeout > remaining: 216 | timeout = remaining 217 | value = await asyncio.wait_for(self.queue.get(), timeout) 218 | if isinstance(value, BaseException): 219 | raise value 220 | return cast(Tuple[str, T], value) 221 | 222 | async def __aiter__(self) -> AsyncIterator[Tuple[str, T]]: 223 | while True: 224 | try: 225 | yield await self 226 | except asyncio.TimeoutError: 227 | return 228 | 229 | def cancel(self, exc: Optional[BaseException] = None) -> None: 230 | if exc is None: 231 | exc = asyncio.CancelledError() 232 | try: 233 | raise exc 234 | except BaseException as exc: 235 | self.queue.put_nowait(exc) 236 | 237 | 238 | def deliver_event(ev: str, payload: ReactionEvent) -> None: 239 | gen = reaction_queues.__iter__() 240 | 241 | def cont_deliver() -> None: 242 | try: 243 | for queue in gen: 244 | queue.put_nowait((ev, payload)) 245 | except: 246 | cont_deliver() 247 | raise 248 | 249 | cont_deliver() 250 | 251 | 252 | @cog 253 | class Reactions(Cog): 254 | @Cog.listener() 255 | async def on_raw_reaction_add(self, payload: RawReactionActionEvent) -> None: 256 | deliver_event("add", payload) 257 | 258 | @Cog.listener() 259 | async def on_raw_reaction_remove(self, payload: RawReactionActionEvent) -> None: 260 | deliver_event("remove", payload) 261 | 262 | @Cog.listener() 263 | async def on_raw_reaction_clear(self, payload: RawReactionClearEvent) -> None: 264 | deliver_event("clear", payload) 265 | 266 
| @Cog.listener() 267 | async def on_raw_reaction_clear_emoji(self, payload: RawReactionClearEmojiEvent) -> None: 268 | deliver_event("clear_emoji", payload) 269 | 270 | 271 | def emoji_key(emoji: Union[Emoji, PartialEmoji, str]) -> Union[str, int]: 272 | if isinstance(emoji, str): 273 | return emoji 274 | elif emoji.id is None: 275 | return emoji.name 276 | else: 277 | return emoji.id 278 | 279 | 280 | async def get_reaction( 281 | msg: Message, 282 | user: Snowflake, 283 | reactions: Dict[Union[Emoji, PartialEmoji, str], T], 284 | *, 285 | timeout: Optional[float] = None, 286 | unreact: bool = True, 287 | ) -> Optional[T]: 288 | """ 289 | Offer a set of reactions on a given message, each corresponding to a given return value, and wait for the given user 290 | to react on any one of them, or until timeout is reached, in which case None is returned. If unreact=True, then all 291 | reactions that were not selected by the user are removed afterwards. 292 | """ 293 | assert client.user is not None 294 | reacts = {emoji_key(key): value for key, value in reactions.items()} 295 | with ReactionMonitor( 296 | channel_id=msg.channel.id, 297 | message_id=msg.id, 298 | author_id=user.id, 299 | event="add", 300 | filter=lambda _, p: emoji_key(p.emoji) in reacts, 301 | timeout_each=timeout, 302 | ) as mon: 303 | try: 304 | await asyncio.gather(*(msg.add_reaction(key) for key in reactions)) 305 | except (discord.NotFound, discord.Forbidden): 306 | pass 307 | try: 308 | _, payload = await mon 309 | except asyncio.TimeoutError: 310 | return None 311 | if unreact: 312 | try: 313 | await asyncio.gather( 314 | *( 315 | msg.remove_reaction(key, client.user) 316 | for key in reactions 317 | if emoji_key(key) != emoji_key(payload.emoji) 318 | ) 319 | ) 320 | except (discord.NotFound, discord.Forbidden): 321 | pass 322 | return reacts.get(emoji_key(payload.emoji)) 323 | 324 | 325 | async def get_input( 326 | msg: Message, 327 | user: Snowflake, 328 | reactions: Dict[Union[Emoji, 
PartialEmoji, str], T], 329 | *, 330 | timeout: Optional[float] = None, 331 | unreact: bool = True, 332 | ) -> Optional[Union[T, Message]]: 333 | """ 334 | Offer a set of reactions on a given message a-la get_reaction, and wait until the user either reacts or responds 335 | with a message. 336 | """ 337 | assert client.user is not None 338 | reacts = {emoji_key(key): value for key, value in reactions.items()} 339 | with ReactionMonitor( 340 | channel_id=msg.channel.id, 341 | message_id=msg.id, 342 | author_id=user.id, 343 | event="add", 344 | filter=lambda _, p: emoji_key(p.emoji) in reacts, 345 | timeout_each=timeout, 346 | ) as mon: 347 | try: 348 | await asyncio.gather(*(msg.add_reaction(key) for key in reactions)) 349 | except (discord.NotFound, discord.Forbidden): 350 | pass 351 | msg_task = asyncio.create_task( 352 | client.wait_for("message", check=lambda m: m.channel == msg.channel and m.author.id == user.id) 353 | ) 354 | reaction_task = asyncio.ensure_future(mon) 355 | try: 356 | done, _ = await asyncio.wait( 357 | (msg_task, reaction_task), timeout=timeout, return_when=asyncio.FIRST_COMPLETED 358 | ) 359 | except asyncio.TimeoutError: 360 | return None 361 | if msg_task in done: 362 | reaction_task.cancel() 363 | if unreact: 364 | try: 365 | await asyncio.gather(*(msg.remove_reaction(key, client.user) for key in reactions)) 366 | except (discord.NotFound, discord.Forbidden): 367 | pass 368 | return msg_task.result() 369 | elif reaction_task in done: 370 | msg_task.cancel() 371 | _, payload = reaction_task.result() 372 | if unreact: 373 | try: 374 | await asyncio.gather( 375 | *( 376 | msg.remove_reaction(key, client.user) 377 | for key in reactions 378 | if emoji_key(key) != emoji_key(payload.emoji) 379 | ) 380 | ) 381 | except (discord.NotFound, discord.Forbidden): 382 | pass 383 | return reacts.get(emoji_key(payload.emoji)) 384 | else: 385 | return None 386 | -------------------------------------------------------------------------------- 
/bot/tasks.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import logging 3 | from typing import Awaitable, Callable, Optional 4 | 5 | import plugins 6 | 7 | 8 | logger: logging.Logger = logging.getLogger(__name__) 9 | 10 | 11 | class Task(asyncio.Task[None]): 12 | __slots__ = "cb", "timeout", "exc_backoff_base", "exc_backoff_multiplier", "queue" 13 | cb: Callable[[], Awaitable[object]] 14 | timeout: Optional[float] 15 | exc_backoff_base: Optional[float] 16 | exc_backoff_multiplier: int 17 | queue: asyncio.Queue[Optional[float]] 18 | 19 | def __init__( 20 | self, 21 | cb: Callable[[], Awaitable[object]], 22 | *, 23 | every: Optional[float] = None, 24 | exc_backoff_base: Optional[float] = None, 25 | name: Optional[str] = None, 26 | ) -> None: 27 | super().__init__(self.task_loop(), loop=asyncio.get_event_loop(), name=name) 28 | self.cb = cb 29 | self.timeout = every 30 | self.exc_backoff_base = exc_backoff_base 31 | self.exc_backoff_multiplier = 1 32 | self.queue = asyncio.Queue() 33 | 34 | def run_once(self) -> None: 35 | """ 36 | Trigger the task to run once, and reset the "every" timer. The task will be run as many times as this function 37 | is called. 38 | """ 39 | self.queue.put_nowait(None) 40 | 41 | def run_coalesced(self, timeout: float) -> None: 42 | """ 43 | Trigger the task to run once, unless another run_coalesced or run_once happens within "timeout" seconds, in 44 | which case the two requests are coalesced and the task runs once. Multiple run_coalesced invocations will join 45 | into one, but a run_once always ends the chain. 
46 | """ 47 | self.queue.put_nowait(timeout) 48 | 49 | async def task_loop(self) -> None: 50 | while True: 51 | try: 52 | try: 53 | timeout = self.timeout 54 | while True: 55 | timeout = await asyncio.wait_for(self.queue.get(), timeout=timeout) 56 | if timeout is None: 57 | break 58 | elif self.timeout is not None and timeout > self.timeout: 59 | timeout = self.timeout 60 | except asyncio.TimeoutError: 61 | pass 62 | await self.cb() 63 | self.exc_backoff_multiplier = 1 64 | except asyncio.CancelledError: 65 | raise 66 | except: 67 | logger.error("Exception in {}".format(self.get_name()), exc_info=True) 68 | if self.exc_backoff_base is not None: 69 | await asyncio.sleep(self.exc_backoff_base * self.exc_backoff_multiplier) 70 | self.exc_backoff_multiplier *= 2 71 | 72 | 73 | def task( 74 | *, every: Optional[float] = None, exc_backoff_base: Optional[float] = None, name: Optional[str] = None 75 | ) -> Callable[[Callable[[], Awaitable[object]]], Task]: 76 | """ 77 | A decorator that registers the function as a task that is called periodically or upon request. The task is cancelled 78 | on plugin unload. The "every" parameter causes the task to wait at most that many seconds between executions. A task 79 | can be started sooner by invoking .run_once or .run_coalesced on it. The "exc_backoff_base" parameter causes an 80 | additional delay in case the task throws an exception, and the delay increases exponentially if the task keeps 81 | throwing exceptions. 
82 | """ 83 | 84 | def register_task(cb: Callable[[], Awaitable[object]]) -> Task: 85 | task = Task(cb, every=every, exc_backoff_base=exc_backoff_base, name=name) 86 | plugins.finalizer(task.cancel) 87 | return task 88 | 89 | return register_task 90 | -------------------------------------------------------------------------------- /discord: -------------------------------------------------------------------------------- 1 | discord.py/discord -------------------------------------------------------------------------------- /docker-compose.yml: -------------------------------------------------------------------------------- 1 | version: '3' 2 | services: 3 | bot: 4 | build: docker 5 | volumes: 6 | - type: bind 7 | source: $PWD 8 | target: /opt/bot/ 9 | depends_on: 10 | db: 11 | condition: service_healthy 12 | 13 | db: 14 | image: postgres:12.20-alpine 15 | environment: 16 | POSTGRES_USER: bot 17 | POSTGRES_PASSWORD: bot 18 | POSTGRES_DB: discord 19 | volumes: 20 | - db:/var/lib/postgresql/data 21 | healthcheck: 22 | test: ["CMD-SHELL", "pg_isready -U bot -d discord"] 23 | start_interval: 1s 24 | start_period: 1m 25 | interval: 10m 26 | 27 | volumes: 28 | db: 29 | -------------------------------------------------------------------------------- /docker/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM python:3.9-alpine 2 | RUN apk add --no-cache gcc g++ musl-dev git 3 | COPY requirements_core.txt / 4 | RUN pip install --use-pep517 --requirement /requirements_core.txt \ 5 | && rm requirements_core.txt 6 | 7 | WORKDIR /opt/bot/ 8 | CMD ["python", "main.py"] 9 | -------------------------------------------------------------------------------- /docker/requirements_core.txt: -------------------------------------------------------------------------------- 1 | aiohttp 2 | aiohttp-session 3 | asyncpg 4 | datrie 5 | pyright==1.1.378 6 | PyYAML 7 | sqlalchemy >= 2 8 | types-requests 9 | types-PyYAML 10 | typing_extensions 11 | 
-------------------------------------------------------------------------------- /log_setup.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import logging.handlers 3 | import time 4 | from typing import Callable, List, Optional, TextIO, Tuple, Type, Union 5 | import warnings 6 | 7 | import static_config 8 | 9 | 10 | logging.basicConfig(handlers=[], force=True) 11 | 12 | 13 | def closure() -> None: 14 | old_showwarning = warnings.showwarning 15 | 16 | def showwarning( 17 | message: Union[Warning, str], 18 | category: Type[Warning], 19 | filename: str, 20 | lineno: int, 21 | file: Optional[TextIO] = None, 22 | line: Optional[str] = None, 23 | ) -> None: 24 | if file is not None: 25 | old_showwarning(message, category, filename, lineno, file, line) 26 | else: 27 | text = warnings.formatwarning(message, category, filename, lineno, line) 28 | logging.getLogger("__builtins__").error(text) 29 | 30 | warnings.showwarning = showwarning 31 | 32 | 33 | closure() 34 | 35 | logger: logging.Logger = logging.getLogger() 36 | logger.setLevel(logging.NOTSET) 37 | 38 | 39 | class Formatter(logging.Formatter): 40 | """A formatter that formats multi-line messages in a greppable fashion""" 41 | 42 | __slots__ = () 43 | 44 | converter = time.gmtime 45 | default_time_format = "%Y-%m-%dT%H:%M:%S" 46 | default_msec_format = "%s.%03d" 47 | 48 | def format(self, record: logging.LogRecord) -> str: 49 | record.asctime = self.formatTime(record, self.datefmt) 50 | if record.exc_info: 51 | if not record.exc_text: 52 | record.exc_text = self.formatException(record.exc_info) 53 | 54 | lines = record.getMessage().split("\n") 55 | if record.exc_text: 56 | lines.extend(record.exc_text.split("\n")) 57 | if record.stack_info: 58 | lines.extend(self.formatStack(record.stack_info).split("\n")) 59 | 60 | lines = list(filter(bool, lines)) 61 | 62 | output: List[str] = [] 63 | for i in range(len(lines)): 64 | record.message = lines[i] 65 | if 
len(lines) == 1: 66 | record.symbol = ":" 67 | elif i == 0: 68 | record.symbol = "{" 69 | elif i == len(lines) - 1: 70 | record.symbol = "}" 71 | else: 72 | record.symbol = "|" 73 | output.append(self.formatMessage(record)) 74 | return "\n".join(output) 75 | 76 | 77 | formatter: logging.Formatter = Formatter("%(asctime)s %(name)s %(levelname)s%(symbol)s %(message)s") 78 | 79 | targets: List[Tuple[int, str, Optional[Callable[[logging.LogRecord], bool]]]] = [ 80 | (logging.DEBUG, "debug.discord", lambda r: r.name.startswith("discord.")), 81 | (logging.DEBUG, "debug", lambda r: not r.name.startswith("discord.")), 82 | (logging.INFO, "info", None), 83 | (logging.WARNING, "warning", None), 84 | (logging.ERROR, "error", None), 85 | (logging.CRITICAL, "critical", None), 86 | ] 87 | 88 | for level, name, cond in targets: 89 | handler = logging.handlers.TimedRotatingFileHandler( 90 | filename="{}/{}.log".format(static_config.Log["directory"], name), 91 | when="midnight", 92 | utc=True, 93 | encoding="utf", 94 | errors="replace", 95 | ) 96 | handler.setLevel(level) 97 | handler.setFormatter(formatter) 98 | if cond: 99 | handler.addFilter(cond) 100 | logger.addHandler(handler) 101 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import log_setup # type: ignore 4 | 5 | 6 | logger: logging.Logger = logging.getLogger(__name__) 7 | 8 | try: 9 | import asyncio 10 | 11 | import plugins 12 | 13 | manager = plugins.PluginManager(["bot", "plugins", "util"]) 14 | manager.register() 15 | 16 | async def async_main() -> None: 17 | main_tasks = None 18 | try: 19 | main_tasks = await manager.load("bot.main_tasks") 20 | await manager.load("bot.autoload") 21 | await main_tasks.wait() 22 | except: 23 | logger.critical("Exception during main event loop", exc_info=True) 24 | finally: 25 | logger.info("Unloading all plugins") 26 | await 
manager.unload_all() 27 | logger.info("Cancelling main tasks") 28 | if main_tasks: 29 | main_tasks.cancel() 30 | await main_tasks.wait_all() 31 | logger.info("Exiting main loop") 32 | 33 | asyncio.run(async_main()) 34 | except: 35 | logger.critical("Exception in main", exc_info=True) 36 | raise 37 | -------------------------------------------------------------------------------- /migrations/plugins.consensus-9043f7f4cbcd154392e59579f64e32418f269840-404a3478aeb168e45611267af447900b1212f9f2.sql: -------------------------------------------------------------------------------- 1 | CREATE TYPE consensus.polltype AS ENUM ('COUNTED', 'CHOICE', 'WITH_COMMENTS', 'WITH_CONCERNS'); 2 | 3 | ALTER TABLE consensus.polls ADD COLUMN poll consensus.polltype; 4 | UPDATE consensus.polls SET poll = 'WITH_CONCERNS'; 5 | ALTER TABLE consensus.polls ALTER COLUMN poll SET NOT NULL; 6 | 7 | ALTER TABLE consensus.polls ADD COLUMN options TEXT[]; 8 | UPDATE consensus.polls SET options = ARRAY[E'\u2705', E'\U0001F518', E'\u274C']; 9 | ALTER TABLE consensus.polls ALTER COLUMN options SET NOT NULL; 10 | 11 | ALTER TABLE consensus.votes ADD COLUMN choice_index BIGINT; 12 | UPDATE consensus.votes SET choice_index = CASE 13 | WHEN vote = 'UPVOTE' THEN 0 14 | WHEN vote = 'NEUTRAL' THEN 1 15 | WHEN vote = 'DOWNVOTE' THEN 2 16 | END; 17 | ALTER TABLE consensus.votes ALTER COLUMN choice_index SET NOT NULL; 18 | 19 | ALTER TABLE consensus.votes DROP COLUMN vote; 20 | DROP TYPE consensus.votetype; 21 | -------------------------------------------------------------------------------- /migrations/plugins.factoids-3809d321fe36da4f650357bf78d02670b8b1a4f2-87a9a7795dabb98da513f93556237eca230119fc.sql: -------------------------------------------------------------------------------- 1 | ALTER TABLE factoids.factoids DROP CONSTRAINT factoids_pkey; 2 | ALTER TABLE factoids.factoids ADD COLUMN id SERIAL PRIMARY KEY; 3 | 4 | CREATE TABLE factoids.aliases 5 | ( name TEXT NOT NULL PRIMARY KEY 6 | , id INTEGER NOT 
NULL REFERENCES factoids.factoids(id) 7 | , author_id BIGINT NOT NULL 8 | , created_at TIMESTAMP NOT NULL 9 | , uses BIGINT NOT NULL 10 | , used_at TIMESTAMP 11 | ); 12 | 13 | INSERT INTO factoids.aliases (name, id, author_id, created_at, uses, used_at) 14 | SELECT name, id, author_id, created_at, uses, used_at FROM factoids.factoids; 15 | 16 | ALTER TABLE factoids.factoids DROP COLUMN name; 17 | -------------------------------------------------------------------------------- /migrations/plugins.factoids-87a9a7795dabb98da513f93556237eca230119fc-aa369b1260ecbabe80e9a33f62099e7afea2bace.sql: -------------------------------------------------------------------------------- 1 | ALTER TABLE factoids.factoids ADD COLUMN flags JSONB; 2 | -------------------------------------------------------------------------------- /migrations/plugins.factoids-aa369b1260ecbabe80e9a33f62099e7afea2bace-409508100f9b343fcf5bf512b57cfd12b166f5b8.sql: -------------------------------------------------------------------------------- 1 | CREATE TABLE factoids.config ( 2 | id BIGINT GENERATED ALWAYS AS (0) STORED NOT NULL, 3 | prefix TEXT, 4 | PRIMARY KEY (id) 5 | ) 6 | -------------------------------------------------------------------------------- /migrations/plugins.log-30418475489ba40f09ca4ee0051f09bb8bcf150b-2c5e3570c48030fa81308961343a62f3379c5b6c.sql: -------------------------------------------------------------------------------- 1 | CREATE INDEX messages_author_id ON log.messages USING BTREE (author_id); 2 | -------------------------------------------------------------------------------- /migrations/plugins.log-4809e558c405bbde63a7dc65cbbdf505ebd14cd7-65ed62be1b679154b675097517dd25c7fc27ad8a.sql: -------------------------------------------------------------------------------- 1 | ALTER TABLE saved_messages ALTER COLUMN content TYPE BYTEA USING CONVERT_TO(content, 'utf8'); 2 | -------------------------------------------------------------------------------- 
/migrations/plugins.log-65ed62be1b679154b675097517dd25c7fc27ad8a-9f4dee1a807ac59d97ed694526009f900c93d22e.sql:
--------------------------------------------------------------------------------
1 | ALTER TABLE saved_files DROP COLUMN content;
2 | ALTER TABLE saved_files ADD COLUMN local_filename TEXT;
3 |
--------------------------------------------------------------------------------
/migrations/plugins.log-9f4dee1a807ac59d97ed694526009f900c93d22e-30418475489ba40f09ca4ee0051f09bb8bcf150b.sql:
--------------------------------------------------------------------------------
1 | CREATE SCHEMA log;
2 | ALTER TABLE saved_messages SET SCHEMA log;
3 | ALTER TABLE log.saved_messages RENAME TO messages;
4 | ALTER TABLE saved_files SET SCHEMA log;
5 | ALTER TABLE log.saved_files RENAME TO files;
6 |
7 | CREATE TABLE log.users
8 |     ( id BIGINT NOT NULL
9 |     , set_at TIMESTAMP NOT NULL
10 |     , username TEXT NOT NULL
11 |     , discrim CHAR(4) NOT NULL
12 |     , unset_at TIMESTAMP
13 |     , PRIMARY KEY (id, set_at)
14 |     );
15 | CREATE TABLE log.nicks
16 |     ( id BIGINT NOT NULL
17 |     , set_at TIMESTAMP NOT NULL
18 |     , nick TEXT
19 |     , unset_at TIMESTAMP
20 |     , PRIMARY KEY (id, set_at)
21 |     );
22 |
--------------------------------------------------------------------------------
/migrations/plugins.modmail-4f5dea7d15ddaa90bec050ba4d8a6460aefe44b8-6e01739d17b263c9005268bcfe60fea051fbc673.sql:
--------------------------------------------------------------------------------
1 | ALTER TABLE modmail.threads ADD PRIMARY KEY (thread_first_message_id);
2 |
--------------------------------------------------------------------------------
/migrations/plugins.modmail-6e01739d17b263c9005268bcfe60fea051fbc673-8515db1e187aade4392fa2755dc62355808f2e6a.sql:
--------------------------------------------------------------------------------
1 | CREATE TABLE modmail.guilds (
2 |     guild_id BIGSERIAL NOT NULL PRIMARY KEY,
3 |     token TEXT NOT NULL,
4 |     channel_id BIGINT NOT NULL,
5 |     role_id BIGINT NOT NULL,
6 |     thread_expiry INTERVAL NOT NULL
7 | );
8 |
9 | WITH conf AS (SELECT key, value FROM kv WHERE namespace = 'plugins.modmail')
10 | INSERT INTO modmail.guilds
11 | SELECT
12 |     (guild.value::JSON #>> '{}')::BIGINT AS guild_id,
13 |     token.value::JSON #>> '{}' AS token,
14 |     (channel.value::JSON #>> '{}')::BIGINT AS channel_id,
15 |     (role.value::JSON #>> '{}')::BIGINT AS role_id,
16 |     (thread_expiry.value::JSON #>> '{}')::BIGINT
17 |         * INTERVAL '1 second' AS thread_expiry
18 | FROM
19 |     conf AS guild,
20 |     conf AS token,
21 |     conf AS channel,
22 |     conf AS role,
23 |     conf AS thread_expiry
24 | WHERE guild.key = ARRAY['guild']
25 |     AND token.key = ARRAY['token']
26 |     AND channel.key = ARRAY['channel']
27 |     AND role.key = ARRAY['role']
28 |     AND thread_expiry.key = ARRAY['thread_expiry'];
29 |
30 | DELETE FROM kv WHERE namespace = 'plugins.modmail';
31 |
--------------------------------------------------------------------------------
/migrations/plugins.modmail-bf5948e3b8d67d57dfda8e9ccb34a8dd44cf5d39-4f5dea7d15ddaa90bec050ba4d8a6460aefe44b8.sql:
--------------------------------------------------------------------------------
1 | CREATE SCHEMA modmail;
2 | ALTER TABLE modmails SET SCHEMA modmail;
3 | ALTER TABLE modmail.modmails RENAME TO messages;
4 | CREATE TABLE modmail.threads
5 |     ( user_id BIGINT NOT NULL
6 |     , thread_first_message_id BIGINT NOT NULL
7 |     , last_used TIMESTAMP NOT NULL
8 |     );
9 |
--------------------------------------------------------------------------------
/migrations/plugins.persistence-cff14827533eafa51a65684a41e83d93ce27b44c-79d5069c83aabeebf0be39f73e2d8828f9ce2080.sql:
--------------------------------------------------------------------------------
1 | CREATE TABLE persistence.roles (
2 |     id BIGSERIAL NOT NULL,
3 |     PRIMARY KEY (id)
4 | );
5 |
--------------------------------------------------------------------------------
/migrations/plugins.roles_review-5305e3042b889ffbd0a5fa96f5bf33e85ad930f9-6705b50d5a48f1ce96f62352069198b94655eba8.sql:
-------------------------------------------------------------------------------- 1 | CREATE TABLE roles_review.roles ( 2 | id BIGSERIAL NOT NULL, 3 | review_channel_id BIGINT NOT NULL, 4 | upvote_limit INTEGER NOT NULL, 5 | downvote_limit INTEGER NOT NULL, 6 | pending_role_id BIGINT, 7 | denied_role_id BIGINT, 8 | prompt TEXT[] NOT NULL, 9 | invitation TEXT NOT NULL, 10 | PRIMARY KEY (id) 11 | ); 12 | -------------------------------------------------------------------------------- /migrations/plugins.roles_review-87439a30c949b57f6a003ae2e2cc626a7cb967a4-5305e3042b889ffbd0a5fa96f5bf33e85ad930f9.sql: -------------------------------------------------------------------------------- 1 | ALTER TABLE roles_review.applications ADD COLUMN voting_id BIGINT; 2 | ALTER TABLE roles_review.applications ADD COLUMN decision BOOLEAN; 3 | UPDATE roles_review.applications SET decision = 'false' WHERE resolved; 4 | ALTER TABLE roles_review.applications DROP COLUMN resolved; 5 | -------------------------------------------------------------------------------- /migrations/plugins.tickets-6712376268cc661947b8ed360d1e1c78c33bcc7f-9da100e1067dfc8cab4bdc43a559a1a96907b628.sql: -------------------------------------------------------------------------------- 1 | ALTER TABLE tickets.tickets ALTER COLUMN created_at SET NOT NULL; 2 | -------------------------------------------------------------------------------- /migrations/plugins.tickets-9da100e1067dfc8cab4bdc43a559a1a96907b628-f69269c5ed1edd9bc84989ab9a218b220dbb4810.sql: -------------------------------------------------------------------------------- 1 | ALTER TYPE tickets.TicketType ADD VALUE 'TIMEOUT'; 2 | -------------------------------------------------------------------------------- /migrations/plugins.tickets-b108465ff07d357fe70fb67e68dbdb592e14a68a-6712376268cc661947b8ed360d1e1c78c33bcc7f.sql: -------------------------------------------------------------------------------- 1 | CREATE VIEW tickets.mod_queues AS 2 | SELECT tkt.id AS id 
3 | FROM tickets.mods mod 4 | INNER JOIN tickets.tickets tkt ON mod.modid = tkt.modid AND tkt.id = 5 | (SELECT t.id 6 | FROM tickets.tickets t 7 | WHERE mod.modid = t.modid AND stage <> 'COMMENTED' 8 | ORDER BY t.id LIMIT 1 9 | ); 10 | -------------------------------------------------------------------------------- /migrations/plugins.tickets-cecbe3f6862e737833c0252aecaba0fbe79c0463-b108465ff07d357fe70fb67e68dbdb592e14a68a.sql: -------------------------------------------------------------------------------- 1 | BEGIN; 2 | 3 | ALTER TYPE tickets.TicketStatus RENAME TO OldTicketStatus; 4 | 5 | CREATE TYPE tickets.TicketStatus AS ENUM 6 | ( 'IN_EFFECT' 7 | , 'EXPIRED' 8 | , 'EXPIRE_FAILED' 9 | , 'REVERTED' 10 | , 'HIDDEN' 11 | ); 12 | 13 | ALTER TABLE tickets.tickets ALTER COLUMN status TYPE tickets.TicketStatus USING 14 | CASE 15 | WHEN status = 'NEW' THEN 'IN_EFFECT' 16 | ELSE status::TEXT::tickets.TicketStatus 17 | END; 18 | ALTER TABLE tickets.history ALTER COLUMN status TYPE tickets.TicketStatus USING 19 | CASE 20 | WHEN status = 'NEW' THEN 'IN_EFFECT' 21 | ELSE status::TEXT::tickets.TicketStatus 22 | END; 23 | 24 | DROP TYPE tickets.OldTicketStatus; 25 | 26 | ALTER TABLE tickets.tickets ADD FOREIGN KEY (modid) REFERENCES tickets.mods (modid); 27 | 28 | ALTER TABLE tickets.mods ADD COLUMN scheduled_delivery TIMESTAMP; 29 | 30 | ALTER TABLE tickets.mods DROP COLUMN last_prompt_msgid; 31 | 32 | CREATE INDEX tickets_mod_queue ON tickets.tickets USING BTREE (modid, id) WHERE stage <> 'COMMENTED'; 33 | 34 | CREATE OR REPLACE FUNCTION tickets.log_ticket_update() 35 | RETURNS TRIGGER AS $log_ticket_update$ 36 | DECLARE 37 | last_version INT; 38 | BEGIN 39 | SELECT version INTO last_version 40 | FROM tickets.history 41 | WHERE id = OLD.id 42 | ORDER BY version DESC LIMIT 1; 43 | IF NOT FOUND THEN 44 | INSERT INTO tickets.history 45 | VALUES 46 | ( 0 47 | , OLD.created_at 48 | , OLD.id 49 | , OLD.type 50 | , OLD.stage 51 | , OLD.status 52 | , OLD.modid 53 | , OLD.targetid 
54 | , OLD.roleid 55 | , OLD.auditid 56 | , OLD.duration 57 | , OLD.comment 58 | , OLD.list_msgid 59 | , OLD.delivered_id 60 | , OLD.created_at 61 | , OLD.modified_by 62 | ); 63 | last_version = 0; 64 | END IF; 65 | INSERT INTO tickets.history 66 | VALUES 67 | ( last_version + 1 68 | , CURRENT_TIMESTAMP AT TIME ZONE 'UTC' 69 | , NEW.id 70 | , NULLIF(NEW.type, OLD.type) 71 | , NULLIF(NEW.stage, OLD.stage) 72 | , NULLIF(NEW.status, OLD.status) 73 | , NULLIF(NEW.modid, OLD.modid) 74 | , NULLIF(NEW.targetid, OLD.targetid) 75 | , NULLIF(NEW.roleid, OLD.roleid) 76 | , NULLIF(NEW.auditid, OLD.auditid) 77 | , NULLIF(NEW.duration, OLD.duration) 78 | , NULLIF(NEW.comment, OLD.comment) 79 | , NULLIF(NEW.list_msgid, OLD.list_msgid) 80 | , NULLIF(NEW.delivered_id, OLD.delivered_id) 81 | , NULLIF(NEW.created_at, OLD.created_at) 82 | , NEW.modified_by 83 | ); 84 | RETURN NULL; 85 | END 86 | $log_ticket_update$ LANGUAGE plpgsql; 87 | 88 | COMMIT; 89 | -------------------------------------------------------------------------------- /migrations/plugins.tickets-db25e78666828bd5fbbbf387796c88a8859b70b8-cecbe3f6862e737833c0252aecaba0fbe79c0463.sql: -------------------------------------------------------------------------------- 1 | ALTER TABLE tickets.tickets ALTER COLUMN created_at SET NOT NULL; 2 | -------------------------------------------------------------------------------- /migrations/plugins.tickets-f69269c5ed1edd9bc84989ab9a218b220dbb4810-424a6e89e1a95a089f089e92a7f2751416dd575b.sql: -------------------------------------------------------------------------------- 1 | ALTER TABLE tickets.tickets ADD COLUMN approved BOOLEAN NOT NULL DEFAULT 'true'; 2 | ALTER TABLE tickets.tickets ALTER COLUMN approved DROP DEFAULT; 3 | ALTER TABLE tickets.history ADD COLUMN approved BOOLEAN; 4 | 5 | DROP TRIGGER log_update ON tickets.tickets; 6 | DROP FUNCTION tickets.log_ticket_update; 7 | 8 | CREATE FUNCTION tickets.log_ticket_update() 9 | RETURNS TRIGGER AS $log_ticket_update$ 10 | DECLARE 11 
| last_version INT; 12 | BEGIN 13 | SELECT version INTO last_version 14 | FROM tickets.history 15 | WHERE id = OLD.id 16 | ORDER BY version DESC LIMIT 1; 17 | IF NOT FOUND THEN 18 | INSERT INTO tickets.history 19 | ( version, last_modified_at, id, type, stage, status, modid, targetid, roleid, auditid 20 | , duration, comment, approved, list_msgid, delivered_id, created_at, modified_by ) 21 | VALUES 22 | ( 0 23 | , OLD.created_at 24 | , OLD.id 25 | , OLD.type 26 | , OLD.stage 27 | , OLD.status 28 | , OLD.modid 29 | , OLD.targetid 30 | , OLD.roleid 31 | , OLD.auditid 32 | , OLD.duration 33 | , OLD.comment 34 | , OLD.approved 35 | , OLD.list_msgid 36 | , OLD.delivered_id 37 | , OLD.created_at 38 | , OLD.modified_by 39 | ); 40 | last_version = 0; 41 | END IF; 42 | INSERT INTO tickets.history 43 | ( version, last_modified_at, id, type, stage, status, modid, targetid, roleid, auditid 44 | , duration, comment, approved, list_msgid, delivered_id, created_at, modified_by ) 45 | VALUES 46 | ( last_version + 1 47 | , CURRENT_TIMESTAMP AT TIME ZONE 'UTC' 48 | , NEW.id 49 | , NULLIF(NEW.type, OLD.type) 50 | , NULLIF(NEW.stage, OLD.stage) 51 | , NULLIF(NEW.status, OLD.status) 52 | , NULLIF(NEW.modid, OLD.modid) 53 | , NULLIF(NEW.targetid, OLD.targetid) 54 | , NULLIF(NEW.roleid, OLD.roleid) 55 | , NULLIF(NEW.auditid, OLD.auditid) 56 | , NULLIF(NEW.duration, OLD.duration) 57 | , NULLIF(NEW.comment, OLD.comment) 58 | , NULLIF(NEW.approved, OLD.approved) 59 | , NULLIF(NEW.list_msgid, OLD.list_msgid) 60 | , NULLIF(NEW.delivered_id, OLD.delivered_id) 61 | , NULLIF(NEW.created_at, OLD.created_at) 62 | , NEW.modified_by 63 | ); 64 | RETURN NULL; 65 | END 66 | $log_ticket_update$ LANGUAGE plpgsql; 67 | 68 | CREATE TRIGGER log_update 69 | AFTER UPDATE ON 70 | tickets.tickets 71 | FOR EACH ROW 72 | WHEN 73 | (OLD.* IS DISTINCT FROM NEW.*) 74 | EXECUTE PROCEDURE 75 | tickets.log_ticket_update(); 76 | -------------------------------------------------------------------------------- 
/migrations/util.db.kv-234193f845d63b034c0d83e99a0282f94278c0fe-9e12e532a2abd0c7ea4b4aff6357b3affd7d8352.sql: -------------------------------------------------------------------------------- 1 | ALTER TABLE kv ALTER COLUMN key TYPE TEXT ARRAY USING ARRAY[key]; 2 | -------------------------------------------------------------------------------- /plugins/appeals.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import TYPE_CHECKING, NoReturn, Optional, Protocol, cast 3 | 4 | from aiohttp import ( 5 | ClientSession, 6 | DummyCookieJar, 7 | FormData, 8 | TraceConfig, 9 | TraceRequestEndParams, 10 | TraceRequestExceptionParams, 11 | TraceRequestStartParams, 12 | ) 13 | from aiohttp.web import ( 14 | Application, 15 | AppRunner, 16 | HTTPBadRequest, 17 | HTTPForbidden, 18 | HTTPInternalServerError, 19 | HTTPSeeOther, 20 | HTTPTemporaryRedirect, 21 | Request, 22 | RouteTableDef, 23 | TCPSite, 24 | ) 25 | import aiohttp_session 26 | import discord 27 | from discord import AllowedMentions, ButtonStyle, Interaction, InteractionType, PartialMessage, TextChannel 28 | from discord.abc import Messageable 29 | from discord.ext.commands import Cog 30 | from discord.ui import Button, View 31 | from sqlalchemy import BigInteger, Integer, func, select 32 | from sqlalchemy.ext.asyncio import async_sessionmaker 33 | import sqlalchemy.orm 34 | from sqlalchemy.orm import Mapped, mapped_column 35 | from sqlalchemy.schema import CreateSchema 36 | from yarl import URL 37 | 38 | from bot.client import client 39 | from bot.cogs import cog 40 | import plugins 41 | import plugins.tickets 42 | import util.db 43 | import util.db.kv 44 | from util.discord import PlainItem, chunk_messages, format, retry 45 | 46 | 47 | if TYPE_CHECKING: 48 | import discord.types.interactions 49 | 50 | 51 | logger = logging.getLogger(__name__) 52 | 53 | registry: sqlalchemy.orm.registry = sqlalchemy.orm.registry() 54 | 55 | sessionmaker = 
async_sessionmaker(util.db.engine, future=True, expire_on_commit=False) 56 | 57 | 58 | @registry.mapped 59 | class Appeal: 60 | __tablename__ = "appeals" 61 | __table_args__ = {"schema": "appeals"} 62 | 63 | id: Mapped[int] = mapped_column(Integer, primary_key=True) 64 | user_id: Mapped[int] = mapped_column(BigInteger, nullable=False) 65 | channel_id: Mapped[int] = mapped_column(BigInteger, nullable=False) 66 | thread_id: Mapped[Optional[int]] = mapped_column(BigInteger) 67 | message_id: Mapped[Optional[int]] = mapped_column(BigInteger) 68 | 69 | if TYPE_CHECKING: 70 | 71 | def __init__( 72 | self, *, user_id: int, channel_id: int, thread_id: Optional[int], message_id: Optional[int] = ... 73 | ) -> None: ... 74 | 75 | async def get_message(self) -> Optional[PartialMessage]: 76 | if self.message_id is None: 77 | return None 78 | channel_id = self.channel_id if self.thread_id is None else self.thread_id 79 | try: 80 | if not isinstance(channel := await client.fetch_channel(channel_id), Messageable): 81 | return None 82 | except (discord.NotFound, discord.Forbidden): 83 | return None 84 | return channel.get_partial_message(self.message_id) 85 | 86 | 87 | class AppealsConf(Protocol): 88 | client_id: int 89 | client_secret: str 90 | guild: int 91 | channel: int 92 | max_appeals: int 93 | 94 | 95 | conf: AppealsConf 96 | http: ClientSession 97 | runner: AppRunner 98 | 99 | 100 | class AppealView(View): 101 | def __init__(self, appeal_id: int) -> None: 102 | super().__init__(timeout=None) 103 | self.add_item( 104 | Button(style=ButtonStyle.danger, label="Close", custom_id="{}:{}:Close".format(__name__, appeal_id)) 105 | ) 106 | 107 | 108 | AUTHORIZE_URL = URL("https://discord.com/oauth2/authorize") 109 | TOKEN_URL = URL("https://discord.com/api/oauth2/token") 110 | ME_URL = URL("https://discord.com/api/users/@me") 111 | 112 | # https://github.com/discord-math/discord-math.github.io/blob/main/_appeals/form.md 113 | APPEAL_FORM_URL = 
URL("https://mathematics.gg/appeals/form") 114 | # https://github.com/discord-math/discord-math.github.io/blob/main/_appeals/success.md 115 | APPEAL_SUCCESS_URL = URL("https://mathematics.gg/appeals/success") 116 | 117 | CALLBACK_URL = URL("https://api.mathematics.gg/appeals/callback") 118 | 119 | routes = RouteTableDef() 120 | 121 | 122 | @routes.get("/auth") 123 | async def get_auth(request: Request) -> NoReturn: 124 | query = { 125 | "client_id": conf.client_id, 126 | "response_type": "code", 127 | "scope": "identify", 128 | "redirect_uri": str(CALLBACK_URL), 129 | } 130 | raise HTTPSeeOther(location=AUTHORIZE_URL.with_query(query)) 131 | 132 | 133 | async def get_user_id(token: str) -> int: 134 | async with http.get(ME_URL, headers={"Authorization": "Bearer {}".format(token)}) as r: 135 | user = await r.json() 136 | 137 | try: 138 | return int(user["id"]) 139 | except (KeyError, ValueError): 140 | raise HTTPInternalServerError(text="No user id") 141 | 142 | 143 | @routes.get("/callback") 144 | async def get_callback(request: Request) -> NoReturn: 145 | if "error" in request.query: 146 | raise HTTPForbidden(text=request.query["error"]) 147 | if "code" not in request.query: 148 | raise HTTPBadRequest(text="No access code provided") 149 | 150 | body = { 151 | "client_id": conf.client_id, 152 | "client_secret": conf.client_secret, 153 | "grant_type": "authorization_code", 154 | "code": request.query["code"], 155 | "redirect_uri": str(CALLBACK_URL), 156 | } 157 | async with http.post(TOKEN_URL, headers={"Accept": "application/json"}, data=FormData(body)) as r: 158 | result = await r.json() 159 | 160 | if "error_description" in result: 161 | raise HTTPForbidden(text=str(result["error_description"])) 162 | if "error" in result: 163 | raise HTTPForbidden(text=str(result["error"])) 164 | 165 | if "access_token" not in result: 166 | raise HTTPInternalServerError(text="No access token") 167 | token = result["access_token"] 168 | 169 | user_id = await get_user_id(token) 
170 | 171 | async with sessionmaker() as session: 172 | # TODO: TOCTOU 173 | stmt = select(func.count(Appeal.id)).where(Appeal.user_id == user_id) 174 | num_appeals = cast(int, (await session.execute(stmt)).scalar()) 175 | if num_appeals >= conf.max_appeals: 176 | raise HTTPForbidden(text="You already have {} active appeals".format(conf.max_appeals)) 177 | 178 | in_guild = False 179 | in_banlist = False 180 | if guild := client.get_guild(conf.guild): 181 | async for entry in guild.bans(after=discord.Object(user_id - 1), limit=1): 182 | if entry.user.id == user_id: 183 | in_banlist = True 184 | else: 185 | if guild.get_member(user_id): 186 | in_guild = True 187 | 188 | cookie = await aiohttp_session.new_session(request) 189 | cookie["token"] = token 190 | 191 | if in_guild: 192 | location = APPEAL_FORM_URL.with_fragment("in_guild") 193 | elif not in_banlist: 194 | location = APPEAL_FORM_URL.with_fragment("not_in_banlist") 195 | else: 196 | location = APPEAL_FORM_URL 197 | raise HTTPTemporaryRedirect(location=location) 198 | 199 | 200 | @routes.post("/submit") 201 | async def post_submit(request: Request) -> NoReturn: 202 | cookie = await aiohttp_session.get_session(request) 203 | if "token" not in cookie: 204 | await get_auth(request) 205 | token = cookie["token"] 206 | 207 | post_data = await request.post() 208 | if "type" not in post_data or "appeal" not in post_data: 209 | raise HTTPBadRequest() 210 | kind = post_data["type"] 211 | reason = post_data.get("reason", "") 212 | appeal = post_data["appeal"] 213 | if not isinstance(kind, str) or not isinstance(reason, str) or not isinstance(appeal, str): 214 | raise HTTPBadRequest() 215 | kind = kind[:11] 216 | reason = reason[:4000] 217 | appeal = appeal[:4000] 218 | 219 | user_id = await get_user_id(token) 220 | 221 | async with sessionmaker() as session: 222 | # TODO: TOCTOU 223 | stmt = select(func.count(Appeal.id)).where(Appeal.user_id == user_id) 224 | num_appeals = cast(int, (await 
session.execute(stmt)).scalar()) 225 | if num_appeals >= conf.max_appeals: 226 | raise HTTPForbidden(text="You already have {} active appeals".format(conf.max_appeals)) 227 | 228 | if not (guild := client.get_guild(conf.guild)): 229 | raise HTTPInternalServerError() 230 | if not isinstance(channel := guild.get_channel(conf.channel), TextChannel): 231 | raise HTTPInternalServerError() 232 | 233 | last_message = None 234 | for content, _ in chunk_messages( 235 | [ 236 | PlainItem(format("**Ban Appeal from** {!m}:\n\n", user_id)), 237 | PlainItem("**Type:** {}\n".format(kind)), 238 | PlainItem("**Reason:** "), 239 | PlainItem(reason), 240 | PlainItem("\n**Appeal:** "), 241 | PlainItem(appeal), 242 | ] 243 | ): 244 | last_message = await retry( 245 | lambda: channel.send(content, allowed_mentions=AllowedMentions.none()), attempts=10 246 | ) 247 | assert last_message 248 | 249 | thread = await retry(lambda: last_message.create_thread(name=str(user_id)), attempts=10) 250 | 251 | appeal = Appeal(user_id=user_id, channel_id=channel.id, thread_id=thread.id) 252 | session.add(appeal) 253 | await session.commit() 254 | 255 | msg = await retry(lambda: thread.send(view=AppealView(appeal.id)), attempts=10) 256 | appeal.message_id = msg.id 257 | await session.commit() 258 | 259 | embeds = plugins.tickets.summarise_tickets( 260 | await plugins.tickets.visible_tickets(session, user_id), "Tickets for {}".format(user_id), dm=False 261 | ) 262 | if embeds: 263 | embeds = list(embeds) 264 | for i in range(0, len(embeds), 10): 265 | await thread.send(embeds=embeds[i : i + 10]) 266 | 267 | raise HTTPSeeOther(location=APPEAL_SUCCESS_URL) 268 | 269 | 270 | app = Application() 271 | app.add_routes(routes) 272 | aiohttp_session.setup(app, aiohttp_session.SimpleCookieStorage()) 273 | 274 | 275 | async def on_request_start(session: ClientSession, context: object, params: TraceRequestStartParams) -> None: 276 | logger.debug("Sending request to {}".format(params.url)) 277 | 278 | 279 | async def 
on_request_end(session: ClientSession, context: object, params: TraceRequestEndParams) -> None: 280 | logger.debug("Request to {} received {}".format(params.url, params.response.status)) 281 | 282 | 283 | async def on_request_exception(session: ClientSession, context: object, params: TraceRequestExceptionParams) -> None: 284 | logger.debug("Request to {} received exception".format(params.url), exc_info=params.exception) 285 | 286 | 287 | @plugins.init 288 | async def init() -> None: 289 | global conf, http, runner 290 | 291 | conf = cast(AppealsConf, await util.db.kv.load(__name__)) 292 | 293 | await util.db.init(util.db.get_ddl(CreateSchema("appeals"), registry.metadata.create_all)) 294 | 295 | trace_config = TraceConfig() 296 | trace_config.on_request_start.append(on_request_start) 297 | trace_config.on_request_end.append(on_request_end) 298 | trace_config.on_request_exception.append(on_request_exception) 299 | http = ClientSession(cookie_jar=DummyCookieJar(), trace_configs=[trace_config]) 300 | plugins.finalizer(http.close) 301 | 302 | runner = AppRunner(app) 303 | await runner.setup() 304 | plugins.finalizer(runner.cleanup) 305 | site = TCPSite(runner, "127.0.0.1", 16720) 306 | await site.start() 307 | 308 | 309 | async def close_appeal(interaction: Interaction, appeal_id: int) -> None: 310 | async with sessionmaker() as session: 311 | if not (appeal := await session.get(Appeal, appeal_id)): 312 | return 313 | 314 | await session.delete(appeal) 315 | if msg := await appeal.get_message(): 316 | try: 317 | await retry( 318 | lambda: msg.edit(content="Appeal handled (user may create new ones).", view=None), attempts=10 319 | ) 320 | except discord.HTTPException: 321 | pass 322 | await session.commit() 323 | 324 | 325 | @cog 326 | class AppealsCog(Cog): 327 | @Cog.listener() 328 | async def on_interaction(self, interaction: Interaction) -> None: 329 | if interaction.type != InteractionType.component or interaction.data is None: 330 | return 331 | data = 
cast("discord.types.interactions.MessageComponentInteractionData", interaction.data) 332 | if data["component_type"] != 2: 333 | return 334 | if ":" not in data["custom_id"]: 335 | return 336 | mod, rest = data["custom_id"].split(":", 1) 337 | if mod != __name__ or ":" not in rest: 338 | return 339 | appeal_id, action = rest.split(":", 1) 340 | try: 341 | appeal_id = int(appeal_id) 342 | except ValueError: 343 | return 344 | if action == "Close": 345 | await close_appeal(interaction, appeal_id) 346 | -------------------------------------------------------------------------------- /plugins/bot_manager.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | import importlib 3 | import sys 4 | import traceback 5 | 6 | from discord.ext.commands import command 7 | 8 | from bot.acl import privileged 9 | from bot.commands import Context, cleanup, plugin_command 10 | import plugins 11 | from util.discord import CodeItem, Typing, chunk_messages, format 12 | import util.restart 13 | 14 | 15 | def get_current_manager(): 16 | manager = plugins.PluginManager.of(__name__) 17 | assert manager 18 | return manager 19 | 20 | 21 | manager = get_current_manager() 22 | 23 | 24 | @plugin_command 25 | @command("restart") 26 | @privileged 27 | async def restart_command(ctx: Context) -> None: 28 | """Restart the bot process.""" 29 | await ctx.send("Restarting...") 30 | util.restart.restart() 31 | 32 | 33 | class PluginConverter(str): 34 | @classmethod 35 | async def convert(cls, ctx: Context, arg: str) -> str: 36 | if "." not in arg: 37 | arg = "plugins." 
+ arg 38 | return arg 39 | 40 | 41 | async def reply_exception(ctx: Context) -> None: 42 | _, exc, tb = sys.exc_info() 43 | for content, files in chunk_messages( 44 | (CodeItem("".join(traceback.format_exception(None, exc, tb)), language="py", filename="error.txt"),) 45 | ): 46 | await ctx.send(content, files=files) 47 | del tb 48 | 49 | 50 | @plugin_command 51 | @cleanup 52 | @command("load") 53 | @privileged 54 | async def load_command(ctx: Context, plugin: PluginConverter) -> None: 55 | """Load a plugin.""" 56 | try: 57 | async with Typing(ctx): 58 | await manager.load(plugin) 59 | except: 60 | await reply_exception(ctx) 61 | else: 62 | await ctx.send("\u2705") 63 | 64 | 65 | @plugin_command 66 | @cleanup 67 | @command("reload") 68 | @privileged 69 | async def reload_command(ctx: Context, plugin: PluginConverter) -> None: 70 | """Reload a plugin.""" 71 | try: 72 | async with Typing(ctx): 73 | await manager.reload(plugin) 74 | except: 75 | await reply_exception(ctx) 76 | else: 77 | await ctx.send("\u2705") 78 | 79 | 80 | @plugin_command 81 | @cleanup 82 | @command("unsafereload") 83 | @privileged 84 | async def unsafe_reload_command(ctx: Context, plugin: PluginConverter) -> None: 85 | """Reload a plugin without its dependents.""" 86 | try: 87 | async with Typing(ctx): 88 | await manager.unsafe_reload(plugin) 89 | except: 90 | await reply_exception(ctx) 91 | else: 92 | await ctx.send("\u2705") 93 | 94 | 95 | @plugin_command 96 | @cleanup 97 | @command("unload") 98 | @privileged 99 | async def unload_command(ctx: Context, plugin: PluginConverter) -> None: 100 | """Unload a plugin.""" 101 | try: 102 | async with Typing(ctx): 103 | await manager.unload(plugin) 104 | except: 105 | await reply_exception(ctx) 106 | else: 107 | await ctx.send("\u2705") 108 | 109 | 110 | @plugin_command 111 | @cleanup 112 | @command("unsafeunload") 113 | @privileged 114 | async def unsafe_unload_command(ctx: Context, plugin: PluginConverter) -> None: 115 | """Unload a plugin without its 
dependents.""" 116 | try: 117 | async with Typing(ctx): 118 | await manager.unsafe_unload(plugin) 119 | except: 120 | await reply_exception(ctx) 121 | else: 122 | await ctx.send("\u2705") 123 | 124 | 125 | @plugin_command 126 | @cleanup 127 | @command("reloadmod") 128 | @privileged 129 | async def reloadmod_command(ctx: Context, module: str) -> None: 130 | """Reload a module.""" 131 | try: 132 | importlib.reload(sys.modules[module]) 133 | except: 134 | await reply_exception(ctx) 135 | else: 136 | await ctx.send("\u2705") 137 | 138 | 139 | @plugin_command 140 | @cleanup 141 | @command("plugins") 142 | @privileged 143 | async def plugins_command(ctx: Context) -> None: 144 | """List loaded plugins.""" 145 | output = defaultdict(list) 146 | for name in sys.modules: 147 | if manager.is_plugin(name): 148 | try: 149 | key = manager.plugins[name].state.name 150 | except KeyError: 151 | key = "???" 152 | output[key].append(name) 153 | await ctx.send( 154 | "\n".join( 155 | format("- {!i}: {}", key, ", ".join(format("{!i}", name) for name in sorted(plugins))) 156 | for key, plugins in output.items() 157 | ) 158 | ) 159 | -------------------------------------------------------------------------------- /plugins/discord_log.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import logging 3 | import sys 4 | from threading import Lock 5 | from types import FrameType 6 | from typing import TYPE_CHECKING, Iterator, List, Literal, Optional, Union, cast 7 | 8 | from discord import Client 9 | from discord.ext.commands import command 10 | from sqlalchemy import BigInteger, Computed 11 | from sqlalchemy.ext.asyncio import async_sessionmaker 12 | import sqlalchemy.orm 13 | from sqlalchemy.orm import Mapped, mapped_column 14 | 15 | from bot.client import client 16 | from bot.commands import Context 17 | from bot.config import plugin_config_command 18 | import plugins 19 | import util.db.kv 20 | from util.discord import CodeItem, 
PartialTextChannelConverter, chunk_messages, format 21 | 22 | 23 | registry = sqlalchemy.orm.registry() 24 | sessionmaker = async_sessionmaker(util.db.engine, expire_on_commit=False) 25 | 26 | 27 | @registry.mapped 28 | class GlobalConfig: 29 | __tablename__ = "syslog_config" 30 | id: Mapped[int] = mapped_column(BigInteger, Computed("0"), primary_key=True) 31 | channel_id: Mapped[Optional[int]] = mapped_column(BigInteger) 32 | 33 | if TYPE_CHECKING: 34 | 35 | def __init__(self, *, id: int = ..., channel_id: Optional[int] = ...) -> None: ... 36 | 37 | 38 | conf: GlobalConfig 39 | logger: logging.Logger = logging.getLogger(__name__) 40 | 41 | 42 | class DiscordHandler(logging.Handler): 43 | __slots__ = "queue", "lock" 44 | queue: List[str] 45 | thread_lock: Lock 46 | 47 | def __init__(self, level: int = logging.NOTSET): 48 | self.queue = [] 49 | self.thread_lock = Lock() # just in case 50 | return super().__init__(level) 51 | 52 | def queue_pop(self) -> Optional[str]: 53 | with self.thread_lock: 54 | if len(self.queue) == 0: 55 | return None 56 | return self.queue.pop(0) 57 | 58 | async def log_discord(self, chan_id: int, client: Client) -> None: 59 | try: 60 | 61 | def fill_items() -> Iterator[CodeItem]: 62 | while (text := self.queue_pop()) is not None: 63 | yield CodeItem(text, language="py", filename="log.txt") 64 | 65 | for content, files in chunk_messages(fill_items()): 66 | await client.get_partial_messageable(chan_id).send(content, files=files) 67 | except: 68 | logger.critical("Could not report exception to Discord", exc_info=True, extra={"no_discord": True}) 69 | 70 | def emit(self, record: logging.LogRecord) -> None: 71 | if hasattr(record, "no_discord"): 72 | return 73 | try: 74 | if asyncio.get_event_loop().is_closed(): 75 | return 76 | except: 77 | return 78 | 79 | if client.is_closed(): 80 | return 81 | 82 | if conf.channel_id is None: 83 | return 84 | 85 | text = self.format(record) 86 | 87 | # Check the traceback for whether we are nested inside 
log_discord, 88 | # as a last resort measure 89 | frame: Optional[FrameType] = sys._getframe() 90 | while frame: 91 | if frame.f_code == self.log_discord.__code__: 92 | del frame 93 | return 94 | frame = frame.f_back 95 | del frame 96 | 97 | with self.thread_lock: 98 | if self.queue: 99 | self.queue.append(text) 100 | else: 101 | self.queue.append(text) 102 | asyncio.create_task(self.log_discord(conf.channel_id, client), name="Logging to Discord") 103 | 104 | 105 | @plugins.init 106 | async def init() -> None: 107 | global conf 108 | await util.db.init(util.db.get_ddl(registry.metadata.create_all)) 109 | async with sessionmaker() as session: 110 | c = await session.get(GlobalConfig, 0) 111 | if not c: 112 | c = GlobalConfig(channel_id=cast(Optional[int], (await util.db.kv.load(__name__)).channel)) 113 | session.add(c) 114 | await session.commit() 115 | conf = c 116 | 117 | handler: logging.Handler = DiscordHandler(logging.ERROR) 118 | handler.setFormatter(logging.Formatter("%(name)s %(levelname)s: %(message)s")) 119 | logging.getLogger().addHandler(handler) 120 | 121 | def finalizer() -> None: 122 | logging.getLogger().removeHandler(handler) 123 | 124 | plugins.finalizer(finalizer) 125 | 126 | 127 | @plugin_config_command 128 | @command("syslog") 129 | async def config(ctx: Context, channel: Optional[Union[Literal["None"], PartialTextChannelConverter]]) -> None: 130 | global conf 131 | async with sessionmaker() as session: 132 | c = await session.get(GlobalConfig, 0) 133 | assert c 134 | if channel is None: 135 | await ctx.send("None" if c.channel_id is None else format("{!c}", conf.channel_id)) 136 | else: 137 | c.channel_id = None if channel == "None" else channel.id 138 | await session.commit() 139 | conf = c 140 | await ctx.send("\u2705") 141 | -------------------------------------------------------------------------------- /plugins/eval.py: -------------------------------------------------------------------------------- 1 | import ast 2 | import builtins 3 | 
import inspect 4 | from io import StringIO 5 | import sys 6 | import traceback 7 | from types import FunctionType 8 | from typing import Any, Callable, Dict, TypeVar, Union 9 | 10 | from discord.ext.commands import Greedy, command 11 | 12 | from bot.acl import privileged 13 | from bot.client import client 14 | from bot.commands import Context, cleanup, plugin_command 15 | from util.discord import CodeBlock, CodeItem, Inline, PlainItem, Typing, chunk_messages 16 | 17 | 18 | T = TypeVar("T") 19 | 20 | 21 | @plugin_command 22 | @cleanup 23 | @command("exec", aliases=["eval"]) 24 | @privileged 25 | async def exec_command(ctx: Context, args: Greedy[Union[CodeBlock, Inline, str]]) -> None: 26 | """ 27 | Execute all code blocks in the command line as python code. 28 | The code can be an expression on a series of statements. The code has all loaded modules in scope, as well as "ctx" 29 | and "client". The print function is redirected. The code can also use top-level "await". 30 | """ 31 | outputs = [] 32 | code_scope: Dict[str, object] = dict(sys.modules) 33 | # Using real builtins to avoid dependency tracking 34 | code_scope["__builtins__"] = builtins 35 | code_scope.update(builtins.__dict__) 36 | code_scope["ctx"] = ctx 37 | code_scope["client"] = client 38 | 39 | def mk_code_print(fp: StringIO) -> Callable[..., None]: 40 | def code_print(*args: object, sep: str = " ", end: str = "\n", file: Any = fp, flush: bool = False): 41 | return print(*args, sep=sep, end=end, file=file, flush=flush) 42 | 43 | return code_print 44 | 45 | fp = StringIO() 46 | try: 47 | async with Typing(ctx): 48 | for arg in args: 49 | if isinstance(arg, (CodeBlock, Inline)): 50 | fp = StringIO() 51 | outputs.append(fp) 52 | code_scope["print"] = mk_code_print(fp) 53 | try: 54 | code = compile( 55 | arg.text, "".format(ctx.message.id), "eval", ast.PyCF_ALLOW_TOP_LEVEL_AWAIT 56 | ) 57 | except: 58 | code = compile( 59 | arg.text, "".format(ctx.message.id), "exec", ast.PyCF_ALLOW_TOP_LEVEL_AWAIT 60 | ) 
61 | fun = FunctionType(code, code_scope) 62 | ret = fun() 63 | if inspect.iscoroutine(ret): 64 | ret = await ret 65 | if ret != None: 66 | mk_code_print(fp)(repr(ret)) 67 | except: 68 | _, exc, tb = sys.exc_info() 69 | mk_code_print(fp)("".join(traceback.format_tb(tb))) 70 | mk_code_print(fp)(repr(exc)) 71 | del tb 72 | 73 | for content, files in chunk_messages( 74 | ( 75 | CodeItem(fp.getvalue(), language="py", filename="output{}.txt".format(i)) 76 | if fp.getvalue() 77 | else PlainItem("\u2705") 78 | ) 79 | for i, fp in enumerate(outputs, start=1) 80 | ): 81 | await ctx.send(content, files=files) 82 | -------------------------------------------------------------------------------- /plugins/help.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | from typing import Any, List, Mapping, Optional, Set 3 | 4 | import discord.ext.commands 5 | from discord.ext.commands import Cog, Command, Group 6 | 7 | import bot.acl 8 | from bot.acl import ACLCheck, EvalResult, evaluate_acl, evaluate_ctx 9 | from bot.client import client 10 | import plugins 11 | from util.discord import format 12 | 13 | 14 | class HelpCommand(discord.ext.commands.HelpCommand): 15 | async def send_bot_help(self, mapping: Mapping[Optional[Cog], List[Command[Any, Any, Any]]]) -> None: 16 | if self.context is None: 17 | return 18 | 19 | commands: Mapping[str, Set[Command[Any, Any, Any]]] = defaultdict(set) 20 | for cmds in mapping.values(): 21 | for cmd in cmds: 22 | allowed = True 23 | for check in cmd.checks: 24 | if isinstance(check, ACLCheck): 25 | acl = bot.acl.commands.get(cmd.qualified_name) 26 | if evaluate_acl(acl, self.context.author, None) == EvalResult.FALSE: 27 | allowed = False 28 | break 29 | if allowed: 30 | commands[cmd.module].add(cmd) 31 | prefix = self.context.prefix or "" 32 | 33 | listing = "\n".join( 34 | "- {}: {}".format( 35 | module.rsplit(".", 1)[-1], 36 | ", ".join(format("{!i}", prefix + cmd.name) 
for cmd in sorted(cmds, key=lambda c: c.name)), 37 | ) 38 | for module, cmds in sorted(commands.items(), key=lambda mc: mc[0].rsplit(".", 1)[-1]) 39 | ) 40 | 41 | await self.get_destination().send( 42 | format( 43 | "**Commands:**\n{}\n\nType {!i} for more info on a command.", 44 | listing, 45 | prefix + (self.invoked_with or "") + " ", 46 | ) 47 | ) 48 | 49 | async def send_command_help(self, command: Command[Any, Any, Any]) -> None: 50 | if self.context is None: 51 | return 52 | 53 | prefix = self.context.prefix or "" 54 | usage = prefix + " ".join(s for s in [command.qualified_name, command.signature] if s) 55 | akanote = ( 56 | "" 57 | if not command.aliases 58 | else "\naka: {}".format(", ".join(format("{!i}", alias) for alias in command.aliases)) 59 | ) 60 | desc = command.help 61 | 62 | privnote = "" 63 | for check in command.checks: 64 | if isinstance(check, ACLCheck): 65 | acl = bot.acl.commands.get(command.qualified_name) 66 | if evaluate_acl(acl, self.context.author, None) == EvalResult.FALSE: 67 | privnote = "\nYou are not allowed to use this command." 68 | break 69 | elif evaluate_acl(acl, *evaluate_ctx(self.context)) == EvalResult.FALSE: # type: ignore 70 | privnote = "\nYou are not allowed to use this command here specifically." 
71 | break 72 | 73 | await self.get_destination().send(format("**Usage:** {!i}{}\n{}{}", usage, akanote, desc, privnote)) 74 | 75 | async def send_group_help(self, group: Group[Any, Any, Any]) -> None: 76 | if self.context is None: 77 | return 78 | 79 | prefix = self.context.prefix or "" 80 | args = [group.qualified_name, group.signature] 81 | if not group.invoke_without_command: 82 | args.append("...") 83 | 84 | usage = prefix + " ".join(s for s in args if s) 85 | akanote = ( 86 | "" if not group.aliases else "\naka: {}".format(", ".join(format("{!i}", alias) for alias in group.aliases)) 87 | ) 88 | desc = group.help 89 | 90 | subcommands = [] 91 | for cmd in sorted(group.walk_commands(), key=lambda c: c.qualified_name): 92 | args = [cmd.name, cmd.signature] 93 | if isinstance(cmd, Group) and not cmd.invoke_without_command: 94 | continue 95 | for parent in cmd.parents: 96 | if not parent.invoke_without_command: 97 | args.insert(0, parent.signature) 98 | args.insert(0, parent.name) 99 | subcommands.append(format("- {!i}", prefix + " ".join(s for s in args if s))) 100 | 101 | privnote = "" 102 | for check in group.checks: 103 | if isinstance(check, ACLCheck): 104 | acl = bot.acl.commands.get(group.qualified_name) 105 | if evaluate_acl(acl, self.context.author, None) == EvalResult.FALSE: 106 | privnote = "\nYou are not allowed to use this command." 107 | break 108 | elif evaluate_acl(acl, *evaluate_ctx(self.context)) == EvalResult.FALSE: # type: ignore 109 | privnote = "\nYou are not allowed to use this command here specifically." 
110 | break 111 | 112 | await self.get_destination().send( 113 | format( 114 | "**Usage:** {!i}{}\n{}\n**Sub-commands:**\n{}{}\n\nType {!i} for more info on a sub-command.", 115 | usage, 116 | akanote, 117 | desc, 118 | "\n".join(subcommands), 119 | privnote, 120 | prefix + (self.invoked_with or "") + " " + group.qualified_name + " ", 121 | ) 122 | ) 123 | 124 | 125 | old_help = client.help_command 126 | client.help_command = HelpCommand() 127 | 128 | 129 | @plugins.finalizer 130 | def restore_help_command() -> None: 131 | client.help_command = old_help 132 | -------------------------------------------------------------------------------- /plugins/keepvanity.py: -------------------------------------------------------------------------------- 1 | from typing import TYPE_CHECKING 2 | 3 | import discord 4 | from discord import Guild, Message, MessageType 5 | from discord.ext.commands import group 6 | from sqlalchemy import TEXT, BigInteger, select 7 | from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker 8 | import sqlalchemy.orm 9 | from sqlalchemy.orm import Mapped, mapped_column 10 | 11 | from bot.client import client 12 | from bot.cogs import Cog, cog 13 | from bot.commands import Context 14 | from bot.config import plugin_config_command 15 | import plugins 16 | import util.db.kv 17 | from util.discord import PartialGuildConverter, format 18 | 19 | 20 | registry = sqlalchemy.orm.registry() 21 | sessionmaker = async_sessionmaker(util.db.engine) 22 | 23 | 24 | @registry.mapped 25 | class GuildConfig: 26 | __tablename__ = "keep_vanity" 27 | 28 | guild_id: Mapped[int] = mapped_column(BigInteger, primary_key=True) 29 | vanity: Mapped[str] = mapped_column(TEXT, nullable=False) 30 | 31 | if TYPE_CHECKING: 32 | 33 | def __init__(self, *, guild_id: int, vanity: str) -> None: ... 
34 | 35 | 36 | @plugins.init 37 | async def init() -> None: 38 | await util.db.init(util.db.get_ddl(registry.metadata.create_all)) 39 | 40 | async with sessionmaker() as session: 41 | conf = await util.db.kv.load(__name__) 42 | if isinstance(guild_id := conf.guild, int): 43 | if isinstance(vanity := conf.vanity, str): 44 | session.add(GuildConfig(guild_id=guild_id, vanity=vanity)) 45 | conf.guild_id = None 46 | conf.vanity = None 47 | await session.commit() 48 | await conf 49 | 50 | for guild in client.guilds: 51 | await check_guild_vanity(session, guild) 52 | 53 | 54 | async def check_guild_vanity(session: AsyncSession, guild: Guild) -> None: 55 | if conf := await session.get(GuildConfig, guild.id): 56 | try: 57 | if await guild.vanity_invite() is not None: 58 | return 59 | except discord.Forbidden: 60 | return 61 | except discord.NotFound: 62 | pass 63 | await guild.edit(vanity_code=conf.vanity) 64 | 65 | 66 | @cog 67 | class KeepVanity(Cog): 68 | """Restores the guild vanity URL as soon as enough boosts are available""" 69 | 70 | @Cog.listener() 71 | async def on_ready(self) -> None: 72 | async with sessionmaker() as session: 73 | for guild in client.guilds: 74 | await check_guild_vanity(session, guild) 75 | 76 | @Cog.listener() 77 | async def on_message(self, msg: Message) -> None: 78 | if msg.type != MessageType.premium_guild_tier_3: 79 | return 80 | if msg.guild is None: 81 | return 82 | async with sessionmaker() as session: 83 | await check_guild_vanity(session, msg.guild) 84 | 85 | 86 | @plugin_config_command 87 | @group("keepvanity", invoke_without_command=True) 88 | async def config(ctx: Context) -> None: 89 | async with sessionmaker() as session: 90 | stmt = select(GuildConfig) 91 | configs = (await session.execute(stmt)).scalars() 92 | await ctx.send( 93 | "\n".join(format("- {!c}: {!i}", conf.guild_id, conf.vanity) for conf in configs) or "No servers registered" 94 | ) 95 | 96 | 97 | @config.command("add") 98 | async def config_add(ctx: Context, 
server: PartialGuildConverter, vanity: str) -> None: 99 | async with sessionmaker() as session: 100 | session.add(GuildConfig(guild_id=server.id, vanity=vanity)) 101 | await session.commit() 102 | await ctx.send("\u2705") 103 | 104 | 105 | @config.command("remove") 106 | async def config_remove(ctx: Context, server: PartialGuildConverter) -> None: 107 | async with sessionmaker() as session: 108 | await session.delete(await session.get(GuildConfig, server.id)) 109 | await session.commit() 110 | await ctx.send("\u2705") 111 | -------------------------------------------------------------------------------- /plugins/modmail.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | from datetime import datetime, timedelta 3 | import logging 4 | from typing import TYPE_CHECKING, Dict, Optional 5 | 6 | import discord 7 | from discord import ( 8 | Activity, 9 | ActivityType, 10 | AllowedMentions, 11 | Client, 12 | DMChannel, 13 | Intents, 14 | Message, 15 | MessageReference, 16 | TextChannel, 17 | Thread, 18 | ) 19 | from discord.ext.commands import group 20 | from sqlalchemy import TEXT, TIMESTAMP, BigInteger, select, update 21 | from sqlalchemy.dialects.postgresql import INTERVAL 22 | from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker 23 | import sqlalchemy.orm 24 | from sqlalchemy.orm import Mapped, mapped_column 25 | from sqlalchemy.schema import CreateSchema 26 | from sqlalchemy.sql.functions import current_timestamp 27 | 28 | import bot.client 29 | from bot.cogs import Cog, cog 30 | from bot.commands import Context 31 | from bot.config import plugin_config_command 32 | from bot.reactions import get_reaction 33 | import plugins 34 | import util.db 35 | from util.discord import ( 36 | DurationConverter, 37 | PartialChannelConverter, 38 | PartialGuildConverter, 39 | PartialRoleConverter, 40 | PlainItem, 41 | UserError, 42 | chunk_messages, 43 | format, 44 | retry, 45 | ) 46 | 47 | 48 | registry = 
sqlalchemy.orm.registry() 49 | sessionmaker = async_sessionmaker(util.db.engine, expire_on_commit=False) 50 | 51 | 52 | @registry.mapped 53 | class ModmailMessage: 54 | __tablename__ = "messages" 55 | __table_args__ = {"schema": "modmail"} 56 | 57 | dm_channel_id: Mapped[int] = mapped_column(BigInteger, nullable=False) 58 | dm_message_id: Mapped[int] = mapped_column(BigInteger, nullable=False) 59 | staff_message_id: Mapped[int] = mapped_column(BigInteger, primary_key=True) 60 | 61 | if TYPE_CHECKING: 62 | 63 | def __init__(self, *, dm_channel_id: int, dm_message_id: int, staff_message_id: int) -> None: ... 64 | 65 | 66 | @registry.mapped 67 | class ModmailThread: 68 | __tablename__ = "threads" 69 | __table_args__ = {"schema": "modmail"} 70 | 71 | user_id: Mapped[int] = mapped_column(BigInteger, nullable=False) 72 | thread_first_message_id: Mapped[int] = mapped_column(BigInteger, primary_key=True) 73 | last_used: Mapped[datetime] = mapped_column(TIMESTAMP, nullable=False) 74 | 75 | if TYPE_CHECKING: 76 | 77 | def __init__(self, *, user_id: int, thread_first_message_id: int, last_used: datetime) -> None: ... 78 | 79 | 80 | @registry.mapped 81 | class GuildConfig: 82 | __tablename__ = "guilds" 83 | __table_args__ = {"schema": "modmail"} 84 | 85 | guild_id: Mapped[int] = mapped_column(BigInteger, primary_key=True) 86 | token: Mapped[str] = mapped_column(TEXT, nullable=False) 87 | channel_id: Mapped[int] = mapped_column(BigInteger, nullable=False) 88 | role_id: Mapped[int] = mapped_column(BigInteger, nullable=False) 89 | thread_expiry: Mapped[timedelta] = mapped_column(INTERVAL, nullable=False) 90 | 91 | if TYPE_CHECKING: 92 | 93 | def __init__( 94 | self, *, guild_id: int, token: str, channel_id: int, role_id: int, thread_expiry: timedelta 95 | ) -> None: ... 
96 | 97 | 98 | logger: logging.Logger = logging.getLogger(__name__) 99 | 100 | message_map: Dict[int, ModmailMessage] = {} 101 | 102 | 103 | @plugins.init 104 | async def init() -> None: 105 | await util.db.init(util.db.get_ddl(CreateSchema("modmail"), registry.metadata.create_all)) 106 | 107 | async with sessionmaker() as session: 108 | for msg in (await session.execute(select(ModmailMessage))).scalars(): 109 | message_map[msg.staff_message_id] = msg 110 | 111 | 112 | async def add_modmail(source: Message, copy: Message) -> None: 113 | async with sessionmaker() as session: 114 | msg = ModmailMessage(dm_channel_id=source.channel.id, dm_message_id=source.id, staff_message_id=copy.id) 115 | session.add(msg) 116 | await session.commit() 117 | message_map[msg.staff_message_id] = msg 118 | 119 | 120 | async def update_thread(conf: GuildConfig, user_id: int) -> Optional[int]: 121 | async with sessionmaker() as session: 122 | stmt = ( 123 | update(ModmailThread) 124 | .returning(ModmailThread.thread_first_message_id) 125 | .where(ModmailThread.user_id == user_id, ModmailThread.last_used > current_timestamp() - conf.thread_expiry) 126 | .values(last_used=current_timestamp()) 127 | .execution_options(synchronize_session=False) 128 | ) 129 | thread = (await session.execute(stmt)).scalars().first() 130 | await session.commit() 131 | return thread 132 | 133 | 134 | async def create_thread(user_id: int, msg_id: int) -> None: 135 | async with sessionmaker() as session: 136 | session.add( 137 | ModmailThread( 138 | user_id=user_id, thread_first_message_id=msg_id, last_used=current_timestamp() # type: ignore 139 | ) 140 | ) 141 | await session.commit() 142 | 143 | 144 | class ModMailClient(Client): 145 | conf: GuildConfig 146 | 147 | async def on_ready(self) -> None: 148 | await self.change_presence(activity=Activity(type=ActivityType.watching, name="DMs")) 149 | 150 | async def on_error(self, event_method: str, *args: object, **kwargs: object) -> None: 151 | 
logger.error("Exception in modmail client {}".format(event_method), exc_info=True) 152 | 153 | async def on_message(self, msg: Message) -> None: 154 | if not msg.guild and self.user is not None and msg.author.id != self.user.id: 155 | try: 156 | guild = bot.client.client.get_guild(self.conf.guild_id) 157 | if guild is None: 158 | return 159 | channel = guild.get_channel(self.conf.channel_id) 160 | if not isinstance(channel, (TextChannel, Thread)): 161 | return 162 | role = guild.get_role(self.conf.role_id) 163 | if role is None: 164 | return 165 | except (ValueError, AttributeError): 166 | return 167 | thread_id = await update_thread(self.conf, msg.author.id) 168 | 169 | items = [PlainItem(msg.content)] 170 | 171 | footer = "".join("\n**Attachment:** {} {}".format(att.filename, att.url) for att in msg.attachments) 172 | if thread_id is None: 173 | footer += format("\n{!m}", role) 174 | if footer: 175 | items.append(PlainItem("\n" + footer)) 176 | 177 | mentions = AllowedMentions.none() 178 | mentions.roles = [role] 179 | reference = None 180 | if thread_id is not None: 181 | reference = MessageReference(message_id=thread_id, channel_id=channel.id, fail_if_not_exists=False) 182 | 183 | embed = ( 184 | discord.Embed( 185 | title=format("Modmail from {}#{}", msg.author.name, msg.author.discriminator), 186 | timestamp=msg.created_at, 187 | ) 188 | .add_field(name="From", value=format("{!m}", msg.author)) 189 | .add_field(name="ID", value=msg.author.id) 190 | ) 191 | if reference is not None: 192 | header = await retry( 193 | lambda: channel.send(embed=embed, allowed_mentions=mentions, reference=reference), attempts=10 194 | ) 195 | else: 196 | header = await retry(lambda: channel.send(embed=embed, allowed_mentions=mentions), attempts=10) 197 | await add_modmail(msg, header) 198 | 199 | for content, _ in chunk_messages(items): 200 | copy = await retry(lambda: channel.send(content, allowed_mentions=mentions), attempts=10) 201 | await add_modmail(msg, copy) 202 | 203 | if 
thread_id is None: 204 | await create_thread(msg.author.id, header.id) 205 | 206 | await msg.add_reaction("\u2709") 207 | 208 | 209 | @cog 210 | class Modmail(Cog): 211 | """Handle modmail messages""" 212 | 213 | @Cog.listener("on_message") 214 | async def modmail_reply(self, msg: Message) -> None: 215 | if msg.author.bot: 216 | return 217 | if msg.reference is None or msg.reference.message_id is None: 218 | return 219 | if msg.reference.message_id not in message_map: 220 | return 221 | if not msg.guild: 222 | return 223 | if not (client := clients.get(msg.guild.id)): 224 | return 225 | modmail = message_map[msg.reference.message_id] 226 | 227 | anon_react = "\U0001F574" 228 | named_react = "\U0001F9CD" 229 | cancel_react = "\u274C" 230 | 231 | try: 232 | query = await msg.channel.send( 233 | "Reply anonymously {}, personally {}, or cancel {}".format(anon_react, named_react, cancel_react) 234 | ) 235 | except (discord.NotFound, discord.Forbidden): 236 | return 237 | 238 | result = await get_reaction( 239 | query, 240 | msg.author, 241 | {anon_react: "anon", named_react: "named", cancel_react: None}, 242 | timeout=120, 243 | unreact=False, 244 | ) 245 | 246 | await query.delete() 247 | if result is None: 248 | await msg.channel.send("Cancelled") 249 | else: 250 | items = [] 251 | if result == "named": 252 | items.append(PlainItem(format("**From {}** {!m}:\n\n", msg.author.display_name, msg.author))) 253 | items.append(PlainItem(msg.content)) 254 | for att in msg.attachments: 255 | items.append(PlainItem("\n**Attachment:** {}".format(att.url))) 256 | 257 | try: 258 | chan = await client.fetch_channel(modmail.dm_channel_id) 259 | if not isinstance(chan, DMChannel): 260 | await msg.channel.send("Could not deliver DM (DM closed)") 261 | return 262 | for content, _ in chunk_messages(items): 263 | await chan.send( 264 | content, 265 | reference=MessageReference( 266 | message_id=modmail.dm_message_id, channel_id=modmail.dm_channel_id, fail_if_not_exists=False 267 | ), 268 
| ) 269 | except (discord.NotFound, discord.Forbidden): 270 | await msg.channel.send("Could not deliver DM (User left guild?)") 271 | else: 272 | await msg.channel.send("Signed reply delivered" if result == "named" else "Anonymous reply delivered") 273 | 274 | 275 | clients: Dict[int, ModMailClient] = {} 276 | 277 | 278 | @plugins.init 279 | async def init_task() -> None: 280 | async def run_modmail(conf: GuildConfig) -> None: 281 | client = clients[conf.guild_id] = ModMailClient( 282 | max_messages=None, 283 | intents=Intents(dm_messages=True), 284 | allowed_mentions=AllowedMentions(everyone=False, roles=False), 285 | ) 286 | client.conf = conf 287 | try: 288 | async with client: 289 | await client.start(conf.token, reconnect=True) 290 | except asyncio.CancelledError: 291 | pass 292 | except: 293 | logger.error("Exception in modmail client task", exc_info=True) 294 | finally: 295 | await client.close() 296 | 297 | async with sessionmaker() as session: 298 | for conf in (await session.execute(select(GuildConfig))).scalars(): 299 | task = asyncio.create_task(run_modmail(conf), name="Modmail client for {}".format(conf.guild_id)) 300 | plugins.finalizer(task.cancel) 301 | 302 | 303 | class GuildContext(Context): 304 | guild_id: int 305 | 306 | 307 | @plugin_config_command 308 | @group("modmail") 309 | async def config(ctx: GuildContext, server: PartialGuildConverter) -> None: 310 | ctx.guild_id = server.id 311 | 312 | 313 | @config.command("new") 314 | async def config_new( 315 | ctx: GuildContext, 316 | token: str, 317 | channel: PartialChannelConverter, 318 | role: PartialRoleConverter, 319 | thread_expiry: DurationConverter, 320 | ) -> None: 321 | async with sessionmaker() as session: 322 | session.add( 323 | GuildConfig( 324 | guild_id=ctx.guild_id, token=token, channel_id=channel.id, role_id=role.id, thread_expiry=thread_expiry 325 | ) 326 | ) 327 | await session.commit() 328 | await ctx.send("\u2705") 329 | 330 | 331 | async def get_conf(session: AsyncSession, 
ctx: GuildContext) -> GuildConfig: 332 | if (conf := await session.get(GuildConfig, ctx.guild_id)) is None: 333 | raise UserError("No config for {}".format(ctx.guild_id)) 334 | return conf 335 | 336 | 337 | @config.command("token") 338 | async def config_token(ctx: GuildContext, token: Optional[str]) -> None: 339 | async with sessionmaker() as session: 340 | conf = await get_conf(session, ctx) 341 | if token is None: 342 | await ctx.send(format("{!i}", conf.token)) 343 | else: 344 | conf.token = token 345 | await session.commit() 346 | await ctx.send("\u2705") 347 | 348 | 349 | @config.command("channel") 350 | async def config_channel(ctx: GuildContext, channel: Optional[PartialChannelConverter]) -> None: 351 | async with sessionmaker() as session: 352 | conf = await get_conf(session, ctx) 353 | if channel is None: 354 | await ctx.send(format("{!c}", conf.channel_id)) 355 | else: 356 | conf.channel_id = channel.id 357 | await session.commit() 358 | await ctx.send("\u2705") 359 | 360 | 361 | @config.command("role") 362 | async def config_role(ctx: GuildContext, role: Optional[PartialRoleConverter]) -> None: 363 | async with sessionmaker() as session: 364 | conf = await get_conf(session, ctx) 365 | if role is None: 366 | await ctx.send(format("{!M}", conf.role_id), allowed_mentions=AllowedMentions.none()) 367 | else: 368 | conf.role_id = role.id 369 | await session.commit() 370 | await ctx.send("\u2705") 371 | 372 | 373 | @config.command("thread_expiry") 374 | async def config_thread_expiry(ctx: GuildContext, thread_expiry: Optional[DurationConverter]) -> None: 375 | async with sessionmaker() as session: 376 | conf = await get_conf(session, ctx) 377 | if thread_expiry is None: 378 | await ctx.send(str(thread_expiry)) 379 | else: 380 | conf.thread_expiry = thread_expiry 381 | await session.commit() 382 | await ctx.send("\u2705") 383 | -------------------------------------------------------------------------------- /plugins/persistence.py: 
-------------------------------------------------------------------------------- 1 | from typing import TYPE_CHECKING, List, Set, cast 2 | 3 | from discord import AllowedMentions, Member 4 | from discord.ext.commands import group 5 | from sqlalchemy import BigInteger, delete, select 6 | from sqlalchemy.dialects.postgresql import insert 7 | from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker 8 | import sqlalchemy.orm 9 | from sqlalchemy.orm import Mapped, mapped_column 10 | from sqlalchemy.schema import CreateSchema 11 | 12 | from bot.cogs import Cog, cog 13 | from bot.commands import Context 14 | from bot.config import plugin_config_command 15 | import plugins 16 | import util.db 17 | import util.db.kv 18 | from util.discord import PartialRoleConverter, format, retry 19 | 20 | 21 | registry: sqlalchemy.orm.registry = sqlalchemy.orm.registry() 22 | 23 | sessionmaker = async_sessionmaker(util.db.engine) 24 | 25 | 26 | @registry.mapped 27 | class PersistedRole: 28 | __tablename__ = "roles" 29 | __table_args__ = {"schema": "persistence"} 30 | 31 | id: Mapped[int] = mapped_column(BigInteger, primary_key=True) 32 | 33 | if TYPE_CHECKING: 34 | 35 | def __init__(self, *, id: int) -> None: ... 
36 | 37 | 38 | @registry.mapped 39 | class MemberRole: 40 | __tablename__ = "member_roles" 41 | __table_args__ = {"schema": "persistence"} 42 | 43 | user_id: Mapped[int] = mapped_column(BigInteger, primary_key=True) 44 | role_id: Mapped[int] = mapped_column(BigInteger, primary_key=True) 45 | 46 | 47 | persisted_roles: Set[int] 48 | 49 | 50 | async def rehash_roles(session: AsyncSession) -> None: 51 | global persisted_roles 52 | stmt = select(PersistedRole.id) 53 | persisted_roles = set((await session.execute(stmt)).scalars()) 54 | 55 | 56 | @plugins.init 57 | async def init() -> None: 58 | global persisted_roles 59 | await util.db.init(util.db.get_ddl(CreateSchema("persistence"), registry.metadata.create_all)) 60 | 61 | async with sessionmaker() as session: 62 | conf = await util.db.kv.load(__name__) 63 | if conf.roles is not None: 64 | for id in cast(List[int], conf.roles): 65 | session.add(PersistedRole(id=id)) 66 | await session.commit() 67 | conf.roles = None 68 | await conf 69 | 70 | await rehash_roles(session) 71 | 72 | 73 | @cog 74 | class Persistence(Cog): 75 | """Role persistence.""" 76 | 77 | @Cog.listener() 78 | async def on_member_remove(self, member: Member) -> None: 79 | role_ids = set(role.id for role in member.roles if role.id in persisted_roles) 80 | if len(role_ids) == 0: 81 | return 82 | async with sessionmaker() as session: 83 | stmt = ( 84 | insert(MemberRole) 85 | .values([{"user_id": member.id, "role_id": role_id} for role_id in role_ids]) 86 | .on_conflict_do_nothing(index_elements=["user_id", "role_id"]) 87 | ) 88 | await session.execute(stmt) 89 | await session.commit() 90 | 91 | @Cog.listener() 92 | async def on_member_join(self, member: Member) -> None: 93 | async with sessionmaker() as session: 94 | stmt = delete(MemberRole).where(MemberRole.user_id == member.id).returning(MemberRole.role_id) 95 | roles = [] 96 | for (role_id,) in await session.execute(stmt): 97 | if (role := member.guild.get_role(role_id)) is not None: 98 | 
roles.append(role) 99 | if len(roles) == 0: 100 | return 101 | await retry(lambda: member.add_roles(*roles, reason="Role persistence", atomic=False)) 102 | await session.commit() 103 | 104 | 105 | async def drop_persistent_role(*, user_id: int, role_id: int) -> None: 106 | async with sessionmaker() as session: 107 | stmt = delete(MemberRole).where(MemberRole.user_id == user_id, MemberRole.role_id == role_id) 108 | await session.execute(stmt) 109 | await session.commit() 110 | 111 | 112 | @plugin_config_command 113 | @group("persistence", invoke_without_command=True) 114 | async def config(ctx: Context) -> None: 115 | async with sessionmaker() as session: 116 | stmt = select(PersistedRole.id) 117 | roles = (await session.execute(stmt)).scalars() 118 | await ctx.send( 119 | ", ".join(format("{!M}", id) for id in roles) or "No roles registered", 120 | allowed_mentions=AllowedMentions.none(), 121 | ) 122 | 123 | 124 | @config.command("add") 125 | async def config_add(ctx: Context, role: PartialRoleConverter) -> None: 126 | async with sessionmaker() as session: 127 | session.add(PersistedRole(id=role.id)) 128 | await session.commit() 129 | await rehash_roles(session) 130 | await ctx.send("\u2705") 131 | 132 | 133 | @config.command("remove") 134 | async def config_remove(ctx: Context, role: PartialRoleConverter) -> None: 135 | async with sessionmaker() as session: 136 | await session.delete(await session.get(PersistedRole, role.id)) 137 | await session.commit() 138 | await rehash_roles(session) 139 | await ctx.send("\u2705") 140 | -------------------------------------------------------------------------------- /plugins/pins.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | from typing import Dict, Optional 3 | 4 | import discord 5 | from discord import MessageType, RawReactionActionEvent 6 | from discord.ext.commands import command 7 | 8 | from bot.acl import privileged 9 | from bot.client import client 10 | from 
bot.commands import Context, add_cleanup, cleanup, plugin_command 11 | from bot.reactions import ReactionMonitor 12 | from util.discord import ReplyConverter, TempMessage, UserError, format, partial_from_reply 13 | 14 | 15 | class AbortDueToUnpin(Exception): 16 | pass 17 | 18 | 19 | class AbortDueToOtherPin(Exception): 20 | pass 21 | 22 | 23 | unpin_requests: Dict[int, ReactionMonitor[RawReactionActionEvent]] = {} 24 | 25 | 26 | @plugin_command 27 | @cleanup 28 | @command("pin") 29 | @privileged 30 | async def pin_command(ctx: Context, message: Optional[ReplyConverter]) -> None: 31 | """Pin a message.""" 32 | to_pin = partial_from_reply(message, ctx) 33 | if ctx.guild is None: 34 | raise UserError("Can only be used in a guild") 35 | guild = ctx.guild 36 | 37 | pin_msg_task = asyncio.create_task( 38 | client.wait_for( 39 | "message", 40 | check=lambda m: m.guild is not None 41 | and m.guild.id == guild.id 42 | and m.channel.id == ctx.channel.id 43 | and m.type == MessageType.pins_add 44 | and m.reference is not None 45 | and m.reference.message_id == to_pin.id, 46 | ) 47 | ) 48 | try: 49 | while True: 50 | try: 51 | await to_pin.pin(reason=format("Requested by {!m}", ctx.author)) 52 | break 53 | except (discord.Forbidden, discord.NotFound): 54 | pin_msg_task.cancel() 55 | break 56 | except discord.HTTPException as exc: 57 | if exc.text == "Cannot execute action on a system message" or exc.text == "Unknown Message": 58 | pin_msg_task.cancel() 59 | break 60 | elif not exc.text.startswith("Maximum number of pins reached"): 61 | raise 62 | pins = await ctx.channel.pins() 63 | 64 | oldest_pin = pins[-1] 65 | 66 | async with TempMessage(ctx, "No space in pins. 
Unpin or press \u267B to remove oldest") as confirm_msg: 67 | await confirm_msg.add_reaction("\u267B") 68 | await confirm_msg.add_reaction("\u274C") 69 | 70 | with ReactionMonitor( 71 | guild_id=guild.id, 72 | channel_id=ctx.channel.id, 73 | message_id=confirm_msg.id, 74 | author_id=ctx.author.id, 75 | event="add", 76 | filter=lambda _, p: p.emoji.name in ["\u267B", "\u274C"], 77 | timeout_each=60, 78 | ) as mon: 79 | try: 80 | if ctx.author.id in unpin_requests: 81 | unpin_requests[ctx.author.id].cancel(AbortDueToOtherPin()) 82 | unpin_requests[ctx.author.id] = mon 83 | _, p = await mon 84 | if p.emoji.name == "\u267B": 85 | await oldest_pin.unpin(reason=format("Requested by {!m}", ctx.author)) 86 | else: 87 | break 88 | except AbortDueToUnpin: 89 | del unpin_requests[ctx.author.id] 90 | except (asyncio.TimeoutError, AbortDueToOtherPin): 91 | pin_msg_task.cancel() 92 | break 93 | else: 94 | del unpin_requests[ctx.author.id] 95 | finally: 96 | try: 97 | pin_msg = await asyncio.wait_for(pin_msg_task, timeout=60) 98 | add_cleanup(ctx, pin_msg) 99 | except asyncio.TimeoutError: 100 | pin_msg_task.cancel() 101 | 102 | 103 | @plugin_command 104 | @cleanup 105 | @command("unpin") 106 | @privileged 107 | async def unpin_command(ctx: Context, message: Optional[ReplyConverter]) -> None: 108 | """Unpin a message.""" 109 | to_unpin = partial_from_reply(message, ctx) 110 | if ctx.guild is None: 111 | raise UserError("Can only be used in a guild") 112 | 113 | try: 114 | await to_unpin.unpin(reason=format("Requested by {!m}", ctx.author)) 115 | if ctx.author.id in unpin_requests: 116 | unpin_requests[ctx.author.id].cancel(AbortDueToUnpin()) 117 | 118 | await ctx.send("\u2705") 119 | except (discord.Forbidden, discord.NotFound): 120 | pass 121 | except discord.HTTPException as exc: 122 | if exc.text == "Cannot execute action on a system message": 123 | pass 124 | elif exc.text == "Unknown Message": 125 | pass 126 | else: 127 | raise 128 | 
-------------------------------------------------------------------------------- /plugins/reminders.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | import logging 3 | from typing import TYPE_CHECKING, Any, Dict, List, Optional, cast 4 | 5 | import discord 6 | from discord import AllowedMentions, MessageReference, Object, TextChannel, Thread 7 | from discord.ext.commands import command, group 8 | from sqlalchemy import TEXT, TIMESTAMP, BigInteger, delete, func, select 9 | from sqlalchemy.ext.asyncio import async_sessionmaker 10 | import sqlalchemy.orm 11 | from sqlalchemy.orm import Mapped, mapped_column 12 | from sqlalchemy.sql.functions import current_timestamp 13 | 14 | from bot.acl import privileged 15 | from bot.client import client 16 | from bot.commands import Context, cleanup, plugin_command 17 | from bot.tasks import task 18 | import plugins 19 | import util.db.kv 20 | from util.discord import DurationConverter, PlainItem, UserError, chunk_messages, format 21 | 22 | 23 | registry = sqlalchemy.orm.registry() 24 | sessionmaker = async_sessionmaker(util.db.engine, expire_on_commit=False) 25 | 26 | 27 | @registry.mapped 28 | class Reminder: 29 | __tablename__ = "reminders" 30 | 31 | id: Mapped[int] = mapped_column(BigInteger, primary_key=True) 32 | user_id: Mapped[int] = mapped_column(BigInteger, nullable=False) 33 | guild_id: Mapped[int] = mapped_column(BigInteger, nullable=False) 34 | channel_id: Mapped[int] = mapped_column(BigInteger, nullable=False) 35 | message_id: Mapped[int] = mapped_column(BigInteger, nullable=False) 36 | time: Mapped[datetime] = mapped_column(TIMESTAMP, nullable=False) 37 | content: Mapped[str] = mapped_column(TEXT, nullable=False) 38 | 39 | if TYPE_CHECKING: 40 | 41 | def __init__( 42 | self, *, user_id: int, guild_id: int, channel_id: int, message_id: int, time: datetime, content: str 43 | ) -> None: ... 
44 | 45 | 46 | logger = logging.getLogger(__name__) 47 | 48 | 49 | def format_msg(guild_id: int, channel_id: int, msg_id: int) -> str: 50 | return "https://discord.com/channels/{}/{}/{}".format(guild_id, channel_id, msg_id) 51 | 52 | 53 | def format_reminder(reminder: Reminder) -> str: 54 | msg = format_msg(reminder.guild_id, reminder.channel_id, reminder.message_id) 55 | if reminder.content == "": 56 | return format("{} for {!f}", msg, reminder.time) 57 | return format("{!i} ({}) for {!f}", reminder.content, msg, reminder.time) 58 | 59 | 60 | async def send_reminder(reminder: Reminder) -> None: 61 | guild = client.get_guild(reminder.guild_id) 62 | if guild is None: 63 | logger.info( 64 | "Reminder {} for user {} silently removed (guild no longer exists)".format(reminder.id, reminder.user_id) 65 | ) 66 | return 67 | try: 68 | channel = await guild.fetch_channel(reminder.channel_id) 69 | except discord.NotFound: 70 | logger.info( 71 | "Reminder {} for user {} silently removed (channel no longer exists)".format(reminder.id, reminder.user_id) 72 | ) 73 | return 74 | if not isinstance(channel, (TextChannel, Thread)): 75 | logger.info( 76 | "Reminder {} for user {} silently removed (invalid channel type)".format(reminder.id, reminder.user_id) 77 | ) 78 | return 79 | try: 80 | creation_time = discord.utils.snowflake_time(reminder.message_id) 81 | await channel.send( 82 | format( 83 | "{!m} asked to be reminded {!R}, {}", 84 | reminder.user_id, 85 | creation_time, 86 | reminder.content, 87 | )[:2000], 88 | reference=MessageReference( 89 | message_id=reminder.message_id, channel_id=reminder.channel_id, fail_if_not_exists=False 90 | ), 91 | allowed_mentions=AllowedMentions(everyone=False, users=[Object(reminder.user_id)], roles=False), 92 | ) 93 | except discord.Forbidden: 94 | logger.info("Reminder {} for user {} silently removed (permission error)".format(reminder.id, reminder.user_id)) 95 | 96 | 97 | @task(name="Reminder expiry task", every=86400, exc_backoff_base=60) 98 
| async def expiry_task() -> None: 99 | await client.wait_until_ready() 100 | 101 | async with sessionmaker() as session: 102 | stmt = delete(Reminder).where(Reminder.time <= func.timezone("UTC", current_timestamp())).returning(Reminder) 103 | for reminder in (await session.execute(stmt)).scalars(): 104 | logger.debug("Expiring reminder for user #{}".format(reminder.user_id)) 105 | await send_reminder(reminder) 106 | await session.commit() 107 | 108 | stmt = select(Reminder.time).order_by(Reminder.time).limit(1) 109 | next_expiry = (await session.execute(stmt)).scalar() 110 | 111 | if next_expiry is not None: 112 | delay = next_expiry - datetime.utcnow() 113 | expiry_task.run_coalesced(delay.total_seconds()) 114 | logger.debug("Waiting for next reminder to expire in {}".format(delay)) 115 | 116 | 117 | @plugins.init 118 | async def init() -> None: 119 | await util.db.init(util.db.get_ddl(registry.metadata.create_all)) 120 | 121 | async with sessionmaker() as session: 122 | conf = await util.db.kv.load(__name__) 123 | for (user_id,) in conf: 124 | for reminder in cast(List[Dict[str, Any]], conf[user_id]): 125 | session.add( 126 | Reminder( 127 | user_id=int(user_id), 128 | guild_id=reminder["guild"], 129 | channel_id=reminder["channel"], 130 | message_id=reminder["msg"], 131 | time=datetime.fromtimestamp(reminder["time"]), 132 | content=reminder["contents"], 133 | ) 134 | ) 135 | await session.commit() 136 | for user_id in [user_id for user_id, in conf]: 137 | conf[user_id] = None 138 | await conf 139 | 140 | expiry_task.run_coalesced(0) 141 | 142 | 143 | @plugin_command 144 | @cleanup 145 | @command("remindme", aliases=["remind"]) 146 | @privileged 147 | async def remindme_command(ctx: Context, interval: DurationConverter, *, text: Optional[str]) -> None: 148 | """Set a reminder with a given message.""" 149 | if ctx.guild is None: 150 | raise UserError("Only usable in a server") 151 | 152 | async with sessionmaker() as session: 153 | reminder = Reminder( 154 | 
user_id=ctx.author.id, 155 | guild_id=ctx.guild.id, 156 | channel_id=ctx.channel.id, 157 | message_id=ctx.message.id, 158 | time=datetime.utcnow() + interval, 159 | content=text or "", 160 | ) 161 | session.add(reminder) 162 | await session.commit() 163 | 164 | expiry_task.run_coalesced(0) 165 | 166 | await ctx.send( 167 | "Created reminder {}".format(format_reminder(reminder))[:2000], allowed_mentions=AllowedMentions.none() 168 | ) 169 | 170 | 171 | @plugin_command 172 | @cleanup 173 | @group("reminder", aliases=["reminders"], invoke_without_command=True) 174 | @privileged 175 | async def reminder_command(ctx: Context) -> None: 176 | """Display your reminders.""" 177 | async with sessionmaker() as session: 178 | stmt = select(Reminder).where(Reminder.user_id == ctx.author.id) 179 | reminders = (await session.execute(stmt)).scalars() 180 | 181 | items = [PlainItem("Your reminders include:\n")] 182 | for reminder in reminders: 183 | items.append(PlainItem("**{}.** Reminder {}\n".format(reminder.id, format_reminder(reminder)))) 184 | for content, _ in chunk_messages(items): 185 | await ctx.send(content, allowed_mentions=AllowedMentions.none()) 186 | 187 | 188 | @reminder_command.command("remove") 189 | @privileged 190 | async def reminder_remove(ctx: Context, id: int) -> None: 191 | """Delete a reminder.""" 192 | async with sessionmaker() as session: 193 | if reminder := await session.get(Reminder, id): 194 | await session.delete(reminder) 195 | await session.commit() 196 | await ctx.send( 197 | "Removed reminder {}".format(format_reminder(reminder))[:2000], allowed_mentions=AllowedMentions.none() 198 | ) 199 | 200 | expiry_task.run_coalesced(0) 201 | else: 202 | raise UserError("Reminder {} does not exist".format(id)) 203 | -------------------------------------------------------------------------------- /plugins/roleoverride.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | from typing 
import TYPE_CHECKING, Dict, List, Set, cast 3 | 4 | from discord import AllowedMentions, Member 5 | from discord.ext.commands import group 6 | from sqlalchemy import BigInteger, select 7 | from sqlalchemy.ext.asyncio import async_sessionmaker 8 | import sqlalchemy.orm 9 | from sqlalchemy.orm import Mapped, mapped_column 10 | 11 | from bot.cogs import Cog, cog 12 | from bot.commands import Context 13 | from bot.config import plugin_config_command 14 | import plugins 15 | import util.db.kv 16 | from util.discord import PartialRoleConverter, format, retry 17 | 18 | 19 | registry = sqlalchemy.orm.registry() 20 | sessionmaker = async_sessionmaker(util.db.engine) 21 | 22 | 23 | @registry.mapped 24 | class Override: 25 | __tablename__ = "roleoverrides" 26 | 27 | retained_role_id: Mapped[int] = mapped_column(BigInteger, primary_key=True) 28 | excluded_role_id: Mapped[int] = mapped_column(BigInteger, primary_key=True) 29 | 30 | if TYPE_CHECKING: 31 | 32 | def __init__(self, *, retained_role_id: int, excluded_role_id: int) -> None: ... 
33 | 34 | 35 | @plugins.init 36 | async def init() -> None: 37 | await util.db.init(util.db.get_ddl(registry.metadata.create_all)) 38 | 39 | async with sessionmaker() as session: 40 | conf = await util.db.kv.load(__name__) 41 | for (key,) in conf: 42 | for role in cast(List[int], conf[key]): 43 | session.add(Override(retained_role_id=int(key), excluded_role_id=role)) 44 | await session.commit() 45 | for key in [key for key, in conf]: 46 | conf[key] = None 47 | await conf 48 | 49 | 50 | @cog 51 | class RoleOverride(Cog): 52 | @Cog.listener() 53 | async def on_member_update(self, before: Member, after: Member) -> None: 54 | removed = [] 55 | async with sessionmaker() as session: 56 | stmt = ( 57 | select(Override.excluded_role_id) 58 | .distinct() 59 | .where(Override.retained_role_id.in_([role.id for role in after.roles])) 60 | ) 61 | excluded = set((await session.execute(stmt)).scalars()) 62 | for role in after.roles: 63 | if role.id in excluded: 64 | removed.append(role) 65 | if len(removed): 66 | await retry(lambda: after.remove_roles(*removed)) 67 | 68 | 69 | @plugin_config_command 70 | @group("roleoverride", invoke_without_command=True) 71 | async def config(ctx: Context) -> None: 72 | async with sessionmaker() as session: 73 | overrides: Dict[int, Set[int]] = defaultdict(set) 74 | stmt = select(Override) 75 | for override in (await session.execute(stmt)).scalars(): 76 | overrides[override.retained_role_id].add(override.excluded_role_id) 77 | 78 | await ctx.send( 79 | "\n".join( 80 | format("- having {!M} removes {}", retained, ", ".join(format("{!M}", excluded) for excluded in excludeds)) 81 | for retained, excludeds in overrides.items() 82 | ) 83 | or "No roles registered", 84 | allowed_mentions=AllowedMentions.none(), 85 | ) 86 | 87 | 88 | @config.command("add") 89 | async def config_add(ctx: Context, retained_role: PartialRoleConverter, excluded_role: PartialRoleConverter) -> None: 90 | async with sessionmaker() as session: 91 | 
session.add(Override(retained_role_id=retained_role.id, excluded_role_id=excluded_role.id)) 92 | await session.commit() 93 | await ctx.send("\u2705") 94 | 95 | 96 | @config.command("remove") 97 | async def config_remove(ctx: Context, retained_role: PartialRoleConverter, excluded_role: PartialRoleConverter) -> None: 98 | async with sessionmaker() as session: 99 | await session.delete(await session.get(Override, (retained_role.id, excluded_role.id))) 100 | await session.commit() 101 | await ctx.send("\u2705") 102 | -------------------------------------------------------------------------------- /plugins/rolereactions.py: -------------------------------------------------------------------------------- 1 | from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Union, cast 2 | 3 | import discord 4 | from discord import ( 5 | AllowedMentions, 6 | Emoji, 7 | Guild, 8 | Message, 9 | MessageReference, 10 | Object, 11 | PartialEmoji, 12 | RawReactionActionEvent, 13 | ) 14 | from discord.abc import Snowflake 15 | import discord.utils 16 | from sqlalchemy import TEXT, BigInteger, ForeignKey, delete, select 17 | from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker 18 | import sqlalchemy.orm 19 | from sqlalchemy.orm import Mapped, mapped_column, raiseload, relationship 20 | from sqlalchemy.schema import CreateSchema 21 | 22 | from bot.acl import privileged 23 | from bot.client import client 24 | from bot.cogs import Cog, cog, group 25 | from bot.commands import Context, cleanup 26 | import plugins 27 | import util.db.kv 28 | from util.discord import ( 29 | InvocationError, 30 | PartialRoleConverter, 31 | ReplyConverter, 32 | UserError, 33 | format, 34 | partial_from_reply, 35 | retry, 36 | ) 37 | 38 | 39 | registry = sqlalchemy.orm.registry() 40 | sessionmaker = async_sessionmaker(util.db.engine) 41 | 42 | 43 | @registry.mapped 44 | class ReactionMessage: 45 | __tablename__ = "messages" 46 | __table_args__ = {"schema": "role_reactions"} 47 | 48 | id: 
Mapped[int] = mapped_column(BigInteger, primary_key=True) 49 | guild_id: Mapped[int] = mapped_column(BigInteger, nullable=False) 50 | channel_id: Mapped[int] = mapped_column(BigInteger, nullable=False) 51 | 52 | reactions: Mapped[List["Reaction"]] = relationship("Reaction", lazy="joined") 53 | 54 | def reference(self) -> MessageReference: 55 | return MessageReference(guild_id=self.guild_id, channel_id=self.channel_id, message_id=self.id) 56 | 57 | if TYPE_CHECKING: 58 | 59 | def __init__(self, *, id: int, guild_id: int, channel_id: int) -> None: ... 60 | 61 | 62 | @registry.mapped 63 | class Reaction: 64 | __tablename__ = "reactions" 65 | __table_args__ = {"schema": "role_reactions"} 66 | 67 | message_id: Mapped[int] = mapped_column(BigInteger, ForeignKey(ReactionMessage.id), primary_key=True) 68 | # either a unicode emoji, or a discord emoji's ID converted to string 69 | emoji: Mapped[str] = mapped_column(TEXT, primary_key=True) 70 | role_id: Mapped[int] = mapped_column(BigInteger, nullable=False) 71 | 72 | if TYPE_CHECKING: 73 | 74 | def __init__(self, *, message_id: int, emoji: str, role_id: int) -> None: ... 
# IDs of every registered ReactionMessage, consulted by the listeners before any database access.
reaction_messages: Set[int]


@plugins.init
async def init() -> None:
    """Create the role_reactions schema, migrate legacy key-value data, and load the message ID cache."""
    global reaction_messages
    await util.db.init(util.db.get_ddl(CreateSchema("role_reactions"), registry.metadata.create_all))

    async with sessionmaker() as session:
        conf = await util.db.kv.load(__name__)
        for (msg_id_str,) in conf:
            msg_id = int(msg_id_str)
            # Read with the same key the iteration produced (and that is cleared below).
            obj = cast(Dict[str, Any], conf[msg_id_str])
            session.add(ReactionMessage(id=msg_id, guild_id=obj["guild"], channel_id=obj["channel"]))
            for emoji, role_id in obj["rolereacts"].items():
                session.add(Reaction(message_id=msg_id, emoji=emoji, role_id=role_id))
        await session.commit()
        for msg_id_str in [msg_id_str for msg_id_str, in conf]:
            conf[msg_id_str] = None
        await conf

        stmt = select(ReactionMessage.id)
        reaction_messages = set((await session.execute(stmt)).scalars())


async def find_message(channel_id: int, msg_id: int) -> Optional[Message]:
    """Fetch a message, returning None if it is missing or the bot may not read it."""
    channel = client.get_partial_messageable(channel_id)
    try:
        return await channel.fetch_message(msg_id)
    except (discord.NotFound, discord.Forbidden):
        return None


def format_role(guild: Optional[Guild], role_id: int) -> str:
    """Render a role as mention plus name and ID, or just the inline ID if it is not in ``guild``."""
    role = guild.get_role(role_id) if guild is not None else None
    if role is None:
        return format("{!i}", role_id)
    return format("{!M}({!i} {!i})", role, role.name, role.id)


def format_emoji(emoji_str: str) -> str:
    """Render a stored emoji string.

    A numeric string that resolves to a usable custom emoji is shown as the emoji followed by its inline-code form;
    anything else is shown as inline code.
    """
    if emoji_str.isdigit():
        emoji = client.get_emoji(int(emoji_str))
        if emoji is not None and emoji.is_usable():
            return str(emoji) + format("({!i})", emoji)
    return format("{!i}", emoji_str)


def make_discord_emoji(emoji_str: str) -> Union[str, Emoji, None]:
    """Turn a stored emoji string into a reaction argument.

    Numeric strings become the custom Emoji if it is known and usable (else None); other strings are returned as-is.
    """
    if emoji_str.isdigit():
        emoji = client.get_emoji(int(emoji_str))
        if emoji is not None and emoji.is_usable():
            return emoji
        return None
    else:
        return emoji_str


async def react_initial(channel_id: int, msg_id: int, emoji_str: str) -> None:
    """Best-effort: add the bot's own reaction with ``emoji_str`` to the message.

    Gives up silently if the message or emoji is unavailable or reacting is forbidden; other HTTP errors propagate.
    """
    react_msg = await find_message(channel_id, msg_id)
    if react_msg is None:
        return
    react_emoji = make_discord_emoji(emoji_str)
    if react_emoji is None:
        return
    try:
        await react_msg.add_reaction(react_emoji)
    except (discord.Forbidden, discord.NotFound):
        pass
    except discord.HTTPException as exc:
        if exc.text != "Unknown Emoji":
            raise


async def get_payload_role(session: AsyncSession, guild: Guild, payload: RawReactionActionEvent) -> Optional[Snowflake]:
    """Return the role bound to the reacted emoji on the reacted message, or None. ``guild`` is currently unused."""
    if payload.emoji.id is not None:
        emoji = str(payload.emoji.id)
    else:
        emoji = payload.emoji.name
    if obj := await session.get(Reaction, (payload.message_id, emoji)):
        return Object(obj.role_id)
    else:
        return None


@cog
class RoleReactions(Cog):
    """Manage role reactions."""

    @Cog.listener()
    async def on_raw_reaction_add(self, payload: RawReactionActionEvent) -> None:
        # Grant the bound role to a non-bot member reacting on a registered message.
        if (member := payload.member) is None:
            return
        if member.bot:
            return
        if payload.message_id not in reaction_messages:
            return
        async with sessionmaker() as session:
            if (role := await get_payload_role(session, member.guild, payload)) is None:
                return
        await retry(lambda: member.add_roles(role, reason="Role reactions on {}".format(payload.message_id)))

    @Cog.listener()
    async def on_raw_reaction_remove(self, payload: RawReactionActionEvent) -> None:
        # payload.member is not used here; the member is resolved from the guild cache instead.
        if payload.guild_id is None:
            return
        if payload.message_id not in reaction_messages:
            return
        if (guild := client.get_guild(payload.guild_id)) is None:
            return
        if (member := guild.get_member(payload.user_id)) is None:
            return
        if member.bot:
            return
        async with sessionmaker() as session:
            if (role := await get_payload_role(session, member.guild, payload)) is None:
                return
        await retry(lambda: member.remove_roles(role, reason="Role reactions on {}".format(payload.message_id)))

    @cleanup
    @group("rolereact")
    @privileged
    async def rolereact_command(self, ctx: Context) -> None:
        """Manage role reactions."""
        pass

    @rolereact_command.command("new")
    @privileged
    async def rolereact_new(self, ctx: Context, message: Optional[ReplyConverter]) -> None:
        """Make the given message a role react message."""
        msg = partial_from_reply(message, ctx)
        async with sessionmaker() as session:
            if await session.get(ReactionMessage, msg.id, options=[raiseload(ReactionMessage.reactions)]):
                raise UserError("Role reactions already exist on {}".format(msg.jump_url))

            if msg.guild is None:
                raise InvocationError("The message must be in a guild")

            session.add(ReactionMessage(id=msg.id, guild_id=msg.guild.id, channel_id=msg.channel.id))
            await session.commit()
        # The listeners ignore messages missing from this cache, which is otherwise only filled by init().
        reaction_messages.add(msg.id)

        await ctx.send("Created role reactions on {}".format(msg.jump_url))

    @rolereact_command.command("delete")
    @privileged
    async def rolereact_delete(self, ctx: Context, message: Optional[ReplyConverter]) -> None:
        """Make the given message not a role react message."""
        msg = partial_from_reply(message, ctx)
        async with sessionmaker() as session:
            if obj := await session.get(ReactionMessage, msg.id, options=[raiseload(ReactionMessage.reactions)]):
                # Child rows first: the reactions relationship is raiseloaded here and has no delete cascade.
                stmt = delete(Reaction).where(Reaction.message_id == obj.id)
                await session.execute(stmt)
                await session.delete(obj)
                await session.commit()
            else:
                raise UserError("Role reactions do not exist on {}".format(msg.jump_url))
        # Keep the listener cache in sync with the database.
        reaction_messages.discard(msg.id)

        await ctx.send("Removed role reactions on {}".format(msg.jump_url))
@rolereact_command.command("list") 238 | @privileged 239 | async def rolereact_list(self, ctx: Context) -> None: 240 | """List role react messages.""" 241 | async with sessionmaker() as session: 242 | stmt = select(ReactionMessage).options(raiseload(ReactionMessage.reactions)) 243 | messages = (await session.execute(stmt)).scalars() 244 | await ctx.send( 245 | "Role reactions exist on:\n{}".format("\n".join(obj.reference().jump_url for obj in messages)) 246 | ) 247 | 248 | @rolereact_command.command("show") 249 | @privileged 250 | async def rolereact_show(self, ctx: Context, message: Optional[ReplyConverter]) -> None: 251 | """List roles on a role react message.""" 252 | msg = partial_from_reply(message, ctx) 253 | async with sessionmaker() as session: 254 | if obj := await session.get(ReactionMessage, msg.id): 255 | await ctx.send( 256 | "Role reactions on {} include: {}".format( 257 | msg.jump_url, 258 | "; ".join( 259 | "{} for {}".format(format_emoji(reaction.emoji), format_role(msg.guild, reaction.role_id)) 260 | for reaction in obj.reactions 261 | ), 262 | ), 263 | allowed_mentions=AllowedMentions.none(), 264 | ) 265 | 266 | else: 267 | raise UserError("Role reactions do not exist on {}".format(msg.jump_url)) 268 | 269 | @rolereact_command.command("add") 270 | @privileged 271 | async def rolereact_add( 272 | self, ctx: Context, message: ReplyConverter, emoji: Union[PartialEmoji, str], role: PartialRoleConverter 273 | ) -> None: 274 | """Add an emoji/role to a role react message.""" 275 | async with sessionmaker() as session: 276 | if not (obj := await session.get(ReactionMessage, message.id)): 277 | raise UserError("Role reactions do not exist on {}".format(message.jump_url)) 278 | 279 | emoji_str = str(emoji.id) if isinstance(emoji, PartialEmoji) else emoji 280 | if reaction := await session.get(Reaction, (message.id, emoji_str)): 281 | await ctx.send( 282 | "Emoji {} already sets role {}".format( 283 | format_emoji(emoji_str), format_role(message.guild, 
reaction.role_id) 284 | ), 285 | allowed_mentions=AllowedMentions.none(), 286 | ) 287 | return 288 | 289 | obj.reactions.append(Reaction(message_id=message.id, emoji=emoji_str, role_id=role.id)) 290 | await session.commit() 291 | 292 | await react_initial(message.channel.id, message.id, emoji_str) 293 | await ctx.send( 294 | "Reacting with {} on message {} now sets {}".format( 295 | format_emoji(emoji_str), message.jump_url, format_role(message.guild, role.id) 296 | ), 297 | allowed_mentions=AllowedMentions.none(), 298 | ) 299 | 300 | @rolereact_command.command("remove") 301 | @privileged 302 | async def rolereact_remove(self, ctx: Context, message: ReplyConverter, emoji: Union[PartialEmoji, str]) -> None: 303 | """Remove an emoji from a role react message.""" 304 | async with sessionmaker() as session: 305 | if not await session.get(ReactionMessage, message.id): 306 | raise UserError("Role reactions do not exist on {}".format(message.jump_url)) 307 | 308 | emoji_str = str(emoji.id) if isinstance(emoji, PartialEmoji) else emoji 309 | if not (reaction := await session.get(Reaction, (message.id, emoji_str))): 310 | await ctx.send( 311 | "Role reactions for {} do not exist on {}".format(format_emoji(emoji_str), message.jump_url) 312 | ) 313 | return 314 | 315 | await session.delete(reaction) 316 | await session.commit() 317 | 318 | await ctx.send( 319 | "Reacting with {} on message {} no longer sets roles".format(format_emoji(emoji_str), message.jump_url), 320 | allowed_mentions=AllowedMentions.none(), 321 | ) 322 | -------------------------------------------------------------------------------- /plugins/roles_dialog.py: -------------------------------------------------------------------------------- 1 | from typing import TYPE_CHECKING, Dict, Iterable, List, Literal, Optional, Tuple, Union, cast 2 | 3 | from discord import AllowedMentions, ButtonStyle, Interaction, Member, Role, SelectOption 4 | from discord.abc import Messageable 5 | from discord.ext.commands import 
group 6 | from discord.ui import Button, Select, View 7 | from sqlalchemy import BOOLEAN, TEXT, BigInteger, ForeignKey, select 8 | from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker 9 | import sqlalchemy.orm 10 | from sqlalchemy.orm import Mapped, mapped_column, raiseload, relationship 11 | from sqlalchemy.schema import CreateSchema 12 | 13 | from bot.commands import Context 14 | from bot.config import plugin_config_command 15 | from bot.interactions import command as app_command, persistent_view 16 | import plugins 17 | import plugins.roles_review 18 | import util.db.kv 19 | from util.discord import CodeBlock, Inline, PartialRoleConverter, Quoted, format, retry 20 | 21 | 22 | registry = sqlalchemy.orm.registry() 23 | sessionmaker = async_sessionmaker(util.db.engine, expire_on_commit=False) 24 | 25 | 26 | @registry.mapped 27 | class SelectField: 28 | __tablename__ = "selects" 29 | __table_args__ = {"schema": "roles_dialog"} 30 | 31 | index: Mapped[int] = mapped_column(BigInteger, primary_key=True) 32 | boolean: Mapped[bool] = mapped_column(BOOLEAN, nullable=False) 33 | 34 | items: Mapped[List["SelectItem"]] = relationship("SelectItem", lazy="joined", order_by="SelectItem.id") 35 | 36 | if TYPE_CHECKING: 37 | 38 | def __init__(self, index: int, boolean: bool) -> None: ... 
39 | 40 | 41 | @registry.mapped 42 | class SelectItem: 43 | __tablename__ = "items" 44 | __table_args__ = {"schema": "roles_dialog"} 45 | 46 | id: Mapped[int] = mapped_column(BigInteger, primary_key=True) 47 | index: Mapped[int] = mapped_column(BigInteger, ForeignKey(SelectField.index), nullable=False) 48 | role_id: Mapped[Optional[int]] = mapped_column(BigInteger) 49 | label: Mapped[Optional[str]] = mapped_column(TEXT) 50 | description: Mapped[Optional[str]] = mapped_column(TEXT) 51 | 52 | if TYPE_CHECKING: 53 | 54 | def __init__( 55 | self, 56 | index: int, 57 | id: int = ..., 58 | role_id: Optional[int] = ..., 59 | label: Optional[str] = ..., 60 | description: Optional[str] = ..., 61 | ) -> None: ... 62 | 63 | 64 | selects: List[SelectField] 65 | 66 | 67 | async def rehash(session: AsyncSession) -> None: 68 | global selects 69 | stmt = select(SelectField).order_by(SelectField.index) 70 | selects = list((await session.execute(stmt)).scalars().unique()) 71 | 72 | 73 | @plugins.init 74 | async def init() -> None: 75 | await util.db.init(util.db.get_ddl(CreateSchema("roles_dialog"), registry.metadata.create_all)) 76 | 77 | async with sessionmaker() as session: 78 | conf = await util.db.kv.load(__name__) 79 | 80 | def mk_item(index: int, item: Union[int, str]) -> SelectItem: 81 | if isinstance(item, int): 82 | return SelectItem(index=index, role_id=item, description=cast(Optional[str], conf[item, "desc"])) 83 | else: 84 | return SelectItem(index=index, label=item) 85 | 86 | if conf.roles is not None: 87 | index = 0 88 | booleans: bool = False 89 | for lst in cast(List[List[Union[int, str]]], conf.roles): 90 | if len(lst) == 1: 91 | session.add(mk_item(index, lst[0])) 92 | booleans = True 93 | else: 94 | if booleans: 95 | session.add(SelectField(index=index, boolean=True)) 96 | index += 1 97 | booleans = False 98 | for l in lst: 99 | session.add(mk_item(index, l)) 100 | session.add(SelectField(index=index, boolean=False)) 101 | index += 1 102 | if booleans: 103 | 
session.add(SelectField(index=index, boolean=True)) 104 | index += 1 105 | 106 | await session.commit() 107 | conf.roles = None 108 | await conf 109 | 110 | await rehash(session) 111 | 112 | 113 | class RoleSelect(Select["RolesView"]): 114 | def __init__( 115 | self, boolean: bool, role_items: Iterable[SelectItem], member: Member, row: Optional[int] = None 116 | ) -> None: 117 | self.roles: Dict[str, Role] = {} 118 | index = 0 119 | options = [] 120 | 121 | for item in role_items: 122 | if item.role_id is not None: 123 | if (role := member.guild.get_role(item.role_id)) is not None: 124 | options.append( 125 | SelectOption( 126 | label=(role.name if item.label is None else item.label)[:100], 127 | value=str(index), 128 | description=(item.description or "")[:100], 129 | default=role in member.roles, 130 | ) 131 | ) 132 | self.roles[str(index)] = role 133 | index += 1 134 | elif item.label is not None: 135 | options.append( 136 | SelectOption(label=item.label[:100], value="_", description=(item.description or "")[:100]) 137 | ) 138 | 139 | if not boolean and sum(option.default for option in options) > 1: 140 | for option in options: 141 | option.default = False 142 | 143 | super().__init__( 144 | placeholder="Select roles..." 
if boolean else "Select a role...", 145 | min_values=0 if boolean else 1, 146 | max_values=len(options) if boolean else 1, 147 | options=options, 148 | ) 149 | 150 | async def callback(self, interaction: Interaction) -> None: 151 | if not isinstance(interaction.user, Member): 152 | await interaction.response.send_message( 153 | "This can only be done in a server.", ephemeral=True, delete_after=60 154 | ) 155 | return 156 | member = interaction.user 157 | 158 | selected_roles = set() 159 | for index in self.values: 160 | if index in self.roles: 161 | selected_roles.add(self.roles[index]) 162 | add_roles = set() 163 | remove_roles = set() 164 | prompt_roles: List[Tuple[Role, plugins.roles_review.ReviewedRole]] = [] 165 | for role in self.roles.values(): 166 | if role in member.roles and role not in selected_roles: 167 | remove_roles.add(role) 168 | if role not in member.roles and role in selected_roles: 169 | pre = await plugins.roles_review.pre_apply(member, role) 170 | if pre == plugins.roles_review.ApplicationStatus.APPROVED: 171 | add_roles.add(role) 172 | elif isinstance(pre, plugins.roles_review.ReviewedRole): 173 | prompt_roles.append((role, pre)) 174 | # TODO: tell them if False? 175 | 176 | if prompt_roles: 177 | await interaction.response.send_modal(plugins.roles_review.RolePromptModal(member.guild, prompt_roles)) 178 | else: 179 | await interaction.response.defer(ephemeral=True) 180 | 181 | if add_roles: 182 | await retry(lambda: member.add_roles(*add_roles, reason="Role dialog")) 183 | if remove_roles: 184 | await retry(lambda: member.remove_roles(*remove_roles, reason="Role dialog")) 185 | 186 | if not prompt_roles: 187 | await interaction.followup.send( 188 | "\u2705 Updated roles." 
if add_roles or remove_roles else "Roles not changed.", ephemeral=True 189 | ) 190 | 191 | 192 | class RolesView(View): 193 | def __init__(self, member: Member) -> None: 194 | super().__init__(timeout=600) 195 | 196 | for select in selects: 197 | self.add_item(RoleSelect(select.boolean, select.items, member)) 198 | 199 | 200 | async def send_roles_view(interaction: Interaction) -> None: 201 | if not isinstance(interaction.user, Member): 202 | await interaction.response.send_message("This can only be done in a server.", ephemeral=True, delete_after=60) 203 | return 204 | await interaction.response.send_message("Select your roles:", view=RolesView(interaction.user), ephemeral=True) 205 | 206 | 207 | class ManageRolesButton(Button["ManageRolesView"]): 208 | def __init__(self) -> None: 209 | super().__init__(style=ButtonStyle.primary, label="Manage roles", custom_id="{}:manage".format(__name__)) 210 | 211 | async def callback(self, interaction: Interaction) -> None: 212 | await send_roles_view(interaction) 213 | 214 | 215 | class ManageRolesView(View): 216 | def __init__(self) -> None: 217 | super().__init__(timeout=None) 218 | self.add_item(ManageRolesButton()) 219 | 220 | 221 | persistent_view(ManageRolesView()) 222 | 223 | 224 | @app_command("roles", description="Manage self-assigned roles.") 225 | async def roles_command(interaction: Interaction) -> None: 226 | await send_roles_view(interaction) 227 | 228 | 229 | @plugin_config_command 230 | @group("roles_dialog", invoke_without_command=True) 231 | async def config(ctx: Context) -> None: 232 | async with sessionmaker() as session: 233 | stmt = select(SelectField).order_by(SelectField.index) 234 | selects = (await session.execute(stmt)).scalars().unique() 235 | 236 | await ctx.send( 237 | "\n".join( 238 | format( 239 | "- index {!i}: {} {}", 240 | select.index, 241 | "multi" if select.boolean else "choice", 242 | ", ".join(format("ID {!i}", item.id) for item in select.items), 243 | ) 244 | for select in selects 245 
| ) 246 | ) 247 | 248 | 249 | @config.command("new") 250 | async def config_new(ctx: Context, index: int) -> None: 251 | async with sessionmaker() as session: 252 | if not await session.get(SelectField, index, options=[raiseload(SelectField.items)]): 253 | session.add(SelectField(index=index, boolean=True)) 254 | item = SelectItem(index=index) 255 | session.add(item) 256 | await session.commit() 257 | await rehash(session) 258 | await ctx.send(format("Created: ID {!i}", item.id)) 259 | 260 | 261 | @config.command("remove") 262 | async def config_remove(ctx: Context, id: int) -> None: 263 | async with sessionmaker() as session: 264 | item = await session.get(SelectItem, id) 265 | assert item 266 | await session.delete(item) 267 | select = await session.get(SelectField, item.index) 268 | assert select 269 | if not select.items: 270 | await session.delete(select) 271 | await session.commit() 272 | await rehash(session) 273 | await ctx.send( 274 | format("Removed ID {!i} and index {!i}", item.id, select.index) 275 | if not select.items 276 | else format("Removed ID {!i}", item.id) 277 | ) 278 | 279 | 280 | @config.command("mode") 281 | async def config_mode(ctx: Context, index: int, mode: Optional[Literal["choice", "multi"]]) -> None: 282 | async with sessionmaker() as session: 283 | select = await session.get(SelectField, index, options=[raiseload(SelectField.items)]) 284 | assert select 285 | if mode is None: 286 | await ctx.send("multi" if select.boolean else "choice") 287 | else: 288 | select.boolean = mode == "choice" 289 | await session.commit() 290 | await rehash(session) 291 | await ctx.send("\u2705") 292 | 293 | 294 | @config.command("role") 295 | async def config_role(ctx: Context, id: int, role: Optional[Union[Literal["None"], PartialRoleConverter]]) -> None: 296 | async with sessionmaker() as session: 297 | item = await session.get(SelectItem, id) 298 | assert item 299 | if role is None: 300 | await ctx.send( 301 | "None" if item.role_id is None else 
format("{!M}", item.role_id), 302 | allowed_mentions=AllowedMentions.none(), 303 | ) 304 | else: 305 | item.role_id = None if role == "None" else role.id 306 | await session.commit() 307 | await rehash(session) 308 | await ctx.send("\u2705") 309 | 310 | 311 | @config.command("label") 312 | async def config_label( 313 | ctx: Context, id: int, label: Optional[Union[Literal["None"], CodeBlock, Inline, Quoted]] 314 | ) -> None: 315 | async with sessionmaker() as session: 316 | item = await session.get(SelectItem, id) 317 | assert item 318 | if label is None: 319 | await ctx.send("None" if item.label is None else format("{!b}", item.label)) 320 | else: 321 | item.label = None if label == "None" else label.text 322 | await session.commit() 323 | await rehash(session) 324 | await ctx.send("\u2705") 325 | 326 | 327 | @config.command("description") 328 | async def config_description( 329 | ctx: Context, id: int, description: Optional[Union[Literal["None"], CodeBlock, Inline, Quoted]] 330 | ) -> None: 331 | async with sessionmaker() as session: 332 | item = await session.get(SelectItem, id) 333 | assert item 334 | if description is None: 335 | await ctx.send("None" if item.description is None else format("{!b}", item.description)) 336 | else: 337 | item.description = None if description == "None" else description.text 338 | await session.commit() 339 | await rehash(session) 340 | await ctx.send("\u2705") 341 | 342 | 343 | async def setup(target: Messageable) -> None: 344 | await target.send(view=ManageRolesView()) 345 | -------------------------------------------------------------------------------- /plugins/update.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import asyncio.subprocess 3 | from typing import TYPE_CHECKING, Optional 4 | 5 | from discord.ext.commands import command, group 6 | from sqlalchemy import TEXT, select 7 | from sqlalchemy.ext.asyncio import async_sessionmaker 8 | import sqlalchemy.orm 9 | 
from sqlalchemy.orm import Mapped, mapped_column 10 | 11 | from bot.acl import privileged 12 | from bot.commands import Context, cleanup, plugin_command 13 | from bot.config import plugin_config_command 14 | import plugins 15 | import util.db.kv 16 | from util.discord import CodeItem, Typing, chunk_messages, format 17 | 18 | 19 | registry = sqlalchemy.orm.registry() 20 | sessionmaker = async_sessionmaker(util.db.engine) 21 | 22 | 23 | @registry.mapped 24 | class GitDirectory: 25 | __tablename__ = "update_git_directories" 26 | 27 | name: Mapped[str] = mapped_column(TEXT, primary_key=True) 28 | directory: Mapped[str] = mapped_column(TEXT, nullable=False) 29 | 30 | if TYPE_CHECKING: 31 | 32 | def __init__(self, *, name: str, directory: str) -> None: ... 33 | 34 | 35 | @plugins.init 36 | async def init() -> None: 37 | await util.db.init(util.db.get_ddl(registry.metadata.create_all)) 38 | 39 | async with sessionmaker() as session: 40 | conf = await util.db.kv.load(__name__) 41 | for key in [key for key, in conf]: 42 | session.add(GitDirectory(name=key, directory=str(conf[key]))) 43 | conf[key] = None 44 | await session.commit() 45 | await conf 46 | 47 | 48 | @plugin_command 49 | @cleanup 50 | @command("update") 51 | @privileged 52 | async def update_command(ctx: Context, bot_directory: Optional[str]) -> None: 53 | """Pull changes from git remote.""" 54 | async with sessionmaker() as session: 55 | cwd = None 56 | if bot_directory is not None: 57 | if conf := await session.get(GitDirectory, bot_directory): 58 | cwd = conf.directory 59 | 60 | git_pull = await asyncio.create_subprocess_exec( 61 | "git", 62 | "pull", 63 | "--ff-only", 64 | "--recurse-submodules", 65 | cwd=cwd, 66 | stdout=asyncio.subprocess.PIPE, 67 | stderr=asyncio.subprocess.STDOUT, 68 | ) 69 | 70 | async with Typing(ctx): 71 | try: 72 | assert git_pull.stdout 73 | output = (await git_pull.stdout.read()).decode("utf", "replace") 74 | finally: 75 | await git_pull.wait() 76 | 77 | for content, files in 
chunk_messages((CodeItem(output, filename="update.txt"),)): 78 | await ctx.send(content, files=files) 79 | 80 | 81 | @plugin_config_command 82 | @group("update", invoke_without_command=True) 83 | async def config(ctx: Context) -> None: 84 | async with sessionmaker() as session: 85 | stmt = select(GitDirectory) 86 | dirs = (await session.execute(stmt)).scalars() 87 | await ctx.send( 88 | "\n".join(format("- {!i}: {!i}", conf.name, conf.directory) for conf in dirs) 89 | or "No repositories registered" 90 | ) 91 | 92 | 93 | @config.command("add") 94 | async def config_add(ctx: Context, name: str, directory: str) -> None: 95 | async with sessionmaker() as session: 96 | session.add(GitDirectory(name=name, directory=directory)) 97 | await session.commit() 98 | await ctx.send("\u2705") 99 | 100 | 101 | @config.command("remove") 102 | async def config_remove(ctx: Context, name: str) -> None: 103 | async with sessionmaker() as session: 104 | await session.delete(await session.get(GitDirectory, name)) 105 | await session.commit() 106 | await ctx.send("\u2705") 107 | -------------------------------------------------------------------------------- /plugins/version.py: -------------------------------------------------------------------------------- 1 | import asyncio.subprocess 2 | 3 | from discord.ext.commands import command 4 | 5 | from bot.acl import privileged 6 | from bot.commands import Context, cleanup, plugin_command 7 | from util.discord import format 8 | 9 | 10 | @plugin_command 11 | @cleanup 12 | @command("version") 13 | @privileged 14 | async def version_command(ctx: Context) -> None: 15 | """Display running bot version including any local changes.""" 16 | git_log = await asyncio.subprocess.create_subprocess_exec( 17 | "git", "log", "--max-count=1", "--format=format:%H%d", "HEAD", stdout=asyncio.subprocess.PIPE 18 | ) 19 | try: 20 | assert git_log.stdout 21 | version = (await git_log.stdout.read()).decode("utf", "replace").rstrip("\n") 22 | finally: 23 | await 
git_log.wait() 24 | 25 | git_status = await asyncio.subprocess.create_subprocess_exec( 26 | "git", "status", "--porcelain", "-z", stdout=asyncio.subprocess.PIPE 27 | ) 28 | try: 29 | assert git_status.stdout 30 | changes = (await git_status.stdout.read()).decode("utf", "replace").split("\0") 31 | finally: 32 | await git_status.wait() 33 | 34 | changes = list(filter(lambda line: line and not line.startswith("??"), changes)) 35 | 36 | if changes: 37 | await ctx.send("{} with changes:\n{}".format(version, "\n".join(format("{!i}", change) for change in changes))) 38 | else: 39 | await ctx.send(version) 40 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.pyright] 2 | exclude = ["discord", "discord.py"] 3 | pythonVersion = "3.9" 4 | 5 | reportUnnecessaryCast = "warning" 6 | reportUnnecessaryTypeIgnoreComment = "warning" 7 | reportFunctionMemberAccess = "warning" 8 | reportUnusedImport = "warning" 9 | reportUnusedClass = "warning" 10 | reportUnusedFunction = "warning" 11 | reportUnusedVariable = "warning" 12 | reportDuplicateImport = "warning" 13 | reportUntypedFunctionDecorator = "error" 14 | reportUntypedClassDecorator = "error" 15 | reportUntypedBaseClass = "error" 16 | reportUntypedNamedTuple = "error" 17 | reportUnknownParameterType = "error" 18 | reportUnknownLambdaType = "error" 19 | reportMissingParameterType = "error" 20 | reportMissingTypeArgument = "error" 21 | reportUnnecessaryIsInstance = "warning" 22 | 23 | [tool.black] 24 | extend-exclude = "discord/|discord.py/" 25 | line-length = 120 26 | target-version = ["py39"] 27 | 28 | [tool.isort] 29 | skip_glob = ["discord/*", "discord.py/*"] 30 | multi_line_output = 3 31 | include_trailing_comma = true 32 | combine_as_imports = true 33 | force_sort_within_sections = true 34 | line_length = 120 35 | known_third_party = "discord" 36 | extra_standard_library = 
"typing_extensions" 37 | lines_after_imports = 2 38 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | -r requirements_core.txt 2 | -r requirements_linting.txt 3 | -------------------------------------------------------------------------------- /requirements_core.txt: -------------------------------------------------------------------------------- 1 | docker/requirements_core.txt -------------------------------------------------------------------------------- /requirements_linting.txt: -------------------------------------------------------------------------------- 1 | black==24.4 2 | isort==5.13 3 | -------------------------------------------------------------------------------- /static_config.py: -------------------------------------------------------------------------------- 1 | """ 2 | Read a basic ini-like config file at startup. The values inside the config aren't really supposed to change during 3 | execution of the bot. This module implements __getattr__ so that you could write: 4 | 5 | import static_config 6 | static_config.foo["bar"] 7 | """ 8 | 9 | from configparser import ConfigParser, SectionProxy 10 | 11 | 12 | config_file = "bot.conf" 13 | 14 | config = ConfigParser() 15 | config.read(config_file, encoding="utf") 16 | 17 | 18 | def writeback() -> None: 19 | """Save the modified config. 
This will erase comments.""" 20 | with open(config_file, "w", encoding="utf") as f: 21 | config.write(f) 22 | 23 | 24 | def __getattr__(name: str) -> SectionProxy: 25 | try: 26 | return config[name] 27 | except KeyError as exc: 28 | raise AttributeError(*exc.args) 29 | -------------------------------------------------------------------------------- /util/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/discord-math/bot/1231f964ea2d701e18dda7b63d565de43846de9b/util/__init__.py -------------------------------------------------------------------------------- /util/asyncio.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | from typing import Awaitable, Callable, Generator, TypeVar 3 | 4 | 5 | R_co = TypeVar("R_co", covariant=True) 6 | 7 | 8 | def __await__(fun: Callable[..., Awaitable[R_co]]) -> Callable[..., Generator[None, None, R_co]]: 9 | """Decorate a class's __await__ with this to be able to write it as an async def.""" 10 | 11 | def wrapper(*args: object, **kwargs: object) -> Generator[None, None, R_co]: 12 | return fun(*args, **kwargs).__await__() 13 | 14 | return wrapper 15 | 16 | 17 | def concurrently(fun: Callable[..., R_co], *args: object, **kwargs: object) -> Awaitable[R_co]: 18 | """ 19 | Run a synchronous blocking computation in a different python thread, avoiding blocking the current async thread. 20 | This function starts the computation and returns a future referring to its result. Beware of (actual) thread-safety. 
21 | """ 22 | return asyncio.get_running_loop().run_in_executor(None, lambda: fun(*args, **kwargs)) 23 | -------------------------------------------------------------------------------- /util/db/__init__.py: -------------------------------------------------------------------------------- 1 | import contextlib 2 | from typing import AsyncIterator, Callable, Union 3 | 4 | import asyncpg 5 | import sqlalchemy 6 | from sqlalchemy import Connection 7 | import sqlalchemy.dialects.postgresql 8 | import sqlalchemy.ext.asyncio 9 | from sqlalchemy.schema import DDLElement, ExecutableDDLElement 10 | 11 | import static_config 12 | import util.db.dsn as util_db_dsn 13 | import util.db.log as util_db_log 14 | 15 | 16 | connection_dsn: str = static_config.DB["dsn"] 17 | 18 | connection_uri: str = util_db_dsn.dsn_to_uri(connection_dsn) 19 | async_connection_uri: str = util_db_dsn.uri_to_asyncpg(connection_uri) 20 | 21 | 22 | @contextlib.asynccontextmanager 23 | async def connection() -> AsyncIterator[util_db_log.LoggingConnection]: 24 | conn = await asyncpg.connect(connection_uri, connection_class=util_db_log.LoggingConnection) 25 | try: 26 | yield conn 27 | finally: 28 | await conn.close() 29 | 30 | 31 | engine: sqlalchemy.ext.asyncio.AsyncEngine = sqlalchemy.ext.asyncio.create_async_engine( 32 | async_connection_uri, pool_pre_ping=True, connect_args={"connection_class": util_db_log.LoggingConnection} 33 | ) 34 | 35 | from util.db.initialization import init as init, init_for as init_for 36 | 37 | 38 | def get_ddl(*cbs: Union[DDLElement, Callable[[Connection], None]]) -> str: 39 | # By default sqlalchemy treats asyncpg as if it had paramstyle="format", which means it tries to escape percent 40 | # signs. We don't want that so we have to override the paramstyle. Ideally "numeric" would be the right choice here 41 | # but that doesn't work. 
42 | dialect = sqlalchemy.dialects.postgresql.dialect(paramstyle="qmark") 43 | ddls = [] 44 | 45 | def executor(sql: ExecutableDDLElement, *args: object, **kwargs: object) -> None: 46 | ddls.append(str(sql.compile(dialect=dialect)) + ";") 47 | 48 | conn = sqlalchemy.create_mock_engine(sqlalchemy.make_url("postgresql://"), executor) 49 | for cb in cbs: 50 | if isinstance(cb, DDLElement): 51 | conn.execute(cb) 52 | else: 53 | cb(conn) # type: ignore 54 | 55 | return "\n".join(ddls) 56 | -------------------------------------------------------------------------------- /util/db/dsn.py: -------------------------------------------------------------------------------- 1 | import re 2 | import urllib.parse 3 | 4 | 5 | dsn_re: re.Pattern[str] = re.compile(r"\s*(\w*)\s*=\s*(?:([^\s' \\]+)|'((?:[^'\\]|\\.)*)')\s*") 6 | unquote_re: re.Pattern[str] = re.compile(r"\\(.)") 7 | 8 | 9 | def dsn_to_uri(dsn: str) -> str: 10 | """ 11 | Convert a key=value style DSN into a postgres:// URI 12 | """ 13 | if dsn.startswith("postgres://") or dsn.startswith("postgresql://"): 14 | return dsn 15 | if "=" not in dsn: 16 | return "postgres://" + urllib.parse.quote(dsn, safe="") 17 | kvs = [] 18 | for key, val, val_quoted in dsn_re.findall(dsn): 19 | if not val: 20 | val = unquote_re.sub(r"\1", val_quoted) 21 | kvs.append((key, val)) 22 | return "postgres://?" + urllib.parse.urlencode(kvs) 23 | 24 | 25 | def uri_to_asyncpg(uri: str) -> str: 26 | return "postgresql+asyncpg://?dsn=" + urllib.parse.quote(uri, safe="") 27 | -------------------------------------------------------------------------------- /util/db/initialization.py: -------------------------------------------------------------------------------- 1 | """ 2 | A simple database migration manager. A module can request to initialize something in the database with the @init_for 3 | and @init decorators. 
4 | """ 5 | 6 | import hashlib 7 | import logging 8 | 9 | import plugins 10 | import static_config 11 | import util.db as db 12 | 13 | 14 | logger = logging.getLogger(__name__) 15 | 16 | meta_initialized = False 17 | 18 | 19 | async def initialize_meta() -> None: 20 | global meta_initialized 21 | if not meta_initialized: 22 | logger.debug("Initializing migration metadata") 23 | try: 24 | async with db.connection() as conn: 25 | await conn.execute( 26 | """ 27 | CREATE SCHEMA IF NOT EXISTS meta 28 | """ 29 | ) 30 | await conn.execute( 31 | """ 32 | CREATE TABLE IF NOT EXISTS meta.schema_hashes 33 | ( name TEXT NOT NULL PRIMARY KEY 34 | , sha1 BYTEA NOT NULL ) 35 | """ 36 | ) 37 | finally: 38 | meta_initialized = True 39 | 40 | 41 | async def init_for(name: str, schema: str) -> None: 42 | """ 43 | Pass DDL SQL statements to initialize something in the database. 44 | 45 | await init_for("module name", "CREATE TABLE foo (bar TEXT)") 46 | 47 | The SQL will be hashed. If a hash for this module doesn't yet exist the SQL code will be executed and the 48 | hash saved. If the known hash for the module matches the computed one, nothing happens. Otherwise we look for a 49 | migration file in a configurable directory and run it, updating the known hash. 
50 | """ 51 | logger.debug("Schema for {}:\n{}".format(name, schema)) 52 | async with db.connection() as conn: 53 | async with conn.transaction(): 54 | await initialize_meta() 55 | old_sha = await conn.fetchval("SELECT sha1 FROM meta.schema_hashes WHERE name = $1", name) 56 | sha = hashlib.sha1(schema.encode("utf")).digest() 57 | logger.debug("{}: old {} new {}".format(name, old_sha.hex() if old_sha is not None else None, sha.hex())) 58 | if old_sha is not None: 59 | if old_sha != sha: 60 | for dirname in static_config.DB["migrations"].split(":"): 61 | filename = "{}/{}-{}-{}.sql".format(dirname, name, old_sha.hex(), sha.hex()) 62 | try: 63 | fp = open(filename, "r", encoding="utf") 64 | break 65 | except FileNotFoundError: 66 | continue 67 | else: 68 | raise FileNotFoundError( 69 | "Could not find {}-{}-{}.sql in {}".format( 70 | name, old_sha.hex(), sha.hex(), static_config.DB["migrations"] 71 | ) 72 | ) 73 | with fp: 74 | logger.debug("{}: Loading migration {}".format(name, filename)) 75 | await conn.execute(fp.read()) 76 | await conn.execute("UPDATE meta.schema_hashes SET sha1 = $2 WHERE name = $1", name, sha) 77 | else: 78 | await conn.execute(schema) 79 | await conn.execute("INSERT INTO meta.schema_hashes (name, sha1) VALUES ($1, $2)", name, sha) 80 | 81 | 82 | async def init(schema: str) -> None: 83 | """Request database initialization for the current plugin.""" 84 | await init_for(plugins.current_plugin().name, schema) 85 | -------------------------------------------------------------------------------- /util/db/kv/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | A simple key-value store that associates to each module name and a string key a 3 | piece of JSON. If a module needs more efficient or structured storage it should 4 | probably have its own DB handling code. 
5 | """ 6 | 7 | from __future__ import annotations 8 | 9 | import asyncio 10 | import contextlib 11 | import json 12 | from typing import AsyncIterator, Dict, Iterator, Optional, Sequence, Set, Tuple, Union, cast 13 | from weakref import WeakValueDictionary 14 | 15 | import util.asyncio 16 | import util.db as util_db 17 | import util.db.log as util_db_log 18 | from util.frozen_dict import FrozenDict 19 | from util.frozen_list import FrozenList 20 | 21 | 22 | schema_initialized = False 23 | 24 | 25 | async def init_schema() -> None: 26 | global schema_initialized 27 | if not schema_initialized: 28 | await util_db.init_for( 29 | __name__, 30 | """ 31 | CREATE TABLE kv 32 | ( namespace TEXT NOT NULL 33 | , key TEXT ARRAY NOT NULL 34 | , value TEXT NOT NULL 35 | , PRIMARY KEY(namespace, key) ); 36 | CREATE INDEX kv_namespace_index 37 | ON kv USING BTREE(namespace); 38 | """, 39 | ) 40 | schema_initialized = True 41 | 42 | 43 | def json_freeze(value: Optional[object]) -> Optional[object]: 44 | if isinstance(value, list): 45 | return FrozenList(json_freeze(v) for v in value) 46 | elif isinstance(value, dict): 47 | return FrozenDict((k, json_freeze(v)) for k, v in value.items()) 48 | else: 49 | return value 50 | 51 | 52 | class ThawingJSONEncoder(json.JSONEncoder): 53 | __slots__ = () 54 | 55 | def default(self, o: object) -> object: 56 | if isinstance(o, FrozenList): 57 | return o.copy() 58 | elif isinstance(o, FrozenDict): 59 | return o.copy() 60 | else: 61 | return super().default(o) 62 | 63 | 64 | def json_encode(value: object) -> Optional[str]: 65 | return json.dumps(value, cls=ThawingJSONEncoder) if value is not None else None 66 | 67 | 68 | def json_decode(text: Optional[str]) -> object: 69 | return json_freeze(json.loads(text)) if text is not None else None 70 | 71 | 72 | @contextlib.asynccontextmanager 73 | async def connect() -> AsyncIterator[util_db_log.LoggingConnection]: 74 | await init_schema() 75 | async with util_db.connection() as conn: 76 | yield conn 77 
| 78 | 79 | async def get_raw_value(namespace: Sequence[str], key: Sequence[str]) -> Optional[str]: 80 | async with connect() as conn: 81 | val = await conn.fetchval( 82 | """ 83 | SELECT value FROM kv WHERE namespace = $1 AND key = $2 84 | """, 85 | namespace, 86 | tuple(key), 87 | ) 88 | return cast(Optional[str], val) 89 | 90 | 91 | async def get_raw_key_values(namespace: str) -> Dict[Tuple[str, ...], str]: 92 | async with connect() as conn: 93 | rows = await conn.fetch( 94 | """ 95 | SELECT key, value FROM kv WHERE namespace = $1 96 | """, 97 | namespace, 98 | ) 99 | return {tuple(row["key"]): row["value"] for row in rows} 100 | 101 | 102 | async def get_namespaces() -> Sequence[str]: 103 | async with connect() as conn: 104 | rows = await conn.fetch( 105 | """ 106 | SELECT DISTINCT namespace FROM kv 107 | """ 108 | ) 109 | return [row["namespace"] for row in rows] 110 | 111 | 112 | async def set_raw_value(namespace: str, key: Sequence[str], value: Optional[str], log_value: bool = True) -> None: 113 | async with connect() as conn: 114 | if value is None: 115 | await conn.execute( 116 | """ 117 | DELETE FROM kv 118 | WHERE namespace = $1 AND key = $2 119 | """, 120 | namespace, 121 | tuple(key), 122 | ) 123 | else: 124 | await conn.execute( 125 | """ 126 | INSERT INTO kv (namespace, key, value) 127 | VALUES ($1, $2, $3) 128 | ON CONFLICT (namespace, key) DO UPDATE SET value = EXCLUDED.value 129 | """, 130 | namespace, 131 | tuple(key), 132 | value, 133 | log_data=True if log_value else {1, 2}, 134 | ) 135 | 136 | 137 | async def set_raw_values(namespace: str, dict: Dict[Sequence[str], Optional[str]], log_value: bool = False) -> None: 138 | removals = [(namespace, tuple(key)) for key, value in dict.items() if value is None] 139 | updates = [(namespace, tuple(key), value) for key, value in dict.items() if value is not None] 140 | async with connect() as conn: 141 | async with conn.transaction(): 142 | if removals: 143 | await conn.executemany( 144 | """ 145 | 
DELETE FROM kv 146 | WHERE namespace = $1 AND key = $2 147 | """, 148 | removals, 149 | ) 150 | if updates: 151 | await conn.executemany( 152 | """ 153 | INSERT INTO kv (namespace, key, value) 154 | VALUES ($1, $2, $3) 155 | ON CONFLICT (namespace, key) DO UPDATE SET value = EXCLUDED.value 156 | """, 157 | updates, 158 | log_data=True if log_value else {1, 2}, 159 | ) 160 | 161 | 162 | class ConfigStore(Dict[Tuple[str, ...], str]): 163 | __slots__ = ("__weakref__", "ready") 164 | ready: asyncio.Event 165 | 166 | def __init__(self, *args: object, **kwargs: object): 167 | super().__init__(*args, **kwargs) 168 | self.ready = asyncio.Event() 169 | 170 | 171 | config_stores: WeakValueDictionary[str, ConfigStore] 172 | config_stores = WeakValueDictionary() 173 | 174 | KeyType = Union[str, int, Sequence[Union[str, int]]] 175 | 176 | 177 | def encode_key(key: KeyType) -> Tuple[str, ...]: 178 | if isinstance(key, (str, int)): 179 | key = (key,) 180 | return tuple(str(k) for k in key) 181 | 182 | 183 | class Config: 184 | """ 185 | This object encapsulates access to the key-value store for a fixed module. Upon construction we load all the pairs 186 | from the DB into memory. The in-memory copy is shared across Config objects for the same module. 187 | __iter__ and __getitem__/__getattr__ will read from this in-memory copy. 188 | __setitem__/__setattr__ will update the in-memory copy. 
awaiting will commit the keys that were modified by this 189 | Config object to the DB (the values may have since been overwritten by other Config objects) 190 | """ 191 | 192 | __slots__ = "_namespace", "_log_value", "_store", "_dirty" 193 | _namespace: str 194 | _log_value: bool 195 | _store: ConfigStore 196 | _dirty: Set[Tuple[str, ...]] 197 | 198 | def __init__(self, namespace: str, log_value: bool, store: ConfigStore): 199 | self._namespace = namespace 200 | self._log_value = log_value 201 | self._store = store 202 | self._dirty = set() 203 | 204 | def __iter__(self) -> Iterator[Tuple[str, ...]]: 205 | return self._store.__iter__() 206 | 207 | def __getitem__(self, key: KeyType) -> object: 208 | return json_decode(self._store.get(encode_key(key))) 209 | 210 | def __setitem__(self, key: KeyType, value: object) -> None: 211 | ek = encode_key(key) 212 | ev = json_encode(value) 213 | if ev is None: 214 | self._store.pop(ek, None) 215 | else: 216 | self._store[ek] = ev 217 | self._dirty.add(ek) 218 | 219 | @util.asyncio.__await__ 220 | async def __await__(self) -> None: 221 | dirty = self._dirty 222 | self._dirty = set() 223 | try: 224 | await set_raw_values(self._namespace, {key: self._store.get(key) for key in dirty}) 225 | except: 226 | self._dirty.update(dirty) 227 | raise 228 | 229 | def __getattr__(self, key: str) -> object: 230 | if key.startswith("_"): 231 | return None 232 | return self[key] 233 | 234 | def __setattr__(self, key: str, value: object) -> None: 235 | if key.startswith("_"): 236 | return super().__setattr__(key, value) 237 | self[key] = value 238 | 239 | 240 | async def load(namespace: str, log_value: bool = False) -> Config: 241 | store = config_stores.get(namespace) 242 | if store is None: 243 | store = ConfigStore() 244 | config_stores[namespace] = store 245 | store.update(await get_raw_key_values(namespace)) 246 | store.ready.set() 247 | await store.ready.wait() 248 | return Config(namespace, log_value, store) 249 | 
-------------------------------------------------------------------------------- /util/db/log.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import Any, Callable, Collection, Optional, Sequence, Union 3 | 4 | from asyncpg import Connection, PostgresLogMessage, Record 5 | from asyncpg.cursor import CursorFactory 6 | from asyncpg.prepared_stmt import PreparedStatement 7 | from asyncpg.transaction import Transaction 8 | 9 | 10 | logger: logging.Logger = logging.getLogger(__name__) 11 | 12 | severity_map = { 13 | "DEBUG": logging.DEBUG, 14 | "LOG": logging.DEBUG, 15 | "NOTICE": logging.INFO, 16 | "INFO": logging.INFO, 17 | "WARNING": logging.WARNING, 18 | "ERROR": logging.ERROR, 19 | "FATAL": logging.ERROR, 20 | "PANIC": logging.ERROR, 21 | } 22 | 23 | 24 | def filter_single(log_data: Union[bool, Collection[int]], data: Sequence[object]) -> str: 25 | spec: Callable[[int], bool] 26 | if isinstance(log_data, bool): 27 | log_data_bool = log_data 28 | spec = lambda _: log_data_bool 29 | else: 30 | log_data_set = log_data 31 | spec = lambda i: i in log_data_set 32 | return "({})".format(",".join(repr(data[i]) if spec(i + 1) else "?" for i in range(len(data)))) 33 | 34 | 35 | def filter_multi(log_data: Union[bool, Collection[int]], data: Sequence[Sequence[object]]) -> str: 36 | spec: Callable[[int], bool] 37 | if isinstance(log_data, bool): 38 | log_data_bool = log_data 39 | spec = lambda _: log_data_bool 40 | else: 41 | log_data_set = log_data 42 | spec = lambda i: i in log_data_set 43 | return ",".join( 44 | "({})".format(",".join(repr(datum[i]) if spec(i + 1) else "?" 
for i in range(len(datum)))) for datum in data 45 | ) 46 | 47 | 48 | def fmt_query_single(query: str, log_data: Union[bool, Collection[int]], args: Sequence[object]) -> str: 49 | if log_data: 50 | return "{} % {}".format(query, filter_single(log_data, args)) 51 | else: 52 | return query 53 | 54 | 55 | def fmt_query_multi(query: str, log_data: Union[bool, Collection[int]], args: Sequence[Sequence[object]]) -> str: 56 | if log_data: 57 | return "{} % {}".format(query, filter_multi(log_data, args)) 58 | else: 59 | return query 60 | 61 | 62 | def fmt_table(name: str, schema: Optional[str]) -> str: 63 | return schema + "." + name if schema is not None else name 64 | 65 | 66 | def log_message(conn: Connection, msg: PostgresLogMessage) -> None: 67 | severity = getattr(msg, "severity_en") or getattr(msg, "severity") 68 | logger.log(severity_map.get(severity, logging.INFO), "{} {}".format(id(conn), msg)) 69 | 70 | 71 | def log_termination(conn: Connection) -> None: 72 | logger.debug("{} closed".format(id(conn))) 73 | 74 | 75 | class LoggingConnection(Connection): 76 | def __init__(self, proto: Any, transport: Any, *args: Any, **kwargs: Any): 77 | logger.debug("{} connected over {!r}".format(id(self), transport)) 78 | super().__init__(proto, transport, *args, **kwargs) 79 | self.add_log_listener(log_message) 80 | self.add_termination_listener(log_termination) 81 | 82 | async def copy_from_query( 83 | self, query: str, *args: object, log_data: Union[bool, Collection[int]] = True, **kwargs: object 84 | ) -> str: 85 | logger.debug("{} copy_from_query: {}".format(id(self), fmt_query_single(query, log_data, args))) 86 | return await super().copy_from_query(query, *args, **kwargs) 87 | 88 | async def copy_from_table(self, table_name: str, schema_name: Optional[str] = None, **kwargs: object) -> str: 89 | logger.debug("{}: copy_from_table: {}".format(id(self), fmt_table(table_name, schema_name))) 90 | return await super().copy_from_table(table_name, schema_name=schema_name, 
**kwargs) 91 | 92 | async def copy_records_to_table(self, table_name: str, schema_name: Optional[str] = None, **kwargs: object) -> str: 93 | logger.debug("{}: copy_records_to_table: {}".format(id(self), fmt_table(table_name, schema_name))) 94 | return await super().copy_records_to_table(table_name, schema_name=schema_name, **kwargs) 95 | 96 | async def copy_to_table(self, table_name: str, schema_name: Optional[str] = None, **kwargs: object) -> str: 97 | logger.debug("{}: copy_to_table: {}".format(id(self), fmt_table(table_name, schema_name))) 98 | return await super().copy_to_table(table_name, schema_name=schema_name, **kwargs) 99 | 100 | def cursor( 101 | self, query: str, *args: object, log_data: Union[bool, Collection[int]] = True, **kwargs: object 102 | ) -> CursorFactory: 103 | logger.debug("{}: cursor: {}".format(id(self), fmt_query_single(query, log_data, args))) 104 | return super().cursor(query, *args, **kwargs) 105 | 106 | async def execute( 107 | self, query: str, *args: object, log_data: Union[bool, Collection[int]] = True, **kwargs: Any 108 | ) -> str: 109 | logger.debug("{} execute: {}".format(id(self), fmt_query_single(query, log_data, args))) 110 | return await super().execute(query, *args, **kwargs) 111 | 112 | async def executemany( 113 | self, 114 | command: str, 115 | args: Sequence[Sequence[object]], 116 | log_data: Union[bool, Collection[int]] = True, 117 | **kwargs: Any, 118 | ) -> None: 119 | logger.debug("{} executemany: {}".format(id(self), fmt_query_multi(command, log_data, args))) 120 | return await super().executemany(command, args, **kwargs) 121 | 122 | async def fetch( # type: ignore 123 | self, query: str, *args: object, log_data: Union[bool, Collection[int]] = True, **kwargs: object 124 | ) -> Sequence[Record]: 125 | logger.debug("{} fetch: {}".format(id(self), fmt_query_single(query, log_data, args))) 126 | return await super().fetch(query, *args, **kwargs) 127 | 128 | async def fetchrow( # type: ignore 129 | self, query: str, 
*args: object, log_data: Union[bool, Collection[int]] = True, **kwargs: object 130 | ) -> Optional[Record]: 131 | logger.debug("{} fetchrow: {}".format(id(self), fmt_query_single(query, log_data, args))) 132 | return await super().fetchrow(query, *args, **kwargs) 133 | 134 | async def fetchval( # type: ignore 135 | self, query: str, *args: object, log_data: Union[bool, Collection[int]] = True, **kwargs: Any 136 | ) -> Optional[Record]: 137 | logger.debug("{} fetchval: {}".format(id(self), fmt_query_single(query, log_data, args))) 138 | return await super().fetchval(query, *args, **kwargs) 139 | 140 | def transaction(self, **kwargs: Any) -> Transaction: 141 | logger.debug("{} transaction".format(id(self))) 142 | return super().transaction(**kwargs) 143 | 144 | async def prepare(self, query: str, **kwargs: object) -> PreparedStatement: 145 | logger.debug("{} prepare: {}".format(id(self), query)) 146 | # TODO: hook into PreparedStatement 147 | return await super().prepare(query, **kwargs) 148 | -------------------------------------------------------------------------------- /util/digraph.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import Dict, Generic, Iterable, Iterator, Set, TypeVar 4 | 5 | 6 | T = TypeVar("T") 7 | 8 | 9 | class Digraph(Generic[T]): 10 | """A directed graph with no isolated vertices and no duplicate edges.""" 11 | 12 | __slots__ = "fwd", "bck" 13 | fwd: Dict[T, Set[T]] 14 | bck: Dict[T, Set[T]] 15 | 16 | def __init__(self) -> None: 17 | """Create an empty graph.""" 18 | self.fwd = {} 19 | self.bck = {} 20 | 21 | def add_edge(self, x: T, y: T) -> None: 22 | """Add an edge from x to y.""" 23 | if x not in self.fwd: 24 | self.fwd[x] = set() 25 | self.fwd[x].add(y) 26 | if y not in self.bck: 27 | self.bck[y] = set() 28 | self.bck[y].add(x) 29 | 30 | def edges_to(self, x: T) -> Set[T]: 31 | """Return a (read-only) set of edges into x.""" 32 | return 
self.bck[x] if x in self.bck else set() 33 | 34 | def edges_from(self, x: T) -> Set[T]: 35 | """Return a (read-only) set of edges from x.""" 36 | return self.fwd[x] if x in self.fwd else set() 37 | 38 | def paths_from(self, x: T) -> Iterator[T]: 39 | """Return vertices that can be reached from x via a path.""" 40 | seen: Set[T] = set() 41 | 42 | def dfs(x: T) -> Iterator[T]: 43 | if x in seen: 44 | return 45 | seen.add(x) 46 | yield x 47 | if x in self.fwd: 48 | for y in self.fwd[x]: 49 | yield from dfs(y) 50 | 51 | yield from dfs(x) 52 | 53 | def paths_to(self, x: T) -> Iterator[T]: 54 | """Return vertices that can reached x via a path.""" 55 | seen: Set[T] = set() 56 | 57 | def dfs(x: T) -> Iterator[T]: 58 | if x in seen: 59 | return 60 | seen.add(x) 61 | yield x 62 | if x in self.bck: 63 | for y in self.bck[x]: 64 | yield from dfs(y) 65 | 66 | yield from dfs(x) 67 | 68 | def subgraph_paths_from(self, x: T) -> Digraph[T]: 69 | """Return an induced subgraph of exactly those vertices that can be reached from x via a path.""" 70 | graph: Digraph[T] = Digraph() 71 | seen: Set[T] = set() 72 | 73 | def dfs(x: T) -> None: 74 | if x in seen: 75 | return 76 | seen.add(x) 77 | if x in self.fwd: 78 | for y in self.fwd[x]: 79 | graph.add_edge(x, y) 80 | dfs(y) 81 | 82 | dfs(x) 83 | return graph 84 | 85 | def subgraph_paths_to(self, x: T) -> Digraph[T]: 86 | """Return an induced subgraph of exactly those vertices that can reach x via a path.""" 87 | graph: Digraph[T] = Digraph() 88 | seen: Set[T] = set() 89 | 90 | def dfs(x: T) -> None: 91 | if x in seen: 92 | return 93 | seen.add(x) 94 | if x in self.bck: 95 | for y in self.bck[x]: 96 | graph.add_edge(y, x) 97 | dfs(y) 98 | 99 | dfs(x) 100 | return graph 101 | 102 | def topo_sort_fwd(self, sources: Iterable[T] = ()) -> Iterator[T]: 103 | """ 104 | Iterate through vertices in such a way that whenever there is an edge from x to y, x will come up earlier in 105 | iteration than y. 
The sources are forcibly included in the iteration. 106 | """ 107 | seen: Set[T] = set() 108 | 109 | def dfs(x: T) -> Iterator[T]: 110 | if x in seen: 111 | return 112 | seen.add(x) 113 | if x in self.bck: 114 | for y in self.bck[x]: 115 | yield from dfs(y) 116 | yield x 117 | 118 | for x in self.fwd: 119 | yield from dfs(x) 120 | for x in self.bck: 121 | yield from dfs(x) 122 | for x in sources: 123 | yield from dfs(x) 124 | 125 | def topo_sort_bck(self, sources: Iterable[T] = ()) -> Iterator[T]: 126 | """ 127 | Iterate through vertices in such a way that whenever there is an edge from x to y, x will come up later in 128 | iteration than y. The sources are forcibly included in the iteration. 129 | """ 130 | seen: Set[T] = set() 131 | 132 | def dfs(x: T) -> Iterator[T]: 133 | if x in seen: 134 | return 135 | seen.add(x) 136 | if x in self.fwd: 137 | for y in self.fwd[x]: 138 | yield from dfs(y) 139 | yield x 140 | 141 | for x in self.bck: 142 | yield from dfs(x) 143 | for x in self.fwd: 144 | yield from dfs(x) 145 | for x in sources: 146 | yield from dfs(x) 147 | 148 | def del_edges_from(self, x: T) -> None: 149 | """Delete all edges from x.""" 150 | if x in self.fwd: 151 | for y in self.fwd[x]: 152 | self.bck[y].discard(x) 153 | if not self.bck[y]: 154 | del self.bck[y] 155 | del self.fwd[x] 156 | 157 | def del_edges_to(self, x: T) -> None: 158 | """Delete all edges into x.""" 159 | if x in self.bck: 160 | for y in self.bck[x]: 161 | self.fwd[y].discard(x) 162 | if not self.fwd[y]: 163 | del self.fwd[y] 164 | del self.bck[x] 165 | -------------------------------------------------------------------------------- /util/frozen_dict.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import Dict, Generic, Iterable, Iterator, Optional, Tuple, TypeVar, Union, overload 4 | 5 | import yaml 6 | 7 | 8 | K = TypeVar("K") 9 | V = TypeVar("V", covariant=True) 10 | T = TypeVar("T") 11 | 
12 | 13 | class FrozenDict(Generic[K, V]): 14 | """ 15 | Immutable dict. Doesn't actually store the underlying dict as a field, instead its methods are closed over the 16 | underlying dict object. 17 | """ 18 | 19 | __slots__ = ( 20 | "___iter__", 21 | "__getitem__", 22 | "__len__", 23 | "__str__", 24 | "__repr__", 25 | "__eq__", 26 | "__ne__", 27 | "__or__", 28 | "__ror__", 29 | "__contains__", 30 | "__reversed__", 31 | "copy", 32 | "get", 33 | "items", 34 | "keys", 35 | "values", 36 | ) 37 | 38 | def __init__(self, *args: object, **kwargs: object): 39 | dct: Dict[K, V] = dict(*args, **kwargs) 40 | 41 | def __iter__() -> Iterator[K]: 42 | return dct.__iter__() 43 | 44 | self.___iter__ = __iter__ 45 | 46 | def __getitem__(key: K, /) -> V: 47 | return dct.__getitem__(key) 48 | 49 | self.__getitem__ = __getitem__ 50 | 51 | def __len__() -> int: 52 | return dct.__len__() 53 | 54 | self.__len__ = __len__ 55 | 56 | def __str__() -> str: 57 | return "FrozenDict({})".format(dct.__str__()) 58 | 59 | self.__str__ = __str__ 60 | 61 | def __repr__() -> str: 62 | return "FrozenDict({})".format(dct.__repr__()) 63 | 64 | self.__repr__ = __repr__ 65 | 66 | def __eq__(other: object, /) -> bool: 67 | return other.__eq__(dct) if isinstance(other, FrozenDict) else dct.__eq__(other) 68 | 69 | self.__eq__ = __eq__ 70 | 71 | def __ne__(other: object, /) -> bool: 72 | return other.__ne__(dct) if isinstance(other, FrozenDict) else dct.__ne__(other) 73 | 74 | self.__ne__ = __ne__ 75 | 76 | def __or__(other: Union[Dict[K, T], FrozenDict[K, T]], /) -> FrozenDict[K, Union[V, T]]: 77 | return other.__ror__(dct) if isinstance(other, FrozenDict) else FrozenDict(dct.__or__(other)) 78 | 79 | self.__or__ = __or__ 80 | 81 | def __ror__(other: Union[Dict[K, T], FrozenDict[K, T]], /) -> FrozenDict[K, Union[V, T]]: 82 | return other.__or__(dct) if isinstance(other, FrozenDict) else FrozenDict(dct.__ror__(other)) 83 | 84 | self.__ror__ = __ror__ 85 | 86 | def __contains__(key: object, /) -> bool: 87 | 
return dct.__contains__(key) 88 | 89 | self.__contains__ = __contains__ 90 | 91 | def __reversed__() -> Iterator[K]: 92 | return dct.__reversed__() 93 | 94 | self.__reversed__ = __reversed__ 95 | 96 | def copy() -> Dict[K, V]: 97 | return dct.copy() 98 | 99 | self.copy = copy 100 | 101 | @overload 102 | def get(key: K, /) -> Optional[V]: ... 103 | 104 | @overload 105 | def get(key: K, default: T, /) -> Union[V, T]: ... 106 | 107 | def get(key: K, default: Optional[T] = None) -> Optional[Union[V, T]]: 108 | return dct.get(key, default) 109 | 110 | self.get = get 111 | 112 | def items() -> Iterable[Tuple[K, V]]: 113 | return dct.items() 114 | 115 | self.items = items 116 | 117 | def keys() -> Iterable[K]: 118 | return dct.keys() 119 | 120 | self.keys = keys 121 | 122 | def values() -> Iterable[V]: 123 | return dct.values() 124 | 125 | self.values = values 126 | 127 | def __iter__(self) -> Iterator[K]: 128 | return self.___iter__() 129 | 130 | 131 | yaml.add_representer(FrozenDict, lambda dumper, data: dumper.represent_dict(data)) # type: ignore 132 | -------------------------------------------------------------------------------- /util/frozen_list.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import Generic, Iterable, Iterator, List, Optional, SupportsIndex, TypeVar, Union, overload 4 | 5 | import yaml 6 | 7 | 8 | T = TypeVar("T", covariant=True) 9 | 10 | 11 | class FrozenList(Generic[T]): 12 | """ 13 | Immutable list. Doesn't actually store the underlying list as a field, instead its methods are closed over the 14 | underlying list object. 
15 | """ 16 | 17 | __slots__ = ( 18 | "___iter__", 19 | "__getitem__", 20 | "__len__", 21 | "__str__", 22 | "__repr__", 23 | "__gt__", 24 | "__lt__", 25 | "__ge__", 26 | "__le__", 27 | "__eq__", 28 | "__ne__", 29 | "__mul__", 30 | "__rmul__", 31 | "__add__", 32 | "__radd__", 33 | "__contains__", 34 | "copy", 35 | "index", 36 | "count", 37 | "without", 38 | ) 39 | 40 | def __init__(self, gen: Iterable[T] = (), /): 41 | lst = list(gen) 42 | 43 | def __iter__() -> Iterator[T]: 44 | return lst.__iter__() 45 | 46 | self.___iter__ = __iter__ 47 | 48 | @overload 49 | def __getitem__(index: SupportsIndex, /) -> T: ... 50 | 51 | @overload 52 | def __getitem__(index: slice, /) -> FrozenList[T]: ... 53 | 54 | def __getitem__(index: Union[SupportsIndex, slice], /) -> Union[T, FrozenList[T]]: 55 | if isinstance(index, slice): 56 | return FrozenList(lst.__getitem__(index)) 57 | else: 58 | return lst.__getitem__(index) 59 | 60 | self.__getitem__ = __getitem__ 61 | 62 | def __len__() -> int: 63 | return lst.__len__() 64 | 65 | self.__len__ = __len__ 66 | 67 | def __str__() -> str: 68 | return "FrozenList({})".format(lst.__str__()) 69 | 70 | self.__str__ = __str__ 71 | 72 | def __repr__() -> str: 73 | return "FrozenList({})".format(lst.__repr__()) 74 | 75 | self.__repr__ = __repr__ 76 | 77 | def __gt__(other: Union[List[T], FrozenList[T]], /) -> bool: 78 | return other.__lt__(lst) if isinstance(other, FrozenList) else lst.__gt__(other) 79 | 80 | self.__gt__ = __gt__ 81 | 82 | def __lt__(other: Union[List[T], FrozenList[T]], /) -> bool: 83 | return other.__gt__(lst) if isinstance(other, FrozenList) else lst.__lt__(other) 84 | 85 | self.__lt__ = __lt__ 86 | 87 | def __ge__(other: Union[List[T], FrozenList[T]], /) -> bool: 88 | return other.__le__(lst) if isinstance(other, FrozenList) else lst.__ge__(other) 89 | 90 | self.__ge__ = __ge__ 91 | 92 | def __le__(other: Union[List[T], FrozenList[T]], /) -> bool: 93 | return other.__ge__(lst) if isinstance(other, FrozenList) else 
lst.__le__(other) 94 | 95 | self.__le__ = __le__ 96 | 97 | def __eq__(other: object, /) -> bool: 98 | return other.__eq__(lst) if isinstance(other, FrozenList) else lst.__eq__(other) 99 | 100 | self.__eq__ = __eq__ 101 | 102 | def __ne__(other: object, /) -> bool: 103 | return other.__ne__(lst) if isinstance(other, FrozenList) else lst.__ne__(other) 104 | 105 | self.__ne__ = __ne__ 106 | 107 | def __mul__(other: SupportsIndex, /) -> FrozenList[T]: 108 | return FrozenList(lst.__mul__(other)) 109 | 110 | self.__mul__ = __mul__ 111 | 112 | def __rmul__(other: SupportsIndex, /) -> FrozenList[T]: 113 | return FrozenList(lst.__rmul__(other)) 114 | 115 | self.__rmul__ = __rmul__ 116 | 117 | def __add__(other: Union[List[T], FrozenList[T]], /) -> FrozenList[T]: 118 | return other.__radd__(lst) if isinstance(other, FrozenList) else FrozenList(lst.__add__(other)) 119 | 120 | self.__add__ = __add__ 121 | 122 | def __radd__(other: Union[List[T], FrozenList[T]], /) -> FrozenList[T]: 123 | return other.__add__(lst) if isinstance(other, FrozenList) else FrozenList(other.__add__(lst)) 124 | 125 | self.__radd__ = __radd__ 126 | 127 | def __contains__(other: object, /) -> bool: 128 | return lst.__contains__(other) 129 | 130 | self.__contains__ = __contains__ 131 | 132 | def copy() -> List[T]: 133 | return lst.copy() 134 | 135 | self.copy = copy 136 | 137 | @overload 138 | def index(value: object, /) -> int: ... 139 | 140 | @overload 141 | def index(value: object, start: SupportsIndex, /) -> int: ... 142 | 143 | @overload 144 | def index(value: object, start: SupportsIndex, stop: SupportsIndex, /) -> int: ... 
145 | 146 | def index(value: object, start: Optional[SupportsIndex] = None, stop: Optional[SupportsIndex] = None, /) -> int: 147 | if stop is None: 148 | if start is None: 149 | return lst.index(value) # type: ignore 150 | else: 151 | return lst.index(value, start) # type: ignore 152 | elif start is None: 153 | return lst.index(value, 0, stop) # type: ignore 154 | else: 155 | return lst.index(value, start, stop) # type: ignore 156 | 157 | self.index = index 158 | 159 | def count(other: object) -> int: 160 | return lst.count(other) # type: ignore 161 | 162 | self.count = count 163 | 164 | def without(other: object) -> FrozenList[T]: 165 | return FrozenList(value for value in lst if value != other) 166 | 167 | self.without = without 168 | 169 | def __iter__(self) -> Iterator[T]: 170 | return self.___iter__() 171 | 172 | 173 | yaml.add_representer(FrozenList, lambda dumper, data: dumper.represent_list(data)) 174 | -------------------------------------------------------------------------------- /util/restart.py: -------------------------------------------------------------------------------- 1 | import atexit 2 | import logging 3 | import os 4 | import sys 5 | 6 | import bot.main_tasks 7 | 8 | 9 | will_restart: bool = False 10 | 11 | logger: logging.Logger = logging.getLogger(__name__) 12 | 13 | 14 | @atexit.register 15 | def atexit_restart_maybe() -> None: 16 | if will_restart: 17 | logger.info("Re-executing {!r} {!r}".format(sys.executable, sys.argv)) 18 | try: 19 | os.execv(sys.executable, [sys.executable] + sys.argv) 20 | except: 21 | logger.critical("Restart failed", exc_info=True) 22 | 23 | 24 | def restart() -> None: 25 | """Restart the bot by stopping the event loop and exec'ing during the shutdown of the python interpreter.""" 26 | global will_restart 27 | logger.info("Restart requested", stack_info=True) 28 | will_restart = True 29 | bot.main_tasks.cancel() 30 | -------------------------------------------------------------------------------- 
/util/setup/__main__.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import logging 3 | 4 | from sqlalchemy import delete, select 5 | from sqlalchemy.ext.asyncio import AsyncSession 6 | from sqlalchemy.schema import CreateSchema 7 | 8 | import plugins 9 | 10 | 11 | logging.basicConfig() 12 | logging.getLogger().setLevel(logging.INFO) 13 | 14 | manager = plugins.PluginManager(["bot", "util"]) 15 | manager.register() 16 | 17 | import bot.acl 18 | import bot.autoload 19 | import bot.commands 20 | import util.db 21 | 22 | 23 | async def async_main() -> None: 24 | async with AsyncSession(util.db.engine) as session: 25 | logging.info("Connecting to database") 26 | await session.execute(select(1)) 27 | logging.info("Creating schema for bot.acl") 28 | await util.db.init_for( 29 | "bot.acl", util.db.get_ddl(CreateSchema("permissions"), bot.acl.registry.metadata.create_all) 30 | ) 31 | logging.info("Creating schema for bot.commands") 32 | await util.db.init_for("bot.commands", util.db.get_ddl(bot.commands.registry.metadata.create_all)) 33 | logging.info("Creating schema for bot.autoload") 34 | await util.db.init_for("bot.autoload", util.db.get_ddl(bot.autoload.registry.metadata.create_all)) 35 | 36 | logging.info("Deleting ACLs") 37 | await session.execute(delete(bot.acl.CommandPermissions)) 38 | await session.execute(delete(bot.acl.ActionPermissions)) 39 | await session.execute(delete(bot.acl.ACL)) 40 | admin_id = await asyncio.get_event_loop().run_in_executor( 41 | None, lambda: int(input("Input your Discord user ID: ")) 42 | ) 43 | logging.info('Creating "admin" ACL assigned to {}'.format(admin_id)) 44 | session.add(bot.acl.ACL(name="admin", data=bot.acl.UserACL(admin_id).serialize(), meta="admin")) 45 | session.add(bot.acl.ActionPermissions(name="acl_override", acl="admin")) 46 | 47 | logging.info("Deleting global command config") 48 | await session.execute(delete(bot.commands.GlobalConfig)) 49 | prefix = await 
asyncio.get_event_loop().run_in_executor(None, lambda: input("Input command prefix: ")) 50 | logging.info("Creating global command config with prefix {!r}".format(prefix)) 51 | session.add(bot.commands.GlobalConfig(prefix=prefix)) 52 | 53 | logging.info("Deleting autoload") 54 | await session.execute(delete(bot.autoload.AutoloadedPlugin)) 55 | autoloaded = ["plugins.eval", "plugins.bot_manager", "plugins.db_manager"] 56 | logging.info("Adding {!r} to autoload".format(autoloaded)) 57 | for p in autoloaded: 58 | session.add(bot.autoload.AutoloadedPlugin(name=p, order=0)) 59 | 60 | logging.info("Committing transaction") 61 | await session.commit() 62 | 63 | await manager.unload_all() 64 | 65 | 66 | asyncio.run(async_main()) 67 | --------------------------------------------------------------------------------