├── .gitattributes ├── .gitignore ├── Dockerfile ├── INSTALL.md ├── LICENSE.txt ├── README.md ├── ace.py ├── cogs ├── ahk │ └── ahk.py ├── backend │ ├── error_handler.py │ └── logger.py ├── configuration.py ├── dwitter.py ├── fun.py ├── games.py ├── hl.py ├── meta.py ├── mixins.py ├── mod.py ├── owner.py ├── remind.py ├── roles.py ├── stars.py ├── tags.py ├── welcome.py └── whois.py ├── compose.yaml ├── docs_service ├── Dockerfile ├── aggregator.py ├── api.py ├── build.py ├── migrate.sql ├── parser_instances │ ├── common.py │ ├── v1.py │ └── v2.py ├── parsers.py └── requirements.txt ├── main.py ├── migrate.py ├── migrate.sql ├── neural ├── Dockerfile ├── README.md ├── api.py ├── data_fetcher.py ├── dataset.py ├── make_embeddings.py ├── model.py ├── process_glove.py ├── requirements.txt ├── text_processor.py └── train.py ├── pyproject.toml ├── requirements.txt └── utils ├── commanderrorlogic.py ├── configtable.py ├── context.py ├── converters.py ├── databasetimer.py ├── fakeuser.py ├── guildconfigrecord.py ├── help.py ├── html2markdown.py ├── pager.py ├── string.py └── time.py /.gitattributes: -------------------------------------------------------------------------------- 1 | * text=auto -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | .idea 3 | .vscode 4 | data 5 | error 6 | feedback 7 | logs 8 | ahk_eval 9 | venv 10 | config.py 11 | model 12 | 13 | test.py 14 | ids.py 15 | test_md.py 16 | test_docs_parser.py 17 | notes.md 18 | cogs/ahk/internal 19 | 20 | torch_config.py 21 | *.pth 22 | neural/embeddings 23 | neural/models 24 | neural/corpus 25 | neural/corpus.zip 26 | 27 | docs_service/docs_v1 28 | docs_service/docs_v2 29 | docs_service/docs_v1.zip 30 | docs_service/docs_v2.zip 31 | build_tag_upload.sh 32 | attachments 33 | -------------------------------------------------------------------------------- /Dockerfile: 
-------------------------------------------------------------------------------- 1 | FROM python:3.11.2 2 | 3 | COPY requirements.txt ./ 4 | RUN pip install --no-cache-dir -r requirements.txt 5 | 6 | COPY cogs cogs 7 | COPY utils utils 8 | COPY ace.py . 9 | COPY main.py . 10 | 11 | CMD ["python3", "-u", "main.py"] 12 | -------------------------------------------------------------------------------- /INSTALL.md: -------------------------------------------------------------------------------- 1 | # Installation 2 | 3 | **If you just want to add this bot to your Discord server, 4 | it is recommended that you add the official instance 5 | with the link found [here](README.md#installing-the-bot).** 6 | 7 | This file describes how you can set up your own instance 8 | on your local PC or on a dedicated server. 9 | If you want to permanently host your own instance, 10 | you should probably put it on a dedicated server, 11 | but for development your local PC will suffice. 12 | 13 | ## Requirements 14 | 15 | * PostgreSQL 16 | * Python 3.9 or newer (the provided Dockerfile uses Python 3.11) 17 | * PIP (should come with Python) 18 | * Git (or GitHub etc.) 19 | 20 | Please install these according to their instructions. 21 | 22 | ## Setting up PostgreSQL 23 | 24 | * **Windows** 25 | * When installing PostgreSQL, you will be asked to choose a password. **Remember it.** 26 | * Open up a Command Prompt and run `psql -U postgres`. 27 | * Log in using the password you chose during the installation. 28 | * **Linux (and \*nix in general)** 29 | * The installation of PostgreSQL should have created a user account called `postgres`. 30 | (If not, it's probably best to reinstall PostgreSQL or search online for a solution.) 31 | * Log into that user's account (`sudo -iu postgres` for example). 32 | * Run `psql`.
33 | * In this `psql` shell, run the following commands: 34 | ```postgresql 35 | CREATE ROLE ace WITH LOGIN PASSWORD 'choose_a_password'; 36 | CREATE DATABASE acebot OWNER ace; 37 | \c acebot 38 | CREATE EXTENSION pg_trgm; 39 | \q -- quit out of psql 40 | ``` 41 | * On Linux, you can now `exit` to return to your own user account. 42 | 43 | ## Setting up AceBot 44 | 45 | * Clone this repository and change into its root folder: 46 | `git clone --recurse-submodules https://github.com/Run1e/AceBot && cd AceBot` 47 | * Create a file called `config.py` and add this content to it: 48 | ```python 49 | import logging 50 | import disnake 51 | 52 | DESCRIPTION = '''A.C.E. - Non-official Instance''' 53 | 54 | BOT_TOKEN = 'your_bot_token' 55 | BOT_INTENTS = disnake.Intents.all() 56 | DEFAULT_PREFIX = '.' 57 | OWNER_ID = your_discord_id # do not put quotes around this 58 | DB_BIND = 'your_database_bind' 59 | LOG_LEVEL = logging.DEBUG # logging.INFO recommended for production 60 | 61 | BOT_ACTIVITY = '@me for help menu' # plain string, shown as the bot's custom status text 62 | 63 | CLOUDAHK_URL = None 64 | CLOUDAHK_USER = None 65 | CLOUDAHK_PASS = None 66 | 67 | DOCS_API_URL = None 68 | 69 | DBL_KEY = None 70 | THECATAPI_KEY = None 71 | WOLFRAM_KEY = None 72 | APIXU_KEY = None 73 | 74 | GAME_PRED_URL = "" 75 | 76 | TEST_GUILDS = None 77 | HELP_CONTROLLERS = {} 78 | ``` 79 | * You can get your bot token from the [Discord Developer Portal](https://discord.com/developers/applications). 80 | If you haven't already: 81 | * Create a new application (its name doesn't matter). 82 | * Go to “Bot” in the left sidebar. 83 | * Click “Add Bot”, read the warning and accept it. 84 | * **Important:** You must enable all the privileged Intents that you are requesting 85 | (by default you are requesting all Intents, so you should enable all of them in that case).
86 | * Your database bind will look like this (using the password you chose for the `ace` role): 87 | ``` 88 | postgresql://ace:your_password@localhost/acebot 89 | ``` 90 | * Your owner ID is the ID of your Discord account. 91 | * To obtain it, open your Discord user settings, under “APP SETTINGS”, go to “Advanced”, and enable “Developer Mode”. 92 | Exit the settings, then right-click yourself anywhere and click “Copy ID”. 93 | * Create another file called `ids.py`, this time with these contents: 94 | ```python 95 | AHK_GUILD_ID = None 96 | 97 | # roles 98 | STAFF_ROLE_ID = None 99 | FORUM_ADM_ROLE_ID = None 100 | FORUM_MOD_ROLE_ID = None 101 | VIP_ROLE_ID = None 102 | LOUNGE_ROLE_ID = None 103 | 104 | # level roles 105 | LEVEL_ROLE_IDS = {} 106 | 107 | # channels 108 | ROLES_CHAN_ID = None 109 | RULES_CHAN_ID = None 110 | GENERAL_CHAN_ID = None 111 | LOGS_CHAN_ID = None 112 | FORUM_THRD_CHAN_ID = None 113 | HELP_FORUM_CHAN_ID = None 114 | ACTIVITY_CHAN_ID = None 115 | EDITED_CHAN_ID = None 116 | DELETED_CHAN_ID = None 117 | GUILD_CHAN_ID = None 118 | EMOJI_SUGGESTIONS_CHAN_ID = None 119 | SUGGESTIONS_CHAN_ID = None 120 | GET_HELP_CHAN_ID = None 121 | 122 | # messages 123 | RULES_MSG_ID = None 124 | 125 | # category ids 126 | OPEN_CATEGORY_ID = None 127 | ACTIVE_CATEGORY_ID = None 128 | ACTIVE_INFO_CHAN_ID = None 129 | CLOSED_CATEGORY_ID = None 130 | 131 | HELP_CHANNEL_IDS = {} 132 | IGNORE_ACTIVE_CHAN_IDS = tuple() 133 | ``` 134 | Both of these files are templates and you can change almost any value in them as you see fit. 135 | `ids.py` in particular needs configuring if you want to use the bot to its full potential; 136 | you can get the ID of channels, categories, roles, etc. by right-clicking on them 137 | (if you have Developer Mode enabled). 138 | 139 | ## Finishing up 140 | 141 | * Run `pip install -r requirements.txt`. 142 | * Run `python migrate.py` to set up all necessary database tables automatically. 143 | 144 | ## That's it! 145 | 146 | You should be able to start the bot with `python main.py`!
147 | -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 RUNIE 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ![Avatar](https://i.imgur.com/Sv7L0a1.png) A.C.E. - Autonomous Command Executor 2 | 3 | [![Discord Bots](https://top.gg/api/widget/status/367977994486022146.svg)](https://discordbots.org/bot/367977994486022146) 4 | [![Discord Bots](https://top.gg/api/widget/servers/367977994486022146.svg)](https://discordbots.org/bot/367977994486022146) 5 | 6 | A fun, general purpose Discord bot! 
7 | 8 | [Click here to add it to your server!](https://discordapp.com/oauth2/authorize?&client_id=367977994486022146&scope=bot&permissions=268823632) 9 | 10 | Support server invite [here.](https://discord.gg/3MsSbRxbKV) 11 | 12 | ## Table of Contents 13 | 14 | * [Usage](#usage) 15 | * [General commands](#general-commands) 16 | * [Starboard](#starboard) 17 | * [Tags](#tags) 18 | * [Bot configuration](#bot-configuration) 19 | * [Moderation](#moderation) 20 | * [Welcome](#welcome) 21 | * [Roles](#roles) 22 | * [Feedback](#feedback) 23 | * [Installing the bot](#installing-the-bot) 24 | * [Acknowledgements](#acknowledgements) 25 | 26 | ## Usage 27 | 28 | ### General commands 29 | 30 | The bot has a plethora of commands. To invoke these, send a message starting with the prefix (`.` by default) followed by the command name. 31 | For example, `.woof` would invoke the `woof` command. 32 | 33 | The `help` command can be run at any point as a reference on how to use the bot. If you need help with a specific command, 34 | run `help *command name here*`. 35 | 36 | ``` 37 | remindme Have the bot remind you about something in the future 38 | wolfram Query Wolfram Alpha 39 | weather Get the weather at a location 40 | choose Pick an item from a list of choices 41 | hl Highlight some code 42 | info Information about a member 43 | server Information about the server 44 | fact Get a random fact 45 | 8 Classic 8ball! 46 | And many more! 47 | ``` 48 | 49 | It can fetch cute random images on demand! 50 | ``` 51 | woof Get a random doggo picture 52 | meow Get a random cat picture 53 | floof Get a random fox picture 54 | quack Get a random duck picture 55 | ``` 56 | 57 | ### Starboard 58 | 59 | A classic starboard implementation. 60 | 61 | A starboard is a channel where "starred" messages are posted. A message can be starred by anyone by reacting to it with 62 | the :star: emoji. After that, anyone else can also star the message, giving it more stars.
63 | 64 | To create a starboard, use the `starboard create` command. This will create a channel where starred messages will be posted. 65 | 66 | Automatic starboard cleaning can be enabled using `starboard threshold`. To have starred messages with fewer than 5 stars be 67 | removed after a week, do `starboard threshold 5`. 68 | To disable auto-cleaning, run `starboard threshold` with no argument. The starboard can also be 69 | temporarily disabled (to clean it, for example) using `starboard lock` and re-enabled using `starboard unlock`. 70 | 71 | Other misc. starboard commands: 72 | ``` 73 | star Star a message by ID 74 | unstar Unstar a message by ID 75 | star info Show information about a starred message 76 | star show Bring up a starred message in the current channel 77 | star delete Delete a starred message from the starboard. Appropriate permissions/relations required to run this. 78 | ``` 79 | Run `help starboard` for a complete list. 80 | 81 | ### Tags 82 | 83 | The tag system is immensely useful for bringing up text or images on demand. 84 | 85 | To try it out, you can create a new tag interactively by simply running `tag make`. 86 | 87 | Here's an example of the tag system in action: 88 | # ![Tag demonstration](https://i.imgur.com/LxEteHI.gif) 89 | 90 | A few of the tag commands: 91 | ``` 92 | tag Bring up a tag 93 | tag create Create a new tag 94 | tag make Create a new tag interactively (recommended!) 95 | tag edit Edit one of your tags 96 | tag delete Delete one of your tags 97 | tag list List a member's tags, or all the server tags 98 | tag info Extensive information about a tag 99 | tags List all of your own tags 100 | ``` 101 | Run `help tag` for a complete list. 102 | 103 | ### Bot configuration 104 | 105 | The prefix of the bot is configurable using the `prefix` command. If you forget the current prefix, the help menu can be brought up by simply mentioning the bot. 106 | ``` 107 | .prefix ! 108 | # new commands are now invoked using !
109 | !woof 110 | ``` 111 | 112 | Using the `modrole` command, you can set a role whose members can also configure the bot. 113 | Members with the mod role can delete any tag, delete starred messages, change the prefix, etc. The only thing members with this role can't do is change the mod role, as this requires administrator privileges. 114 | ``` 115 | .modrole @somerole 116 | ``` 117 | To see what the current configuration is, run `config`. 118 | 119 | 120 | ### Moderation 121 | 122 | To enable member muting, create a role that prohibits sending messages and set it with `muterole <role>`. 123 | To mute a member, use `mute <member>`; to unmute, use `unmute <member>`. Similarly, `ban` and `unban` also exist. 124 | 125 | You can issue tempbans and tempmutes: 126 | ``` 127 | tempban <member> <duration> [reason] 128 | tempmute <member> <duration> [reason] 129 | 130 | examples: 131 | tempban @dave 5 days Stop spamming! 132 | tempmute @bobby 1 hr Read rules again. 133 | ``` 134 | 135 | Tempbanning will make the bot attempt to send a DM to the banned member with the reason for the ban, if provided. 136 | 137 | ### Welcome 138 | 139 | The bot can be configured to send a message each time a new member joins your server. 140 | 141 | To set this up, first specify the channel the messages should be sent in using `welcome channel`. Then set up a welcome message using `welcome message`. The following placeholders are available: 142 | ``` 143 | {user} Replaced with a mention of the member that joined 144 | {guild} Replaced with the server name 145 | {member_count} Replaced with the server member count 146 | ``` 147 | 148 | To check that the welcome system works, run `welcome test`. If it fails, it will tell you what to fix! 149 | 150 | Run `help welcome` for a complete list of commands. 151 | 152 | ### Roles 153 | 154 | The bot can create a role selector for you. Here's an example of such a selector: 155 | # ![Role selector](https://i.imgur.com/1RoSHLs.png) 156 | By clicking a reaction, the user is given the corresponding role.
157 | 158 | Run `help roles` for a full list of commands. 159 | 160 | ### Feedback 161 | 162 | You can send thoughts, feedback and suggestions directly to me by using the `feedback` command! 163 | 164 | ## Installing the bot 165 | 166 | **If you want this bot in your server, I would prefer that you invite the official instance using the 167 | [invite link](https://discordapp.com/oauth2/authorize?&client_id=367977994486022146&scope=bot&permissions=268823632).** 168 | 169 | Nevertheless, if you want to set it up for yourself, follow the instructions in [`INSTALL.md`](INSTALL.md). 170 | 171 | ## Acknowledgements 172 | 173 | Contributors: CloakerSmoker, GeekDude, sjc, Cap'n Odin 174 | -------------------------------------------------------------------------------- /ace.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import logging.handlers 3 | import os 4 | from datetime import datetime 5 | 6 | import aiohttp 7 | import asyncpg 8 | from disnake.ext import commands 9 | 10 | from config import * 11 | from utils.configtable import ConfigTable 12 | from utils.context import AceContext 13 | from utils.guildconfigrecord import GuildConfigRecord 14 | from utils.help import PaginatedHelpCommand 15 | from utils.string import po 16 | 17 | EXTENSIONS = ( 18 | "cogs.backend.error_handler", 19 | "cogs.backend.logger", 20 | "cogs.fun", 21 | "cogs.configuration", 22 | "cogs.tags", 23 | "cogs.stars", 24 | "cogs.meta", 25 | "cogs.mod", 26 | "cogs.games", 27 | "cogs.remind", 28 | "cogs.welcome", 29 | "cogs.roles", 30 | "cogs.whois", 31 | "cogs.hl", 32 | "cogs.ahk.ahk", 33 | "cogs.ahk.internal.logger", 34 | "cogs.ahk.internal.security", 35 | "cogs.dwitter", 36 | "cogs.owner", 37 | ) 38 | 39 | 40 | class AceBot(commands.Bot): 41 | support_link = "https://discord.gg/3MsSbRxbKV" 42 | 43 | aiohttp: aiohttp.ClientSession 44 | db: asyncpg.pool 45 | config: ConfigTable 46 | startup_time: datetime 47 | 48 | def __init__(self, **kwargs): 49
| super().__init__( 50 | command_prefix=self.prefix_resolver, 51 | owner_id=OWNER_ID, 52 | description=DESCRIPTION, 53 | help_command=PaginatedHelpCommand(), 54 | max_messages=20000, 55 | activity=disnake.CustomActivity(name="Custom Status", state="Booting up..."), 56 | status=disnake.Status.do_not_disturb, 57 | **kwargs, 58 | ) 59 | 60 | # created in login 61 | self.db = None 62 | 63 | self.config = ConfigTable( 64 | self, table="config", primary="guild_id", record_class=GuildConfigRecord 65 | ) 66 | 67 | self.startup_time = datetime.utcnow() 68 | 69 | self.log = logging.getLogger("acebot") 70 | 71 | aiohttp_log = logging.getLogger("http") 72 | 73 | async def on_request_end(session, ctx, end): 74 | resp = end.response 75 | aiohttp_log.info( 76 | "[%s %s] %s %s (%s)", 77 | str(resp.status), 78 | resp.reason, 79 | end.method.upper(), 80 | end.url, 81 | resp.content_type, 82 | ) 83 | 84 | trace_config = aiohttp.TraceConfig() 85 | trace_config.on_request_end.append(on_request_end) 86 | 87 | self.aiohttp = aiohttp.ClientSession( 88 | loop=self.loop, 89 | timeout=aiohttp.ClientTimeout(total=5), 90 | trace_configs=[trace_config], 91 | ) 92 | 93 | self.modified_times = dict() 94 | 95 | async def on_connect(self): 96 | self.log.info("Connected to gateway!") 97 | 98 | async def on_resumed(self): 99 | self.log.info("Reconnected to gateway!") 100 | 101 | # re-set presence on connection resumed 102 | await self.change_presence() 103 | await self.set_status(activity_text=BOT_ACTIVITY) 104 | 105 | async def on_ready(self): 106 | await self.set_status(activity_text=BOT_ACTIVITY) 107 | self.log.info("Ready! 
%s", po(self.user)) 108 | 109 | async def set_status( 110 | self, 111 | status: disnake.Status = disnake.Status.online, 112 | activity_text: str = None, 113 | ): 114 | activity = disnake.CustomActivity(name="Custom Status", state=activity_text) 115 | await self.change_presence(activity=activity, status=status) 116 | 117 | async def on_message(self, message): 118 | # ignore DMs and bot accounts 119 | if message.guild is None or message.author.bot: 120 | return 121 | 122 | # don't process commands before bot is ready 123 | if not self.is_ready(): 124 | # rather than wait for the bot to be ready, we return to avoid users 125 | # who send their commands multiple times from being processed. 126 | return 127 | 128 | await self.process_commands(message) 129 | 130 | async def process_commands(self, message: disnake.Message): 131 | if message.author.bot: 132 | return 133 | 134 | ctx: AceContext = await self.get_context(message, cls=AceContext) 135 | 136 | # if messages starts with a bot mention... 
137 | if message.content.startswith((self.user.mention, "<@!%s>" % self.user.id)): 138 | # set the bot prefix and invoke the help command 139 | prefixes = await self.prefix_resolver(self, message) 140 | ctx.prefix = prefixes[-1] 141 | command = message.content[message.content.find(">") + 1 :].strip() 142 | await ctx.send_help(command or None) 143 | return 144 | 145 | if ctx.command is None: 146 | return 147 | 148 | perms = ctx.perms 149 | if not perms.send_messages or not perms.read_message_history: 150 | return 151 | 152 | await self.invoke(ctx) 153 | 154 | async def prefix_resolver(self, bot, message): 155 | if message.guild is None: 156 | return DEFAULT_PREFIX 157 | 158 | gc = await self.config.get_entry(message.guild.id) 159 | return gc.prefix or DEFAULT_PREFIX 160 | 161 | def load_extensions(self): 162 | reloaded = list() 163 | 164 | for name in EXTENSIONS: 165 | file_name = name.replace(".", "/") + ".py" 166 | 167 | if os.path.isfile(file_name): 168 | mtime = os.stat(file_name).st_mtime_ns 169 | 170 | if mtime > self.modified_times.get(name, 0): 171 | if name in self.extensions.keys(): 172 | meth = self.reload_extension 173 | else: 174 | meth = self.load_extension 175 | 176 | self.log.debug("Loading %s", name) 177 | 178 | meth(name) 179 | self.modified_times[name] = mtime 180 | 181 | reloaded.append(name) 182 | 183 | return reloaded 184 | 185 | @property 186 | def invite_link(self): 187 | return disnake.utils.oauth_url( 188 | self.user.id, 189 | permissions=disnake.Permissions(1374658358486), 190 | scopes=["bot", "applications.commands"], 191 | ) 192 | 193 | async def login(self, token: str) -> None: 194 | self.log.info("Creating postgres pool") 195 | self.db = await asyncpg.create_pool(DB_BIND) 196 | self.log.info("Loading extensions") 197 | self.load_extensions() 198 | self.log.info("Logging in to discord") 199 | return await super().login(token) 200 | 201 | 202 | if __name__ == "__main__": 203 | import sys 204 | 205 | print("The entry point has moved, use 
main.py to run the bot now.") 206 | sys.exit(1) 207 | -------------------------------------------------------------------------------- /cogs/backend/error_handler.py: -------------------------------------------------------------------------------- 1 | import disnake 2 | from disnake.ext import commands 3 | 4 | from ace import AceBot 5 | from cogs.mixins import AceMixin 6 | from utils.commanderrorlogic import CommandErrorLogic 7 | from utils.context import AceContext 8 | from utils.time import pretty_seconds 9 | 10 | 11 | class ErrorHandler(commands.Cog, AceMixin): 12 | @commands.Cog.listener("on_command_error") 13 | @commands.Cog.listener("on_slash_command_error") 14 | async def on_invocable_error(self, ctx: AceContext, exc): 15 | """Handle command errors.""" 16 | async with CommandErrorLogic(ctx, exc) as handler: 17 | if isinstance(exc, commands.CommandInvokeError): 18 | if isinstance(exc.original, disnake.HTTPException): 19 | self.bot.log.debug("Command failed with %s", str(exc.original)) 20 | return 21 | 22 | handler.oops() 23 | 24 | elif isinstance(exc, commands.ConversionError): 25 | handler.oops() 26 | 27 | elif isinstance(exc, commands.UserInputError): 28 | handler.set( 29 | title=str(exc), 30 | description="Usage: `{0.prefix}{1.qualified_name} {1.signature}`".format( 31 | ctx, ctx.command 32 | ), 33 | ) 34 | 35 | elif isinstance(exc, commands.DisabledCommand): 36 | handler.set( 37 | description="Sorry, command has been disabled by owner. Try again later!" 
38 | ) 39 | 40 | elif isinstance(exc, commands.CommandOnCooldown): 41 | handler.set( 42 | title="You are on cooldown.", 43 | description="Try again in {0}.".format(pretty_seconds(exc.retry_after)), 44 | ) 45 | 46 | elif isinstance(exc, commands.BotMissingPermissions): 47 | handler.set(description=str(exc)) 48 | 49 | elif isinstance(exc, (commands.CheckFailure, commands.CommandNotFound)): 50 | return 51 | 52 | elif isinstance(exc, commands.CommandError): 53 | handler.set(description=str(exc)) 54 | 55 | elif isinstance(exc, disnake.DiscordException): 56 | handler.oops() 57 | 58 | 59 | def setup(bot: AceBot): 60 | bot.add_cog(ErrorHandler(bot)) 61 | -------------------------------------------------------------------------------- /cogs/backend/logger.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | from typing import Union 3 | 4 | import disnake 5 | from disnake.ext import commands 6 | 7 | from ace import AceBot 8 | from cogs.mixins import AceMixin 9 | from utils.context import AceContext 10 | from utils.string import po 11 | 12 | 13 | class InternalLogger(commands.Cog, AceMixin): 14 | @commands.Cog.listener() 15 | async def on_command(self, ctx: Union[AceContext, disnake.Interaction]): 16 | spl = ctx.message.content.split("\n") 17 | self.bot.log.info( 18 | "%s in %s: %s", 19 | po(ctx.author), 20 | po(ctx.guild), 21 | spl[0] + (" ..." 
if len(spl) > 1 else ""), 22 | ) 23 | 24 | @commands.Cog.listener() 25 | async def on_command_completion(self, ctx: AceContext): 26 | await self.bot.db.execute( 27 | "INSERT INTO log (guild_id, channel_id, user_id, timestamp, command, type) VALUES ($1, $2, $3, $4, $5, $6)", 28 | ctx.guild.id, 29 | ctx.channel.id, 30 | ctx.author.id, 31 | datetime.utcnow(), 32 | ctx.command.qualified_name, 33 | "PREFIX", 34 | ) 35 | 36 | @commands.Cog.listener() 37 | async def on_slash_command_completion(self, inter: disnake.ApplicationCommandInteraction): 38 | await self.bot.db.execute( 39 | "INSERT INTO log (guild_id, channel_id, user_id, timestamp, command, type) VALUES ($1, $2, $3, $4, $5, $6)", 40 | inter.guild.id, 41 | inter.channel.id, 42 | inter.author.id, 43 | datetime.utcnow(), 44 | inter.application_command.qualified_name, 45 | "APPLICATION", 46 | ) 47 | 48 | 49 | def setup(bot: AceBot): 50 | bot.add_cog(InternalLogger(bot)) 51 | -------------------------------------------------------------------------------- /cogs/configuration.py: -------------------------------------------------------------------------------- 1 | import disnake 2 | from disnake.ext import commands 3 | 4 | from cogs.mixins import AceMixin 5 | from config import DEFAULT_PREFIX 6 | from utils.converters import LengthConverter 7 | 8 | 9 | class PrefixConverter(LengthConverter): 10 | async def convert(self, ctx, argument): 11 | argument = await super().convert(ctx, argument) 12 | 13 | if argument != disnake.utils.escape_markdown(argument): 14 | raise commands.BadArgument("No markdown allowed in prefix.") 15 | 16 | return argument 17 | 18 | 19 | class Configuration(AceMixin, commands.Cog): 20 | """Bot configuration available to administrators and people in the moderator role.""" 21 | 22 | async def cog_check(self, ctx): 23 | return await ctx.is_mod() 24 | 25 | @commands.command() 26 | async def config(self, ctx): 27 | """View current configuration.""" 28 | 29 | gc = await 
self.bot.config.get_entry(ctx.guild.id) 30 | 31 | e = disnake.Embed(description="Bot configuration.") 32 | e.set_author(name=ctx.guild.name, icon_url=ctx.guild.icon or None) 33 | 34 | mod_role = gc.mod_role 35 | 36 | format_obj = lambda o: "{}\nID: {}".format(o.mention, o.id) 37 | 38 | e.add_field(name="Prefix", value="`{0}`".format(gc.prefix or DEFAULT_PREFIX)) 39 | 40 | e.add_field( 41 | name="Moderation role", 42 | value="None" if mod_role is None else format_obj(mod_role), 43 | ) 44 | 45 | e.set_footer(text="ID: {}".format(ctx.guild.id)) 46 | 47 | await ctx.send(embed=e) 48 | 49 | @commands.command() 50 | async def prefix(self, ctx, *, prefix: PrefixConverter(1, 8) = None): 51 | """Set a guild-specific prefix. Leave argument empty to clear.""" 52 | 53 | gc = await self.bot.config.get_entry(ctx.guild.id) 54 | 55 | await gc.update(prefix=prefix) 56 | 57 | if prefix is None: 58 | data = "Prefix reset to `{0}`".format(DEFAULT_PREFIX) 59 | else: 60 | data = "Prefix set to `{0}`".format(prefix) 61 | 62 | data += "\n\nIf you forget your prefix, or simply need help, just mention the bot!" 63 | 64 | await ctx.send(data) 65 | 66 | @commands.command() 67 | @commands.has_permissions( 68 | administrator=True 69 | ) # only allow administrators to change the moderator role 70 | async def modrole(self, ctx, *, role: disnake.Role = None): 71 | """Set the moderator role. Only modifiable by server administrators. Leave argument empty to clear.""" 72 | 73 | gc = await self.bot.config.get_entry(ctx.guild.id) 74 | 75 | if role is None: 76 | await gc.update(mod_role_id=None) 77 | await ctx.send("Mod role cleared.") 78 | else: 79 | await gc.update(mod_role_id=role.id) 80 | await ctx.send( 81 | f"Mod role has been set to `{role.name}` ({role.id}). " 82 | "Members with this role can configure and manage the bot." 
83 | ) 84 | 85 | 86 | def setup(bot): 87 | bot.add_cog(Configuration(bot)) 88 | -------------------------------------------------------------------------------- /cogs/dwitter.py: -------------------------------------------------------------------------------- 1 | import re 2 | from collections import OrderedDict 3 | from datetime import datetime 4 | from itertools import islice 5 | 6 | import disnake 7 | from disnake.ext import commands 8 | 9 | from cogs.mixins import AceMixin 10 | 11 | 12 | class Dwitter(AceMixin, commands.Cog): 13 | """Commands for the Dwitter server.""" 14 | 15 | def __init__(self, bot): 16 | super().__init__(bot) 17 | 18 | self.url = "https://www.dwitter.net/" 19 | self.guilds = (395956681793863690, 517692823621861407) 20 | 21 | @commands.Cog.listener() 22 | async def on_message(self, message): 23 | if message.guild is None or message.author.bot: 24 | return 25 | 26 | if message.guild.id not in self.guilds: 27 | return 28 | 29 | short = OrderedDict.fromkeys(re.findall(r".?(d/(\d*)).?", message.content)) 30 | 31 | for group in islice(short.keys(), 0, 2): 32 | await self.dwitterlink(message, group[1]) 33 | 34 | async def dwitterlink(self, message, id): 35 | async with self.bot.aiohttp.get(self.url + "api/dweets/" + id) as resp: 36 | if resp.status != 200: 37 | return 38 | 39 | dweet = await resp.json() 40 | 41 | if "link" not in dweet: 42 | return 43 | 44 | e = await self.embeddweet(dweet) 45 | 46 | try: 47 | await message.channel.send(embed=e) 48 | except disnake.HTTPException: 49 | pass 50 | 51 | async def embeddweet(self, dweet): 52 | e = disnake.Embed(description="```js\n{}\n```".format(dweet["code"])) 53 | 54 | e.add_field(name="Awesomes", value=dweet["awesome_count"]) 55 | e.add_field(name="Link", value="[d/{}]({})".format(dweet["id"], dweet["link"])) 56 | 57 | remix_of = dweet["remix_of"] 58 | 59 | if remix_of is not None: 60 | e.add_field( 61 | name="Remix of", 62 | value="[d/{}]({})".format(remix_of, self.url + "d/" + str(remix_of)), 63 
| ) 64 | 65 | author = dweet["author"] 66 | e.set_author(name=author["username"], url=author["link"], icon_url=author["avatar"]) 67 | 68 | e.set_footer(text="Posted") 69 | e.timestamp = datetime.strptime(dweet["posted"].split(".")[0], "%Y-%m-%dT%H:%M:%S") 70 | 71 | return e 72 | 73 | 74 | def setup(bot): 75 | bot.add_cog(Dwitter(bot)) 76 | -------------------------------------------------------------------------------- /cogs/games.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import dataclasses 3 | import logging 4 | import string 5 | from datetime import datetime 6 | from enum import Enum 7 | from random import choice, randrange, sample 8 | from typing import Any, Literal, Optional, Union 9 | from urllib.parse import unquote 10 | 11 | import aiohttp 12 | import disnake 13 | from disnake.ext import commands 14 | from rapidfuzz import fuzz, process 15 | 16 | from cogs.mixins import AceMixin 17 | from utils.configtable import ConfigTable 18 | 19 | log = logging.getLogger(__name__) 20 | 21 | REQUEST_FAILED = commands.CommandError("Request failed, try again later.") 22 | 23 | 24 | # TRIVIA CONSTANTS 25 | 26 | 27 | class Difficulty(Enum): 28 | EASY = 1 29 | MEDIUM = 2 30 | HARD = 3 31 | 32 | 33 | TRIVIA_CUSTOM_ID_PREFIX = "trivia:v1:" 34 | 35 | CORRECT_EMOJI = "\N{WHITE HEAVY CHECK MARK}" 36 | WRONG_EMOJI = "\N{CROSS MARK}" 37 | 38 | CORRECT_MESSAGES = ( 39 | "Nice one!", 40 | "That's right!", 41 | "That one was easy, eh?", 42 | "Correct!", 43 | ) 44 | 45 | WRONG_MESSAGES = ( 46 | "Nope!", 47 | "Oof.", 48 | "Yikes.", 49 | "That one was hard.", 50 | "That was wrong!", 51 | "That wasn't easy.", 52 | "Afraid that's wrong!", 53 | "That's incorrect!", 54 | ) 55 | 56 | FOOTER_FORMAT = "Score: {} • You can go again in 5 minutes." 
57 | 58 | DIFFICULTY_COLORS = { 59 | Difficulty.EASY: disnake.Color.blue(), 60 | Difficulty.MEDIUM: disnake.Color.from_rgb(212, 212, 35), 61 | Difficulty.HARD: disnake.Color.red(), 62 | } 63 | 64 | API_BASE = "https://opentdb.com/" 65 | API_CATEGORY_LIST_URL = API_BASE + "api_category.php" 66 | API_URL = API_BASE + "api.php" 67 | QUESTION_TIMEOUT = 20.0 68 | 69 | MULTIPLE_MAP = ( 70 | "\N{Digit One}\N{Combining Enclosing Keycap}", 71 | "\N{Digit Two}\N{Combining Enclosing Keycap}", 72 | "\N{Digit Three}\N{Combining Enclosing Keycap}", 73 | "\N{Digit Four}\N{Combining Enclosing Keycap}", 74 | ) 75 | 76 | BOOLEAN_MAP = ( 77 | "\N{REGIONAL INDICATOR SYMBOL LETTER Y}", 78 | "\N{REGIONAL INDICATOR SYMBOL LETTER N}", 79 | ) 80 | 81 | SCORE_POT = {Difficulty.EASY: 400, Difficulty.MEDIUM: 800, Difficulty.HARD: 1200} 82 | 83 | PENALTY_DIV = 2 84 | CATEGORY_PENALTY = 2.5 85 | 86 | # NATO CONSTANTS 87 | 88 | PHONETICS = ( 89 | ("alpha", "alfa"), 90 | "bravo", 91 | "charlie", 92 | "delta", 93 | "echo", 94 | "foxtrot", 95 | "golf", 96 | "hotel", 97 | "india", 98 | ("juliet", "juliett"), 99 | "kilo", 100 | "lima", 101 | "mike", 102 | "november", 103 | "oscar", 104 | "papa", 105 | "quebec", 106 | "romeo", 107 | "sierra", 108 | "tango", 109 | "uniform", 110 | "victor", 111 | ("whiskey", "whisky"), 112 | ("x-ray", "xray"), 113 | "yankee", 114 | "zulu", 115 | ) 116 | 117 | LETTERS = list(string.ascii_lowercase) 118 | NATO = {x[0]: x[1] for x in zip(LETTERS, PHONETICS)} 119 | 120 | 121 | @dataclasses.dataclass 122 | class TriviaQuestion: 123 | type: Literal["multiple", "boolean"] 124 | category: str 125 | question: str 126 | hash: int = dataclasses.field(init=False) 127 | correct_answer: str 128 | difficulty: Difficulty 129 | incorrect_answers: list[str] 130 | _options: Optional[list] = dataclasses.field(init=False, default=None) 131 | _correct_emoji: Optional[str] = dataclasses.field(init=False, default=None) 132 | 133 | def __post_init__(self): 134 | if self.type not in 
("multiple", "boolean"): 135 | raise ValueError("Unknown question type: {}".format(self.type)) 136 | self.hash = hash(self.question) 137 | 138 | @classmethod 139 | def from_result(cls, res: dict, *, difficulty: Difficulty) -> "TriviaQuestion": 140 | question = unquote(res["question"]) 141 | return cls( 142 | type=res["type"], 143 | category=unquote(res["category"]), 144 | correct_answer=unquote(res["correct_answer"]), 145 | incorrect_answers=list(unquote(ans) for ans in res["incorrect_answers"]), 146 | question=question, 147 | difficulty=difficulty, 148 | ) 149 | 150 | @property 151 | def options(self) -> list[str]: 152 | if not self._options: 153 | if self.type == "multiple": 154 | options = list(self.incorrect_answers) 155 | correct_pos = randrange(0, len(options) + 1) 156 | options.insert(correct_pos, self.correct_answer) 157 | self._correct_emoji = MULTIPLE_MAP[correct_pos] 158 | 159 | elif self.type == "boolean": 160 | options = ["True", "False"] 161 | self._correct_emoji = BOOLEAN_MAP[int(self.correct_answer == "False")] 162 | else: 163 | raise RuntimeError 164 | self._options = options 165 | return self._options 166 | 167 | @property 168 | def buttons(self) -> list[disnake.ui.Button]: 169 | buttons = [] 170 | longest_option = max(len(opt) for opt in self.options) 171 | for emoji, option in zip(self.option_emojis, self.options): 172 | buttons.append( 173 | disnake.ui.Button( 174 | style=disnake.ButtonStyle.primary, 175 | label=option, 176 | custom_id=TRIVIA_CUSTOM_ID_PREFIX + "ans_choices:" + emoji, 177 | ) 178 | ) 179 | return buttons 180 | 181 | @property 182 | def correct_emoji(self) -> str: 183 | if not self._correct_emoji: 184 | self.options # get the correct emoji 185 | return self._correct_emoji 186 | 187 | @property 188 | def option_emojis(self) -> tuple[str, ...]: 189 | return BOOLEAN_MAP if self.type == "boolean" else MULTIPLE_MAP 190 | 191 | def to_embed(self) -> disnake.Embed: 192 | question_string = "{}\n\n{}\n".format( 193 | self.question, 194 | 
"\n".join( 195 | "{} {}".format(emoji, option) 196 | for emoji, option in zip(self.option_emojis, self.options) 197 | ), 198 | ) 199 | e = disnake.Embed( 200 | title="Trivia time!", 201 | description="**Category**: {}\n**Difficulty**: {}".format( 202 | self.category, self.difficulty.name.lower() 203 | ), 204 | color=DIFFICULTY_COLORS[self.difficulty], 205 | ) 206 | e.add_field(name="Question", value=self.question, inline=False) 207 | return e 208 | 209 | 210 | class CategoryConverter(commands.Converter): 211 | async def convert(self, ctx, argument): 212 | res, score, junk = process.extractOne( 213 | query=argument, choices=ctx.cog.trivia_categories.keys(), scorer=fuzz.ratio 214 | ) 215 | 216 | if score < 76: 217 | # will never be shown so no need to prettify it 218 | raise ValueError() 219 | 220 | _id = ctx.cog.trivia_categories[res] 221 | return choice(_id) if isinstance(_id, list) else _id 222 | 223 | 224 | class DifficultyConverter(commands.Converter): 225 | async def convert(self, ctx, argument): 226 | name = argument.upper() 227 | 228 | try: 229 | return Difficulty[name] 230 | except KeyError: 231 | pass 232 | 233 | try: 234 | return Difficulty(int(name)) 235 | except ValueError: 236 | cleaner = commands.clean_content(escape_markdown=True) 237 | raise commands.CommandError( 238 | "'{}' is not a valid difficulty.".format(await cleaner.convert(ctx, name)) 239 | ) 240 | 241 | 242 | class Games(AceMixin, commands.Cog): 243 | def __init__(self, bot): 244 | super().__init__(bot) 245 | 246 | self.config = ConfigTable(bot, "trivia", ("guild_id", "user_id")) 247 | 248 | self.trivia_categories: dict[str, Any] = {} 249 | 250 | self.bot.loop.create_task(self.get_trivia_categories()) 251 | 252 | async def get_trivia_categories(self): 253 | try: 254 | async with self.bot.aiohttp.get(API_CATEGORY_LIST_URL) as resp: 255 | if resp.status != 200: 256 | log.info("Failed getting trivia categories, trying again in 10 seconds...") 257 | await asyncio.sleep(10) 258 | 
asyncio.create_task(self.get_trivia_categories()) 259 | return 260 | 261 | res = await resp.json() 262 | except asyncio.TimeoutError: 263 | return 264 | 265 | categories = dict() 266 | 267 | for category in res["trivia_categories"]: 268 | name = category["name"].lower() 269 | 270 | if ":" in name: 271 | spl = name.split(":") 272 | cat = spl[0].strip() 273 | name = spl[1].strip() 274 | 275 | if cat not in categories: 276 | categories[cat] = list() 277 | 278 | if isinstance(categories[cat], list): 279 | categories[cat].append(category["id"]) 280 | 281 | name = name.replace(" ", "_") 282 | 283 | categories[name] = category["id"] 284 | 285 | categories["anime"] = categories.pop("japanese_anime_&_manga") 286 | categories["science"].append(categories.pop("science_&_nature")) 287 | categories["musicals"] = categories.pop("musicals_&_theatres") 288 | categories["cartoons"] = categories.pop("cartoon_&_animations") 289 | 290 | self.trivia_categories = categories 291 | 292 | async def fetch_question( 293 | self, 294 | *, 295 | difficulty: Difficulty = Difficulty.MEDIUM, 296 | category: Optional[Union[str, list[str]]] = None, 297 | ): 298 | params = dict( 299 | amount=1, 300 | encode="url3986", 301 | difficulty=difficulty.name.lower(), 302 | ) 303 | if category is not None: 304 | params["category"] = choice(category) if category is list else category 305 | try: 306 | async with self.bot.aiohttp.get(API_URL, params=params, raise_for_status=True) as resp: 307 | resp.raise_for_status() 308 | res = await resp.json() 309 | except (TimeoutError, aiohttp.ClientResponseError) as e: 310 | raise REQUEST_FAILED from e 311 | 312 | question = TriviaQuestion.from_result(res["results"][0], difficulty=difficulty) 313 | return question 314 | 315 | @commands.group(invoke_without_command=True, cooldown_after_parsing=True) 316 | @commands.bot_has_permissions(embed_links=True) 317 | @commands.cooldown(rate=2, per=60.0, type=commands.BucketType.member) 318 | async def trivia( 319 | self, 320 | ctx, 
321 | category: Optional[CategoryConverter] = None, 322 | *, 323 | difficulty: DifficultyConverter = None, 324 | ): 325 | """Trivia time! Optionally specify a difficulty or category and difficulty as arguments. Valid difficulties are `easy`, `medium` and `hard`. Valid categories can be listed with `trivia categories`.""" 326 | 327 | diff = difficulty 328 | 329 | if diff is None: 330 | diff = choice(list(Difficulty)) 331 | 332 | # if we have a category ID, insert it into the query params for the question request 333 | 334 | try: 335 | question = await self.fetch_question( 336 | difficulty=diff, 337 | category=category, 338 | ) 339 | except Exception: 340 | self.trivia.reset_cooldown(ctx) 341 | raise 342 | 343 | embed = question.to_embed() 344 | 345 | msg = await ctx.send(embed=embed, components=question.buttons) 346 | 347 | now = datetime.utcnow() 348 | 349 | def check(interaction: disnake.MessageInteraction): 350 | return ( 351 | interaction.message.id == msg.id 352 | and interaction.author.id == ctx.author.id 353 | and (custom_id := interaction.data.custom_id).startswith( 354 | TRIVIA_CUSTOM_ID_PREFIX + "ans_choices:" 355 | ) 356 | and custom_id.removeprefix(TRIVIA_CUSTOM_ID_PREFIX + "ans_choices:") 357 | in question.option_emojis 358 | ) 359 | 360 | try: 361 | interaction: disnake.MessageInteraction = await self.bot.wait_for( 362 | "message_interaction", check=check, timeout=QUESTION_TIMEOUT 363 | ) 364 | 365 | answer = interaction.component.custom_id.removeprefix( 366 | TRIVIA_CUSTOM_ID_PREFIX + "ans_choices:" 367 | ) 368 | 369 | answered_at = datetime.utcnow() 370 | score = self._calculate_score(SCORE_POT[diff], answered_at - now) 371 | 372 | if answer == question.correct_emoji: 373 | # apply penalty if category was specified 374 | if category: 375 | score = int(score / CATEGORY_PENALTY) 376 | 377 | current_score = await self._on_correct(ctx, answered_at, question.hash, score) 378 | 379 | e = disnake.Embed( 380 | title="{} {}".format(CORRECT_EMOJI, 
choice(CORRECT_MESSAGES)), 381 | description="You gained {} points.".format(score), 382 | color=disnake.Color.green(), 383 | ) 384 | 385 | # make the correct answer green and disable the buttons 386 | components = disnake.ui.ActionRow.rows_from_message(msg) 387 | for row in components: 388 | for component in row: 389 | component.disabled = True 390 | if component.custom_id == interaction.component.custom_id: 391 | component.style = disnake.ButtonStyle.green 392 | await interaction.response.edit_message(components=components) 393 | 394 | await interaction.followup.send(embed=e) 395 | else: 396 | score = int(score / PENALTY_DIV) 397 | current_score = await self._on_wrong(ctx, answered_at, question.hash, score) 398 | 399 | e = disnake.Embed( 400 | title="{} {}".format(WRONG_EMOJI, choice(WRONG_MESSAGES)), 401 | description="You lost {} points.".format(score), 402 | color=disnake.Color.red(), 403 | ) 404 | 405 | if question.type == "multiple": 406 | e.description += "\nThe correct answer is ***`{}`***".format( 407 | question.correct_answer 408 | ) 409 | 410 | # make the correct answer green, the guessed answer red, and disable the buttons 411 | components = disnake.ui.ActionRow.rows_from_message(msg) 412 | for row in components: 413 | for component in row: 414 | component.disabled = True 415 | if not isinstance(component, disnake.ui.Button): 416 | continue 417 | if component.custom_id == interaction.component.custom_id: 418 | component.style = disnake.ButtonStyle.red 419 | elif ( 420 | component.custom_id.removeprefix( 421 | TRIVIA_CUSTOM_ID_PREFIX + "ans_choices:" 422 | ) 423 | == question.correct_emoji 424 | ): 425 | component.style = disnake.ButtonStyle.green 426 | 427 | await interaction.response.edit_message(components=components) 428 | 429 | await interaction.followup.send(embed=e) 430 | 431 | except asyncio.TimeoutError: 432 | score = int(SCORE_POT[diff] / 4) 433 | answered_at = datetime.utcnow() 434 | 435 | components = 
disnake.ui.ActionRow.rows_from_message(msg) 436 | for row in components: 437 | for component in row: 438 | if isinstance(component, disnake.ui.Button): 439 | component.disabled = True 440 | component.style = disnake.ButtonStyle.gray 441 | try: 442 | await msg.edit(components=components) 443 | except disnake.HTTPException: 444 | pass 445 | 446 | await msg.reply( 447 | "Question timed out and you lost {} points. Answer within {} seconds next time!".format( 448 | score, int(QUESTION_TIMEOUT) 449 | ), 450 | fail_if_not_exists=False, 451 | ) 452 | 453 | await self._on_wrong(ctx, answered_at, question.hash, score) 454 | 455 | def _calculate_score(self, pot, time_spent): 456 | time_div = QUESTION_TIMEOUT - time_spent.total_seconds() / 2 457 | points = pot * time_div / QUESTION_TIMEOUT 458 | return int(points) 459 | 460 | async def _on_correct(self, ctx, answered_at, question_hash, add_score): 461 | entry = await self.config.get_entry(ctx.guild.id, ctx.author.id) 462 | 463 | await entry.update(score=entry.score + add_score, correct_count=entry.correct_count + 1) 464 | await self._insert_question(ctx, answered_at, question_hash, True) 465 | 466 | return entry.score 467 | 468 | async def _on_wrong(self, ctx, answered_at, question_hash, remove_score): 469 | entry = await self.config.get_entry(ctx.guild.id, ctx.author.id) 470 | 471 | await entry.update(score=entry.score - remove_score, wrong_count=entry.wrong_count + 1) 472 | await self._insert_question(ctx, answered_at, question_hash, False) 473 | 474 | return entry.score 475 | 476 | async def _insert_question(self, ctx, answered_at, question_hash, result): 477 | await self.db.execute( 478 | "INSERT INTO trivia_stats (guild_id, user_id, timestamp, question_hash, result) VALUES ($1, $2, $3, $4, $5)", 479 | ctx.guild.id, 480 | ctx.author.id, 481 | answered_at, 482 | question_hash, 483 | result, 484 | ) 485 | 486 | @trivia.command() 487 | @commands.bot_has_permissions(embed_links=True) 488 | async def categories(self, ctx): 489 
| """Get a list of valid categories for the trivia command.""" 490 | 491 | e = disnake.Embed(description="\n".join(self.trivia_categories.keys())) 492 | e.set_footer(text="Specifying a category halves your winnings.") 493 | 494 | await ctx.send(embed=e) 495 | 496 | @trivia.command() 497 | @commands.bot_has_permissions(embed_links=True) 498 | async def stats(self, ctx, *, member: disnake.Member = None): 499 | """Get your own or another members' trivia stats.""" 500 | 501 | member = member or ctx.author 502 | 503 | entry = await self.config.get_entry(ctx.guild.id, member.id) 504 | 505 | total_games = entry.correct_count + entry.wrong_count 506 | 507 | if total_games == 0: 508 | win_rate = 0 509 | else: 510 | win_rate = int(entry.correct_count / total_games * 100) 511 | 512 | e = disnake.Embed() 513 | 514 | e.set_author(name=member.display_name, icon_url=member.display_avatar.url) 515 | 516 | e.add_field(name="Score", value=str(entry.score)) 517 | e.add_field(name="Correct", value="{} games".format(str(entry.correct_count))) 518 | e.add_field(name="Wrong", value="{} games".format(str(entry.wrong_count))) 519 | e.add_field(name="Games played", value="{} games".format(str(total_games))) 520 | e.add_field(name="Correct percentage", value="{}%".format(str(win_rate))) 521 | 522 | await ctx.send(embed=e) 523 | 524 | @trivia.command() 525 | @commands.bot_has_permissions(embed_links=True) 526 | async def ranks(self, ctx): 527 | """See trivia leaderboard.""" 528 | 529 | leaders = await self.db.fetch( 530 | "SELECT * FROM trivia WHERE guild_id=$1 ORDER BY score DESC LIMIT 8", 531 | ctx.guild.id, 532 | ) 533 | 534 | e = disnake.Embed(title="Trivia leaderboard", color=DIFFICULTY_COLORS[Difficulty.MEDIUM]) 535 | 536 | mentions = "\n".join("<@{}>".format(leader.get("user_id")) for leader in leaders) 537 | scores = "\n".join(str(leader.get("score")) for leader in leaders) 538 | 539 | e.add_field(name="User", value=mentions) 540 | e.add_field(name="Score", value=scores) 541 | 542 | 
await ctx.send(embed=e) 543 | 544 | @commands.command() 545 | async def nato(self, ctx, count: int = 3): 546 | """Learn the NATO phonetic alphabet.""" 547 | 548 | if count < 1: 549 | raise commands.CommandError("Please pick a length larger than 0.") 550 | 551 | if count > 16: 552 | raise commands.CommandError("Sorry, please pick lengths lower or equal to 16.") 553 | 554 | lets = sample(LETTERS, k=count) 555 | 556 | await ctx.send(f'**{"".join(lets).upper()}**?') 557 | 558 | def check(m): 559 | return m.channel == ctx.channel and m.author == ctx.author 560 | 561 | try: 562 | msg = await self.bot.wait_for("message", check=check, timeout=60.0) 563 | except asyncio.TimeoutError: 564 | await ctx.send(f"Sorry {ctx.author.mention}, time ran out!") 565 | return 566 | 567 | answer = msg.content.lower().split() 568 | 569 | async def failed(): 570 | right = [] 571 | for let in lets: 572 | asd = NATO[let] 573 | right.append(asd[0] if isinstance(asd, tuple) else asd) 574 | await ctx.send( 575 | f'Sorry, that was wrong! The correct answer was `{" ".join(right).upper()}`' 576 | ) 577 | 578 | if len(answer) != len(lets): 579 | return await failed() 580 | 581 | for index, part in enumerate(answer): 582 | answer = NATO[lets[index]] 583 | if isinstance(answer, tuple): 584 | if part not in answer: 585 | return await failed() 586 | else: 587 | if part != answer: 588 | return await failed() 589 | 590 | await ctx.send("Correct! 
✅") 591 | 592 | 593 | def setup(bot): 594 | bot.add_cog(Games(bot)) 595 | -------------------------------------------------------------------------------- /cogs/hl.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import re 3 | 4 | import disnake 5 | from disnake.ext import commands 6 | 7 | from cogs.mixins import AceMixin 8 | from ids import AHK_GUILD_ID 9 | from utils.context import is_mod 10 | from utils.converters import LengthConverter 11 | 12 | log = logging.getLogger(__name__) 13 | 14 | DELETE_EMOJI = "\N{Put Litter in Its Place Symbol}" 15 | DEFAULT_LANG = "py" 16 | 17 | 18 | class LangConverter(LengthConverter): 19 | async def convert(self, ctx, argument): 20 | argument = await super().convert(ctx, argument) 21 | 22 | if argument != disnake.utils.escape_markdown(argument): 23 | raise commands.BadArgument("No markdown allowed in the codebox language.") 24 | 25 | return argument 26 | 27 | 28 | lang_converter = LangConverter(1, 32) 29 | 30 | 31 | class Highlighter(AceMixin, commands.Cog): 32 | """Create highlighted code-boxes with one command.""" 33 | 34 | @commands.command(aliases=["h1"]) 35 | @commands.bot_has_permissions(manage_messages=True, add_reactions=True) 36 | async def hl(self, ctx, *, code): 37 | """Highlight some code.""" 38 | 39 | await ctx.message.delete() 40 | 41 | # include spaces/tabs at the beginning 42 | code = ctx.message.content[len(ctx.prefix) + 3 :] 43 | 44 | # don't allow three backticks in a row, alternative is to throw error upon this case 45 | code = code.replace("``", "`\u200b`") 46 | 47 | # replace triple+ newlines with double newlines 48 | code = re.sub("\n\n+", "\n\n", code) 49 | 50 | # trim start and finish 51 | code = code.strip() 52 | 53 | # get the language this user should use 54 | lang = ( 55 | await self.db.fetchval( 56 | "SELECT lang FROM highlight_lang WHERE guild_id=$1 AND (user_id=$2 OR user_id=$3)", 57 | ctx.guild.id, 58 | 0, 59 | ctx.author.id, 60 | ) 61 | 
or DEFAULT_LANG 62 | ) 63 | 64 | code = "```{}\n{}\n```".format(lang, code) 65 | code += "*Paste by {0} - Click {1} to delete.*".format(ctx.author.mention, DELETE_EMOJI) 66 | 67 | if len(code) > 2000: 68 | raise commands.CommandError("Code contents too long to paste.") 69 | 70 | ar = disnake.ui.ActionRow() 71 | ar.add_button( 72 | 0, 73 | style=disnake.ButtonStyle.secondary, 74 | label="🗑️", 75 | custom_id=f"hldeletebutton_{ctx.author.id}", 76 | ) 77 | await ctx.send(code, components=[ar]) 78 | 79 | @commands.Cog.listener() 80 | async def on_button_click(self, inter: disnake.MessageInteraction): 81 | if inter.guild_id is None: 82 | return 83 | 84 | if not inter.component.custom_id.startswith("hldeletebutton"): 85 | return 86 | 87 | try: 88 | author_id: str = int(inter.component.custom_id.split("_")[1]) 89 | except ValueError: 90 | return 91 | 92 | if author_id != inter.author.id: 93 | await inter.response.send_message( 94 | "Sorry, this button is not for you!", 95 | ephemeral=True, 96 | delete_after=12, 97 | ) 98 | return 99 | 100 | await inter.message.delete() 101 | 102 | @commands.command() 103 | @commands.bot_has_permissions(embed_links=True) 104 | async def lang(self, ctx, *, language: lang_converter = None): 105 | """Set your preferred highlighting language in this server.""" 106 | 107 | if language is None: 108 | server_lang = await self.db.fetchval( 109 | "SELECT lang FROM highlight_lang WHERE guild_id=$1 AND user_id=$2", 110 | ctx.guild.id, 111 | 0, 112 | ) 113 | 114 | user_lang = await self.db.fetchval( 115 | "SELECT lang FROM highlight_lang WHERE guild_id=$1 AND user_id=$2", 116 | ctx.guild.id, 117 | ctx.author.id, 118 | ) 119 | 120 | e = disnake.Embed(description="Do `.lang clear` to clear preference.") 121 | 122 | e.add_field( 123 | name="Server setting", 124 | value=f'`{DEFAULT_LANG + " (default)" if server_lang is None else server_lang}`', 125 | ) 126 | 127 | e.add_field( 128 | name="Personal setting", 129 | value="Not set" if user_lang is None else 
f"`{user_lang}`", 130 | ) 131 | 132 | await ctx.send(embed=e) 133 | return 134 | 135 | if language == "clear": 136 | ret = await self.db.execute( 137 | "DELETE FROM highlight_lang WHERE guild_id=$1 AND user_id=$2", 138 | ctx.guild.id, 139 | ctx.author.id, 140 | ) 141 | 142 | await ctx.send( 143 | "No preference previously set" if ret == "DELETE 0" else "Preference cleared." 144 | ) 145 | else: 146 | await self.db.execute( 147 | "INSERT INTO highlight_lang (guild_id, user_id, lang) VALUES ($1, $2, $3) ON CONFLICT " 148 | "(guild_id, user_id) DO UPDATE SET lang=$3", 149 | ctx.guild.id, 150 | ctx.author.id, 151 | language, 152 | ) 153 | 154 | await ctx.send(f"Set your specific highlighting language to '{language}'.") 155 | 156 | @commands.command(aliases=["guildlang"]) 157 | @is_mod() 158 | async def serverlang(self, ctx, *, language: lang_converter): 159 | """Set a guild-specific highlighting language. Can be overridden individually by users.""" 160 | 161 | if language == "clear": 162 | ret = await self.db.execute( 163 | "DELETE FROM highlight_lang WHERE guild_id=$1 AND user_id=$2", 164 | ctx.guild.id, 165 | 0, 166 | ) 167 | 168 | await ctx.send( 169 | "No preference previously set" if ret == "DELETE 0" else "Preference cleared." 
170 | ) 171 | else: 172 | await self.db.execute( 173 | "INSERT INTO highlight_lang (guild_id, user_id, lang) VALUES ($1, $2, $3) ON CONFLICT " 174 | "(guild_id, user_id) DO UPDATE SET lang=$3", 175 | ctx.guild.id, 176 | 0, 177 | language, 178 | ) 179 | 180 | await ctx.send(f"Set server-specific highlighting language to '{language}'.") 181 | 182 | @commands.command(aliases=["p"], hidden=True) 183 | async def paste(self, ctx): 184 | """Legacy, not removed because some people still use it instead of the newer tags in the tag system.""" 185 | 186 | msg = "To paste code snippets directly into the chat, use the highlight command:\n```.hl *paste code here*```" 187 | 188 | if ctx.guild.id == AHK_GUILD_ID: 189 | msg += ( 190 | "\nIf you have a larger script you want to share, paste it to the AutoHotkey pastebin instead:\n" 191 | "https://p.autohotkey.com/" 192 | ) 193 | 194 | await ctx.send(msg) 195 | 196 | 197 | def setup(bot): 198 | bot.add_cog(Highlighter(bot)) 199 | -------------------------------------------------------------------------------- /cogs/meta.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | from datetime import datetime, timedelta, timezone 3 | from itertools import islice 4 | from os import getcwd 5 | from pathlib import Path 6 | 7 | import disnake 8 | import psutil 9 | from disnake.ext import commands 10 | from pygit2 import GIT_SORT_TOPOLOGICAL, GIT_STATUS_IGNORED, Repository, GitError 11 | 12 | from cogs.mixins import AceMixin 13 | from utils.context import AceContext 14 | from utils.converters import MaybeMemberConverter 15 | from utils.string import yesno 16 | from utils.time import pretty_datetime, pretty_timedelta 17 | 18 | GITHUB_LINK = "https://github.com/Run1e/AceBot" 19 | GITHUB_BRANCH = "master" 20 | COULD_NOT_FIND = commands.CommandError("Couldn't find command.") 21 | 22 | MEDALS = ( 23 | "\N{FIRST PLACE MEDAL}", 24 | "\N{SECOND PLACE MEDAL}", 25 | "\N{THIRD PLACE MEDAL}", 26 | 
"\N{SPORTS MEDAL}", 27 | "\N{SPORTS MEDAL}", 28 | ) 29 | 30 | 31 | class Meta(AceMixin, commands.Cog): 32 | """Commands about the bot itself.""" 33 | 34 | def __init__(self, bot): 35 | super().__init__(bot) 36 | 37 | try: 38 | self.repo = Repository(".") 39 | except GitError: 40 | self.repo = None 41 | 42 | self.process = psutil.Process() 43 | 44 | # no blockerino so we do this here in init 45 | self.process.cpu_percent() 46 | 47 | @commands.command(aliases=["join"]) 48 | async def invite(self, ctx): 49 | """Get bot invite link.""" 50 | 51 | await ctx.send(self.bot.invite_link) 52 | 53 | @commands.command() 54 | @commands.bot_has_permissions(embed_links=True) 55 | async def stats(self, ctx, member: MaybeMemberConverter = None): 56 | """Show bot or user command stats.""" 57 | 58 | if member is None: 59 | await self._stats_guild(ctx) 60 | else: 61 | await self._stats_member(ctx, member) 62 | 63 | async def _stats_member(self, ctx, member): 64 | past_day = datetime.utcnow() - timedelta(days=1) 65 | 66 | first_command = await self.db.fetchval( 67 | "SELECT timestamp FROM log WHERE guild_id=$1 AND user_id=$2 LIMIT 1", 68 | ctx.guild.id, 69 | member.id, 70 | ) 71 | 72 | total_uses = await self.db.fetchval( 73 | "SELECT COUNT(id) FROM log WHERE guild_id=$1 AND user_id=$2", 74 | ctx.guild.id, 75 | member.id, 76 | ) 77 | 78 | commands_alltime = await self.db.fetch( 79 | "SELECT COUNT(id), command FROM log WHERE guild_id=$1 AND user_id=$2 GROUP BY command " 80 | "ORDER BY COUNT DESC LIMIT 5", 81 | ctx.guild.id, 82 | member.id, 83 | ) 84 | 85 | commands_today = await self.db.fetch( 86 | "SELECT COUNT(id), command FROM log WHERE guild_id=$1 AND user_id=$2 AND timestamp > $3 " 87 | "GROUP BY command ORDER BY COUNT DESC LIMIT 5", 88 | ctx.guild.id, 89 | member.id, 90 | past_day, 91 | ) 92 | 93 | e = disnake.Embed() 94 | e.set_author(name=member.name, icon_url=member.display_avatar.url) 95 | e.add_field(name="Top Commands", value=self._stats_craft_list(commands_alltime)) 96 | 
e.add_field(name="Top Commands Today", value=self._stats_craft_list(commands_today)) 97 | 98 | self._stats_embed_fill(e, total_uses, first_command) 99 | 100 | await ctx.send(embed=e) 101 | 102 | async def _stats_guild(self, ctx): 103 | past_day = datetime.utcnow() - timedelta(days=1) 104 | total_uses = await self.db.fetchval( 105 | "SELECT COUNT(id) FROM log WHERE guild_id=$1", ctx.guild.id 106 | ) 107 | 108 | first_command = await self.db.fetchval( 109 | "SELECT timestamp FROM log WHERE guild_id=$1 LIMIT 1", ctx.guild.id 110 | ) 111 | 112 | commands_today = await self.db.fetch( 113 | "SELECT COUNT(id), command FROM log WHERE guild_id=$1 AND timestamp > $2 GROUP BY command " 114 | "ORDER BY COUNT DESC LIMIT 5", 115 | ctx.guild.id, 116 | past_day, 117 | ) 118 | 119 | commands_alltime = await self.db.fetch( 120 | "SELECT COUNT(id), command FROM log WHERE guild_id=$1 GROUP BY command ORDER BY COUNT DESC LIMIT 5", 121 | ctx.guild.id, 122 | ) 123 | 124 | users_today = await self.db.fetch( 125 | "SELECT COUNT(id), user_id FROM log WHERE guild_id=$1 AND timestamp > $2 GROUP BY user_id " 126 | "ORDER BY COUNT DESC LIMIT 5", 127 | ctx.guild.id, 128 | past_day, 129 | ) 130 | 131 | users_alltime = await self.db.fetch( 132 | "SELECT COUNT(id), user_id FROM log WHERE guild_id=$1 GROUP BY user_id " 133 | "ORDER BY COUNT DESC LIMIT 5", 134 | ctx.guild.id, 135 | ) 136 | 137 | e = disnake.Embed() 138 | e.set_author(name=ctx.guild.name, icon_url=ctx.guild.icon or None) 139 | e.add_field(name="Top Commands", value=self._stats_craft_list(commands_alltime)) 140 | e.add_field(name="Top Commands Today", value=self._stats_craft_list(commands_today)) 141 | 142 | e.add_field( 143 | name="Top Users", 144 | value=self._stats_craft_list( 145 | users_alltime, [f"<@{user_id}>" for _, user_id in users_alltime] 146 | ), 147 | ) 148 | 149 | e.add_field( 150 | name="Top Users Today", 151 | value=self._stats_craft_list( 152 | users_today, [f"<@{user_id}>" for _, user_id in users_today] 153 | ), 154 | 
) 155 | 156 | self._stats_embed_fill(e, total_uses, first_command) 157 | 158 | await ctx.send(embed=e) 159 | 160 | def _stats_embed_fill(self, e, total_uses, first_command): 161 | e.description = f"{total_uses} total commands issued." 162 | if first_command is not None: 163 | e.timestamp = first_command 164 | e.set_footer(text="First command invoked") 165 | 166 | def _stats_craft_list(self, cmds, members=None): 167 | value = "" 168 | for index, cmd in enumerate(cmds): 169 | value += f"\n{MEDALS[index]} {members[index] if members else cmd[1]} ({cmd[0]} uses)" 170 | 171 | if not len(value): 172 | return "None so far!" 173 | 174 | return value[1:] 175 | 176 | def format_commit(self, commit): 177 | short, _, _ = commit.message.partition("\n") 178 | short_sha2 = commit.hex[0:6] 179 | tz = timezone(timedelta(minutes=commit.commit_time_offset)) 180 | time = datetime.fromtimestamp(commit.commit_time).replace(tzinfo=tz) 181 | offset = pretty_datetime( 182 | time.astimezone(timezone.utc).replace(tzinfo=None), ignore_time=True 183 | ) 184 | return f"[`{short_sha2}`]({GITHUB_LINK}/commit/{commit.hex}) {short} ({offset})" 185 | 186 | def get_last_commits(self, count=3): 187 | if self.repo is None: 188 | return "-" 189 | 190 | return "\n".join( 191 | self.format_commit(c) 192 | for c in islice(self.repo.walk(self.repo.head.target, GIT_SORT_TOPOLOGICAL), count) 193 | ) 194 | 195 | @commands.command() 196 | @commands.bot_has_permissions(embed_links=True) 197 | async def about(self, ctx, *, command: str = None): 198 | """Show info about the bot or a command.""" 199 | 200 | if command is None: 201 | await self._about_bot(ctx) 202 | else: 203 | cmd = self.bot.get_command(command) 204 | if cmd is None or cmd.hidden: 205 | raise commands.CommandError("No command with that name found.") 206 | await self._about_command(ctx, cmd) 207 | 208 | async def _about_bot(self, ctx): 209 | e = disnake.Embed( 210 | title="Click here to add the bot to your own server!", 211 | 
description=f"{self.get_last_commits()}\n\n[Support server here!]({self.bot.support_link})", 212 | url=self.bot.invite_link, 213 | ) 214 | 215 | owner = self.bot.get_user(self.bot.owner_id) 216 | e.set_author(name=str(owner), icon_url=owner.display_avatar.url) 217 | 218 | e.add_field(name="Developer", value=str(self.bot.get_user(self.bot.owner_id))) 219 | 220 | invokes = await self.db.fetchval("SELECT COUNT(*) FROM log") 221 | e.add_field(name="Command invokes", value="{0:,d}".format(invokes)) 222 | 223 | guilds, text, voice, users = 0, 0, 0, 0 224 | 225 | for guild in self.bot.guilds: 226 | guilds += 1 227 | users += len(guild.members) 228 | for channel in guild.channels: 229 | if isinstance(channel, disnake.TextChannel): 230 | text += 1 231 | elif isinstance(channel, disnake.VoiceChannel): 232 | voice += 1 233 | 234 | unique = len(self.bot.users) 235 | 236 | e.add_field(name="Servers", value=str(guilds)) 237 | 238 | memory_usage = self.process.memory_full_info().uss / 1024**2 239 | cpu_usage = self.process.cpu_percent() / psutil.cpu_count() 240 | 241 | e.add_field( 242 | name="Process", 243 | value="CPU: {0:.2f}%\nMemory: {1:.2f} MiB".format(cpu_usage, memory_usage), 244 | ) 245 | 246 | e.add_field(name="Members", value="{0:,d} total\n{1:,d} unique".format(users, unique)) 247 | e.add_field( 248 | name="Channels", 249 | value="{0:,d} total\n{1:,d} text channels\n{2:,d} voice channels".format( 250 | text + voice, text, voice 251 | ), 252 | ) 253 | 254 | now = datetime.utcnow() 255 | e.set_footer( 256 | text="Last restart {0} ago".format(pretty_timedelta(now - self.bot.startup_time)) 257 | ) 258 | 259 | await ctx.send(embed=e) 260 | 261 | async def _about_command(self, ctx, command: commands.Command): 262 | e = disnake.Embed( 263 | title=command.qualified_name + " " + command.signature, 264 | description=command.description or command.help, 265 | ) 266 | 267 | e.add_field(name="Qualified name", value=command.qualified_name) 268 | 269 | try: 270 | can_run = await 
command.can_run(ctx) 271 | except commands.CommandError: 272 | can_run = False 273 | 274 | e.add_field(name="Can you run it?", value=yesno(can_run)) 275 | 276 | e.add_field(name="Enabled", value=yesno(command.enabled)) 277 | 278 | invokes = await self.db.fetchval( 279 | "SELECT COUNT(*) FROM log WHERE command=$1", command.qualified_name 280 | ) 281 | e.add_field(name="Total invokes", value="{0:,d}".format(invokes)) 282 | 283 | here_invokes = await self.db.fetchval( 284 | "SELECT COUNT(*) FROM log WHERE command=$1 AND guild_id=$2", 285 | command.qualified_name, 286 | ctx.guild.id, 287 | ) 288 | e.add_field(name="Invokes in this server", value="{0:,d}".format(here_invokes)) 289 | 290 | if command.aliases: 291 | e.set_footer(text="Also known as: " + ", ".join(command.aliases)) 292 | 293 | await ctx.send(embed=e) 294 | 295 | @commands.command(aliases=["fb"]) 296 | @commands.cooldown(rate=2, per=120.0, type=commands.BucketType.user) 297 | async def feedback(self, ctx, *, feedback: str): 298 | """Give me some feedback about the bot!""" 299 | 300 | with open( 301 | "data/feedback/{}.feedback".format(str(ctx.message.id)), 302 | "w", 303 | encoding="utf-8-sig", 304 | ) as f: 305 | f.write(ctx.stamp + "\n\n" + feedback) 306 | 307 | await ctx.send("Feedback sent. 
Thanks for helping improve the bot!") 308 | 309 | @commands.command() 310 | async def support(self, ctx): 311 | """Get link to support server.""" 312 | 313 | await ctx.send(self.bot.support_link) 314 | 315 | @commands.command(aliases=["source"]) 316 | async def code(self, ctx: AceContext, *, command: str = None): 317 | """Get a github link to the source code of a command.""" 318 | 319 | if command is None: 320 | await ctx.send(GITHUB_LINK) 321 | return 322 | 323 | cmd: commands.Command = self.bot.get_command(command) 324 | 325 | # not a command 326 | if cmd is None: 327 | raise COULD_NOT_FIND 328 | 329 | callback = cmd.callback 330 | source_file = Path(inspect.getsourcefile(callback)).relative_to(getcwd()) 331 | 332 | if self.repo is not None: 333 | # pygit2 expects forward slashes which breaks on windows with pathlib 334 | source_file_str = str(source_file).replace("\\", "/") 335 | 336 | # not in repo 337 | try: 338 | file_status = self.repo.status_file(source_file_str) 339 | except KeyError: 340 | raise COULD_NOT_FIND 341 | 342 | # ignored 343 | if file_status & GIT_STATUS_IGNORED: 344 | raise COULD_NOT_FIND 345 | 346 | lines, first_line_no = inspect.getsourcelines(callback) 347 | await ctx.send( 348 | f"<{GITHUB_LINK}/blob/{GITHUB_BRANCH}/{source_file}#L{first_line_no}-L{first_line_no + len(lines) - 1}>" 349 | ) 350 | 351 | 352 | def setup(bot): 353 | bot.add_cog(Meta(bot)) 354 | -------------------------------------------------------------------------------- /cogs/mixins.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | 3 | from ace import AceBot 4 | 5 | 6 | class AceMixin: 7 | def __init__(self, bot: AceBot): 8 | self.bot: AceBot = bot 9 | 10 | @property 11 | def db(self): 12 | return self.bot.db 13 | -------------------------------------------------------------------------------- /cogs/owner.py: -------------------------------------------------------------------------------- 1 | import 
asyncio 2 | import copy 3 | import io 4 | import logging 5 | import textwrap 6 | import traceback 7 | from collections import Counter 8 | from contextlib import redirect_stdout 9 | from datetime import datetime, timedelta 10 | from urllib.parse import urlparse 11 | 12 | import disnake 13 | from bs4 import BeautifulSoup 14 | from disnake.ext import commands 15 | from disnake.mixins import Hashable 16 | from tabulate import tabulate 17 | 18 | from cogs.mixins import AceMixin 19 | from config import BOT_ACTIVITY 20 | from utils.context import AceContext 21 | from utils.converters import MaxValueConverter 22 | from utils.pager import Pager 23 | from utils.string import shorten 24 | from utils.time import pretty_datetime, pretty_timedelta 25 | 26 | log = logging.getLogger(__name__) 27 | 28 | 29 | class Owner(AceMixin, commands.Cog): 30 | """Commands accessible only to the bot owner.""" 31 | 32 | def __init__(self, bot): 33 | super().__init__(bot) 34 | 35 | self.help_cog = bot.get_cog("AutoHotkeyHelpSystem") 36 | self.event_counter = Counter() 37 | 38 | async def cog_check(self, ctx): 39 | return await self.bot.is_owner(ctx.author) 40 | 41 | def cleanup_code(self, content): 42 | """Automatically removes code blocks from the code.""" 43 | 44 | # remove ```py\n``` 45 | if content.startswith("```") and content.endswith("```"): 46 | return "\n".join(content.split("\n")[1:-1]) 47 | 48 | # remove `foo` 49 | return content.strip("` \n") 50 | 51 | @commands.Cog.listener() 52 | async def on_socket_event_type(self, event_type): 53 | self.event_counter[event_type] += 1 54 | 55 | @commands.command(hidden=True) 56 | async def prompt(self, ctx: AceContext, user: disnake.Member = None): 57 | result = await ctx.prompt(user_override=user) 58 | await ctx.send(result) 59 | 60 | @commands.command(hidden=True) 61 | async def adminprompt(self, ctx: AceContext): 62 | result = await ctx.admin_prompt() 63 | await ctx.send(result) 64 | 65 | @commands.command() 66 | async def eval(self, ctx, *, 
body: str): 67 | """Evaluates some code.""" 68 | 69 | from pprint import pprint 70 | 71 | from tabulate import tabulate 72 | 73 | env = { 74 | "disnake.": disnake, 75 | "bot": self.bot, 76 | "ctx": ctx, 77 | "channel": ctx.channel, 78 | "author": ctx.author, 79 | "guild": ctx.guild, 80 | "message": ctx.message, 81 | "pprint": pprint, 82 | "tabulate": tabulate, 83 | "db": self.db, 84 | } 85 | 86 | env.update(globals()) 87 | 88 | body = self.cleanup_code(body) 89 | stdout = io.StringIO() 90 | 91 | to_compile = f'async def func():\n{textwrap.indent(body, " ")}' 92 | 93 | try: 94 | exec(to_compile, env) 95 | except Exception as e: 96 | return await ctx.send(f"```py\n{e.__class__.__name__}: {e}\n```") 97 | 98 | func = env["func"] 99 | try: 100 | with redirect_stdout(stdout): 101 | ret = await func() 102 | except Exception as e: 103 | value = stdout.getvalue() 104 | await ctx.send(f"```py\n{value}{traceback.format_exc()}\n```") 105 | else: 106 | value = stdout.getvalue() 107 | try: 108 | await ctx.message.add_reaction("\u2705") 109 | except: 110 | pass 111 | 112 | if ret is None: 113 | if value: 114 | if len(value) > 1990: 115 | fp = io.BytesIO(value.encode("utf-8")) 116 | await ctx.send("Log too large...", file=disnake.File(fp, "results.txt")) 117 | else: 118 | await ctx.send(f"```py\n{value}\n```") 119 | 120 | @commands.command() 121 | async def sql(self, ctx, *, query: str): 122 | """Execute a SQL query.""" 123 | 124 | try: 125 | result = await self.db.fetch(query) 126 | except Exception as exc: 127 | raise commands.CommandError(str(exc)) 128 | 129 | if not len(result): 130 | await ctx.send("No rows returned.") 131 | return 132 | 133 | table = tabulate(result, {header: header for header in result[0].keys()}) 134 | 135 | if len(table) > 1994: 136 | fp = io.BytesIO(table.encode("utf-8")) 137 | await ctx.send("Too many results...", file=disnake.File(fp, "results.txt")) 138 | else: 139 | await ctx.send("```" + table + "```") 140 | 141 | @commands.command() 142 | async def 
gateway(self, ctx, *, n=None): 143 | """Print gateway event counters.""" 144 | 145 | table = tabulate( 146 | tabular_data=[ 147 | (name, format(count, ",d")) for name, count in self.event_counter.most_common(n) 148 | ], 149 | headers=("Event", "Count"), 150 | ) 151 | 152 | paginator = commands.Paginator() 153 | for line in table.split("\n"): 154 | paginator.add_line(line) 155 | 156 | for page in paginator.pages: 157 | await ctx.send(page) 158 | 159 | @commands.command() 160 | async def ping(self, ctx): 161 | """Check response time.""" 162 | 163 | msg = await ctx.send("Wait...") 164 | 165 | await msg.edit( 166 | content="Response: {}.\nGateway: {}".format( 167 | pretty_timedelta(msg.created_at - ctx.message.created_at), 168 | pretty_timedelta(timedelta(seconds=self.bot.latency)), 169 | ) 170 | ) 171 | 172 | @commands.command() 173 | async def repeat(self, ctx, repeats: int, *, command): 174 | """Repeat a command.""" 175 | 176 | if repeats < 1: 177 | raise commands.CommandError("Repeat count must be more than 0.") 178 | 179 | msg = copy.copy(ctx.message) 180 | msg.content = ctx.prefix + command 181 | 182 | new_ctx = await self.bot.get_context(msg, cls=AceContext) 183 | 184 | for i in range(repeats): 185 | await new_ctx.reinvoke() 186 | 187 | @commands.command(name="reload", aliases=["rl"]) 188 | @commands.bot_has_permissions(add_reactions=True) 189 | async def _reload(self, ctx): 190 | """Reload edited extensions.""" 191 | 192 | reloaded = self.bot.load_extensions() 193 | 194 | if reloaded: 195 | log.info("Reloaded cogs: %s", ", ".join(reloaded)) 196 | await ctx.send("Reloaded cogs: " + ", ".join("`{0}`".format(ext) for ext in reloaded)) 197 | else: 198 | await ctx.send("Nothing to reload.") 199 | 200 | @commands.command() 201 | async def decache(self, ctx, guild_id: int): 202 | """Clear cache of table data of a specific guild.""" 203 | 204 | configs = ( 205 | self.bot.config, 206 | self.bot.get_cog("Starboard").config, 207 | self.bot.get_cog("Moderation").config, 
208 | self.bot.get_cog("Welcome").config, 209 | self.bot.get_cog("Roles").config, 210 | ) 211 | 212 | cleared = [] 213 | 214 | for config in configs: 215 | if await config.clear_entry(guild_id): 216 | cleared.append(config) 217 | 218 | await ctx.send( 219 | "Cleared entries for:\n```\n{0}\n```".format( 220 | "\n".join(config.table for config in cleared) 221 | ) 222 | ) 223 | 224 | @commands.command() 225 | @commands.bot_has_permissions(manage_messages=True) 226 | async def say(self, ctx, channel: disnake.TextChannel, *, content: str): 227 | """Send a message in a channel.""" 228 | 229 | await ctx.message.delete() 230 | await channel.send(content) 231 | 232 | @commands.command() 233 | async def status(self, ctx): 234 | """Refresh the status of the bot in case Discord cleared it.""" 235 | 236 | await self.bot.change_presence() 237 | await self.bot.set_status(activity_text=BOT_ACTIVITY) 238 | 239 | @commands.command() 240 | async def print(self, ctx, *, body: str): 241 | """Calls eval but wraps code in print()""" 242 | 243 | await ctx.invoke(self.eval, body=f"pprint({body})") 244 | 245 | 246 | def setup(bot): 247 | bot.add_cog(Owner(bot)) 248 | -------------------------------------------------------------------------------- /cogs/remind.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from datetime import datetime, timedelta 3 | from enum import IntEnum 4 | 5 | import disnake 6 | import parsedatetime 7 | from disnake.ext import commands 8 | 9 | from cogs.mixins import AceMixin 10 | from utils.converters import SerialConverter 11 | from utils.databasetimer import ColumnTimer 12 | from utils.pager import Pager 13 | from utils.string import po, shorten 14 | from utils.time import pretty_datetime, pretty_timedelta 15 | 16 | log = logging.getLogger(__name__) 17 | 18 | SUCCESS_EMOJI = "\U00002705" 19 | DEFAULT_REMINDER_MESSAGE = "Hey, wake up!" 
20 | MIN_DELTA = timedelta(minutes=1) 21 | MAX_DELTA = timedelta(days=365 * 10) 22 | MAX_REMINDERS = 32 23 | 24 | 25 | class RemindPager(Pager): 26 | async def create_base_embed(self): 27 | embed = disnake.Embed(title="All your reminders for this server.") 28 | 29 | author = self.ctx.author 30 | embed.set_author(name=author.name, icon_url=author.display_avatar.url) 31 | 32 | return embed 33 | 34 | async def update_page_embed(self, embed, page, entries): 35 | now = datetime.utcnow() 36 | 37 | embed.clear_fields() 38 | 39 | for record in entries: 40 | _id = record.get("id") 41 | remind_on = record.get("remind_on") 42 | message = record.get("message") 43 | 44 | delta = remind_on - now 45 | 46 | time_text = pretty_timedelta(delta) 47 | embed.add_field( 48 | name=f"{_id}: {time_text}", 49 | value=shorten(message, 256) if message is not None else DEFAULT_REMINDER_MESSAGE, 50 | inline=False, 51 | ) 52 | 53 | 54 | def dt_factory(): 55 | return datetime.utcnow().replace(hour=0, minute=0, second=0, microsecond=0) 56 | 57 | 58 | class Timescale(IntEnum): 59 | TIME = 0 60 | DAY = 1 61 | MONTH = 2 62 | YEAR = 3 63 | 64 | 65 | class ReminderConverter(commands.Converter): 66 | NO_DT_FOUND = commands.CommandError("No time/date found in input.") 67 | 68 | async def convert(self, ctx, argument): 69 | cal = parsedatetime.Calendar() 70 | now = datetime.utcnow() 71 | 72 | try: 73 | ret = cal.nlp(argument, now) 74 | except Exception: 75 | raise self.NO_DT_FOUND 76 | 77 | if not ret: 78 | raise self.NO_DT_FOUND 79 | 80 | (dt, flags, start, end, text), *rest = ret 81 | 82 | if flags == 0: 83 | raise self.NO_DT_FOUND 84 | 85 | before = argument[:start].strip(" ") 86 | end = argument[end:].strip(" ") 87 | 88 | joiners = (",",) 89 | 90 | for joiner in joiners: 91 | if before.endswith(joiner) and end.startswith(joiner): 92 | end = end[len(joiner) :] 93 | 94 | before = before.strip(" ") 95 | end = end.strip(" ") 96 | 97 | parts = list() 98 | if before: 99 | parts.append(before) 100 | if end: 
101 | parts.append(end) 102 | 103 | text = " ".join(parts) 104 | 105 | return now, dt, None if not text else text 106 | 107 | 108 | class Reminders(AceMixin, commands.Cog): 109 | """Set, view, and delete reminders. 110 | 111 | Examples: 112 | `.remindme in 3 days do the laundry` 113 | `.remindme call back john in 10 minutes` 114 | `.remindme apply for job 17th of august` 115 | `.remindme tomorrow take out trash` 116 | 117 | Absolute dates/times are in UTC. 118 | """ 119 | 120 | def __init__(self, bot): 121 | super().__init__(bot) 122 | self.timer = ColumnTimer(self.bot, "reminder_complete", table="remind", column="remind_on") 123 | 124 | @commands.Cog.listener() 125 | async def on_reminder_complete(self, record): 126 | _id = record.get("id") 127 | guild_id = record.get("guild_id") 128 | channel_id = record.get("channel_id") 129 | user_id = record.get("user_id") 130 | message_id = record.get("message_id") 131 | made_on = record.get("made_on") 132 | message = record.get("message") 133 | 134 | channel = self.bot.get_channel(channel_id) 135 | user = self.bot.get_user(user_id) 136 | 137 | desc = message or DEFAULT_REMINDER_MESSAGE 138 | 139 | if message_id is not None: 140 | jump_url = "https://discord.com/channels/{0}/{1}/{2}".format( 141 | guild_id, channel_id, message_id 142 | ) 143 | desc += f"\n\n[Click for context!]({jump_url})" 144 | 145 | e = disnake.Embed(title="Reminder!", description=desc, timestamp=made_on) 146 | 147 | e.set_footer(text=f"#{channel.name}") 148 | 149 | try: 150 | if channel is not None: 151 | await channel.send(content=f"<@{user_id}>", embed=e) 152 | elif user is not None: 153 | await user.send(embed=e) 154 | except disnake.HTTPException as exc: 155 | log.info("Failed sending reminder #%s for %s - %s", _id, po(user), str(exc)) 156 | 157 | @commands.command(aliases=["remind", "reminder"]) 158 | @commands.bot_has_permissions(add_reactions=True) 159 | async def remindme(self, ctx, *, when_and_what: ReminderConverter()): 160 | """Create a new 
reminder.""" 161 | 162 | now, when, message = when_and_what 163 | 164 | if when < now: 165 | raise commands.CommandError("Specified time is in the past.") 166 | 167 | if when - now > MAX_DELTA: 168 | raise commands.CommandError("Sorry, can't remind in more than a year in the future.") 169 | 170 | if message is not None and len(message) > 1024: 171 | raise commands.CommandError("Sorry, keep the message below 1024 characters!") 172 | 173 | count = await self.db.fetchval( 174 | "SELECT COUNT(id) FROM remind WHERE user_id=$1", ctx.author.id 175 | ) 176 | if count > MAX_REMINDERS: 177 | raise commands.CommandError( 178 | f"Sorry, you can't have more than {MAX_REMINDERS} active reminders at once." 179 | ) 180 | 181 | await self.db.execute( 182 | "INSERT INTO remind (guild_id, channel_id, user_id, message_id, made_on, remind_on, message) VALUES ($1, $2, $3, $4, $5, $6, $7)", 183 | ctx.guild.id, 184 | ctx.channel.id, 185 | ctx.author.id, 186 | ctx.message.id, 187 | now, 188 | when, 189 | message, 190 | ) 191 | 192 | self.timer.maybe_restart(when) 193 | 194 | remind_in = when - now 195 | remind_in += timedelta(microseconds=1000000 - (remind_in.microseconds % 1000000)) 196 | 197 | await ctx.send("You will be reminded in {}.".format(pretty_timedelta(remind_in))) 198 | 199 | log.info("%s set a reminder for %s.", po(ctx.author), pretty_datetime(when)) 200 | 201 | @commands.command() 202 | @commands.bot_has_permissions(embed_links=True) 203 | async def reminders(self, ctx): 204 | """List your reminders in this guild.""" 205 | 206 | res = await self.db.fetch( 207 | "SELECT * FROM remind WHERE guild_id=$1 AND user_id=$2 ORDER BY id DESC", 208 | ctx.guild.id, 209 | ctx.author.id, 210 | ) 211 | 212 | if not len(res): 213 | raise commands.CommandError("Couldn't find any reminders.") 214 | 215 | await RemindPager(ctx, res, per_page=3).go() 216 | 217 | @commands.command(hidden=True) 218 | async def delreminder(self, ctx, *, reminder_id: SerialConverter()): 219 | """Delete a reminder. 
Must be your own reminder.""" 220 | 221 | res = await self.db.execute( 222 | "DELETE FROM remind WHERE id=$1 AND guild_id=$2 AND user_id=$3", 223 | reminder_id, 224 | ctx.guild.id, 225 | ctx.author.id, 226 | ) 227 | 228 | if res == "DELETE 1": 229 | await ctx.send("Reminder deleted.") 230 | self.timer.restart_if(lambda record: record.get("id") == reminder_id) 231 | else: 232 | raise commands.CommandError("Reminder not found, or you do not own it.") 233 | 234 | 235 | def setup(bot): 236 | bot.add_cog(Reminders(bot)) 237 | -------------------------------------------------------------------------------- /cogs/welcome.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import logging 3 | 4 | import disnake 5 | from disnake.ext import commands 6 | 7 | from cogs.mixins import AceMixin 8 | from utils.configtable import ConfigTable, ConfigTableRecord 9 | from utils.string import po 10 | 11 | log = logging.getLogger(__name__) 12 | 13 | WELCOME_NOT_SET_UP_ERROR = commands.CommandError( 14 | "You don't seem to have set up a welcome message yet, do `welcome` to see available commands." 15 | ) 16 | 17 | 18 | class WelcomeRecord(ConfigTableRecord): 19 | @property 20 | def channel(self): 21 | if self.channel_id is None: 22 | return None 23 | 24 | guild = self._config.bot.get_guild(self.guild_id) 25 | if guild is None: 26 | return None 27 | 28 | return guild.get_channel(self.channel_id) 29 | 30 | 31 | class Welcome(AceMixin, commands.Cog): 32 | """Show welcome messages to new members. 
33 | 34 | Welcome message replacements: 35 | `{user}` - member mention 36 | `{guild}` - server name 37 | `{member_count}` - server member count 38 | 39 | Example: 40 | `.welcome message Welcome {user} to my server!` 41 | """ 42 | 43 | def __init__(self, bot): 44 | super().__init__(bot) 45 | 46 | self.config = ConfigTable(bot, "welcome", "guild_id", WelcomeRecord) 47 | 48 | async def cog_check(self, ctx): 49 | return await ctx.is_mod() 50 | 51 | @commands.Cog.listener() 52 | async def on_member_join(self, member): 53 | entry = await self.config.get_entry(member.guild.id, construct=False) 54 | 55 | if entry is None: 56 | return 57 | 58 | if entry.enabled is False: 59 | return 60 | 61 | if entry.content is None: 62 | return 63 | 64 | channel = entry.channel 65 | 66 | if channel is None: 67 | return 68 | 69 | # sleep a bit and then dispatch welcome event 70 | await asyncio.sleep(2) 71 | self.bot.dispatch("welcome", member, channel, entry.content) 72 | 73 | @commands.Cog.listener() 74 | async def on_welcome(self, member, channel, message): 75 | replace_table = dict( 76 | guild=member.guild.name, 77 | user=member.mention, 78 | member_count=member.guild.member_count, 79 | ) 80 | 81 | for key, val in replace_table.items(): 82 | message = message.replace("{" + key + "}", str(val)) 83 | 84 | log.info("Sending welcome message for %s in %s", po(member), po(member.guild)) 85 | 86 | try: 87 | await channel.send(message) 88 | except disnake.HTTPException: 89 | pass 90 | 91 | @commands.group(hidden=True, invoke_without_command=True) 92 | async def welcome(self, ctx): 93 | await ctx.send_help(self.welcome) 94 | 95 | @welcome.command() 96 | async def message(self, ctx, *, message: str): 97 | """Set a new welcome message.""" 98 | 99 | if len(message) > 1024: 100 | raise commands.CommandError("Welcome message has to be shorter than 1024 characters.") 101 | 102 | # make sure an entry for this exists... 
103 | entry = await self.config.get_entry(ctx.guild.id) 104 | await entry.update(content=message) 105 | 106 | await ctx.send("Welcome message updated. Do `welcome test` to test.") 107 | 108 | @welcome.command() 109 | async def channel(self, ctx, *, channel: disnake.TextChannel = None): 110 | """Set or view welcome message channel.""" 111 | 112 | entry = await self.config.get_entry(ctx.guild.id) 113 | 114 | if channel is None: 115 | if entry.channel_id is None: 116 | raise commands.CommandError("Welcome channel not yet set.") 117 | 118 | channel = entry.channel 119 | if channel is None: 120 | raise commands.CommandError( 121 | "Channel previously set but not found, try setting a new one." 122 | ) 123 | 124 | else: 125 | await entry.update(channel_id=channel.id) 126 | 127 | await ctx.send(f"Welcome channel set to {channel.mention}") 128 | 129 | @welcome.command() 130 | async def raw(self, ctx): 131 | """Get the raw contents of your welcome message. Useful for editing.""" 132 | 133 | entry = await self.config.get_entry(ctx.guild.id, construct=False) 134 | 135 | if entry is None or entry.content is None: 136 | raise WELCOME_NOT_SET_UP_ERROR 137 | 138 | await ctx.send(disnake.utils.escape_markdown(entry.content)) 139 | 140 | @welcome.command() 141 | async def test(self, ctx): 142 | """Test your welcome command.""" 143 | 144 | entry = await self.config.get_entry(ctx.guild.id, construct=False) 145 | 146 | if entry is None: 147 | raise WELCOME_NOT_SET_UP_ERROR 148 | 149 | channel = entry.channel 150 | 151 | if channel is None: 152 | if entry.channel_id is None: 153 | raise commands.CommandError( 154 | "You haven't set up a welcome channel yet.\nSet up with `welcome channel [channel]`" 155 | ) 156 | else: 157 | raise commands.CommandError( 158 | "Welcome channel previously set but not found.\nPlease set again using `welcome channel [channel]`" 159 | ) 160 | 161 | if entry.enabled is False: 162 | raise commands.CommandError( 163 | "Welcome messages are disabled.\nEnable with 
`welcome enable`" 164 | ) 165 | 166 | if entry.content is None: 167 | raise commands.CommandError( 168 | "No welcome message set.\nSet with `welcome message `" 169 | ) 170 | 171 | await self.on_member_join(ctx.author) 172 | 173 | @welcome.command() 174 | async def enable(self, ctx): 175 | """Enable welcome messages.""" 176 | 177 | entry = await self.config.get_entry(ctx.guild.id) 178 | 179 | if entry.enabled is True: 180 | raise commands.CommandError("Welcome messages already enabled.") 181 | 182 | await entry.update(enabled=True) 183 | await ctx.send("Welcome messages enabled.") 184 | 185 | @welcome.command() 186 | async def disable(self, ctx): 187 | """Disable welcome messages.""" 188 | 189 | entry = await self.config.get_entry(ctx.guild.id) 190 | 191 | if entry.enabled is False: 192 | raise commands.CommandError("Welcome messages already disabled.") 193 | 194 | await entry.update(enabled=False) 195 | await ctx.send("Welcome messages disabled.") 196 | 197 | 198 | def setup(bot): 199 | bot.add_cog(Welcome(bot)) 200 | -------------------------------------------------------------------------------- /cogs/whois.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime, timezone 2 | 3 | import disnake 4 | from disnake.ext import commands 5 | 6 | from cogs.mixins import AceMixin 7 | from utils.string import po 8 | from utils.time import pretty_datetime, pretty_timedelta 9 | 10 | 11 | class WhoIs(AceMixin, commands.Cog): 12 | """View info about a member.""" 13 | 14 | @commands.command() 15 | @commands.bot_has_permissions(embed_links=True) 16 | async def info(self, ctx, *, member: disnake.Member = None): 17 | """Display information about user or self.""" 18 | 19 | member = member or ctx.author 20 | 21 | e = disnake.Embed(description="") 22 | 23 | if member.bot: 24 | e.description = "This account is a bot.\n\n" 25 | 26 | e.description += member.mention 27 | 28 | if member.activity: 29 | 
e.add_field(name="Activity", value=member.activity.name) 30 | 31 | e.set_author(name=str(member), icon_url=member.display_avatar.url) 32 | 33 | now = datetime.now(timezone.utc) 34 | created = member.created_at 35 | joined = member.joined_at 36 | 37 | e.add_field( 38 | name="Account age", 39 | value="{0} • Created ".format( 40 | pretty_timedelta(now - created), round(created.timestamp()) 41 | ), 42 | inline=False, 43 | ) 44 | 45 | e.add_field( 46 | name="Member for", 47 | value="{0} • Joined ".format( 48 | pretty_timedelta(now - joined), round(joined.timestamp()) 49 | ), 50 | ) 51 | 52 | if len(member.roles) > 1: 53 | e.add_field( 54 | name="Roles", 55 | value=" ".join(role.mention for role in reversed(member.roles[1:])), 56 | inline=False, 57 | ) 58 | 59 | e.set_footer(text="ID: " + str(member.id)) 60 | 61 | await ctx.send(embed=e) 62 | 63 | @commands.command(aliases=["newmembers"]) 64 | @commands.bot_has_permissions(embed_links=True) 65 | async def newusers(self, ctx, *, count=5): 66 | """List newly joined members.""" 67 | 68 | count = min(max(count, 5), 25) 69 | 70 | now = datetime.now(timezone.utc) 71 | e = disnake.Embed() 72 | 73 | for idx, member in enumerate( 74 | sorted(ctx.guild.members, key=lambda m: m.joined_at, reverse=True) 75 | ): 76 | if idx >= count: 77 | break 78 | 79 | value = "Joined {0} ago\nCreated {1} ago".format( 80 | pretty_timedelta(now - member.joined_at), 81 | pretty_timedelta(now - member.created_at), 82 | ) 83 | e.add_field(name=po(member), value=value, inline=False) 84 | 85 | await ctx.send(embed=e) 86 | 87 | @commands.command() 88 | async def avatar(self, ctx, *, member: disnake.Member = None): 89 | """Show an enlarged version of a members avatar.""" 90 | if member is None: 91 | member = ctx.author 92 | await ctx.send(member.display_avatar.url) 93 | 94 | 95 | def setup(bot): 96 | bot.add_cog(WhoIs(bot)) 97 | -------------------------------------------------------------------------------- /compose.yaml: 
-------------------------------------------------------------------------------- 1 | services: 2 | postgres: 3 | container_name: ace-postgres 4 | image: postgres:16.2 5 | ports: 6 | - 5432:5432 7 | environment: 8 | POSTGRES_USER: ace_user 9 | POSTGRES_PASSWORD: ace_pass 10 | POSTGRES_DB: ace_db 11 | volumes: 12 | - ace-pg-data:/var/lib/postgresql/data 13 | 14 | volumes: 15 | ace-pg-data: 16 | -------------------------------------------------------------------------------- /docs_service/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM python:3.11.2 2 | 3 | RUN mkdir app 4 | WORKDIR /app 5 | 6 | COPY parser_instances . 7 | COPY aggregator.py . 8 | COPY api.py . 9 | COPY parsers.py . 10 | COPY requirements.txt . 11 | 12 | RUN pip install --no-cache-dir -r requirements.txt 13 | 14 | CMD ["python3", "-u", "api.py"] -------------------------------------------------------------------------------- /docs_service/aggregator.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | 4 | from parsers import Entry, HeadersParser 5 | from tqdm import tqdm 6 | 7 | 8 | class Aggregator: 9 | def __init__(self, folder, version) -> None: 10 | self.folder = folder 11 | self.version = version 12 | self.entries = dict() 13 | self._parsed_files = set() 14 | 15 | def bulk_parse_from_dir(self, path, parser_type, **parser_kwargs): 16 | return self.bulk_parse( 17 | [ 18 | parser_type(self.folder, self.version, htm, **parser_kwargs) 19 | for htm in self._get_htms(path, parser_type) 20 | ] 21 | ) 22 | 23 | def bulk_parse(self, parsers): 24 | for parser in tqdm(parsers): 25 | parser.parse() 26 | 27 | if parser.page not in self.entries: 28 | # page has not been parsed before, so just plonk parser entries into aggregator entries 29 | self.entries[parser.page] = parser.entries 30 | else: 31 | # page HAS been parsed before, so we need to weave/update entries 32 | current_entries = 
self.entries[parser.page] 33 | for fragment, entry in parser.entries.items(): 34 | present_entry = current_entries.get(fragment, None) 35 | if present_entry is None: 36 | current_entries[fragment] = entry 37 | else: 38 | present_entry.merge(entry) 39 | 40 | self._parsed_files.add((parser.__class__, parser.page)) 41 | 42 | def parse_data_index(self, file): 43 | parser_cache = dict() 44 | 45 | with open(f"{self.folder}/{file}") as f: 46 | index = json.loads(f.read()[12:-2]) 47 | for indice in tqdm(index): 48 | name, page, *_ = indice 49 | 50 | if "#" in page: 51 | page, fragment = page.split("#") 52 | else: 53 | fragment = None 54 | 55 | if page in parser_cache: 56 | parser = parser_cache[page] 57 | else: 58 | try: 59 | parser = HeadersParser(self.folder, self.version, page) 60 | except FileNotFoundError: 61 | continue 62 | parser_cache[page] = parser 63 | 64 | tag = parser.bs.find(True, id=fragment) 65 | text, syntax, version = parser.tag_parse(tag) 66 | 67 | entry = Entry( 68 | name=name, 69 | primary_names=parser.name_splitter(name)[1], 70 | page=page, 71 | content=text or None, 72 | fragment=fragment, 73 | syntax=syntax, 74 | version=version, 75 | parents=None, 76 | secondary_names=None, 77 | ) 78 | 79 | if page not in self.entries: 80 | self.entries[page] = dict() 81 | 82 | current_entry = self.entries[page].get(fragment, None) 83 | if current_entry is None: 84 | self.entries[page][fragment] = entry 85 | else: 86 | current_entry.merge(entry) 87 | 88 | def iter_entries(self): 89 | for entries in self.entries.values(): 90 | for entry in entries.values(): 91 | yield entry 92 | 93 | def check_name_semantically_exists(self, name, names): 94 | def cmp(a, b): 95 | a = a.lower() 96 | b = b.lower() 97 | return a == b or a == b + "()" 98 | 99 | for check_name in names: 100 | if cmp(name, check_name): 101 | return True 102 | return False 103 | 104 | def name_map(self): 105 | mapper = dict() # name: id 106 | 107 | for entry in self.iter_entries(): 108 | for pname in 
entry.primary_names: 109 | if pname not in mapper: 110 | mapper[pname] = entry.id 111 | 112 | for entry in self.iter_entries(): 113 | if entry.secondary_names is None: 114 | continue 115 | for pname in entry.secondary_names: 116 | if pname not in mapper: 117 | mapper[pname] = entry.id 118 | 119 | used = set() 120 | to_delete = set() 121 | for name in mapper.keys(): 122 | lower_name = name.lower() 123 | if lower_name in used or f"{lower_name}()" in used: 124 | to_delete.add(name) 125 | elif lower_name.endswith("()") and lower_name[:-2] in used: 126 | to_delete.add(name[:-2]) 127 | else: 128 | used.add(lower_name) 129 | 130 | for item in to_delete: 131 | del mapper[item] 132 | 133 | return mapper 134 | 135 | def assign_ids(self, start_at): 136 | _id = start_at 137 | for entry in self.iter_entries(): 138 | entry.id = _id 139 | _id += 1 140 | 141 | @property 142 | def entry_count(self): 143 | count = 0 144 | for _ in self.iter_entries(): 145 | count += 1 146 | return count 147 | 148 | def _get_htms(self, folder, filter_on_type): 149 | filtered = [] 150 | for file in os.listdir(f"{self.folder}/{folder}"): 151 | if not file.endswith(".htm"): 152 | continue 153 | 154 | file = f"{folder}/{file}" 155 | 156 | if (filter_on_type, file) in self._parsed_files: 157 | continue 158 | 159 | filtered.append(file) 160 | 161 | return filtered 162 | 163 | def printer(self): 164 | for entry in self.iter_entries(): 165 | print("UUID:", entry.id) 166 | print("Name:", entry.name) 167 | print("Primary names:", entry.primary_names) 168 | print("Secondary names:", entry.secondary_names) 169 | print("Page:", entry.page) 170 | print("Fragment:", entry.fragment) 171 | print("Syntax:", entry.syntax) 172 | print("Version:", entry.version) 173 | print("Parents:", entry.parents) 174 | print() 175 | print(entry.content) 176 | 177 | print("\n", "-" * 100, "\n") 178 | -------------------------------------------------------------------------------- /docs_service/api.py: 
-------------------------------------------------------------------------------- 1 | import re 2 | from collections import defaultdict 3 | from random import choices 4 | 5 | import asyncpg 6 | from rapidfuzz import fuzz, process 7 | from sanic import Blueprint, Request, Sanic 8 | from sanic.response import HTTPResponse, json 9 | 10 | import config 11 | 12 | app = Sanic("ahkdocs_api") 13 | 14 | api = Blueprint("api", url_prefix="/api") 15 | 16 | 17 | meaning_scalar = lambda v: 1 / ((v * 0.5) ** 2 + 1) 18 | 19 | 20 | def processor(s): 21 | s = s.strip().lower() 22 | return re.sub(r"(\(|\))", "", s) 23 | 24 | 25 | def docs_search(names, query, k): 26 | query = query.strip() 27 | 28 | if not query: 29 | return choices(names, k=k) 30 | 31 | word_scores = [] 32 | 33 | splitters = [query] 34 | splitters.extend(query.split(" ")) 35 | scores = defaultdict(float) 36 | 37 | for i, word in enumerate(splitters): 38 | word_scores = process.extract( 39 | query=word, 40 | choices=names, 41 | scorer=fuzz.WRatio, 42 | processor=processor, 43 | limit=100, 44 | ) 45 | 46 | for name, score, _ in word_scores: 47 | scores[name] += score * meaning_scalar(i) 48 | 49 | return list(name for name, _ in sorted(scores.items(), key=lambda item: item[1], reverse=True))[ 50 | :k 51 | ] 52 | 53 | 54 | def entry_to_dict(row): 55 | return dict( 56 | id=row.get("id"), 57 | v=row.get("v"), 58 | name=row.get("name"), 59 | page=row.get("page"), 60 | fragment=row.get("fragment"), 61 | content=row.get("content"), 62 | syntax=row.get("syntax"), 63 | version=row.get("version"), 64 | ) 65 | 66 | 67 | async def get_entry(conn: asyncpg.Connection, docs_id: int, lineage=True, search_match=None): 68 | row = await conn.fetchrow("SELECT * FROM docs_entry WHERE id=$1", docs_id) 69 | 70 | o = entry_to_dict(row) 71 | 72 | if lineage: 73 | parents = [] 74 | children = [] 75 | 76 | for parent_id in row.get("parents"): 77 | if parent_id is None: # shouldn't happen 78 | continue 79 | parents.append(await get_entry(conn, 
parent_id, lineage=False)) 80 | 81 | rows = await conn.fetch( 82 | "SELECT * FROM docs_entry WHERE parents[array_upper(parents, 1)] = $1", 83 | docs_id, 84 | ) 85 | for row in rows: 86 | children.append(entry_to_dict(row)) 87 | 88 | o["parents"] = parents 89 | o["children"] = children 90 | o["search_match"] = search_match 91 | 92 | return o 93 | 94 | 95 | @api.post("/search") 96 | async def search(request: Request): 97 | data = request.json 98 | 99 | if data is None: 100 | return HTTPResponse(status=400) 101 | 102 | q = data.get("q", None) 103 | v = data.get("v", None) 104 | 105 | if q is None or v is None: 106 | return HTTPResponse(status=400) 107 | 108 | if not isinstance(q, str) or not isinstance(v, int): 109 | return HTTPResponse(status=400) 110 | 111 | names = app.ctx.names[v] 112 | 113 | res = docs_search(names, q, k=5) 114 | return json(res) 115 | 116 | 117 | @api.post("/entry") 118 | async def entry(request: Request): 119 | data = request.json 120 | 121 | if data is None: 122 | return HTTPResponse(status=400) 123 | 124 | q = data.get("q", None) 125 | v = data.get("v", None) 126 | 127 | if q is None or v is None: 128 | return HTTPResponse(status=400) 129 | 130 | if not isinstance(q, str) or not isinstance(v, int): 131 | return HTTPResponse(status=400) 132 | 133 | names = app.ctx.names[v] 134 | 135 | res = docs_search(names, q, k=1)[0] 136 | docs_id = app.ctx.id_map[v][res] 137 | 138 | async with app.ctx.pool.acquire() as conn: 139 | conn: asyncpg.Connection 140 | async with conn.transaction(): 141 | return json(await get_entry(conn, docs_id, lineage=True, search_match=res)) 142 | 143 | 144 | @app.signal("server.init.before") 145 | async def setup(app, loop): 146 | pool = await asyncpg.create_pool(config.DB_BIND) 147 | 148 | async with pool.acquire() as conn: 149 | conn: asyncpg.Connection 150 | async with conn.transaction(): 151 | res = await conn.fetch("SELECT * FROM docs_name") 152 | id_map = {v: {} for v in (1, 2)} 153 | names = {v: [] for v in (1, 2)} 
154 | 155 | for row in res: 156 | v = row.get("v") 157 | docs_id = row.get("docs_id") 158 | name = row.get("name") 159 | id_map[v][name] = docs_id 160 | names[v].append(name) 161 | 162 | app.ctx.names = names 163 | app.ctx.id_map = id_map 164 | app.ctx.pool = pool 165 | 166 | 167 | app.blueprint(api) 168 | 169 | if __name__ == "__main__": 170 | app.run(host="0.0.0.0", port=8000) 171 | -------------------------------------------------------------------------------- /docs_service/build.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import os 3 | import re 4 | import shutil 5 | from zipfile import ZipFile 6 | 7 | import aiohttp 8 | import asyncpg 9 | from aggregator import Aggregator 10 | from bs4 import BeautifulSoup 11 | from parser_instances.common import command, default 12 | from parser_instances.v1 import get as v1_get 13 | from parser_instances.v2 import get as v2_get 14 | from parsers import HeadersParser 15 | 16 | import config 17 | 18 | 19 | async def view_h(parsers, base, path): 20 | folder = base 21 | if path: 22 | folder += f"/{path}" 23 | for i, file in enumerate(sorted(os.listdir(folder))): 24 | if not file.endswith(".htm"): 25 | continue 26 | 27 | ff = f"{folder}/{file}" 28 | 29 | found = False 30 | for parser in parsers: 31 | parser_doink = f"{parser.base}/{parser.page}" 32 | if parser_doink == ff: 33 | found = True 34 | break 35 | 36 | if found: 37 | continue 38 | 39 | bs = BeautifulSoup(open(ff, "r").read(), "lxml") 40 | h = [] 41 | a = set() 42 | for tag in bs.find_all(re.compile(r"^h\d$")): 43 | h.append(tag.name) 44 | if bs.find("h3", id="Methods") or bs.find("h3", id="Properties"): 45 | a.add("object") 46 | if bs.find("h3", id="SubCommands"): 47 | a.add("subcommands") 48 | 49 | if a: 50 | print(file, a) 51 | print(h) 52 | 53 | 54 | async def store(pool: asyncpg.Pool, agg: Aggregator, version: int, id_start_at=1): 55 | print("storing version", version, "starting at id", id_start_at) 56 
| print(agg.entry_count, "entries") 57 | print(len(agg.name_map()), "names") 58 | 59 | entry_sql = ( 60 | "INSERT INTO docs_entry (id, v, name, page, fragment, content, syntax, version, parents) " 61 | "VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)" 62 | ) 63 | 64 | name_sql = "INSERT INTO docs_name (v, docs_id, name) VALUES ($1, $2, $3)" 65 | 66 | agg.assign_ids(id_start_at) 67 | 68 | ent = [ 69 | ( 70 | entry.id, 71 | version, 72 | entry.name, 73 | entry.page, 74 | entry.fragment, 75 | entry.content if entry.content else None, 76 | entry.syntax, 77 | entry.version, 78 | [e.id for e in entry.parents or []], 79 | ) 80 | for entry in agg.iter_entries() 81 | ] 82 | 83 | names = [(version, _id, name) for name, _id in agg.name_map().items()] 84 | 85 | async with pool.acquire() as conn: 86 | conn: asyncpg.Connection 87 | async with conn.transaction(): 88 | await conn.executemany(entry_sql, ent) 89 | await conn.executemany(name_sql, names) 90 | 91 | print("finished storing version", version) 92 | 93 | 94 | async def downloader(url, download_to, extract_to): 95 | print("downloading", url) 96 | 97 | # delete old stuff 98 | try: 99 | os.remove(download_to) 100 | except FileNotFoundError: 101 | pass 102 | 103 | shutil.rmtree(extract_to, ignore_errors=True) 104 | 105 | # fetch docs package 106 | async with aiohttp.ClientSession() as session: 107 | async with session.get(url) as resp: 108 | if resp.status != 200: 109 | raise ValueError("http returned:", resp.status) 110 | 111 | with open(download_to, "wb") as f: 112 | f.write(await resp.read()) 113 | 114 | print("extracting to ", extract_to) 115 | 116 | # and extract it 117 | zip_ref = ZipFile(download_to, "r") 118 | zip_ref.extractall(extract_to) 119 | zip_ref.close() 120 | 121 | 122 | async def build_v1_aggregator(folder, download=False) -> Aggregator: 123 | if download: 124 | await downloader( 125 | url="https://github.com/AutoHotkey/AutoHotkeyDocs/archive/v1.zip", 126 | download_to="docs_v1.zip", 127 | extract_to=folder, 128 
| ) 129 | 130 | print("parsing v1 docs") 131 | 132 | folder += "/AutoHotkeyDocs-1/docs" 133 | 134 | agg = Aggregator(folder=folder, version=1) 135 | 136 | agg.bulk_parse(v1_get(folder)) 137 | agg.bulk_parse_from_dir("lib", parser_type=HeadersParser, **command) 138 | agg.bulk_parse_from_dir("misc", parser_type=HeadersParser, **default()) 139 | agg.parse_data_index("static/source/data_index.js") 140 | 141 | return agg 142 | 143 | 144 | async def build_v2_aggregator(folder, download=False) -> Aggregator: 145 | if download: 146 | await downloader( 147 | url="https://github.com/AutoHotkey/AutoHotkeyDocs/archive/v2.zip", 148 | download_to="docs_v2.zip", 149 | extract_to=folder, 150 | ) 151 | 152 | print("parsing v2 docs") 153 | 154 | folder += "/AutoHotkeyDocs-2/docs" 155 | 156 | agg = Aggregator(folder=folder, version=2) 157 | 158 | agg.bulk_parse(v2_get(folder)) 159 | agg.bulk_parse_from_dir("lib", parser_type=HeadersParser, **command) 160 | agg.bulk_parse_from_dir("misc", parser_type=HeadersParser, **default()) 161 | agg.parse_data_index("static/source/data_index.js") 162 | 163 | return agg 164 | 165 | 166 | async def main(): 167 | db = await asyncpg.create_pool(config.DB_BIND) 168 | await db.execute("TRUNCATE docs_name, docs_entry, docs_syntax RESTART IDENTITY") 169 | 170 | agg = await build_v1_aggregator("docs_v1", download=True) 171 | await store(db, agg, 1) 172 | start_at = agg.entry_count + 1 173 | 174 | print() 175 | 176 | agg = await build_v2_aggregator("docs_v2", download=True) 177 | await store(db, agg, 2, id_start_at=start_at) 178 | 179 | await db.close() 180 | 181 | print("done") 182 | 183 | 184 | if __name__ == "__main__": 185 | loop = asyncio.new_event_loop() 186 | loop.run_until_complete(main()) 187 | # loop.run_forever() 188 | -------------------------------------------------------------------------------- /docs_service/migrate.sql: -------------------------------------------------------------------------------- 1 | CREATE TABLE IF NOT EXISTS docs_entry 
( 2 | id SERIAL UNIQUE, 3 | v SMALLINT NOT NULL, 4 | name TEXT NOT NULL, 5 | page TEXT NOT NULL, 6 | fragment TEXT NULL, 7 | content TEXT NULL, 8 | syntax TEXT NULL, 9 | version TEXT NULL, 10 | parents INTEGER[] NULL 11 | ); 12 | 13 | CREATE TABLE IF NOT EXISTS docs_name ( 14 | id SERIAL UNIQUE, 15 | v SMALLINT NOT NULL, 16 | docs_id INT REFERENCES docs_entry (id) NOT NULL, 17 | name TEXT NOT NULL 18 | ); 19 | 20 | CREATE TABLE IF NOT EXISTS docs_syntax ( 21 | id SERIAL UNIQUE, 22 | docs_id INT REFERENCES docs_entry (id) NOT NULL, 23 | syntax TEXT NOT NULL 24 | ); 25 | -------------------------------------------------------------------------------- /docs_service/parser_instances/common.py: -------------------------------------------------------------------------------- 1 | from parsers import BULLET_SPACED 2 | 3 | # for simple command pages (like Abs.htm) 4 | command = dict(ignore=lambda h, t, p: h > 1) 5 | 6 | 7 | # for pages with commands that have subcommands 8 | def subcommand(prefix): 9 | return dict( 10 | prefix_mapper=[ 11 | (2, 1), 12 | (lambda h, t, p: p[-1] == "SubCommands", prefix + "{}"), 13 | ( 14 | 3, 15 | 1, 16 | ), # I removed this since it made gui.htm look nicer, maybe a dumb idea 17 | (4, 3), 18 | ], 19 | basic_name_check=lambda h, t, p: h in (1, 3) and p[-1] != "SubCommands", 20 | ) 21 | 22 | 23 | # for pages describing objects with methods and properties 24 | def obj( 25 | prefix, 26 | instance_name=None, 27 | meth="Methods", 28 | prop="Properties", 29 | func="Functions", 30 | staticmeth="StaticMethods", 31 | ): 32 | instance_name = instance_name or prefix 33 | 34 | # v2 input hook has these two badboys... 
35 | props = [prop, "General_Properties", "Option_Properties"] 36 | 37 | return dict( 38 | prefix_mapper=[ 39 | (2, 1), 40 | (lambda h, t, p: p[-1] == meth, instance_name + ".{}()"), 41 | (lambda h, t, p: p[-1] in props, instance_name + ".{}"), 42 | (lambda h, t, p: p[-1] == func, "{}()"), 43 | (lambda h, t, p: p[-1] == staticmeth, prefix + "()"), 44 | (3, 1), 45 | (4, 3), 46 | ], 47 | basic_name_check=lambda h, t, p: h in (1, 3) 48 | and p[-1] not in [meth, func, staticmeth, *props], 49 | ) 50 | 51 | 52 | # stop: which h* to stop adding at 53 | # remap: remap the name of the h1 tag when prefixing 54 | # ignore: list of fragments to ignore 55 | # prefix: up to which h* should the remapper apply 56 | def default(stop=3, remap=1, ignore=None, prefix=None): 57 | if stop > 3: 58 | raise ValueError("You gotta fix the mapper in this case") 59 | 60 | # always add basic name for h1 tags 61 | basic = set([1]) 62 | 63 | if remap is None: 64 | mapper = None 65 | else: 66 | if isinstance(remap, int): 67 | one = remap 68 | elif isinstance(remap, str): 69 | one = BULLET_SPACED.join((remap, "{}")) 70 | mapper = [(i + 1, one) for i in range(1, prefix or stop)] 71 | 72 | # if we're doing remapping but also excluding some h* tags 73 | # because of *prefix*, then we need to add plain names for the 74 | # difference 75 | if prefix: 76 | for i in range(prefix, stop): 77 | basic.add(i + 1) 78 | 79 | d = dict( 80 | prefix_mapper=mapper, 81 | basic_name_check=lambda h, t, p: True if remap is None else h in basic, 82 | ignore=lambda h, t, p: h > stop or (ignore and t.get("id", None) in ignore), 83 | ) 84 | 85 | return d 86 | -------------------------------------------------------------------------------- /docs_service/parser_instances/v1.py: -------------------------------------------------------------------------------- 1 | from parsers import HeadersParser, TableParser 2 | 3 | from .common import command, default, obj, subcommand 4 | 5 | # listview, treeview, etc 6 | view = dict( 7 | 
prefix_mapper=[(2, 1), (3, 1), (4, 3)], 8 | basic_name_check=lambda h, t, p: h in (1, 3), 9 | ) 10 | 11 | 12 | # guicontrols.htm page which is unique 13 | # also holy shit what is going on 14 | guicontrols = dict( 15 | prefix_mapper=[ 16 | ( 17 | lambda h, t, p: h == 2 and t.find_next_sibling("p").text.startswith("Description:"), 18 | "{} Control", 19 | ), 20 | (2, 1), 21 | (3, -1), 22 | (4, -1), 23 | ], 24 | basic_name_check=lambda h, t, p: ( 25 | not t.get("id", "").endswith("_Options") 26 | and not (h == 2 and t.find_next_sibling("p").text.startswith("Description:")) 27 | ), 28 | ) 29 | 30 | 31 | def get(base): 32 | return ( 33 | HeadersParser(base, 1, "lib/Math.htm", **default(prefix=2)), 34 | HeadersParser(base, 1, "lib/ListView.htm", **view), 35 | HeadersParser(base, 1, "lib/TreeView.htm", **view), 36 | HeadersParser(base, 1, "lib/Gui.htm", **subcommand("Gui, ")), 37 | HeadersParser(base, 1, "lib/Menu.htm", **subcommand("Menu, ")), 38 | HeadersParser(base, 1, "lib/Control.htm", **subcommand("Control, ")), 39 | HeadersParser(base, 1, "lib/GuiControl.htm", **subcommand("GuiControl, ")), 40 | HeadersParser(base, 1, "lib/GuiControls.htm", **guicontrols), 41 | HeadersParser(base, 1, "lib/File.htm", **obj("File")), 42 | HeadersParser(base, 1, "lib/Func.htm", **obj("Func")), 43 | HeadersParser(base, 1, "lib/Object.htm", **obj("Object")), 44 | HeadersParser(base, 1, "lib/Enumerator.htm", **obj("Enum")), 45 | HeadersParser(base, 1, "lib/ComObjArray.htm", **obj("ComObjArray")), 46 | HeadersParser( 47 | base, 48 | 1, 49 | "lib/InputHook.htm", 50 | **obj("InputHook", prop="object"), 51 | ), # TODO: docs are dumb I cba fixing this one 52 | HeadersParser(base, 1, "lib/Process.htm", **subcommand("Process, ")), 53 | HeadersParser(base, 1, "lib/Thread.htm", **subcommand("Thread, ")), 54 | HeadersParser(base, 1, "lib/ControlGet.htm", **subcommand("ControlGet, ")), 55 | HeadersParser(base, 1, "lib/Drive.htm", **subcommand("Drive, ")), 56 | HeadersParser(base, 1, 
"lib/DriveGet.htm", **subcommand("DriveGet, ")), 57 | HeadersParser(base, 1, "lib/GuiControlGet.htm", **subcommand("GuiControlGet, ")), 58 | HeadersParser(base, 1, "lib/SysGet.htm", **subcommand("SysGet, ")), 59 | HeadersParser(base, 1, "lib/Transform.htm", **subcommand("Transform, ")), 60 | HeadersParser(base, 1, "lib/WinGet.htm", **subcommand("WinGet, ")), 61 | HeadersParser(base, 1, "lib/WinSet.htm", **subcommand("WinSet, ")), 62 | HeadersParser(base, 1, "misc/RegEx-QuickRef.htm", **default(remap="RegEx")), 63 | HeadersParser(base, 1, "AHKL_DBGPClients.htm", **default()), 64 | HeadersParser(base, 1, "AHKL_Features.htm", **default()), 65 | HeadersParser(base, 1, "Concepts.htm", **default(remap="Concepts")), 66 | HeadersParser(base, 1, "FAQ.htm", **default(remap="FAQ")), 67 | HeadersParser(base, 1, "Functions.htm", **default(remap=None)), 68 | HeadersParser(base, 1, "Hotkeys.htm", **default(remap="Hotkeys")), 69 | HeadersParser(base, 1, "Hotstrings.htm", **default()), 70 | HeadersParser(base, 1, "Language.htm", **default(stop=2, remap=None)), 71 | HeadersParser(base, 1, "Objects.htm", **default()), 72 | HeadersParser(base, 1, "Program.htm", **default(stop=2, remap=None)), 73 | HeadersParser(base, 1, "Scripts.htm", **default(stop=2, remap=None)), 74 | HeadersParser(base, 1, "Tutorial.htm", **default(remap="Tutorial")), 75 | HeadersParser(base, 1, "Variables.htm", **default(remap=None, ignore=["loop"])), 76 | HeadersParser(base, 1, "KeyList.htm", **default(remap="List of Keys")), 77 | HeadersParser(base, 1, "HotkeyFeatures.htm", **default(remap=None)), 78 | TableParser(base, 1, "Variables.htm"), 79 | TableParser(base, 1, "KeyList.htm"), 80 | TableParser(base, 1, "Hotkeys.htm"), 81 | ) 82 | -------------------------------------------------------------------------------- /docs_service/parser_instances/v2.py: -------------------------------------------------------------------------------- 1 | from parser_instances.common import command, default, obj 2 | from parsers 
import HeadersParser 3 | 4 | # guicontrols.htm page which is unique 5 | # also holy shit what is going on 6 | guicontrols = dict( 7 | prefix_mapper=[ 8 | ( 9 | lambda h, t, p: h == 2 and t.find_next_sibling("p").text.startswith("Description:"), 10 | "{} Control", 11 | ), 12 | (2, 1), 13 | (3, -1), 14 | (4, -1), 15 | ], 16 | basic_name_check=lambda h, t, p: h == 1, 17 | ignore=lambda h, t, p: h > 3, 18 | ) 19 | 20 | 21 | def get(base): 22 | return ( 23 | HeadersParser(base, 2, "AHKL_DBGPClients.htm", **default()), 24 | HeadersParser(base, 2, "Concepts.htm", **default(remap="Concepts")), 25 | HeadersParser(base, 2, "FAQ.htm", **default(remap="FAQ")), 26 | HeadersParser(base, 2, "Functions.htm", **default(remap=None)), 27 | HeadersParser(base, 2, "Hotkeys.htm", **default(remap="Hotkeys")), 28 | HeadersParser(base, 2, "Hotstrings.htm", **default()), 29 | HeadersParser(base, 2, "Language.htm", **default(stop=2, remap=None)), 30 | HeadersParser(base, 2, "Objects.htm", **default()), 31 | HeadersParser(base, 2, "Program.htm", **default(stop=2, remap=None)), 32 | HeadersParser(base, 2, "Scripts.htm", **default(stop=2, remap=None)), 33 | HeadersParser(base, 2, "Tutorial.htm", **default(remap="Tutorial")), 34 | HeadersParser(base, 2, "Variables.htm", **default(remap=None, ignore=["loop"])), 35 | HeadersParser(base, 2, "KeyList.htm", **default(remap="List of Keys")), 36 | HeadersParser(base, 2, "HotkeyFeatures.htm", **default(remap=None)), 37 | HeadersParser(base, 2, "v1-changes.htm", **default(stop=1)), 38 | HeadersParser(base, 2, "v2-changes.htm", **default(stop=1)), 39 | HeadersParser(base, 2, "lib/Any.htm", **obj("Value")), 40 | HeadersParser(base, 2, "lib/Array.htm", **obj("Array", "ArrayObj")), 41 | HeadersParser(base, 2, "lib/Buffer.htm", **obj("Buffer", "BufferObj")), 42 | HeadersParser(base, 2, "lib/Class.htm", **obj("ClassObj")), 43 | # HeadersParser(base, 2, "lib/ComObjArray.htm", **object("ComObjArray")), # page not formatted correctly 44 | HeadersParser( 45 | 
base, 2, "lib/Enumerator.htm", **obj("Enum") 46 | ), # kind of weird, has a function thing too? 47 | HeadersParser(base, 2, "lib/File.htm", **obj("File", "FileObj")), 48 | HeadersParser(base, 2, "lib/Func.htm", **obj("Func", "FuncObj")), 49 | HeadersParser(base, 2, "lib/Gui.htm", **obj("Gui", "MyGui", staticmeth="Static_Methods")), 50 | HeadersParser(base, 2, "lib/GuiControl.htm", **obj("GuiControl", "GuiCtrl")), 51 | HeadersParser(base, 2, "lib/Map.htm", **obj("Map", "MapObj")), 52 | HeadersParser(base, 2, "lib/Menu.htm", **obj("Menu", "MyMenu")), 53 | HeadersParser(base, 2, "lib/Object.htm", **obj("Object", "Obj")), 54 | # HeadersParser(base, 2, "lib/InputHook.htm", **object("InputHook")), # h3 tags that should be h2, and new property id names??? 55 | HeadersParser(base, 2, "lib/GuiControls.htm", **guicontrols), 56 | HeadersParser(base, 2, "lib/Hotstring.htm", **default(stop=2)), 57 | HeadersParser(base, 2, "lib/ListView.htm", **obj("ListView", "LV", meth="BuiltIn")), 58 | HeadersParser(base, 2, "lib/TreeView.htm", **obj("TreeView", "TV", meth="BuiltIn")), 59 | HeadersParser(base, 2, "lib/Math.htm", **default(prefix=2)), 60 | ) 61 | 62 | 63 | """ 64 | issues: 65 | ClassObj.Call() should also have an entry for ClassObj() really 66 | """ 67 | -------------------------------------------------------------------------------- /docs_service/parsers.py: -------------------------------------------------------------------------------- 1 | import re 2 | from itertools import chain 3 | 4 | from bs4 import BeautifulSoup, NavigableString, Tag 5 | from markdownify import MarkdownConverter 6 | 7 | DOCS_URL_FMT = "https://www.autohotkey.com/docs/v{}/" 8 | ANY_HEADER_RE = re.compile(r"^h\d$") 9 | BIG_HEADER_RE = re.compile(r"^h[1-3]$") 10 | BULLET_SPACED = " • " 11 | 12 | """ 13 | current issues: 14 | - see gui cancel/hide and how it wraps a div around with another id 15 | """ 16 | 17 | 18 | class Entry: 19 | def __init__( 20 | self, 21 | name, 22 | primary_names, 23 | page, 24 | 
content, 25 | fragment=None, 26 | syntax=None, 27 | version=None, 28 | parents=None, 29 | secondary_names=None, 30 | ): 31 | self.id = None 32 | self.name = name 33 | self.primary_names = primary_names 34 | self.page = page 35 | self.fragment = fragment 36 | self.content = content 37 | self.syntax = syntax 38 | self.version = version 39 | self.parents = parents 40 | self.secondary_names = secondary_names 41 | 42 | def merge(self, other): 43 | or_fields = ("content", "syntax", "version") 44 | for field in or_fields: 45 | setattr(self, field, getattr(self, field) or getattr(other, field)) 46 | 47 | if self.secondary_names is None: 48 | self.secondary_names = [] 49 | 50 | self.secondary_names.extend(other.primary_names) 51 | 52 | 53 | class DocsMarkdownConverter(MarkdownConverter): 54 | def __init__(self, url_folder, url_file, **options): 55 | self.url_folder = url_folder 56 | self.url_file = url_file 57 | self.version = None 58 | super().__init__(**options) 59 | 60 | def convert_span(self, el: Tag, text, convert_as_inline): 61 | classes = el.get("class", None) 62 | if classes is None: 63 | return text 64 | 65 | if "optional" in classes: 66 | return f"[{text}]" 67 | 68 | if "ver" in classes: 69 | self.version = text 70 | 71 | return text 72 | 73 | def convert_code(self, el, text, convert_as_inline): 74 | return f"`{text}`" 75 | 76 | def convert_a(self, el, text, convert_as_inline): 77 | href = el.get("href") 78 | if href.startswith("#"): 79 | url = f"{self.url_folder}/{self.url_file}#{href}" 80 | else: 81 | url = f"{self.url_folder}/{href}" 82 | 83 | return f"[{text}]({url})" 84 | 85 | def convert_strong(self, el, text, convert_as_inline): 86 | return f"**{text}**" 87 | 88 | def convert_em(self, el, text, convert_as_inline): 89 | return f"*{text}*" 90 | 91 | def convert_pre(self, el, text, convert_as_inline): 92 | if not text: 93 | return "" 94 | code_language = self.options["code_language"] 95 | 96 | if self.options["code_language_callback"]: 97 | code_language = 
self.options["code_language_callback"](el) or code_language 98 | 99 | return "\n```%s\n%s\n```\n" % (code_language, text) 100 | 101 | 102 | class TemporaryOptions: 103 | def __init__(self, converter, **options) -> None: 104 | self.converter = converter 105 | self.options = options 106 | self._restore = {} 107 | 108 | def __enter__(self): 109 | for k, v in self.options.items(): 110 | self._restore[k] = self.converter.options[k] 111 | self.converter.options[k] = v 112 | 113 | def __exit__(self, exc_type, exc_value, exc_traceback): 114 | for k, v in self._restore.items(): 115 | self.converter.options[k] = v 116 | 117 | 118 | class Parser: 119 | def __init__(self, base, version, page) -> None: 120 | self.base = base 121 | self.version = version 122 | self.page = page 123 | self.parser = "lxml" 124 | 125 | self.entries = dict() 126 | 127 | with open(f"{self.base}/{self.page}", "r") as f: 128 | self.bs = BeautifulSoup(f.read(), self.parser) 129 | 130 | full_url = DOCS_URL_FMT.format(version) + page 131 | *to_join, url_file = full_url.split("/") 132 | url_folder = "/".join(to_join) 133 | 134 | self.converter = DocsMarkdownConverter( 135 | url_folder=url_folder, 136 | url_file=url_file, 137 | convert=["span", "code", "a", "strong", "em"], 138 | ) 139 | 140 | def md(self, soup, **opt): 141 | with TemporaryOptions(self.converter, **opt): 142 | return self.converter.convert_soup(soup).strip() 143 | 144 | def add_entry(self, entry: Entry): 145 | self.entries[entry.fragment] = entry 146 | 147 | def parse(self): 148 | raise NotImplementedError("Must be implemented by subclass") 149 | 150 | def tag_to_str(self, tag: Tag): 151 | def to_str(tag): 152 | if isinstance(tag, NavigableString): 153 | return str(tag) 154 | elif tag.name == "br": 155 | return "\n" 156 | 157 | content = "" 158 | for child in tag.children: 159 | content += to_str(child) 160 | 161 | return content 162 | 163 | return to_str(tag).strip() or None 164 | 165 | def strip_versioning(self, tag: Tag): 166 | found_tags 
= tag.find_all("span", class_="ver") 167 | if not found_tags: 168 | return None 169 | 170 | ver = self.tag_to_str(found_tags[0]) 171 | 172 | for found_tag in found_tags: 173 | found_tag.decompose() 174 | 175 | return ver 176 | 177 | def tag_parse(self, parent_tag: Tag): 178 | markup = "
" 179 | found_p = False 180 | syntax = None 181 | tag: Tag = parent_tag 182 | 183 | while tag := tag.next_sibling: 184 | if isinstance(tag, NavigableString): 185 | markup += tag 186 | continue 187 | 188 | elif tag.name == "p": 189 | if not found_p: 190 | markup += str(tag) 191 | found_p = True 192 | else: 193 | break 194 | 195 | elif tag.name == "pre": 196 | _classes = tag.get("class", []) 197 | 198 | if "Syntax" in _classes: 199 | syntax = self.md( 200 | tag, 201 | escape_underscores=False, 202 | escape_asterisks=False, 203 | convert=["span"], 204 | ) 205 | 206 | break 207 | 208 | else: 209 | break 210 | 211 | text = self.md(BeautifulSoup(markup, self.parser)) 212 | return text, syntax, self.converter.version 213 | 214 | def name_splitter(self, name): 215 | if name.startswith("MinIndeMinIndexx"): 216 | print("what") 217 | 218 | splits = [" / ", "\n"] 219 | 220 | temp = [] 221 | for split in splits: 222 | if split in name: 223 | temp.extend(name.split(split)) 224 | break 225 | else: 226 | temp.append(name) 227 | 228 | names = [name.strip() for name in temp if name.strip()] 229 | 230 | return " / ".join(names), names 231 | 232 | 233 | class HeadersParser(Parser): 234 | def __init__( 235 | self, 236 | base, 237 | version, 238 | page, 239 | prefix_mapper=None, 240 | basic_name_check=lambda h, t, p: True, 241 | ignore=lambda h, t, p: False, 242 | ) -> None: 243 | self.prefix_mapper: list = prefix_mapper 244 | self.basic_name_check = basic_name_check 245 | self.ignore = ignore 246 | super().__init__(base, version, page) 247 | 248 | def level_from_header_tag(self, tag: Tag): 249 | return int(tag.name[1]) 250 | 251 | def names_from_header(self, tag: Tag, parents: list): 252 | orig_name, names = self.name_splitter(self.tag_to_str(tag)) 253 | 254 | return orig_name, list( 255 | chain(*[self._names_from_header(name, tag, parents) for name in names]) 256 | ) 257 | 258 | def _names_from_header(self, text: str, tag: Tag, parents: list): 259 | h_level = 
self.level_from_header_tag(tag) 260 | parent_ids = [p.fragment or p.name for p in parents if p is not None] 261 | names = [] 262 | 263 | try: 264 | add_name = self.basic_name_check(h_level, tag, parent_ids) 265 | except IndexError: 266 | add_name = True 267 | 268 | if add_name: 269 | names.append(text) 270 | 271 | def do_action(action): 272 | if isinstance(action, int): 273 | try: 274 | parent = parents[action] 275 | parent_name = parent.name 276 | names.append(f"{parent_name}{BULLET_SPACED}{text}") 277 | except IndexError: 278 | pass 279 | elif isinstance(action, str): 280 | names.append(action.format(text)) 281 | elif callable(action): 282 | names.append(action(text)) 283 | 284 | if self.prefix_mapper is not None: 285 | for check, action in self.prefix_mapper: 286 | if isinstance(check, int): 287 | if h_level == check: 288 | do_action(action) 289 | break 290 | elif callable(check): 291 | try: 292 | if check(h_level, tag, parent_ids): 293 | do_action(action) 294 | break 295 | except: 296 | pass 297 | 298 | return names 299 | 300 | def process_tag(self, tag: Tag, parents: list): 301 | header_version = self.strip_versioning(tag) 302 | name, primary_names = self.names_from_header(tag, parents) 303 | text, syntax, parsed_version = self.tag_parse(tag) 304 | version = header_version or parsed_version 305 | 306 | fragment = tag.get("id", None) 307 | 308 | # fix for methods in object-like pages, since their method h3 tags 309 | # do not have an id attr, but their previous_sibling div has, we 310 | # just manually check for that instead 311 | # really we should not be using a HeadersParser, and rather 312 | # use something that finds methods by finding divs with ids 313 | # the range(3) is a hack to make it work for v2 Gui.htm 314 | previous = tag 315 | for _ in range(3): 316 | previous = previous.previous_element 317 | if previous.name == "div" and "methodShort" in previous.get("class", []): 318 | fragment = previous.get("id", None) 319 | break 320 | 321 | entry = Entry( 
322 | name=name, 323 | primary_names=primary_names, 324 | page=self.page, 325 | content=text, 326 | fragment=fragment, 327 | syntax=syntax, 328 | version=version, 329 | parents=parents[1:], 330 | secondary_names=None, 331 | ) 332 | 333 | self.add_entry(entry) 334 | return entry 335 | 336 | def parse(self): 337 | parents = [None] * 10 338 | 339 | for tag in self.bs.find_all(ANY_HEADER_RE): 340 | ids = tag.get("id", None) 341 | 342 | if ids is not None and "toc" in ids: 343 | continue 344 | 345 | h_level = self.level_from_header_tag(tag) 346 | tag_parents = [e for e in parents[:h_level]] 347 | 348 | try: 349 | ignore = self.ignore(h_level, tag, tag_parents) 350 | except: 351 | ignore = False 352 | 353 | if ignore: 354 | continue 355 | 356 | entry = self.process_tag(tag, tag_parents) 357 | 358 | if entry is not None: 359 | parents[h_level] = entry 360 | 361 | 362 | class TableParser(Parser): 363 | def parse(self): 364 | for tr in self.bs.find_all("tr", id=True): 365 | first = True 366 | for td in tr.find_all("td"): 367 | if first: 368 | first = False 369 | version = self.strip_versioning(td) 370 | orig_name, names = self.name_splitter(self.tag_to_str(td)) 371 | else: 372 | desc = self.md(td) 373 | 374 | fragment = tr.get("id") 375 | 376 | entry = Entry( 377 | name=orig_name, 378 | primary_names=names, 379 | page=self.page, 380 | content=desc, 381 | fragment=fragment, 382 | version=version or self.converter.version, 383 | ) 384 | 385 | self.add_entry(entry) 386 | -------------------------------------------------------------------------------- /docs_service/requirements.txt: -------------------------------------------------------------------------------- 1 | markdownify==0.11.6 2 | beautifulsoup4==4.10.0 3 | aiohttp==3.9.5 4 | asyncpg==0.28.0 5 | rapidfuzz==3.4.0 6 | sanic==23.12.1 7 | tqdm==4.64.1 -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 
| import asyncio 2 | import logging.handlers 3 | import os 4 | 5 | import asyncpg 6 | import coloredlogs 7 | import disnake 8 | from disnake.ext.commands import CommandSyncFlags 9 | 10 | from ace import AceBot 11 | from config import BOT_INTENTS, BOT_TOKEN, LOG_LEVEL, TEST_GUILDS 12 | 13 | 14 | def setup_logger(): 15 | # init first log file 16 | if not os.path.isfile("logs/log.log"): 17 | # we have to make the logs dir before we log to it 18 | if not os.path.exists("logs"): 19 | os.makedirs("logs") 20 | open("logs/log.log", "w+") 21 | 22 | # set logging levels for various libs 23 | logging.getLogger("disnake").setLevel(logging.INFO) 24 | logging.getLogger("websockets").setLevel(logging.INFO) 25 | logging.getLogger("asyncpg").setLevel(logging.INFO) 26 | logging.getLogger("asyncio").setLevel(logging.INFO) 27 | 28 | # we want out logging formatted like this everywhere 29 | fmt = logging.Formatter( 30 | "{asctime} [{levelname}] {name}: {message}", 31 | datefmt="%Y-%m-%d %H:%M:%S", 32 | style="{", 33 | ) 34 | 35 | coloredlogs.install( 36 | level=logging.DEBUG, 37 | fmt="{asctime} [{levelname}] {name}: {message}", 38 | style="{", 39 | level_styles=dict( 40 | debug=dict(color=12), 41 | info=dict(color=15), 42 | warning=dict(bold=True, color=13), 43 | critical=dict(bold=True, color=9), 44 | ), 45 | ) 46 | 47 | file = logging.handlers.TimedRotatingFileHandler( 48 | "data/logs/log.log", when="midnight", encoding="utf-8-sig" 49 | ) 50 | file.setFormatter(fmt) 51 | file.setLevel(logging.INFO) 52 | 53 | # get the __main__ logger and add handlers 54 | root = logging.getLogger() 55 | root.setLevel(LOG_LEVEL) 56 | root.addHandler(file) 57 | 58 | return logging.getLogger(__name__) 59 | 60 | 61 | def setup(): 62 | 63 | # misc. 
monkey-patching 64 | class Embed(disnake.Embed): 65 | def __init__(self, color=disnake.Color.blue(), **attrs): 66 | attrs["color"] = color 67 | super().__init__(**attrs) 68 | 69 | disnake.Embed = Embed 70 | 71 | def patched_execute(old): 72 | async def new( 73 | self, 74 | query, 75 | args, 76 | limit, 77 | timeout, 78 | return_status=False, 79 | ignore_custom_codec=False, 80 | record_class=None, 81 | ): 82 | log.debug(query) 83 | return await old( 84 | self, 85 | query, 86 | args, 87 | limit, 88 | timeout, 89 | return_status=return_status, 90 | ignore_custom_codec=ignore_custom_codec, 91 | record_class=record_class, 92 | ) 93 | 94 | return new 95 | 96 | asyncpg.Connection._execute = patched_execute(asyncpg.Connection._execute) 97 | 98 | # create allowed mentions 99 | allowed_mentions = disnake.AllowedMentions( 100 | everyone=False, 101 | users=True, 102 | roles=False, 103 | replied_user=True, 104 | ) 105 | 106 | command_sync_flags = CommandSyncFlags( 107 | allow_command_deletion=True, 108 | sync_commands=True, 109 | sync_commands_debug=True, 110 | sync_global_commands=True, 111 | sync_guild_commands=True, 112 | sync_on_cog_actions=False, 113 | ) 114 | 115 | # init bot 116 | log.info("Initializing bot") 117 | bot = AceBot( 118 | loop=loop, 119 | intents=BOT_INTENTS, 120 | allowed_mentions=allowed_mentions, 121 | command_sync_flags=command_sync_flags, 122 | test_guilds=TEST_GUILDS, 123 | ) 124 | 125 | return bot 126 | 127 | 128 | if __name__ == "__main__": 129 | # create folders 130 | for path in ("data", "data/logs", "data/error", "data/feedback", "data/ahk_eval"): 131 | if not os.path.exists(path): 132 | os.makedirs(path) 133 | 134 | log = setup_logger() 135 | loop = asyncio.get_event_loop() 136 | 137 | bot = setup() 138 | bot.run(BOT_TOKEN) 139 | -------------------------------------------------------------------------------- /migrate.sql: -------------------------------------------------------------------------------- 1 | -- ALTER TYPE mod_event_type RENAME 
VALUE 'MUTE' TO 'TIMEOUT'; 2 | -- ALTER TYPE security_action RENAME VALUE 'MUTE' TO 'TIMEOUT'; 3 | -- ALTER TABLE mod_timer ADD COLUMN completed BOOLEAN NOT NULL DEFAULT FALSE; 4 | -- DROP TABLE docs_param; 5 | -- alter table log add column type command_type not null default 'PREFIX'; 6 | -- should also alter the above one to not have a default value anymore after rows have been set 7 | -- alter table mod_timer drop constraint mod_timer_guild_id_user_id_event_key; 8 | -- also make sure help_claim has correct owner 9 | 10 | DO $$ 11 | BEGIN 12 | IF NOT EXISTS (SELECT 1 FROM pg_type WHERE typname = 'mod_event_type') THEN 13 | CREATE TYPE mod_event_type AS ENUM ('BAN', 'TIMEOUT'); 14 | END IF; 15 | 16 | IF NOT EXISTS (SELECT 1 FROM pg_type WHERE typname = 'security_action') THEN 17 | CREATE TYPE security_action AS ENUM ('TIMEOUT', 'KICK', 'BAN'); 18 | END IF; 19 | 20 | IF NOT EXISTS (SELECT 1 FROM pg_type WHERE typname = 'command_type') THEN 21 | CREATE TYPE command_type AS ENUM ('PREFIX', 'APPLICATION'); 22 | END IF; 23 | END$$; 24 | 25 | -- guild config 26 | CREATE TABLE IF NOT EXISTS config ( 27 | id SERIAL UNIQUE, 28 | guild_id BIGINT UNIQUE NOT NULL, 29 | prefix VARCHAR(8) NULL, 30 | mod_role_id BIGINT NULL 31 | ); 32 | 33 | -- moderation values 34 | CREATE TABLE IF NOT EXISTS mod_config ( 35 | id SERIAL UNIQUE, 36 | guild_id BIGINT UNIQUE NOT NULL, 37 | 38 | log_channel_id BIGINT NULL, 39 | mute_role_id BIGINT NULL, 40 | 41 | spam_action security_action NULL, 42 | spam_count SMALLINT NOT NULL DEFAULT 8, 43 | spam_per SMALLINT NOT NULL DEFAULT 10, 44 | 45 | mention_action security_action NULL, 46 | mention_count SMALLINT NOT NULL DEFAULT 8, 47 | mention_per SMALLINT NOT NULL DEFAULT 16, 48 | 49 | raid BOOLEAN NOT NULL DEFAULT FALSE, 50 | raid_age INTERVAL NULL, 51 | raid_default_avatar BOOLEAN NOT NULL DEFAULT FALSE 52 | ); 53 | 54 | CREATE TABLE IF NOT EXISTS mod_timer ( 55 | id SERIAL UNIQUE, 56 | 57 | guild_id BIGINT NOT NULL, 58 | user_id BIGINT NOT NULL, 59 
| mod_id BIGINT NULL, 60 | 61 | event mod_event_type NOT NULL, 62 | 63 | created_at TIMESTAMP NOT NULL, 64 | duration INTERVAL NULL, 65 | 66 | reason TEXT NULL, 67 | userdata JSON NULL, 68 | 69 | completed BOOLEAN NOT NULL DEFAULT FALSE 70 | ); 71 | 72 | -- highlighter languages 73 | CREATE TABLE IF NOT EXISTS highlight_lang ( 74 | id SERIAL UNIQUE, 75 | guild_id BIGINT NOT NULL, 76 | user_id BIGINT NOT NULL DEFAULT 0, 77 | lang VARCHAR(32) NOT NULL, 78 | UNIQUE (guild_id, user_id) 79 | ); 80 | 81 | -- starboard config 82 | CREATE TABLE IF NOT EXISTS starboard ( 83 | id SERIAL UNIQUE, 84 | guild_id BIGINT UNIQUE NOT NULL, 85 | channel_id BIGINT NULL, 86 | locked BOOLEAN NOT NULL DEFAULT FALSE, 87 | threshold SMALLINT NULL, 88 | minimum SMALLINT NULL 89 | ); 90 | 91 | -- starmessage 92 | CREATE TABLE IF NOT EXISTS star_msg ( 93 | id SERIAL UNIQUE, 94 | guild_id BIGINT NOT NULL, 95 | channel_id BIGINT NOT NULL, 96 | user_id BIGINT NOT NULL, 97 | message_id BIGINT UNIQUE NOT NULL, 98 | star_message_id BIGINT NULL, 99 | starred_at TIMESTAMP NOT NULL, 100 | starrer_id BIGINT NOT NULL 101 | ); 102 | 103 | -- starrers 104 | CREATE TABLE IF NOT EXISTS starrers ( 105 | id SERIAL UNIQUE, 106 | star_id INTEGER NOT NULL REFERENCES star_msg (id) ON DELETE CASCADE, 107 | user_id BIGINT NOT NULL, 108 | UNIQUE (star_id, user_id) 109 | ); 110 | 111 | -- fact list 112 | CREATE TABLE IF NOT EXISTS facts ( 113 | id SERIAL UNIQUE, 114 | content TEXT NOT NULL 115 | ); 116 | 117 | -- tag list 118 | CREATE TABLE IF NOT EXISTS tag ( 119 | id SERIAL UNIQUE, 120 | name VARCHAR(32) NOT NULL, 121 | alias VARCHAR(32) NULL, 122 | guild_id BIGINT NOT NULL, 123 | user_id BIGINT NOT NULL, 124 | uses INT NOT NULL DEFAULT 0, 125 | created_at TIMESTAMP NOT NULL, 126 | edited_at TIMESTAMP NULL, 127 | viewed_at TIMESTAMP NULL, 128 | content VARCHAR(2000) NOT NULL 129 | ); 130 | 131 | -- command log 132 | CREATE TABLE IF NOT EXISTS log ( 133 | id SERIAL UNIQUE, 134 | guild_id BIGINT NOT NULL, 135 | 
channel_id BIGINT NOT NULL, 136 | user_id BIGINT NOT NULL, 137 | timestamp TIMESTAMP NOT NULL, 138 | command TEXT NOT NULL, 139 | type command_type NOT NULL 140 | ); 141 | 142 | CREATE TABLE IF NOT EXISTS remind ( 143 | id SERIAL UNIQUE, 144 | guild_id BIGINT NOT NULL, 145 | channel_id BIGINT NOT NULL, 146 | user_id BIGINT NOT NULL, 147 | message_id BIGINT NULL, 148 | made_on TIMESTAMP NOT NULL, 149 | remind_on TIMESTAMP NOT NULL, 150 | message TEXT 151 | ); 152 | 153 | CREATE TABLE IF NOT EXISTS welcome ( 154 | id SERIAL UNIQUE, 155 | guild_id BIGINT UNIQUE NOT NULL, 156 | channel_id BIGINT, 157 | enabled BOOLEAN NOT NULL DEFAULT TRUE, 158 | content VARCHAR(1024) 159 | ); 160 | 161 | -- docs stuff 162 | CREATE TABLE IF NOT EXISTS role ( 163 | id SERIAL UNIQUE, 164 | guild_id BIGINT UNIQUE NOT NULL, 165 | channel_id BIGINT NULL, 166 | message_ids BIGINT[8] NOT NULL DEFAULT ARRAY[]::BIGINT[8], 167 | selectors INTEGER[8] NOT NULL DEFAULT ARRAY[]::INTEGER[8] 168 | ); 169 | 170 | CREATE TABLE IF NOT EXISTS role_selector ( 171 | id SERIAL UNIQUE, 172 | guild_id BIGINT NOT NULL, 173 | title VARCHAR(256) NOT NULL, 174 | description VARCHAR(1024) NULL, 175 | icon VARCHAR(256) NULL, 176 | inline BOOLEAN NOT NULL DEFAULT TRUE, 177 | roles INTEGER[25] NOT NULL DEFAULT ARRAY[]::INTEGER[25] 178 | ); 179 | 180 | CREATE TABLE IF NOT EXISTS role_entry ( 181 | id SERIAL UNIQUE, 182 | guild_id BIGINT NOT NULL, 183 | role_id BIGINT UNIQUE NOT NULL, 184 | emoji VARCHAR(56) NOT NULL, 185 | name VARCHAR(199) NOT NULL, 186 | description VARCHAR(1024) NOT NULL 187 | ); 188 | 189 | 190 | CREATE TABLE IF NOT EXISTS trivia ( 191 | id SERIAL UNIQUE, 192 | guild_id BIGINT NOT NULL, 193 | user_id BIGINT NOT NULL, 194 | correct_count INT NOT NULL DEFAULT 0, 195 | wrong_count INT NOT NULL DEFAULT 0, 196 | score BIGINT NOT NULL DEFAULT 0, 197 | UNIQUE (guild_id, user_id) 198 | ); 199 | 200 | CREATE TABLE IF NOT EXISTS trivia_stats ( 201 | id SERIAL UNIQUE, 202 | guild_id BIGINT NOT NULL, 203 | 
user_id BIGINT NOT NULL, 204 | timestamp TIMESTAMP NOT NULL, 205 | question_hash BIGINT NOT NULL, 206 | result BOOL NOT NULL 207 | ); 208 | 209 | CREATE TABLE IF NOT EXISTS help_claim ( 210 | guild_id BIGINT NOT NULL, 211 | channel_id BIGINT NOT NULL, 212 | user_id BIGINT NOT NULL, 213 | UNIQUE (guild_id, channel_id) 214 | ); -------------------------------------------------------------------------------- /neural/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM python:3.11.2 2 | 3 | COPY requirements.txt / 4 | RUN pip install --no-cache-dir -r requirements.txt 5 | 6 | COPY api.py . 7 | COPY model.py . 8 | COPY text_processor.py . 9 | COPY torch_config.py . 10 | COPY model.pth . 11 | COPY embeddings embeddings 12 | 13 | CMD ["python3", "-u", "api.py"] 14 | -------------------------------------------------------------------------------- /neural/README.md: -------------------------------------------------------------------------------- 1 | # Game script prediction model architecture 2 | 3 | To install PyTorch compiled for CUDA, see: 4 | https://pytorch.org/ 5 | 6 | ``` 7 | TextCNN( 8 | (embedding): Embedding(6442, 100, padding_idx=0) 9 | (convs): ModuleList( 10 | (0): Conv2d(1, 64, kernel_size=(2, 100), stride=(1, 1)) 11 | (1): Conv2d(1, 64, kernel_size=(3, 100), stride=(1, 1)) 12 | ) 13 | (dropout): Dropout(p=0.5, inplace=False) 14 | (fc): Linear(in_features=128, out_features=1, bias=True) 15 | (sigmoid): Sigmoid() 16 | ) 17 | Trainable params: 676457 18 | ``` -------------------------------------------------------------------------------- /neural/api.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | 3 | import torch 4 | from sanic import Request, Sanic 5 | from sanic.response import HTTPResponse, json 6 | from text_processor import TextProcessor 7 | from torch_config import EMBEDDINGS_DIR 8 | from torchtext.data.utils import get_tokenizer 9 | 10 | from model 
import TextCNN 11 | 12 | app = Sanic("torch_api") 13 | 14 | embeddings = torch.load(f"{EMBEDDINGS_DIR}/vectors.pkl") 15 | 16 | model = TextCNN( 17 | embeddings=embeddings, 18 | n_filters=64, 19 | filter_sizes=[2, 3], 20 | dropout=0.0, 21 | ) 22 | 23 | device = torch.device("cpu") 24 | model.load_state_dict(torch.load("model.pth", map_location=device)) 25 | model.eval() 26 | 27 | text_processing = TextProcessor( 28 | wti=pickle.load(open(f"{EMBEDDINGS_DIR}/wti.pkl", "rb")), 29 | tokenizer=get_tokenizer("basic_english"), 30 | standardize=True, 31 | min_len=3, 32 | ) 33 | 34 | 35 | @app.post("/game") 36 | async def game(request: Request): 37 | q = request.form.get("q", None) 38 | 39 | if q is None: 40 | return HTTPResponse(status=400) 41 | 42 | tokens = text_processing.process(q) 43 | x = torch.unsqueeze(tokens, dim=0) 44 | 45 | pred = model(x) 46 | pred = torch.squeeze(pred).item() 47 | 48 | # TODO: add logging 49 | 50 | return json(dict(p=pred)) 51 | 52 | 53 | if __name__ == "__main__": 54 | app.run(host="0.0.0.0", port=7000) 55 | -------------------------------------------------------------------------------- /neural/data_fetcher.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import os 3 | 4 | import asyncpg 5 | from torch_config import CORPUS_DIR, DB_BIND 6 | 7 | 8 | async def fetch_data(): 9 | db = await asyncpg.connect(DB_BIND) 10 | rs = await db.fetch("SELECT * FROM corpus WHERE truth IN (0, 1)") 11 | 12 | for r in rs: 13 | _id = r.get("id") 14 | 15 | if not os.path.isdir(CORPUS_DIR): 16 | os.mkdir(CORPUS_DIR) 17 | 18 | label = "1" if r.get("truth") else "0" 19 | path = f"{CORPUS_DIR}/{label}" 20 | 21 | if not os.path.isdir(path): 22 | os.mkdir(path) 23 | 24 | file_path = f"{path}/{_id}.txt" 25 | # if not os.path.isfile(file_path): 26 | with open(file_path, "w", encoding="utf-8") as f: 27 | f.write(" ".join(r.get("data"))) 28 | 29 | 30 | if __name__ == "__main__": 31 | loop = asyncio.get_event_loop() 
32 | task = loop.create_task(fetch_data()) 33 | loop.run_until_complete(task) 34 | -------------------------------------------------------------------------------- /neural/dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | from torch.utils.data import Dataset 6 | 7 | 8 | def pad(tokens, seq_len, value=0): 9 | orig_len = tokens.size(0) 10 | 11 | # if too long, we need to strip 12 | if orig_len > seq_len: 13 | return tokens[:seq_len] 14 | 15 | if seq_len == orig_len: 16 | return tokens 17 | 18 | return F.pad(input=tokens, pad=(0, seq_len - orig_len), value=value) 19 | 20 | 21 | class TextDataset(Dataset): 22 | def __init__(self, folder, processor): 23 | self.texts = list() 24 | self.labels = list() 25 | self.tokenized = dict() 26 | 27 | self.processor = processor 28 | 29 | # load corpus into .texts and .labels 30 | for dir, subdir, files in os.walk(folder): 31 | dir = dir.replace(r"\\", "/") 32 | 33 | for file in files: 34 | with open(f"{dir}/{file}", "r", encoding="utf-8") as f: 35 | text = f.read() 36 | 37 | if not len(text): 38 | continue 39 | 40 | self.texts.append(text) 41 | 42 | label = int(dir[-1]) 43 | self.labels.append(label) 44 | 45 | def get_tokens(self, item): 46 | tokens = self.tokenized.get(item, None) 47 | 48 | if tokens is None: 49 | tokens = self.processor.process(self.texts[item]) 50 | self.tokenized[item] = tokens 51 | 52 | return tokens 53 | 54 | def __len__(self): 55 | return len(self.texts) 56 | 57 | def __getitem__(self, item): 58 | # get a data point from the loader 59 | 60 | # get the tokens for this item 61 | tokens = self.get_tokens(item) 62 | 63 | # get the label for this item 64 | label = self.labels[item] 65 | 66 | # return tokens and the labels as a FloatTensor 67 | return tokens, torch.tensor(label, dtype=torch.float32) 68 | 69 | 70 | class Sequencer: 71 | def __init__(self, sequence_len): 72 | self.sequence_len = 
sequence_len 73 | 74 | def __call__(self, batch): 75 | # sort the batch from longest to shortest sentences 76 | # not necessarily necessary for convnet but is for LSTM layer with padding and packing 77 | # batch = sorted(batch, key=lambda x: x[0].size(), reverse=True) 78 | 79 | # and pad and stack into a LongTensor 80 | tokens = torch.stack([pad(token_list, self.sequence_len) for token_list, _ in batch]) 81 | 82 | # strip labels 83 | labels = torch.tensor([label for _, label in batch], dtype=torch.float32) 84 | 85 | return tokens, labels 86 | -------------------------------------------------------------------------------- /neural/make_embeddings.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | from collections import Counter 4 | 5 | import torch 6 | from torch_config import CORPUS_DIR, EMBEDDINGS_DIR, GLOVE_DIR 7 | from torchtext.data.utils import get_tokenizer 8 | from tqdm import tqdm 9 | 10 | tokenizer = get_tokenizer("basic_english") 11 | counter = Counter() 12 | 13 | for dir, subdir, files in os.walk(CORPUS_DIR): 14 | dir = dir.replace(r"\\", "/") 15 | 16 | print("\nReading:", dir) 17 | for file in tqdm(files): 18 | with open(f"{dir}/{file}", "r", encoding="utf-8") as f: 19 | text = f.read() 20 | 21 | if not len(text): 22 | continue 23 | 24 | counter.update(tokenizer(text)) 25 | 26 | glove_wti: dict = pickle.load(open(f"{GLOVE_DIR}/word2idx.pkl", "rb")) 27 | glove_vectors: torch.Tensor = torch.load(f"{GLOVE_DIR}/vectors.pkl") 28 | 29 | embed_idx = 0 30 | embed_wti = dict() 31 | embed_vectors = [] 32 | 33 | print("Copying relevant embeddings...") 34 | for word in counter.keys(): 35 | glove_idx = glove_wti.get(word, None) 36 | 37 | if glove_idx is None: 38 | continue 39 | 40 | embed_wti[word] = embed_idx 41 | embed_vectors.append(glove_vectors[glove_idx]) 42 | embed_idx += 1 43 | 44 | if not os.path.exists(EMBEDDINGS_DIR): 45 | os.mkdir(EMBEDDINGS_DIR) 46 | 47 | pickle.dump(embed_wti, 
open(f"{EMBEDDINGS_DIR}/wti.pkl", "wb")) 48 | torch.save(torch.stack(embed_vectors), f"{EMBEDDINGS_DIR}/vectors.pkl") 49 | -------------------------------------------------------------------------------- /neural/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class TextCNN(nn.Module): 7 | def __init__(self, embeddings, n_filters, filter_sizes, dropout): 8 | super().__init__() 9 | 10 | num_embeddings = embeddings.size(0) 11 | embedding_dim = embeddings.size(1) 12 | 13 | self.embedding = nn.Embedding(num_embeddings, embedding_dim, padding_idx=0, sparse=False) 14 | self.embedding.load_state_dict(dict(weight=embeddings)) 15 | # self.embedding.weight.requires_grad = False 16 | 17 | self.convs = nn.ModuleList( 18 | [ 19 | nn.Conv2d( 20 | in_channels=1, 21 | out_channels=n_filters, 22 | kernel_size=(fs, embedding_dim), 23 | ) 24 | for fs in filter_sizes 25 | ] 26 | ) 27 | 28 | self.dropout = nn.Dropout(dropout) 29 | self.fc = nn.Linear(len(filter_sizes) * n_filters, 1) 30 | self.sigmoid = nn.Sigmoid() 31 | 32 | def forward(self, x): 33 | # embedded = [batch_size, sequence_len, embedding_size] -> [32, 380, 100] 34 | embedded = self.embedding(x) 35 | embedded = embedded.unsqueeze(1) 36 | 37 | conved = [F.relu(conv(embedded)).squeeze(3) for conv in self.convs] 38 | pooled = [F.max_pool1d(conv, conv.shape[2]).squeeze(2) for conv in conved] 39 | 40 | x = self.dropout(torch.cat(pooled, dim=1)) 41 | 42 | x = self.fc(x) 43 | return self.sigmoid(x) 44 | -------------------------------------------------------------------------------- /neural/process_glove.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | 3 | import torch 4 | from torch_config import GLOVE_DIR 5 | from tqdm import tqdm 6 | 7 | idx = 0 8 | word2idx = {} 9 | glove_size = 1193517 10 | vectors = torch.empty((glove_size, 100), 
dtype=torch.float32) 11 | 12 | with open(f"{GLOVE_DIR}/glove.twitter.27B.100d.txt", "rb") as f: 13 | for l in tqdm(f, total=glove_size): 14 | line = l.decode().split() 15 | 16 | if len(line) != 101: 17 | continue 18 | 19 | word = line[0] 20 | word2idx[word] = idx 21 | 22 | c = line[1:] 23 | vector = torch.tensor([float(x) for x in c], dtype=torch.float32) 24 | vectors[idx] = vector 25 | 26 | idx += 1 27 | 28 | pickle.dump(word2idx, open(f"{GLOVE_DIR}/word2idx.pkl", "wb")) 29 | torch.save(vectors, f"{GLOVE_DIR}/vectors.pkl") 30 | -------------------------------------------------------------------------------- /neural/requirements.txt: -------------------------------------------------------------------------------- 1 | numpy<2 2 | torch==2.1.0 3 | torchtext==0.16.0 4 | sanic==23.12.1 5 | unidecode==1.1.1 6 | -------------------------------------------------------------------------------- /neural/text_processor.py: -------------------------------------------------------------------------------- 1 | import re 2 | import string 3 | 4 | import torch 5 | import torch.nn.functional as F 6 | from unidecode import unidecode 7 | 8 | 9 | class TextProcessor: 10 | def __init__(self, wti, tokenizer, standardize=True, min_len=None): 11 | self.wti = wti 12 | self.tokenizer = tokenizer 13 | self.do_standardize = standardize 14 | self.min_len = min_len 15 | 16 | def process(self, text): 17 | # converts a string to a list of word indices, using the tokenizer and "word to index" map 18 | text = self.standardize(text) if self.do_standardize else text 19 | tensor = torch.LongTensor([self.wti.get(word, 1) for word in self.tokenizer(text)]) 20 | 21 | if self.min_len is not None: 22 | tensor_len = tensor.size(0) 23 | if tensor_len < self.min_len: 24 | tensor = F.pad( 25 | input=tensor, pad=(0, self.min_len - tensor_len), value=1 26 | ) # we need to pad with UNK 27 | 28 | return tensor 29 | 30 | @staticmethod 31 | def standardize(s: str): 32 | # make lowercase 33 | s = s.lower() 34 | 35 | # 
remove urls 36 | s = re.sub(r"^https?:\/\/.*[\r\n]*", "", s, flags=re.MULTILINE) 37 | 38 | # remove diacritics 39 | s = unidecode(s) 40 | 41 | # remove numbers 42 | s = re.sub(f"[{string.digits}\.,]", " ", s) 43 | 44 | # remove punctuation 45 | s = re.sub(f"[{re.escape(string.punctuation)}]", "", s) 46 | 47 | # condense whitespaces 48 | s = re.sub(r"\s+", " ", s) 49 | 50 | return s.lower().strip() 51 | -------------------------------------------------------------------------------- /neural/train.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | import pickle 4 | 5 | import torch 6 | from dataset import Sequencer, TextDataset 7 | from text_processor import TextProcessor 8 | from torch import nn 9 | from torch.utils.data import DataLoader 10 | from torch_config import CORPUS_DIR, EMBEDDINGS_DIR 11 | from torchtext.data.utils import get_tokenizer 12 | from tqdm import tqdm 13 | 14 | from model import TextCNN 15 | 16 | DATA_SPLIT = 0.75 17 | SEQUENCE_LEN = 380 18 | 19 | 20 | def main(): 21 | device = torch.device("cuda") 22 | 23 | embedding_vectors = torch.load(f"{EMBEDDINGS_DIR}/vectors.pkl") 24 | 25 | text_processor = TextProcessor( 26 | wti=pickle.load(open(f"{EMBEDDINGS_DIR}/wti.pkl", "rb")), 27 | tokenizer=get_tokenizer("basic_english"), 28 | standardize=True, 29 | min_len=3, 30 | ) 31 | 32 | dataset = TextDataset(CORPUS_DIR, text_processor) 33 | 34 | # split into training and test set 35 | # TODO: fix this splitting sometimes failing when corpus size changes 36 | train_set, test_set = torch.utils.data.random_split( 37 | dataset, 38 | [int(len(dataset) * DATA_SPLIT), int(len(dataset) * (1.0 - DATA_SPLIT))], 39 | ) 40 | 41 | # count number of samples in each class 42 | class_count = [0, 0] 43 | for data, label in dataset: 44 | class_count[int(label.item())] += 1 45 | 46 | # get relative weights for classes 47 | _sum = sum(class_count) 48 | class_count[0] /= _sum 49 | class_count[1] /= _sum 50 | 51 
| # reverse the weights since we're getting the inverse for the sampler 52 | class_count = list(reversed(class_count)) 53 | 54 | # set weight for every sample 55 | weights = [class_count[int(x[1].item())] for x in train_set] 56 | 57 | # weighted sampler 58 | sampler = torch.utils.data.WeightedRandomSampler( 59 | weights=weights, num_samples=len(train_set), replacement=True 60 | ) 61 | 62 | train_loader = DataLoader( 63 | dataset=train_set, 64 | batch_size=32, 65 | collate_fn=Sequencer(SEQUENCE_LEN), 66 | sampler=sampler, 67 | ) 68 | 69 | test_loader = DataLoader(dataset=test_set, batch_size=32, collate_fn=Sequencer(SEQUENCE_LEN)) 70 | 71 | # number of filters in each convolutional filter 72 | N_FILTERS = 64 73 | 74 | # sizes and number of convolutional layers 75 | FILTER_SIZES = [2, 3] 76 | 77 | # dropout for between conv and dense layers 78 | DROPOUT = 0.5 79 | 80 | model = TextCNN( 81 | embeddings=embedding_vectors, 82 | n_filters=N_FILTERS, 83 | filter_sizes=FILTER_SIZES, 84 | dropout=DROPOUT, 85 | ).to(device) 86 | 87 | print(model) 88 | print( 89 | "Trainable params:", 90 | sum(p.numel() for p in model.parameters() if p.requires_grad), 91 | ) 92 | 93 | criterion = nn.BCELoss() 94 | optimizer = torch.optim.Adam(model.parameters(), lr=1e-3) 95 | 96 | EPOCHS = 12 97 | 98 | best_acc = 0.0 99 | 100 | # training loop 101 | for epoch in range(EPOCHS): 102 | print("Epoch", epoch + 1) 103 | 104 | for i, data in tqdm(enumerate(train_loader), total=len(train_loader)): 105 | # get word indices vector and corresponding labels 106 | x, labels = data 107 | 108 | # send to device 109 | x = x.to(device) 110 | labels = labels.to(device) 111 | 112 | # make predictions 113 | predictions = model(x).squeeze() 114 | 115 | # calculate loss 116 | loss = criterion(predictions, labels) 117 | 118 | # learning stuff... 
119 | optimizer.zero_grad() 120 | loss.backward() 121 | optimizer.step() 122 | 123 | # evaluate 124 | with torch.no_grad(): 125 | model.eval() 126 | 127 | correct = 0 128 | wrong = 0 129 | m = [[0, 0], [0, 0]] 130 | 131 | for data in test_loader: 132 | x, label = data 133 | x = x.to(device) 134 | 135 | predictions = model(x).squeeze() 136 | 137 | for truth, prediction in zip(label, predictions): 138 | y = int(truth.item()) 139 | y_pred = 1 if prediction.item() > 0.5 else 0 140 | 141 | m[y][y_pred] += 1 142 | 143 | if y == y_pred: 144 | correct += 1 145 | else: 146 | wrong += 1 147 | 148 | model.train() 149 | 150 | acc = correct / (correct + wrong) 151 | if acc > best_acc: 152 | best_acc = acc 153 | for file in glob.glob("models/model_*.pth"): 154 | os.remove(file) 155 | torch.save(model.state_dict(), f"models/state_{epoch}.pth") 156 | 157 | print() 158 | print("Correct:", f"{correct}/{correct + wrong}", "Accuracy:", acc) 159 | print("[[TN, FP], [FN, TP]]") 160 | print(m) 161 | print() 162 | 163 | # put into evaluation mode 164 | model.eval() 165 | 166 | text_processor.do_standardize = True 167 | 168 | with torch.no_grad(): 169 | while True: 170 | text = input("Prompt: ") 171 | x = text_processor.process(text) 172 | x = torch.tensor(x).unsqueeze(dim=0) 173 | print(model(x.to(device)).squeeze()) 174 | 175 | 176 | if __name__ == "__main__": 177 | if not os.path.isdir("models"): 178 | os.mkdir("models") 179 | main() 180 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.black] 2 | line-length = 100 3 | target-version = ["py311"] 4 | 5 | [tool.isort] 6 | profile = "black" -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | disnake==2.9.1 2 | asyncpg==0.28.0 3 | # rapidfuzz==1.9.1 4 | rapidfuzz==3.4.0 5 
| parsedatetime==2.6 6 | python-Levenshtein==0.12.2 7 | lxml==4.9.3 8 | beautifulsoup4==4.10.0 9 | tabulate==0.8.9 10 | emoji==1.6.1 11 | pygit2==1.13.1 12 | psutil==5.8.0 13 | Unidecode==1.3.2 14 | colorama==0.4.4 15 | coloredlogs==15.0.1 16 | -------------------------------------------------------------------------------- /utils/commanderrorlogic.py: -------------------------------------------------------------------------------- 1 | import traceback 2 | from datetime import datetime 3 | from pprint import saferepr 4 | 5 | import disnake 6 | from disnake.ext import commands 7 | 8 | 9 | class CommandErrorLogic: 10 | def __init__(self, ctx, exc): 11 | self.bot = ctx.bot 12 | self.ctx = ctx 13 | self.exc = exc 14 | 15 | self.embed = None 16 | self.save = False 17 | 18 | async def __aenter__(self): 19 | return self 20 | 21 | async def __aexit__(self, exc_type, exc_val, exc_tb): 22 | # if an error was never set, do nothing 23 | if self.embed is None: 24 | return 25 | 26 | # first try to send error 27 | try: 28 | ctx = self.ctx 29 | e = self.embed 30 | 31 | extra = dict() 32 | if isinstance(ctx, commands.Context): 33 | perms = ctx.perms 34 | else: 35 | perms = ctx.permissions 36 | extra["ephemeral"] = True 37 | 38 | if perms.embed_links: 39 | if self.save: 40 | e.description += self.support_text(True) 41 | await ctx.send(embed=e, **extra) 42 | else: 43 | content = str() 44 | if isinstance(e.title, str): 45 | content += e.title 46 | elif isinstance(e.author.name, str): 47 | content += e.author.name 48 | 49 | if isinstance(e.description, str) and len(e.description): 50 | content += "\n\n" + e.description 51 | 52 | if self.save: 53 | content += self.support_text(False) 54 | 55 | await ctx.send(content, **extra) 56 | 57 | except disnake.HTTPException: 58 | pass 59 | 60 | # after doing that, save and raise if it's an oops 61 | finally: 62 | if self.save: 63 | self.save_error() 64 | raise self.exc 65 | 66 | @staticmethod 67 | def new_embed(**kwargs): 68 | return 
disnake.Embed(color=0x36393E, **kwargs) 69 | 70 | def support_text(self, embeddable): 71 | support_link = self.bot.support_link 72 | 73 | content = "\n\nYou can join the support server " 74 | 75 | if embeddable: 76 | return content + "[here]({0})!".format(support_link) 77 | else: 78 | return content + "here: <{0}>".format(support_link) 79 | 80 | def set(self, **kwargs): 81 | self.embed = self.new_embed(**kwargs) 82 | 83 | def oops(self): 84 | desc = ( 85 | "An exception occured while processing the command.\n" 86 | "My developer has been notified and the issue will hopefully be fixed soon!" 87 | ) 88 | 89 | e = self.new_embed(description=desc) 90 | e.set_author(name="Oops!", icon_url=self.bot.user.display_avatar.url) 91 | 92 | self.save = True 93 | self.embed = e 94 | 95 | def save_error(self): 96 | ctx = self.ctx 97 | exc = self.exc 98 | 99 | timestamp = str(datetime.utcnow()).split(".")[0].replace(" ", "_").replace(":", "") 100 | filename = str(ctx.message.id) + "_" + timestamp + ".error" 101 | 102 | try: 103 | raise exc 104 | except: 105 | tb = traceback.format_exc() 106 | 107 | content = ( 108 | "{0.stamp}\n\nMESSAGE CONTENT:\n{0.message.content}\n\n" 109 | "COMMAND: {0.command.qualified_name}\nARGS: {1}\nKWARGS: {2}\n\n{3}" 110 | ).format(ctx, saferepr(ctx.args[2:]), saferepr(ctx.kwargs), tb) 111 | 112 | with open("data/error/{0}".format(filename), "w", encoding="utf-8-sig") as f: 113 | f.write(content) 114 | -------------------------------------------------------------------------------- /utils/configtable.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import logging 3 | 4 | log = logging.getLogger(__name__) 5 | 6 | 7 | class ConfigTableRecord(object): 8 | _data = dict() 9 | 10 | def __init__(self, config, record): 11 | self._config = config 12 | self._data = dict() 13 | self._dirty = set() 14 | 15 | for key, value in record.items(): 16 | self._data[key] = value 17 | 18 | def __getattr__(self, key): 
19 | if key in self._data: 20 | return self._data[key] 21 | 22 | def __setattr__(self, key, value): 23 | if key in self._data: 24 | self.set(key, value) 25 | else: 26 | self.__dict__[key] = value 27 | 28 | def _build_dirty(self, start_at=1): 29 | return ", ".join( 30 | "{} = ${}".format(key, idx + start_at) for idx, key in enumerate(self._dirty) 31 | ) 32 | 33 | def _set_dirty(self, key): 34 | if key not in self._data: 35 | raise AttributeError( 36 | "Attempted to set key {} to dirty, but it does not exist".format(key) 37 | ) 38 | self._dirty.add(key) 39 | 40 | def _clear_dirty(self): 41 | self._dirty.clear() 42 | 43 | def get(self, key): 44 | if key in self._data: 45 | return self._data[key] 46 | else: 47 | raise AttributeError("Key '{}' not defined in this table.".format(key)) 48 | 49 | def set(self, key, value): 50 | if key not in self._data: 51 | raise AttributeError("Key '{}' not defined in this table.".format(key)) 52 | 53 | self._data[key] = value 54 | self._set_dirty(key) 55 | 56 | async def update(self, **kwargs): 57 | for key, val in kwargs.items(): 58 | self.set(key, val) 59 | 60 | if not self._dirty: 61 | raise ValueError("No values dirty for table {}".format(self._config.table)) 62 | 63 | query = "UPDATE {} SET {} WHERE {}".format( 64 | self._config.table, 65 | self._build_dirty(len(self._config.primary) + 1), 66 | self._config.build_predicate(), 67 | ) 68 | 69 | keys = tuple(self._data[primary] for primary in self._config.primary) 70 | values = tuple(self._data[key] for key in self._dirty) 71 | 72 | await self._config.bot.db.execute(query, *keys, *values) 73 | 74 | self._clear_dirty() 75 | 76 | 77 | class ConfigTable: 78 | def __init__(self, bot, table, primary, record_class=None): 79 | record_class = record_class or ConfigTableRecord 80 | 81 | if record_class is not ConfigTableRecord and not issubclass( 82 | record_class, ConfigTableRecord 83 | ): 84 | raise TypeError("entry_class must inherit from ConfigTableEntry.") 85 | 86 | if isinstance(primary, 
str): 87 | primary = (primary,) 88 | elif not isinstance(primary, tuple): 89 | raise TypeError("Primary keys must be tuple or string.") 90 | 91 | self.bot = bot 92 | self.table = table 93 | self.primary = primary 94 | self.entries = dict() 95 | 96 | self._record_class = record_class 97 | self._lock = asyncio.Lock() 98 | self._non_existent = set() 99 | 100 | log.debug("Constructed ConfigTable for table %s with keys %s", table, primary) 101 | 102 | def build_predicate(self, start_at=1): 103 | return " AND ".join( 104 | "{} = ${}".format(key, idx + start_at) for idx, key in enumerate(self.primary) 105 | ) 106 | 107 | def get_keys_from_record(self, record): 108 | return tuple(record.get(primary) for primary in self.primary) 109 | 110 | @property 111 | def _insert_query(self): 112 | return "INSERT INTO {} ({}) VALUES ({})".format( 113 | self.table, 114 | ", ".join(self.primary), 115 | ", ".join("${}".format(idx + 1) for idx, _ in enumerate(self.primary)), 116 | ) 117 | 118 | async def insert_record(self, record, keys=None): 119 | keys = keys or self.get_keys_from_record(record) 120 | 121 | if keys in self._non_existent: 122 | self._non_existent.remove(keys) 123 | 124 | log.debug("Inserting record with keys %s for table %s", keys, self.table) 125 | 126 | entry = self._record_class(self, record) 127 | self.entries[keys] = entry 128 | 129 | return entry 130 | 131 | async def get_entry(self, *keys, construct=True): 132 | keys = tuple(keys) 133 | 134 | for key in keys: 135 | if not isinstance(key, int): 136 | raise TypeError("Primary key must be int.") 137 | 138 | if not construct and keys in self._non_existent: 139 | return None 140 | 141 | async with self._lock: 142 | if keys in self.entries: 143 | return self.entries[keys] 144 | 145 | get_query = "SELECT * FROM {} WHERE ".format(self.table) + self.build_predicate() 146 | 147 | record = await self.bot.db.fetchrow(get_query, *keys) 148 | 149 | if record is None: 150 | if not construct: 151 | self._non_existent.add(keys) 152 
| return None 153 | elif keys in self._non_existent: 154 | self._non_existent.remove(keys) 155 | 156 | await self.bot.db.execute(self._insert_query, *keys) 157 | record = await self.bot.db.fetchrow(get_query, *keys) 158 | 159 | return await self.insert_record(record, keys=keys) 160 | 161 | def has_entry(self, *keys): 162 | return tuple(keys) in self.entries 163 | 164 | async def clear_entry(self, *keys): 165 | """Returns True if key(s) found in entries dict.""" 166 | 167 | keys = tuple(keys) 168 | 169 | async with self._lock: 170 | if keys in self._non_existent: 171 | log.info("Clearing non-existent entry %s for table %s", keys, self.table) 172 | self._non_existent.remove(keys) 173 | 174 | removed = bool(self.entries.pop(keys, False)) 175 | 176 | if removed: 177 | log.info("Clearing entry %s for table %s", keys, self.table) 178 | 179 | return removed 180 | -------------------------------------------------------------------------------- /utils/context.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | from typing import TYPE_CHECKING 3 | 4 | import disnake 5 | from disnake.ext import commands 6 | 7 | from utils.string import po 8 | from utils.time import pretty_datetime 9 | 10 | if TYPE_CHECKING: 11 | from ace import AceBot 12 | 13 | STATIC_PERMS = ("add_reactions", "manage_messages", "embed_links") 14 | PROMPT_REQUIRED_PERMS = ("embed_links",) 15 | PROMPT_EMOJIS = ("\N{WHITE HEAVY CHECK MARK}", "\N{CROSS MARK}") 16 | 17 | 18 | async def is_mod_pred(ctx): 19 | return await ctx.is_mod() 20 | 21 | 22 | def is_mod(): 23 | return commands.check(is_mod_pred) 24 | 25 | 26 | async def can_prompt_pred(ctx): 27 | perms = ctx.perms 28 | missing_perms = list(perm for perm in PROMPT_REQUIRED_PERMS if not getattr(perms, perm)) 29 | 30 | if not missing_perms: 31 | return True 32 | 33 | raise commands.BotMissingPermissions(missing_perms) 34 | 35 | 36 | def can_prompt(): 37 | return commands.check(can_prompt_pred) 38 | 39 | 
40 | class PromptView(disnake.ui.View): 41 | def __init__(self, check_user, *args, **kwargs): 42 | super().__init__(*args, **kwargs) 43 | 44 | self.check_user = check_user 45 | self.message = None 46 | 47 | self.result = False 48 | self.event = asyncio.Event() 49 | 50 | async def finish(self, message: disnake.Message): 51 | self.event.set() 52 | 53 | if message is not None: 54 | try: 55 | await message.delete() 56 | except disnake.HTTPException: 57 | pass 58 | 59 | async def wait(self): 60 | await self.event.wait() 61 | return self.result 62 | 63 | async def on_timeout(self) -> None: 64 | await self.finish(self.message) 65 | 66 | async def interaction_check(self, interaction: disnake.MessageInteraction): 67 | return interaction.author == self.check_user 68 | 69 | @disnake.ui.button( 70 | label="Continue", 71 | emoji="\N{White Heavy Check Mark}", 72 | style=disnake.ButtonStyle.primary, 73 | ) 74 | async def yes(self, button: disnake.ui.Button, inter: disnake.MessageInteraction): 75 | self.result = True 76 | await self.finish(inter.message) 77 | 78 | @disnake.ui.button(label="Abort", emoji="\N{CROSS MARK}", style=disnake.ButtonStyle.secondary) 79 | async def no(self, button: disnake.ui.Button, inter: disnake.MessageInteraction): 80 | self.result = False 81 | await self.finish(inter.message) 82 | 83 | 84 | class AceContext(commands.Context): 85 | def __init__(self, **kwargs): 86 | super().__init__(**kwargs) 87 | self.bot: "AceBot" 88 | 89 | @property 90 | def db(self): 91 | return self.bot.db 92 | 93 | @property 94 | def http(self): 95 | return self.bot.aiohttp 96 | 97 | @property 98 | def perms(self): 99 | return self.channel.permissions_for(self.guild.me) 100 | 101 | @property 102 | def pretty(self): 103 | return "{0.display_name} ({0.id}) in {1.name} ({1.id})".format(self.author, self.guild) 104 | 105 | @property 106 | def stamp(self): 107 | return "TIME: {}\nGUILD: {}\nCHANNEL: #{}\nAUTHOR: {}\nMESSAGE ID: {}".format( 108 | 
pretty_datetime(self.message.created_at), 109 | po(self.guild), 110 | po(self.channel), 111 | po(self.author), 112 | str(self.message.id), 113 | ) 114 | 115 | async def is_mod(self, member=None): 116 | """Check if invoker or member has bot moderator rights.""" 117 | 118 | member = member or self.author 119 | 120 | # always allow bot owner 121 | if member.id == self.bot.owner_id: 122 | return True 123 | 124 | # true if member has administrator perms in this channel 125 | if self.channel.permissions_for(member).administrator: 126 | return True 127 | 128 | # only last way member can be mod if they're in the moderator role 129 | gc = await self.bot.config.get_entry(member.guild.id) 130 | 131 | # false if not set 132 | if gc.mod_role_id is None: 133 | return False 134 | 135 | # if set, see if author has this role 136 | 137 | return bool(disnake.utils.get(member.roles, id=gc.mod_role_id)) 138 | 139 | async def send_help(self, command=None): 140 | """Convenience method for sending help.""" 141 | 142 | perms = self.perms 143 | missing_perms = list(perm for perm in STATIC_PERMS if not getattr(perms, perm)) 144 | 145 | if missing_perms: 146 | help_cmd = self.bot.static_help_command 147 | help_cmd.missing_perms = missing_perms 148 | else: 149 | help_cmd = self.bot.help_command 150 | 151 | help_cmd.context = self 152 | 153 | if isinstance(command, commands.Command): 154 | command = command.qualified_name 155 | 156 | await help_cmd.command_callback(self, command=command) 157 | 158 | async def prompt(self, title=None, prompt=None, user_override=None): 159 | """Creates a yes/no prompt.""" 160 | 161 | perms = self.perms 162 | if not all(getattr(perms, perm) for perm in PROMPT_REQUIRED_PERMS): 163 | return False 164 | 165 | e = disnake.Embed(description=prompt or "No description provided.") 166 | 167 | e.set_author(name=title or "Prompt", icon_url=self.bot.user.display_avatar.url) 168 | 169 | view = PromptView(check_user=user_override or self.author, timeout=60.0) 170 | 171 | try: 
172 | message = await self.send( 173 | content=None if user_override is None else user_override.mention, 174 | embed=e, 175 | view=view, 176 | ) 177 | 178 | view.message = message 179 | except disnake.HTTPException: 180 | return False 181 | 182 | return await view.wait() 183 | 184 | async def admin_prompt(self, raise_on_abort=True): 185 | result = await self.prompt( 186 | title="Warning!", 187 | prompt=( 188 | "You are about to do an administrative action on an item you do not own.\n\n" 189 | "Are you sure you want to continue?" 190 | ), 191 | ) 192 | 193 | if raise_on_abort and not result: 194 | raise commands.CommandError("Administrative action aborted.") 195 | 196 | return result 197 | -------------------------------------------------------------------------------- /utils/converters.py: -------------------------------------------------------------------------------- 1 | import re 2 | from inspect import Parameter 3 | 4 | import disnake 5 | import emoji 6 | from disnake.ext import commands 7 | 8 | from .fakeuser import FakeUser 9 | 10 | empty = Parameter.empty 11 | 12 | 13 | def param_name(converter, ctx): 14 | fallback = "Argument" 15 | 16 | for param_name, parameter in ctx.command.params.items(): 17 | param_conv = parameter.annotation 18 | 19 | if param_conv == empty: 20 | continue 21 | 22 | if param_conv is converter: 23 | return param_name 24 | 25 | return fallback 26 | 27 | 28 | def _make_int(converter, ctx, argument): 29 | try: 30 | return int(argument) 31 | except ValueError: 32 | name = param_name(converter, ctx) 33 | raise commands.BadArgument(f"{name} should be a number.") 34 | 35 | 36 | class MaybeMemberConverter(commands.MemberConverter): 37 | async def resolve_id(self, ctx, member_id): 38 | member = ctx.guild.get_member(member_id) 39 | if member is not None: 40 | return member 41 | 42 | try: 43 | return await ctx.guild.fetch_member(member_id) 44 | except disnake.HTTPException: 45 | return FakeUser(member_id, ctx.guild) 46 | 47 | async def 
convert(self, ctx, argument): 48 | try: 49 | return await super().convert(ctx, argument) 50 | except commands.BadArgument as exc: 51 | # handles pure id's 52 | if argument.isdigit(): 53 | return await self.resolve_id(ctx, int(argument)) 54 | 55 | # handles mentions 56 | match = re.match(r"<@!?([0-9]+)>$", argument) 57 | if match is not None: 58 | return await self.resolve_id(ctx, int(match.group(1))) 59 | 60 | raise exc 61 | 62 | 63 | class EmojiConverter(commands.Converter): 64 | async def convert(self, ctx, argument): 65 | guild_emojis = list(str(e) for e in ctx.guild.emojis) 66 | 67 | if argument not in emoji.UNICODE_EMOJI["en"]: 68 | if argument not in guild_emojis: 69 | raise commands.BadArgument("Unknown emoji.") 70 | 71 | return argument 72 | 73 | 74 | class MaxValueConverter(commands.Converter): 75 | def __init__(self, _max): 76 | self.max = _max 77 | 78 | async def convert(self, ctx, argument): 79 | value = _make_int(self, ctx, argument) 80 | 81 | if value > self.max: 82 | name = param_name(self, ctx) 83 | raise commands.BadArgument(f"{name} must be lower than {self.max}") 84 | 85 | return value 86 | 87 | 88 | class SerialConverter(commands.Converter): 89 | MAX = pow(2, 31) - 1 90 | 91 | async def convert(self, ctx, argument): 92 | value = _make_int(self, ctx, argument) 93 | 94 | if value > self.MAX: 95 | name = param_name(self, ctx) 96 | raise commands.BadArgument(f"{name} must be a lower value.") 97 | 98 | return value 99 | 100 | 101 | class RangeConverter(commands.Converter): 102 | def __init__(self, min, max): 103 | self.min = min 104 | self.max = max 105 | 106 | def _make_error(self, ctx): 107 | name = param_name(self, ctx) 108 | return commands.BadArgument(f"{name} must be between {self.min} and {self.max}") 109 | 110 | async def convert(self, ctx, argument): 111 | value = _make_int(self, ctx, argument) 112 | 113 | if value < self.min: 114 | raise self._make_error(ctx) 115 | 116 | if value > self.max: 117 | raise self._make_error(ctx) 118 | 119 | 
return value 120 | 121 | 122 | class LengthConverter(commands.Converter): 123 | def __init__(self, min=1, max=32): 124 | self.min = min 125 | self.max = max 126 | 127 | def _make_error(self, ctx): 128 | name = param_name(self, ctx) 129 | return commands.BadArgument(f"{name} must be between {self.min} and {self.max} characters.") 130 | 131 | async def convert(self, ctx, argument): 132 | length = len(argument) 133 | 134 | if length < self.min: 135 | raise self._make_error(ctx) 136 | 137 | if length > self.max: 138 | raise self._make_error(ctx) 139 | 140 | return argument 141 | 142 | 143 | class MaxLengthConverter(commands.Converter): 144 | def __init__(self, max=32): 145 | self.max = max 146 | 147 | async def convert(self, ctx, argument): 148 | length = len(argument) 149 | 150 | if length > self.max: 151 | name = param_name(self, ctx) 152 | raise commands.BadArgument(f"{name} must be shorter than {self.max} characters.") 153 | 154 | return argument 155 | -------------------------------------------------------------------------------- /utils/databasetimer.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import logging 3 | from datetime import datetime, timedelta 4 | 5 | import asyncpg 6 | import disnake 7 | 8 | from utils.time import pretty_timedelta 9 | 10 | log = logging.getLogger(__name__) 11 | 12 | 13 | class DatabaseTimer: 14 | MAX_SLEEP = timedelta(days=40) 15 | 16 | def __init__(self, bot, event_name): 17 | self.bot = bot 18 | self.event_name = event_name 19 | 20 | self.record = None 21 | self.task = self.start_task() 22 | 23 | def start_task(self): 24 | return self.bot.loop.create_task(self.dispatch()) 25 | 26 | def restart_task(self): 27 | self.task.cancel() 28 | self.task = self.start_task() 29 | 30 | async def dispatch(self): 31 | # don't run timer before bot is ready 32 | await self.bot.wait_until_ready() 33 | 34 | try: 35 | while True: 36 | # fetch next record 37 | record = await self.get_record() 
38 | 39 | # if none was found, sleep for 40 days and check again 40 | if record is None: 41 | log.debug("No record found for %s, sleeping", self.event_name) 42 | await asyncio.sleep(self.MAX_SLEEP.total_seconds()) 43 | continue 44 | 45 | # we are now with this record 46 | self.record = record 47 | 48 | # get datetime again in case query took a lot of time 49 | now = datetime.utcnow() 50 | then = self.when(record) 51 | dt = then - now 52 | 53 | # if the next record is in the future, sleep until it should be invoked 54 | if now < then: 55 | log.debug( 56 | "%s dispatching in %s", 57 | self.event_name, 58 | pretty_timedelta(then - now), 59 | ) 60 | await asyncio.sleep(dt.total_seconds()) 61 | 62 | await self.cleanup_record(record) 63 | self.record = None 64 | 65 | log.debug("Dispatching event %s", self.event_name) 66 | 67 | # run it 68 | self.bot.dispatch(self.event_name, record) 69 | 70 | except (disnake.ConnectionClosed, asyncpg.PostgresConnectionError) as e: 71 | # if anything happened, sleep for 15 seconds then attempt a restart 72 | log.warning( 73 | "DatabaseTimer got exception %s: attempting restart in 15 seconds", 74 | str(e), 75 | ) 76 | 77 | await asyncio.sleep(15) 78 | self.restart_task() 79 | 80 | async def get_record(self): 81 | raise NotImplementedError 82 | 83 | async def cleanup_record(self, record): 84 | raise NotImplementedError 85 | 86 | def when(self, record): 87 | raise NotImplementedError 88 | 89 | def maybe_restart(self, dt): 90 | if self.record is None: 91 | self.restart_task() 92 | 93 | elif dt < self.when(self.record): 94 | self.restart_task() 95 | 96 | def restart_if(self, pred): 97 | if self.record is None or pred(self.record): 98 | self.restart_task() 99 | 100 | 101 | class ColumnTimer(DatabaseTimer): 102 | def __init__(self, bot, event_name, table, column): 103 | super().__init__(bot, event_name) 104 | self.table = table 105 | self.column = column 106 | 107 | async def get_record(self): 108 | return await self.bot.db.fetchrow( 109 | 
"SELECT * FROM {0} WHERE {1} < $1 AND {1} IS NOT NULL ORDER BY {1} LIMIT 1".format( 110 | self.table, self.column 111 | ), 112 | datetime.utcnow() + self.MAX_SLEEP, 113 | ) 114 | 115 | async def cleanup_record(self, record): 116 | await self.bot.db.execute( 117 | "DELETE FROM {0} WHERE id=$1".format(self.table), record.get("id") 118 | ) 119 | 120 | def when(self, record): 121 | return record.get(self.column) 122 | -------------------------------------------------------------------------------- /utils/fakeuser.py: -------------------------------------------------------------------------------- 1 | from disnake import Object 2 | 3 | 4 | class FakeAsset: 5 | def __init__(self, url): 6 | self.url = url 7 | 8 | def __str__(self): 9 | return self.url 10 | 11 | 12 | class FakeUser(Object): 13 | def __init__(self, id, guild=None, **data): 14 | super().__init__(id) 15 | 16 | self._guild = guild 17 | self._data = data 18 | 19 | @property 20 | def guild(self): 21 | if self._guild is None: 22 | raise ValueError("FakeUser does not have a guild") 23 | 24 | return self._guild 25 | 26 | @property 27 | def mention(self): 28 | return f"<@!{self.id}>" 29 | 30 | @property 31 | def name(self): 32 | return self._data.get("name", "Unknown User") 33 | 34 | @property 35 | def nick(self): 36 | return self._data.get("nick", None) 37 | 38 | @property 39 | def display_name(self): 40 | return self.nick or self.name 41 | 42 | @property 43 | def discriminator(self): 44 | return self._data.get("discriminator", "????") 45 | 46 | @property 47 | def avatar(self): 48 | return FakeAsset( 49 | url=self._data.get("avatar_url", "https://cdn.discordapp.com/embed/avatars/0.png") 50 | ) 51 | 52 | @property 53 | def display_avatar(self): 54 | return self.avatar 55 | 56 | def __str__(self): 57 | name = self.name 58 | nick = self.nick 59 | discriminator = self.discriminator 60 | 61 | string = "" 62 | if nick is not None: 63 | string += nick 64 | elif name is not None: 65 | string += name 66 | else: 67 | raise 
ValueError("Not enough information in FakeMember data to craft str") 68 | 69 | if discriminator is not None: 70 | string += "#" + discriminator 71 | 72 | return string 73 | -------------------------------------------------------------------------------- /utils/guildconfigrecord.py: -------------------------------------------------------------------------------- 1 | from .configtable import ConfigTableRecord 2 | 3 | 4 | class GuildConfigRecord(ConfigTableRecord): 5 | @property 6 | def mod_role(self): 7 | if self.mod_role_id is None: 8 | return None 9 | 10 | guild = self._config.bot.get_guild(self.guild_id) 11 | 12 | if guild is None: 13 | return None 14 | 15 | return guild.get_role(self.mod_role_id) 16 | -------------------------------------------------------------------------------- /utils/help.py: -------------------------------------------------------------------------------- 1 | import disnake 2 | from disnake.ext import commands 3 | 4 | from utils.pager import Pager 5 | 6 | 7 | class HelpPager(Pager): 8 | commands_per_page = 8 9 | 10 | def add_page(self, cog_name, cog_desc, commands): 11 | """Will split into several pages to accommodate the per_page limit.""" 12 | 13 | # will obviously not run if no commands are in the page 14 | for commands_slice in [ 15 | commands[i : i + self.commands_per_page] 16 | for i in range(0, len(commands), self.commands_per_page) 17 | ]: 18 | self.entries.append((cog_name, cog_desc, commands_slice)) 19 | 20 | def craft_invite_string(self): 21 | return "[Enjoying the bot? 
Invite it to your own server!]({0})".format( 22 | self.ctx.bot.invite_link 23 | ) 24 | 25 | async def create_base_embed(self): 26 | return disnake.Embed() 27 | 28 | async def update_page_embed(self, embed, page, entries): 29 | cog_name, cog_desc, commands = entries[0] 30 | 31 | name = f"{cog_name} Commands" 32 | 33 | desc = "" 34 | if self.ctx.guild.owner != self.ctx.author: 35 | desc += self.craft_invite_string() 36 | 37 | if cog_desc is not None: 38 | desc += "\n\n" + cog_desc 39 | 40 | embed.set_author(name=name, icon_url=self.ctx.bot.user.display_avatar.url) 41 | embed.description = desc 42 | 43 | embed.clear_fields() 44 | for name, value in commands: 45 | embed.add_field(name=name, value=value, inline=False) 46 | 47 | async def help_embed(self, e): 48 | e.set_author(name="How do I use the bot?", icon_url=self.bot.user.display_avatar.url) 49 | 50 | e.description = ( 51 | "Invoke a command by sending the prefix followed by a command name.\n\n" 52 | "For example, the command signature `track ` can be invoked by doing `track yellow`\n\n" 53 | "The different argument brackets mean:" 54 | ) 55 | 56 | e.add_field(name="", value="the argument is required.", inline=False) 57 | e.add_field(name="[argument]", value="the argument is optional.\n\u200b", inline=False) 58 | 59 | e.add_field( 60 | name="Support Server", 61 | value="Join the support server!\n" + self.bot.support_link, 62 | ) 63 | 64 | 65 | class PaginatedHelpCommand(commands.HelpCommand): 66 | """Cog that implements the help command and help pager.""" 67 | 68 | async def package_command(self, command, force=False, long_help=False): 69 | if command.hidden or not command.enabled: 70 | return None 71 | 72 | if not force: 73 | try: 74 | if not await command.can_run(self.context): 75 | return None 76 | except commands.CommandError: 77 | return None 78 | 79 | help_message = command.brief or command.help 80 | 81 | if help_message is None: 82 | help_message = "No description available." 
83 | elif not long_help: 84 | help_message = help_message.split("\n")[0] 85 | 86 | # unsure if I want this 87 | if command.aliases: 88 | help_message += "\nAliases: `" + ", ".join(command.aliases) + "`" 89 | 90 | return self.context.prefix + get_signature(command), help_message 91 | 92 | async def prepare_help_command(self, ctx, command=None): 93 | self.context = ctx 94 | self.pager = HelpPager(ctx, list(), per_page=1) 95 | 96 | async def add_cog(self, cog: commands.Cog, force=False): 97 | cog_name = cog.__class__.__name__ 98 | cog_desc = cog.__doc__ 99 | 100 | commands = [] 101 | added = [] 102 | 103 | for command in cog.walk_commands(): 104 | if command in added: 105 | continue 106 | 107 | added.append(command) 108 | 109 | pack = await self.package_command(command, force=force) 110 | if pack is None: 111 | continue 112 | 113 | commands.append(pack) 114 | 115 | if not commands: 116 | return True 117 | 118 | self.pager.add_page(cog_name, cog_desc, commands) 119 | 120 | async def send_bot_help(self, mapping): 121 | for cog in mapping: 122 | if cog is not None: 123 | await self.add_cog(cog) 124 | 125 | await self.pager.go() 126 | 127 | async def send_cog_help(self, cog): 128 | if await self.add_cog(cog, force=True): 129 | return 130 | 131 | await self.pager.go() 132 | 133 | async def send_group_help(self, group): 134 | cog_name = group.cog_name 135 | 136 | if cog_name is not None and group.cog_name.lower() == group.name.lower(): 137 | await self.send_cog_help(group.cog) 138 | return 139 | 140 | commands = [] 141 | seen = [] 142 | 143 | for command in group.walk_commands(): 144 | if command in seen: 145 | continue 146 | 147 | seen.append(command) 148 | 149 | pack = await self.package_command(command) 150 | if pack is None: 151 | continue 152 | 153 | commands.append(pack) 154 | 155 | # if we found no commands, just stop here 156 | if not commands: 157 | await self.stop() 158 | return 159 | 160 | self.pager.add_page(group.cog_name, group.cog.__doc__, commands) 161 | 
await self.pager.go() 162 | 163 | async def send_command_help(self, command): 164 | cog_name = command.cog_name 165 | 166 | if cog_name is not None and cog_name.lower() == command.name: 167 | await self.send_cog_help(command.cog) 168 | return 169 | 170 | pack = await self.package_command(command, force=True, long_help=True) 171 | 172 | if pack is None: # probably means it's hidden 173 | await self.stop() 174 | return 175 | 176 | self.pager.add_page(cog_name, command.cog.__doc__, [pack]) 177 | await self.pager.go() 178 | 179 | async def stop(self): 180 | await self.send_error_message(await self.command_not_found(self.context.kwargs["command"])) 181 | 182 | async def command_not_found(self, command_name): 183 | return commands.CommandNotFound(command_name) 184 | 185 | async def send_error_message(self, error): 186 | if not isinstance(error, commands.CommandNotFound): 187 | return 188 | 189 | command_name = str(error) 190 | 191 | for cog in self.context.bot.cogs: 192 | if command_name == cog.lower(): 193 | await self.send_cog_help(self.context.bot.get_cog(cog)) 194 | return 195 | 196 | await self.context.send("Command '{0}' not found.".format(command_name)) 197 | 198 | 199 | # rip is just the signature command ripped from the lib, but with alias support removed. 
200 | def get_signature(command): 201 | """Returns a POSIX-like signature useful for help command output.""" 202 | 203 | result = [] 204 | parent = command.full_parent_name 205 | 206 | name = command.name if not parent else parent + " " + command.name 207 | result.append(name) 208 | 209 | if command.usage: 210 | result.append(command.usage) 211 | return " ".join(result) 212 | 213 | params = command.clean_params 214 | if not params: 215 | return " ".join(result) 216 | 217 | for name, param in params.items(): 218 | if param.default is not param.empty: 219 | # We don't want None or '' to trigger the [name=value] case and instead it should 220 | # do [name] since [name=None] or [name=] are not exactly useful for the user. 221 | should_print = ( 222 | param.default if isinstance(param.default, str) else param.default is not None 223 | ) 224 | if should_print: 225 | result.append("[%s=%s]" % (name, param.default)) 226 | else: 227 | result.append("[%s]" % name) 228 | elif param.kind == param.VAR_POSITIONAL: 229 | result.append("[%s...]" % name) 230 | else: 231 | result.append("<%s>" % name) 232 | 233 | return " ".join(result) 234 | -------------------------------------------------------------------------------- /utils/html2markdown.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | from bs4 import BeautifulSoup, NavigableString 4 | 5 | PREPEND = dict( 6 | br="\n", # linebreak 7 | ) 8 | 9 | 10 | WRAP = dict( 11 | b="**", # bold 12 | em="*", # emphasis 13 | i="*", # italics 14 | # blockquote='*', 15 | ) 16 | 17 | 18 | MULTIWRAP = dict( 19 | li=(" • ", "\n"), # list item 20 | ) 21 | 22 | 23 | SPACING = dict( 24 | p=2, 25 | div=2, 26 | ul=2, 27 | ) 28 | 29 | 30 | CODEBOX_NAMES = ("code", "pre") 31 | 32 | 33 | class CreditsEmpty(Exception): 34 | pass 35 | 36 | 37 | class Result: 38 | def __init__(self, credits): 39 | self.credits = credits 40 | self.content = "" 41 | 42 | def __str__(self): 43 | return self.content 44 | 
45 | def consume(self, amount: int): 46 | if amount > self.credits: 47 | raise CreditsEmpty() 48 | 49 | self.credits -= amount 50 | 51 | def feed(self, amount: int): 52 | self.credits += amount 53 | 54 | def can_afford(self, string): 55 | return len(string) <= self.credits 56 | 57 | def ensure_spacing(self, spacing=2): 58 | while self.content.endswith("\n" * (spacing + 1)): 59 | self.content = self.content[:-1] 60 | self.feed(1) 61 | 62 | while not self.content.endswith("\n" * spacing): 63 | self.content += "\n" 64 | self.consume(1) 65 | 66 | def add(self, string): 67 | self.content += string 68 | 69 | def add_and_consume(self, string, trunc=False): 70 | to_consume = len(string) 71 | do_raise = False 72 | 73 | if trunc is True and to_consume > self.credits: 74 | to_consume = self.credits 75 | string = string[:to_consume] 76 | do_raise = True 77 | 78 | self.consume(to_consume) 79 | self.add(string) 80 | 81 | if do_raise: 82 | raise CreditsEmpty() 83 | 84 | 85 | class HTML2Markdown: 86 | def __init__(self, escaper=None, big_box=False, lang=None, max_len=2000, base_url=None): 87 | self.result = None 88 | self.escaper = escaper 89 | self.max_len = max_len 90 | self.base_url = base_url 91 | self.big_box = big_box 92 | self.lang = lang 93 | 94 | self.cutoff = "..." 
95 | 96 | def convert(self, html): 97 | self.result = Result(max(self.max_len, 8) - len(self.cutoff) - 1) 98 | 99 | try: 100 | self.traverse(BeautifulSoup(html, "html.parser")) 101 | except CreditsEmpty: 102 | if str(self.result).endswith(" "): 103 | self.result.add(self.cutoff) 104 | else: 105 | self.result.add(" " + self.cutoff) 106 | 107 | content = str(self.result) 108 | 109 | # shorten groups of more than 2 newlines to just 2 newlines 110 | content = re.sub("\n\n+", "\n\n", content) 111 | 112 | # remove multiple newlines after triple backticks 113 | if self.big_box is not None: 114 | # always only one newline at the side of either triple backtick 115 | content = re.sub("\n+```\n+", "\n```\n", content) 116 | 117 | if self.lang is not None: 118 | content = re.sub( 119 | "\n+```{}\n+".format(self.lang), 120 | "\n```{}\n".format(self.lang), 121 | content, 122 | ) 123 | 124 | # strip of trailing/leading newlines and return 125 | return content.strip("\n") 126 | 127 | def traverse(self, tag): 128 | for node in tag.contents: 129 | if isinstance(node, NavigableString): 130 | self.navigable_string(node) 131 | else: 132 | if node.name in CODEBOX_NAMES: 133 | self.codebox(node) 134 | continue 135 | elif node.name == "a": 136 | self.link(node) 137 | continue 138 | 139 | back_required = False 140 | 141 | if node.name in PREPEND: 142 | front, back = PREPEND[node.name], "" 143 | 144 | elif node.name in WRAP: 145 | wrap_str = WRAP[node.name] 146 | front, back = wrap_str, wrap_str 147 | 148 | # for wrapping, we *must* add the back char(s) 149 | back_required = True 150 | 151 | elif node.name in MULTIWRAP: 152 | front, back = MULTIWRAP[node.name] 153 | 154 | else: 155 | front, back = "", "" 156 | 157 | # if we can't add the front + back and at least one char, just raise creditsempty 158 | if not self.result.can_afford(front + back + " "): 159 | raise CreditsEmpty() 160 | 161 | self.result.add_and_consume(front) 162 | 163 | # prematurely consume the back characters if it *must* 
be added later 164 | if back_required: 165 | self.result.consume(len(back)) 166 | 167 | try: 168 | self.traverse(node) 169 | except CreditsEmpty as exc: 170 | if back_required: 171 | self.result.add(back) 172 | raise exc 173 | 174 | if back_required: 175 | self.result.add(back) 176 | else: 177 | if node.name in SPACING: 178 | self.result.ensure_spacing(SPACING[node.name]) 179 | else: 180 | self.result.add_and_consume(back) 181 | 182 | def navigable_string(self, node): 183 | content = str(node) 184 | if content == "\n": 185 | return 186 | self.result.add_and_consume( 187 | self.escaper(content) if callable(self.escaper) else content, True 188 | ) 189 | 190 | def get_content(self, tag): 191 | content = self._get_content_meta(tag) 192 | return content.strip() 193 | 194 | def _get_content_meta(self, tag): 195 | if isinstance(tag, NavigableString): 196 | return str(tag) 197 | elif tag.name == "br": 198 | return "\n" 199 | 200 | content = "" 201 | for child in tag.children: 202 | content += self._get_content_meta(child) 203 | 204 | return content 205 | 206 | def codebox(self, tag): 207 | front, back = self._codebox_wraps() 208 | 209 | # specific fix for autohotkey rss 210 | for br in tag.find_all("br"): 211 | br.replace_with("\n") 212 | 213 | contents = self.get_content(tag) 214 | 215 | self.result.add_and_consume(front + contents + back) 216 | 217 | def _codebox_wraps(self): 218 | return ( 219 | "```{}\n".format(self.lang or "") if self.big_box else "`", 220 | "\n```\n" if self.big_box else "`", 221 | ) 222 | 223 | def link(self, tag): 224 | credits = self.result.credits 225 | 226 | link = self._format_link(tag["href"]) 227 | contents = self.get_content(tag) 228 | 229 | full = "[{}]({})".format(contents, link) 230 | 231 | if link is None: 232 | self.result.add_and_consume(contents, True) 233 | elif credits >= len(full): 234 | self.result.add_and_consume(full) 235 | elif credits >= len(link) + 5: 236 | self.result.add_and_consume( 237 | "[{}]({})".format(contents[: 
credits - len(link) - 4], link) 238 | ) 239 | else: 240 | self.result.add_and_consume(contents, True) 241 | 242 | def _format_link(self, href): 243 | if re.match(r"^.+:\/\/", href): 244 | return href 245 | 246 | if self.base_url is None: 247 | return None 248 | 249 | if href.startswith("#"): 250 | return self.base_url + href 251 | else: 252 | return "/".join(self.base_url.split("/")[:-1]) + "/" + href 253 | -------------------------------------------------------------------------------- /utils/pager.py: -------------------------------------------------------------------------------- 1 | from math import ceil 2 | from typing import TYPE_CHECKING, Union 3 | 4 | import disnake 5 | 6 | if TYPE_CHECKING: 7 | from disnake.ext import commands 8 | 9 | FIRST_LABEL = "First page" 10 | LAST_LABEL = "Last page" 11 | NEXT_LABEL = "Next page" 12 | PREV_LABEL = "Previous page" 13 | STOP_LABEL = "Stop" 14 | 15 | FIRST_EMOJI = "\N{BLACK LEFT-POINTING DOUBLE TRIANGLE WITH VERTICAL BAR}" 16 | NEXT_EMOJI = "\N{BLACK RIGHT-POINTING DOUBLE TRIANGLE}" 17 | PREV_EMOJI = "\N{BLACK LEFT-POINTING DOUBLE TRIANGLE}" 18 | LAST_EMOJI = "\N{BLACK RIGHT-POINTING DOUBLE TRIANGLE WITH VERTICAL BAR}" 19 | STOP_EMOJI = "\N{BLACK SQUARE FOR STOP}" 20 | HELP_EMOJI = "\N{WHITE QUESTION MARK ORNAMENT}" 21 | 22 | 23 | class Pager(disnake.ui.View): 24 | def __init__( 25 | self, 26 | ctx: Union[disnake.AppCommandInteraction, "commands.Context"], 27 | entries=None, 28 | per_page=6, 29 | timeout=60.0 * 15, 30 | ): 31 | super().__init__(timeout=timeout) 32 | 33 | self.ctx = ctx 34 | self.entries = entries 35 | self.per_page = per_page 36 | 37 | self.page = 0 38 | self.embed = None 39 | self.message = None 40 | 41 | self.__buttons = None 42 | 43 | async def go(self, at_page=0): 44 | if isinstance(self.ctx, disnake.Interaction): 45 | meth = ( 46 | self.ctx.followup.send 47 | if self.ctx.response.is_done() 48 | else self.ctx.response.send_message 49 | ) 50 | else: 51 | meth = self.ctx.send 52 | 53 | kwargs = 
dict(embed=await self.init(at_page=at_page)) 54 | 55 | if self.top_page: 56 | kwargs["view"] = self 57 | 58 | self.message = await meth(**kwargs) 59 | 60 | async def init(self, at_page=0): 61 | self.embed = await self.create_base_embed() 62 | await self.try_page(at_page) 63 | 64 | return self.embed 65 | 66 | def get_page_entries(self, page): 67 | """Converts a page number to a range of entries.""" 68 | base = page * self.per_page 69 | return self.entries[base : base + self.per_page] 70 | 71 | async def create_base_embed(self): 72 | raise NotImplementedError() 73 | 74 | async def update_page_embed(self, embed, page, entries): 75 | raise NotImplementedError() 76 | 77 | @property 78 | def buttons(self): 79 | if self.__buttons is not None: 80 | return self.__buttons 81 | 82 | buttons = { 83 | child.label: child 84 | for child in self.children 85 | if isinstance(child, disnake.ui.Button) and child.label is not None 86 | } 87 | 88 | self.__buttons = buttons 89 | return buttons 90 | 91 | @property 92 | def top_page(self): 93 | return ceil(len(self.entries) / self.per_page) - 1 94 | 95 | async def try_page(self, page): 96 | if not 0 <= page <= self.top_page: 97 | return 98 | 99 | self.page = page 100 | 101 | if self.top_page: 102 | self.embed.set_footer(text=f"Page {page + 1}/{self.top_page + 1}") 103 | 104 | is_first = self.page == 0 105 | is_last = self.page == self.top_page 106 | self.buttons[PREV_LABEL].disabled = is_first 107 | self.buttons[NEXT_LABEL].disabled = is_last 108 | self.buttons[FIRST_LABEL].disabled = is_first 109 | self.buttons[LAST_LABEL].disabled = is_last 110 | 111 | await self.update_page_embed(self.embed, page, self.get_page_entries(page)) 112 | 113 | @property 114 | def author(self): 115 | return self.ctx.author 116 | 117 | async def interaction_check(self, interaction: disnake.MessageInteraction) -> bool: 118 | return interaction.author == self.author 119 | 120 | async def on_timeout(self) -> None: 121 | try: 122 | await 
self.message.edit(view=None) 123 | except disnake.HTTPException: 124 | pass 125 | 126 | @disnake.ui.button(label=PREV_LABEL, emoji=PREV_EMOJI, style=disnake.ButtonStyle.primary, row=0) 127 | async def prev_page(self, button: disnake.ui.Button, inter: disnake.MessageInteraction): 128 | await self.try_page(self.page - 1) 129 | await inter.response.edit_message(embed=self.embed, view=self) 130 | 131 | @disnake.ui.button(label=NEXT_LABEL, emoji=NEXT_EMOJI, style=disnake.ButtonStyle.primary, row=0) 132 | async def next_page(self, button: disnake.ui.Button, inter: disnake.MessageInteraction): 133 | await self.try_page(self.page + 1) 134 | await inter.response.edit_message(embed=self.embed, view=self) 135 | 136 | @disnake.ui.button( 137 | label=FIRST_LABEL, emoji=FIRST_EMOJI, style=disnake.ButtonStyle.secondary, row=1 138 | ) 139 | async def first_page(self, button: disnake.ui.Button, inter: disnake.MessageInteraction): 140 | await self.try_page(0) 141 | await inter.response.edit_message(embed=self.embed, view=self) 142 | 143 | @disnake.ui.button( 144 | label=LAST_LABEL, emoji=LAST_EMOJI, style=disnake.ButtonStyle.secondary, row=1 145 | ) 146 | async def last_page(self, button: disnake.ui.Button, inter: disnake.MessageInteraction): 147 | await self.try_page(self.top_page) 148 | await inter.response.edit_message(embed=self.embed, view=self) 149 | 150 | @disnake.ui.button(label=STOP_LABEL, emoji=STOP_EMOJI, style=disnake.ButtonStyle.danger, row=1) 151 | async def stop_pager(self, button: disnake.ui.Button, inter: disnake.MessageInteraction): 152 | await inter.response.edit_message(view=None) 153 | self.stop() 154 | -------------------------------------------------------------------------------- /utils/string.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | 3 | from utils.time import pretty_datetime 4 | 5 | 6 | def shorten(text, max_char=2000): 7 | """Shortens text to fit within max_chars and 
max_newline.""" 8 | 9 | if max_char < 16: 10 | raise ValueError("Only shortens down to 16 characters") 11 | 12 | if max_char >= len(text): 13 | return text 14 | 15 | text = text[0 : max_char - 4] 16 | 17 | for i in range(1, 16): 18 | if text[-i] in (" ", "\n"): 19 | return text[0 : len(text) - i + 1] + "..." 20 | 21 | return text + " ..." 22 | 23 | 24 | def po(obj): 25 | try: 26 | name = str(obj) 27 | except: 28 | name = "Unknown" 29 | return "{0} ({1})".format(name, obj.id) 30 | 31 | 32 | def yesno(b): 33 | return "Yes" if b else "No" 34 | -------------------------------------------------------------------------------- /utils/time.py: -------------------------------------------------------------------------------- 1 | import math 2 | from datetime import datetime, timedelta 3 | 4 | from disnake.ext import commands 5 | 6 | ordinal = lambda n: "%d%s" % ( 7 | n, 8 | "tsnrhtdd"[(math.floor(n / 10) % 10 != 1) * (n % 10 < 4) * n % 10 :: 4], 9 | ) 10 | 11 | 12 | steps = dict( 13 | year=timedelta(days=365), 14 | week=timedelta(days=7), 15 | day=timedelta(days=1), 16 | hour=timedelta(hours=1), 17 | minute=timedelta(minutes=1), 18 | second=timedelta(seconds=1), 19 | millisecond=timedelta(milliseconds=1), 20 | ) 21 | 22 | 23 | def pretty_timedelta(td: timedelta): 24 | """Returns a pretty string of a timedelta""" 25 | 26 | if not isinstance(td, timedelta): 27 | raise ValueError("timedelta expected, '{}' given".format(type(td))) 28 | 29 | parts = [] 30 | 31 | for name, span in steps.items(): 32 | if td >= span: 33 | count = int(td / span) 34 | td -= count * span 35 | parts.append("{} {}{}".format(count, name, "s" if count > 1 else "")) 36 | if len(parts) >= 2 or name == "second": 37 | break 38 | elif len(parts): 39 | break 40 | 41 | return ", ".join(parts) 42 | 43 | 44 | def pretty_seconds(s): 45 | return pretty_timedelta(timedelta(seconds=s)) 46 | 47 | 48 | def pretty_datetime(dt: datetime, ignore_time=False): 49 | if not isinstance(dt, datetime): 50 | raise 
ValueError("datetime expected, '{}' given".format(type(dt))) 51 | 52 | return "{0} {1}".format( 53 | ordinal(int(dt.strftime("%d"))), 54 | dt.strftime("%b %Y" + ("" if ignore_time else " %H:%M")), 55 | ) 56 | 57 | 58 | class TimeMultConverter(commands.Converter): 59 | async def convert(self, ctx, mult): 60 | try: 61 | mult = float(mult) 62 | except ValueError: 63 | raise commands.CommandError("Argument has to be float.") 64 | 65 | if mult < 1.0: 66 | raise commands.CommandError("Unit must be more than 1.") 67 | 68 | return mult 69 | 70 | 71 | class TimeDeltaConverter(commands.Converter): 72 | async def convert(self, ctx, unit): 73 | unit = unit.lower() 74 | 75 | if unit in ("s", "sec", "secs", "second", "seconds"): 76 | return timedelta(seconds=1) 77 | elif unit in ("m", "min", "mins", "minute", "minutes"): 78 | return timedelta(minutes=1) 79 | elif unit in ("h", "hr", "hrs", "hour", "hours"): 80 | return timedelta(hours=1) 81 | elif unit in ("d", "day", "days"): 82 | return timedelta(days=1) 83 | elif unit in ("w", "wk", "week", "weeks"): 84 | return timedelta(weeks=1) 85 | else: 86 | raise commands.BadArgument("Unknown time type.") 87 | --------------------------------------------------------------------------------