├── .gitignore ├── LICENSE ├── README.md ├── bot ├── __init__.py ├── __main__.py ├── bot.py ├── command.py ├── commands.py ├── config.json ├── db │ ├── __init__.py │ └── mongodb_connector.py ├── discord │ ├── __init__.py │ ├── client.py │ └── models.py ├── entity_manager.py ├── models.py ├── paginator.py └── sites │ ├── __init__.py │ ├── atcoder.py │ ├── codechef.py │ ├── codeforces.py │ ├── competitive_programming_site.py │ ├── models.py │ └── site_container.py └── requirements.txt /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | *.pyo 3 | 4 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Soumik Sarkar 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # CPBot 2 | 3 | A simple Discord bot that assists with competitive programming 4 | 5 | This is intended to be a learning project for me, and a small tool to be used by the members of my college's Discord group for programmers. 6 | 7 | ### Features 8 | 9 | - CPBot can fetch a list of upcoming contests from supported competitive programming sites. These are displayed when the associated command is invoked. Supports filtering by site and limiting by count. 10 | - CPBot can monitor profiles of users on supported programming sites. This allows users to subscribe using their handle and be notified when their rating is updated after participation in a contest. 11 | 12 | **Warning**: CPBot is intended to be a simple bot, and is not expected to be very robust or handle a large amount of activity. Much of that can be alleviated by using an established API wrapper such as [discord.py](https://github.com/Rapptz/discord.py), which I decided against for the sake of learning. 
13 | 14 | -------------------------------------------------------------------------------- /bot/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/meooow25/cp-discord-bot/4d25b51f9dc4dc44105a6cebeeaea9ef1191c8c1/bot/__init__.py -------------------------------------------------------------------------------- /bot/__main__.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import asyncio 3 | import json 4 | import logging 5 | import os 6 | 7 | from .bot import Bot 8 | from .entity_manager import EntityManager 9 | from .db import MongoDBConnector 10 | from .discord import Client 11 | from .sites import AtCoder, CodeChef, Codeforces, SiteContainer 12 | 13 | logger = logging.getLogger(__name__) 14 | 15 | DISCORD_TOKEN = os.environ['DISCORD_TOKEN'] 16 | MONGODB_SRV = os.environ['MONGODB_SRV'] 17 | 18 | with open('./bot/config.json') as file: 19 | CONFIG = json.load(file) 20 | 21 | 22 | def main(): 23 | parser = argparse.ArgumentParser() 24 | parser.add_argument('--log', default='WARNING') 25 | args = parser.parse_args() 26 | numeric_level = getattr(logging, args.log.upper(), None) 27 | if not isinstance(numeric_level, int): 28 | raise ValueError(f'Invalid log level: {args.log}') 29 | logging.basicConfig(format='{levelname}:{name}:{message}', style='{', level=numeric_level) 30 | 31 | discord_client = Client(DISCORD_TOKEN, name=CONFIG['name'], activity_name=CONFIG['activity']) 32 | mongodb_connector = MongoDBConnector(MONGODB_SRV, CONFIG['db_name']) 33 | entity_manager = EntityManager(mongodb_connector) 34 | sites = [ 35 | AtCoder(**CONFIG['at_config']), 36 | CodeChef(**CONFIG['cc_config']), 37 | Codeforces(**CONFIG['cf_config']), 38 | ] 39 | site_container = SiteContainer(sites=sites) 40 | 41 | bot = Bot(CONFIG['name'], discord_client, site_container, entity_manager, 42 | triggers=CONFIG['triggers'], allowed_channels=CONFIG['channels']) 43 
| 44 | try: 45 | asyncio.run(bot.run()) 46 | except Exception: 47 | logger.exception('Grinding halt') 48 | 49 | 50 | if __name__ == '__main__': 51 | main() 52 | -------------------------------------------------------------------------------- /bot/bot.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import platform 3 | from datetime import timedelta, timezone 4 | from operator import itemgetter 5 | 6 | from . import command, commands 7 | from .discord import Channel 8 | from .models import User 9 | 10 | 11 | class Bot: 12 | PYTHON_URL = 'https://www.python.org' 13 | GITHUB_URL = 'https://github.com/meooow25/cp-discord-bot' 14 | CONTESTS_PER_PAGE = 5 15 | # TODO: Support separate time zones per channel or server 16 | TIMEZONE = timezone(timedelta(hours=5, minutes=30)) 17 | 18 | def __init__(self, name, client, site_container, entity_manager, triggers=None, allowed_channels=None): 19 | self.name = name 20 | self.client = client 21 | self.site_container = site_container 22 | self.entity_manager = entity_manager 23 | self.triggers = triggers 24 | self.allowed_channels = allowed_channels 25 | self.logger = logging.getLogger(self.__class__.__qualname__) 26 | 27 | self.command_map = {} 28 | for attr_name in dir(commands): 29 | attr = getattr(commands, attr_name) 30 | if isinstance(attr, command.Command): 31 | cmd = attr 32 | self.command_map[cmd.name] = cmd 33 | self.logger.info(f'Loaded commands: {self.command_map.keys()}') 34 | 35 | # Help message begin. 
36 | self.help_message = {} 37 | if not triggers: 38 | self.help_message['content'] = '*@mention me to activate me.*\n' 39 | else: 40 | self.help_message['content'] = f'*@mention me or use my trigger `{self.triggers[0]}` to activate me.*\n' 41 | 42 | fields = [cmd.embed_field_rep() for cmd in self.command_map.values() if not cmd.hidden] 43 | fields.sort(key=itemgetter('name')) 44 | self.help_message['embed'] = { 45 | 'title': 'Supported commands:', 46 | 'fields': fields, 47 | } 48 | 49 | # Info message begin. 50 | self.info_message = { 51 | 'content': f'*Hello, I am **{self.name}**!*', 52 | 'embed': { 53 | 'description': f'A half-baked bot made by *meooow*\n' 54 | f'Written in awesome [Python 3.7]({self.PYTHON_URL})\n' 55 | f'Check me out on [Github]({self.GITHUB_URL})!', 56 | }, 57 | } 58 | 59 | # Status message begin. 60 | self.status_message = { 61 | 'content': '*Status info*', 62 | 'embed': { 63 | 'fields': [ 64 | { 65 | 'name': 'System', 66 | 'value': f'Python version: {platform.python_version()}\n' 67 | f'OS and version: {platform.system()}-{platform.release()}', 68 | }, 69 | ], 70 | }, 71 | } 72 | 73 | async def run(self): 74 | """Runs the entity manager, site container, and Discord client.""" 75 | await self.entity_manager.run() 76 | await self.site_container.run(get_all_users=self.get_all_users, 77 | on_profile_fetch=self.on_profile_fetch) 78 | await self.client.run(on_message=self.on_message) 79 | 80 | async def on_message(self, message): 81 | """Callback intended to be executed when the Discord client receives a message.""" 82 | 83 | # message.author is None when message is sent by a webhook. 
84 | if not message.author or message.author.bot: 85 | return 86 | 87 | args = message.content.split() 88 | if not args: 89 | return 90 | has_trigger = args[0] in self.triggers if self.triggers else False 91 | has_trigger = has_trigger or args[0] == f'<@{self.client.user["id"]}>' 92 | if has_trigger: 93 | args = args[1:] 94 | if not args: 95 | return 96 | on_allowed_channel = self.allowed_channels is None or message.channel_id in self.allowed_channels 97 | 98 | channel = await self.get_channel(message.channel_id) 99 | if channel.type == Channel.Type.DM: 100 | await self.run_command_from_map(args, message, is_dm=True) 101 | elif has_trigger and on_allowed_channel: 102 | await self.run_command_from_map(args, message, is_dm=False) 103 | 104 | async def run_command_from_map(self, args, message, is_dm): 105 | """Executes the command named ``args[0]`` if it exists.""" 106 | 107 | # Ignore command case. 108 | args[0] = args[0].lower() 109 | cmd = self.command_map.get(args[0]) 110 | if cmd is None: 111 | self.logger.info(f'Unrecognized command {args}') 112 | return 113 | if cmd.allow_dm and is_dm or cmd.allow_guild and not is_dm: 114 | try: 115 | await cmd.execute(self, args[1:], message) 116 | except command.IncorrectUsageException as ex: 117 | self.logger.info(f'Incorrect usage: {ex}') 118 | else: 119 | self.logger.info(f'Command not allowed in current channel type (guild/DM): "{message.content}"') 120 | 121 | async def get_channel(self, channel_id): 122 | """Returns the Discord channel with given channel id. 123 | 124 | Attempts to find the channel in the entity manager first. If not found, the client is queried and the returned 125 | channel saved to the entity manager before returning. 
126 | """ 127 | channel = self.entity_manager.get_channel(channel_id) 128 | if channel is None: 129 | channel = await self.client.get_channel(channel_id) 130 | await self.entity_manager.save_channel(channel) 131 | return channel 132 | 133 | def get_all_users(self): 134 | """Returns a shallow copy of the list of all users.""" 135 | return self.entity_manager.users[:] 136 | 137 | async def on_profile_fetch(self, user, old_profile, new_profile): 138 | """Callback intended to be executed when the site container updates a user site profile.""" 139 | 140 | changed = await self.entity_manager.update_user_site_profile(user.discord_id, new_profile) 141 | if not changed: 142 | return 143 | self.logger.debug(f'Changed profile: {old_profile.to_dict()}, {new_profile.to_dict()}') 144 | msg = { 145 | 'content': '*Your profile has been updated*', 146 | 'embed': User.get_profile_change_embed(old_profile, new_profile) 147 | } 148 | await self.client.send_message(msg, user.dm_channel_id) 149 | -------------------------------------------------------------------------------- /bot/command.py: -------------------------------------------------------------------------------- 1 | class Command: 2 | """An executable bot command.""" 3 | 4 | def __init__(self, func, name=None, usage=None, desc=None, hidden=False, 5 | allow_guild=True, allow_dm=False): 6 | """ 7 | :param func: the function that actually does the work 8 | :param name: the command name 9 | :param usage: command usage information 10 | :param desc: command description 11 | :param hidden: whether the command is hidden 12 | :param allow_guild: whether the command is allowed in guild channels 13 | :param allow_dm: whether the command is allowed in DM channels 14 | """ 15 | self.func = func 16 | self.name = func.__name__ if name is None else name 17 | self.usage = func.__name__ if usage is None else usage 18 | self.desc = func.__name__ if desc is None else desc 19 | self.hidden = hidden 20 | self.allow_guild = allow_guild 21 | 
self.allow_dm = allow_dm 22 | 23 | async def execute(self, *args, **kwargs): 24 | """Execute the command.""" 25 | await self.func(*args, **kwargs) 26 | 27 | def embed_field_rep(self): 28 | """Returns a Discord embed field representing this command.""" 29 | return { 30 | 'name': self.usage, 31 | 'value': self.desc, 32 | } 33 | 34 | 35 | class IncorrectUsageException(Exception): 36 | """Represents an exception raised when a command is used incorrectly.""" 37 | 38 | def __init__(self, msg=None, cmd=None): 39 | """ 40 | :param msg: a message to be displayed 41 | :param cmd: the command in context 42 | """ 43 | if cmd: 44 | msg = f'Command "{cmd}": {msg}' if msg else f'Command "{cmd}"' 45 | if msg: 46 | super().__init__(msg) 47 | else: 48 | super().__init__() 49 | 50 | 51 | def command(func=None, **kwargs): 52 | """Wraps an async function in a Command object, intended for use as a decorator""" 53 | if func is not None: 54 | return Command(func, **kwargs) 55 | return lambda fun: Command(fun, **kwargs) 56 | 57 | 58 | def assert_true(val, msg=None, cmd=None): 59 | if val is not True: 60 | msg = msg or f'Expected True, found {val}' 61 | raise IncorrectUsageException(msg, cmd) 62 | 63 | 64 | def assert_none(val, msg=None, cmd=None): 65 | if val is not None: 66 | msg = msg or f'Expected None, found {val}' 67 | raise IncorrectUsageException(msg, cmd) 68 | 69 | 70 | def assert_not_none(val, msg=None, cmd=None): 71 | if val is None: 72 | msg = msg or f'Expected not None, found None' 73 | raise IncorrectUsageException(msg, cmd) 74 | 75 | 76 | def assert_int(val, msg=None, cmd=None): 77 | try: 78 | int(val) 79 | except ValueError: 80 | msg = msg or f'Expected int, found {val}' 81 | raise IncorrectUsageException(msg, cmd) 82 | 83 | 84 | def assert_arglen(args, num, msg=None, cmd=None): 85 | if len(args) != num: 86 | msg = msg or f'Expected {num} arguments, found {len(args)}' 87 | raise IncorrectUsageException(msg, cmd) 88 | 
-------------------------------------------------------------------------------- /bot/commands.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import logging 3 | import time 4 | from datetime import datetime, timedelta 5 | 6 | from .discord import Channel 7 | from . import command, paginator 8 | 9 | logger = logging.getLogger(__name__) 10 | 11 | 12 | @command.command(desc='Responds with boop') 13 | async def beep(bot, args, message): 14 | command.assert_arglen(args, 0, cmd=message.content) 15 | reply = {'content': '*boop*'} 16 | await bot.client.send_message(reply, message.channel_id) 17 | 18 | 19 | @command.command(usage='help [cmd]', 20 | desc='Displays information about commands. When `cmd` is provided, only displays ' 21 | 'information about that command') 22 | async def help(bot, args, message): 23 | if not args: 24 | reply = bot.help_message 25 | await paginator.paginate_and_send(reply, bot, message.channel_id, per_page=4, 26 | time_active=15 * 60, time_delay=2 * 60) 27 | else: 28 | command.assert_arglen(args, 1, cmd=message.content) 29 | cmd_name = args.pop() 30 | cmd = bot.command_map.get(cmd_name) 31 | command.assert_not_none(cmd, msg=f'Unrecognized command "{cmd_name}"', cmd=message.content) 32 | field = cmd.embed_field_rep() 33 | field['name'] = 'Usage: ' + field['name'] 34 | reply = { 35 | 'embed': { 36 | 'title': cmd_name, 37 | 'fields': [field], 38 | } 39 | } 40 | await bot.client.send_message(reply, message.channel_id) 41 | 42 | 43 | @command.command(desc='Displays bot info') 44 | async def info(bot, args, message): 45 | command.assert_arglen(args, 0, cmd=message.content) 46 | reply = bot.info_message 47 | await bot.client.send_message(reply, message.channel_id) 48 | 49 | 50 | @command.command(usage='next [cnt] [at] [cc] [cf] [px]', 51 | desc='Displays future contests. If `cnt` is absent, displays the next contest. ' 52 | 'If `all`, displays all upcoming contests. 
If `day`, displays contests ' 53 | 'which start within the next 24 hours. Optional site filters can be used, ' 54 | 'where `at` = *AtCoder*, `cc` = *CodeChef* and `cf` = *Codeforces*') 55 | async def next(bot, args, message): 56 | args = [arg.lower() for arg in args] 57 | site_tag_to_name = {} 58 | cnt = None 59 | for arg in args: 60 | name = bot.site_container.get_site_name(arg) 61 | if name is not None: 62 | site_tag_to_name[arg] = name 63 | elif arg in ('all', 'day'): 64 | command.assert_none(cnt, msg='More than 1 cnt argument', cmd=message.content) 65 | cnt = arg 66 | else: 67 | raise command.IncorrectUsageException(msg=f'Unrecognized argument "{arg}"', cmd=message.content) 68 | cnt = cnt or 1 69 | 70 | if cnt == 'day': 71 | start_max = datetime.now().timestamp() + timedelta(days=1).total_seconds() 72 | contests = bot.site_container.get_future_contests_before(start_max, site_tag_to_name.keys()) 73 | logger.info(f'{len(contests)} contests fetched before {start_max}') 74 | else: 75 | contests = bot.site_container.get_future_contests_cnt(cnt, site_tag_to_name.keys()) 76 | logger.info(f'{len(contests)} contests fetched out of {cnt}') 77 | 78 | if contests: 79 | reply = create_message_from_contests(contests, cnt, site_tag_to_name.values(), bot.TIMEZONE) 80 | await paginator.paginate_and_send(reply, bot, message.channel_id, per_page=bot.CONTESTS_PER_PAGE, 81 | time_active=15 * 60, time_delay=2 * 60) 82 | else: 83 | reply = {'content': '*No contest found*'} 84 | await bot.client.send_message(reply, message.channel_id) 85 | 86 | 87 | def create_message_from_contests(contests, cnt, site_names, bot_timezone): 88 | descs = [] 89 | for contest in contests: 90 | start = datetime.fromtimestamp(contest.start, bot_timezone) 91 | start = start.strftime('%d %b %y, %H:%M') 92 | 93 | duration_days, rem_secs = divmod(contest.length, 60 * 60 * 24) 94 | duration_hrs, rem_secs = divmod(rem_secs, 60 * 60) 95 | duration_mins, rem_secs = divmod(rem_secs, 60) 96 | duration = 
f'{duration_hrs}h {duration_mins}m' 97 | if duration_days > 0: 98 | duration = f'{duration_days}d ' + duration 99 | 100 | descs.append((contest.name, contest.site_name, start, duration, contest.url)) 101 | 102 | max_site_name_len = max(len(desc[1]) for desc in descs) 103 | max_duration_len = max(len(desc[3]) for desc in descs) 104 | em = '\u2001' 105 | 106 | def make_field(name, site_name, start, duration, url): 107 | return { 108 | 'name': name, 109 | 'value': (f'`{site_name.ljust(max_site_name_len, em)}{em}|' 110 | f'{em}{start}{em}|' 111 | f'{em}{duration.rjust(max_duration_len, em)}{em}|' 112 | f'{em}`[`link \u25F3`]({url} "Link to contest page")'), 113 | } 114 | 115 | if cnt == 'day': 116 | title = 'Contests that start under 24 hours from now' 117 | else: 118 | title = 'Upcoming contests' 119 | embed = { 120 | 'fields': [make_field(*desc) for desc in descs], 121 | } 122 | if site_names: 123 | embed['description'] = 'Showing only: ' + ', '.join(name for name in site_names) 124 | 125 | message = { 126 | 'content': f'*{title}*', 127 | 'embed': embed, 128 | } 129 | return message 130 | 131 | 132 | @command.command(desc='Displays bot status') 133 | async def status(bot, args, message): 134 | command.assert_arglen(args, 0, cmd=message.content) 135 | reply = copy.deepcopy(bot.status_message) 136 | now = time.time() 137 | uptime = (now - bot.client.start_time) / 3600 138 | field1 = { 139 | 'name': 'Bot Uptime', 140 | 'value': f'Online since {uptime:.1f} hrs ago' 141 | } 142 | field2 = { 143 | 'name': 'Last Updated', 144 | 'value': '', 145 | } 146 | # TODO: Shift the code below to a member function of Site. 
147 | for site in bot.site_container.sites: 148 | last = (now - site.contests_last_fetched) / 60 149 | field2['value'] += f'{site.NAME}: {last:.0f} mins ago\n' 150 | reply['embed']['fields'] += [field1, field2] 151 | await bot.client.send_message(reply, message.channel_id) 152 | 153 | 154 | @command.command(usage='showsub [at|cc|cf]', 155 | desc='Show registered profiles. A particular site can be specified', 156 | allow_dm=True) 157 | async def showsub(bot, args, message): 158 | user_id = message.author.id 159 | user = bot.entity_manager.get_user(user_id) 160 | if not args: 161 | if user is None: 162 | reply = {'content': f'*You are not subscribed to any site*'} 163 | await bot.client.send_message(reply, message.channel_id) 164 | return 165 | 166 | embed = user.get_all_profiles_embed() 167 | if not embed: 168 | reply = {'content': f'*You are not subscribed to any site*'} 169 | await bot.client.send_message(reply, message.channel_id) 170 | return 171 | 172 | reply = { 173 | 'content': '*Your registered profiles:*', 174 | 'embed': embed, 175 | } 176 | await bot.client.send_message(reply, message.channel_id) 177 | return 178 | 179 | command.assert_arglen(args, 1, cmd=message.content) 180 | site_tag = args[0].lower() 181 | site_name = bot.site_container.get_site_name(site_tag) 182 | command.assert_not_none(site_name, msg='Unrecognized site', cmd=message.content) 183 | 184 | if user is None: 185 | reply = {'content': f'*You are not subscribed to {site_name}*'} 186 | await bot.client.send_message(reply, message.channel_id) 187 | return 188 | 189 | profile = bot.entity_manager.get_user(user_id).get_profile_for_site(site_tag) 190 | if profile is None: 191 | reply = {'content': f'*You are not subscribed to {site_name}*'} 192 | await bot.client.send_message(reply, message.channel_id) 193 | return 194 | 195 | embed = bot.entity_manager.get_user(user_id).get_profile_embed(site_tag) 196 | reply = {'embed': embed} 197 | await bot.client.send_message(reply, message.channel_id) 
198 | 199 | 200 | @command.command(usage='sub at|cc|cf handle', 201 | desc='Subscribe to profile changes', 202 | allow_dm=True) 203 | async def sub(bot, args, message): 204 | command.assert_arglen(args, 2, cmd=message.content) 205 | user_id = message.author.id 206 | site_tag = args[0].lower() 207 | site_name = bot.site_container.get_site_name(site_tag) 208 | command.assert_not_none(site_name, msg='Unrecognized site', cmd=message.content) 209 | 210 | await bot.client.trigger_typing(message.channel_id) 211 | handle = args[1] 212 | profile = await bot.site_container.fetch_profile(handle, site_tag=site_tag) 213 | if profile is None: 214 | reply = {'content': '*No user found with given handle*'} 215 | await bot.client.send_message(reply, message.channel_id) 216 | return 217 | 218 | if bot.entity_manager.get_user(user_id) is None: 219 | # Register new user with DM channel ID. 220 | channel = await bot.get_channel(message.channel_id) 221 | if channel.type != Channel.Type.DM: 222 | channel = await bot.client.get_dm_channel(user_id) 223 | bot.entity_manager.create_user(user_id, channel.id) 224 | 225 | await bot.entity_manager.update_user_site_profile(user_id, profile) 226 | embed = bot.entity_manager.get_user(user_id).get_profile_embed(site_tag) 227 | reply = { 228 | 'content': '*Your profile has been registered*', 229 | 'embed': embed, 230 | } 231 | await bot.client.send_message(reply, message.channel_id) 232 | 233 | 234 | @command.command(usage='unsub at|cc|cf', 235 | desc='Unsubscribe from profile changes', 236 | allow_dm=True) 237 | async def unsub(bot, args, message): 238 | command.assert_arglen(args, 1, cmd=message.content) 239 | user_id = message.author.id 240 | site_tag = args[0].lower() 241 | site_name = bot.site_container.get_site_name(site_tag) 242 | command.assert_not_none(site_name, msg='Unrecognized site', cmd=message.content) 243 | 244 | user = bot.entity_manager.get_user(user_id) 245 | if user is None: 246 | reply = {'content': f'*You are not subscribed to 
{site_name}*'} 247 | await bot.client.send_message(reply, message.channel_id) 248 | return 249 | 250 | profile = bot.entity_manager.get_user(user_id).get_profile_for_site(site_tag) 251 | if profile is None: 252 | reply = {'content': f'*You are not subscribed to {site_name}*'} 253 | await bot.client.send_message(reply, message.channel_id) 254 | return 255 | 256 | await bot.entity_manager.delete_user_site_profile(user_id, site_tag) 257 | reply = {'content': f'*You are now unsubscribed from {site_name}*'} 258 | await bot.client.send_message(reply, message.channel_id) 259 | -------------------------------------------------------------------------------- /bot/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "Bot", 3 | "triggers": [ 4 | "trigger" 5 | ], 6 | "channels": [ 7 | "channel_id" 8 | ], 9 | "activity": "activity", 10 | "db_name": "db", 11 | "at_config": { 12 | "contest_refresh_interval": 600, 13 | "user_refresh_interval": 2700, 14 | "user_delay_interval": 10 15 | }, 16 | "cc_config": { 17 | "contest_refresh_interval": 600, 18 | "user_refresh_interval": 2700, 19 | "user_delay_interval": 10 20 | }, 21 | "cf_config": { 22 | "contest_refresh_interval": 600, 23 | "user_refresh_interval": 1800, 24 | "user_delay_interval": 2 25 | } 26 | } 27 | -------------------------------------------------------------------------------- /bot/db/__init__.py: -------------------------------------------------------------------------------- 1 | from .mongodb_connector import MongoDBConnector 2 | 3 | __all__ = ['MongoDBConnector'] 4 | -------------------------------------------------------------------------------- /bot/db/mongodb_connector.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import logging 3 | import motor.motor_asyncio 4 | 5 | 6 | class MongoDBConnector: 7 | """Handles connection with a MongoDB database.""" 8 | 9 | def __init__(self, srv_url, db_name): 10 
| self.srv_url = srv_url 11 | self.db_name = db_name 12 | self.client = None 13 | self.db = None 14 | self.logger = logging.getLogger(self.__class__.__qualname__) 15 | 16 | def connect(self): 17 | """Initialize MongoDB client to provided database.""" 18 | self.logger.info('Connecting to MongoDB') 19 | loop = asyncio.get_running_loop() 20 | self.client = motor.motor_asyncio.AsyncIOMotorClient(self.srv_url, io_loop=loop) 21 | self.db = self.client[self.db_name] 22 | 23 | async def put_user(self, user): 24 | """Store a user to the database.""" 25 | await self.db.users.replace_one({'discord_id': user['discord_id']}, user, upsert=True) 26 | 27 | async def put_channel(self, channel): 28 | """Store a channel to the database.""" 29 | await self.db.channels.replace_one({'id': channel['id']}, channel, upsert=True) 30 | 31 | async def get_all_users(self): 32 | """Retrieve a list of all users from the database.""" 33 | cursor = self.db.users.find() 34 | return await cursor.to_list(length=None) 35 | 36 | async def get_all_channels(self): 37 | """Retrieve a list of all channels from the database.""" 38 | cursor = self.db.channels.find() 39 | return await cursor.to_list(length=None) 40 | -------------------------------------------------------------------------------- /bot/discord/__init__.py: -------------------------------------------------------------------------------- 1 | from .client import Client, EventType 2 | from .models import Channel, Message, User 3 | 4 | __all__ = ['Channel', 'Client', 'EventType', 'Message', 'User'] 5 | -------------------------------------------------------------------------------- /bot/discord/client.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import json 3 | import logging 4 | import platform 5 | import time 6 | from enum import IntEnum 7 | 8 | import aiohttp 9 | 10 | from .models import Channel, Message 11 | 12 | 13 | class Opcode(IntEnum): 14 | DISPATCH = 0 15 | HEARTBEAT = 1 16 | 
IDENTIFY = 2 17 | STATUS_UPDATE = 3 18 | VOICE_STATE_UPDATE = 4 19 | RESUME = 6 20 | RECONNECT = 7 21 | REQUEST_GUILD_MEMBERS = 8 22 | INVALID_SESSION = 9 23 | HELLO = 10 24 | HEARTBEAT_ACK = 11 25 | 26 | 27 | class EventType: 28 | # This list is not exhaustive. 29 | READY = 'READY' 30 | CHANNEL_CREATE = 'CHANNEL_CREATE' 31 | CHANNEL_UPDATE = 'CHANNEL_UPDATE' 32 | CHANNEL_DELETE = 'CHANNEL_DELETE' 33 | GUILD_CREATE = 'GUILD_CREATE' 34 | GUILD_UPDATE = 'GUILD_UPDATE' 35 | GUILD_DELETE = 'GUILD_DELETE' 36 | MESSAGE_CREATE = 'MESSAGE_CREATE' 37 | MESSAGE_UPDATE = 'MESSAGE_UPDATE' 38 | MESSAGE_DELETE = 'MESSAGE_DELETE' 39 | PRESENCE_UPDATE = 'PRESENCE_UPDATE' 40 | TYPING_START = 'TYPING_START' 41 | MESSAGE_REACTION_ADD = 'MESSAGE_REACTION_ADD' 42 | MESSAGE_REACTION_REMOVE = 'MESSAGE_REACTION_REMOVE' 43 | 44 | 45 | class Client: 46 | API_URL = 'https://discordapp.com/api' 47 | 48 | def __init__(self, token, name='Bot', activity_name=None): 49 | self.token = token 50 | self.name = name 51 | self.headers = { 52 | 'Authorization': f'Bot {self.token}', 53 | 'User-Agent': self.name, 54 | } 55 | self.activity_name = activity_name 56 | 57 | self.on_message = None 58 | self.listeners = {} 59 | self.user = None 60 | self.start_time = None 61 | self.last_seq = None 62 | self.logger = logging.getLogger(self.__class__.__qualname__) 63 | 64 | async def run(self, *, on_message): 65 | """Connect to Discord and run forever.""" 66 | self.on_message = on_message 67 | resp = await self._request('GET', '/gateway') 68 | socket_url = resp['url'] 69 | async with aiohttp.ClientSession() as session: 70 | async with session.ws_connect(f'{socket_url}?v=6&encoding=json') as ws: 71 | self.start_time = time.time() 72 | self.logger.info('Websocket connected') 73 | async for msg in ws: 74 | if msg.type == aiohttp.WSMsgType.CLOSE: 75 | self.logger.error(f'Discord closed the connection: {msg.data}, {msg.extra}') 76 | raise Exception(f'Websocket closed') 77 | if msg.type == aiohttp.WSMsgType.ERROR: 78 | 
self.logger.error(f'Websocket error response: {msg.data}') 79 | raise Exception(f'Websocket error') 80 | elif msg.type == aiohttp.WSMsgType.TEXT: 81 | await self._handle_message(ws, msg.data) 82 | else: 83 | self.logger.warning(f'Unhandled type: {msg.type}, {msg.data}') 84 | raise Exception('Discord websocket disconnected') 85 | 86 | async def _request(self, method, path, headers=None, json_data=None, expect_json=True): 87 | """Send a HTTP request to the Discord API.""" 88 | headers = headers or self.headers 89 | self.logger.debug(f'Request: {method} {path} {headers} {json_data}') 90 | async with aiohttp.request(method, f'{self.API_URL}{path}', headers=headers, json=json_data) as response: 91 | # TODO: Implement a way of ensuring rate limits 92 | response.raise_for_status() 93 | if expect_json: 94 | return await response.json() 95 | 96 | async def _handle_message(self, ws, msg): 97 | """Handle a websocket message.""" 98 | msg = json.loads(msg) 99 | op = msg['op'] 100 | if msg.get('s'): 101 | self.last_seq = msg['s'] 102 | typ = msg.get('t') 103 | data = msg.get('d') 104 | self.logger.info(f'Received: {op} {typ}') 105 | if op == Opcode.HELLO: 106 | self.logger.info(data) 107 | reply = { 108 | 'op': Opcode.IDENTIFY, 109 | 'd': { 110 | 'token': self.token, 111 | 'properties': { 112 | '$os': platform.platform(terse=1), 113 | }, 114 | 'compress': False, 115 | }, 116 | } 117 | if self.activity_name: 118 | reply['d']['presence'] = { 119 | 'game': { 120 | 'name': self.activity_name, 121 | 'type': 0, 122 | }, 123 | 'status': 'online', 124 | 'since': None, 125 | 'afk': False, 126 | } 127 | await ws.send_json(reply) 128 | asyncio.create_task(self._heartbeat_task(ws, data['heartbeat_interval'])) 129 | elif op == Opcode.HEARTBEAT_ACK: 130 | self.logger.info('Heartbeat-ack received') 131 | elif op == Opcode.DISPATCH: 132 | self.logger.debug('Handling dispatch') 133 | await self._handle_dispatch(typ, data) 134 | else: 135 | self.logger.info(f'Did not handle opcode with data: 
{data}') 136 | 137 | async def _heartbeat_task(self, ws, interval_ms): 138 | """Run forever, send a heartbeat through the websocket ``ws`` every ``interval_ms`` milliseconds.""" 139 | interval_sec = interval_ms / 1000 140 | data = {'op': Opcode.HEARTBEAT} 141 | while True: 142 | await asyncio.sleep(interval_sec) 143 | data['d'] = self.last_seq 144 | self.logger.info(f'Sending heartbeat {self.last_seq}') 145 | await ws.send_json(data) 146 | 147 | async def _handle_dispatch(self, typ, data): 148 | """Handle a websocket dispatch event.""" 149 | if typ == EventType.READY: 150 | self.user = data['user'] 151 | self.logger.info(f'Self data: {self.user}') 152 | elif typ == EventType.MESSAGE_CREATE: 153 | if self.on_message: 154 | message = Message(**data) 155 | self.logger.debug('Calling on_message handler') 156 | # Run on_message as a separate coroutine. 157 | asyncio.create_task(self.on_message(message)) 158 | else: 159 | dict_ = self.listeners.get(typ) 160 | if dict_: 161 | for listener in dict_.values(): 162 | asyncio.create_task(listener(data)) 163 | 164 | def register_listener(self, event, tag, listener): 165 | """Register a listener to listen to Discord gateway dispatch events. 166 | 167 | :param event: the event type to listen to 168 | :param tag: a unique tag to identify the listener by 169 | :param listener: an awaitable listener 170 | """ 171 | dict_ = self.listeners.setdefault(event, {}) 172 | if tag in dict_: 173 | raise KeyError(f'Another listener with tag "{tag}" exists') 174 | dict_[tag] = listener 175 | 176 | def unregister_listener(self, event, tag): 177 | """Unregister a listener to Discord gateway dispatch events by tag. 
async def send_message(self, message, channel_id):
    """Send a message on Discord.

    :param message: the message as a dict
    :param channel_id: the channel to send the message to
    :return: the sent Message object
    """
    # Log text spelled correctly so that it can be found when searching the logs.
    self.logger.info(f'Sending message to channel {channel_id}')
    message_d = await self._request('POST', f'/channels/{channel_id}/messages', json_data=message)
    return Message(**message_d)
async def get_dm_channel(self, user_id):
    """Get the channel object for the DM channel with the user with given id.

    :param user_id: the id of the user, sent to the API as ``recipient_id``
    :return: a ``Channel`` built from the API response
    """
    self.logger.info(f'Getting DM channel for user: {user_id}')
    payload = await self._request(
        'POST',
        '/users/@me/channels',
        json_data={'recipient_id': user_id},
    )
    return Channel(**payload)
def _to_known_type(enum_cls, raw):
    """Return the ``enum_cls`` member for ``raw``, or ``raw`` itself if no member has that value."""
    try:
        return enum_cls(raw)
    except ValueError:
        # Discord has type codes that are not listed in the enums below; an unlisted
        # code is kept as-is instead of aborting the parsing of the whole object.
        return raw


class User:
    """A Discord user, built from the keys of a user ``dict``."""

    __slots__ = ('id', 'username', 'discriminator', 'bot')

    def __init__(self, **kwargs):
        """
        :param kwargs: the user fields; ``id``, ``username`` and ``discriminator`` are required,
            ``bot`` is optional and defaults to ``None``
        :raises KeyError: if a required field is missing
        """
        self.id = kwargs['id']
        self.username = kwargs['username']
        self.discriminator = kwargs['discriminator']
        self.bot = kwargs.get('bot')

    def to_dict(self):
        """Return a ``dict`` of the attributes that are not ``None``."""
        return {
            key: getattr(self, key)
            for key in self.__slots__
            if getattr(self, key) is not None
        }


class Channel:
    """A Discord channel, built from the keys of a channel ``dict``."""

    __slots__ = ('id', 'type', 'name', 'guild_id', 'recipients')

    class Type(IntEnum):
        GUILD_TEXT = 0
        DM = 1
        GUILD_VOICE = 2
        GROUP_DM = 3
        GUILD_CATEGORY = 4

    def __init__(self, **kwargs):
        """
        :param kwargs: the channel fields; ``id`` and ``type`` are required, ``name``, ``guild_id``
            and ``recipients`` (a list of user dicts) are optional
        :raises KeyError: if a required field is missing
        """
        self.id = kwargs['id']
        # A ``Type`` member when the code is known, the raw code otherwise.
        self.type = _to_known_type(self.Type, kwargs['type'])
        self.name = kwargs.get('name')
        self.guild_id = kwargs.get('guild_id')
        self.recipients = None
        if kwargs.get('recipients'):
            self.recipients = [User(**user_d) for user_d in kwargs.get('recipients')]

    def to_dict(self):
        """Return a ``dict`` of the attributes that are not ``None``, with recipients as dicts."""
        channel_d = {
            key: getattr(self, key)
            for key in self.__slots__
            if getattr(self, key) is not None
        }
        if self.recipients:
            # Replace the User objects copied above by their dict form.
            channel_d['recipients'] = [user.to_dict() for user in self.recipients]
        return channel_d


class Message:
    """A Discord message, built from the keys of a message ``dict``."""

    __slots__ = ('id', 'type', 'channel_id', 'webhook_id', 'author', 'content', 'embeds')

    class Type(IntEnum):
        DEFAULT = 0
        RECIPIENT_ADD = 1
        RECIPIENT_REMOVE = 2
        CALL = 3
        CHANNEL_NAME_CHANGE = 4
        CHANNEL_ICON_CHANGE = 5
        CHANNEL_PINNED_MESSAGE = 6
        GUILD_MEMBER_JOIN = 7

    def __init__(self, **kwargs):
        """
        :param kwargs: the message fields; ``id``, ``type``, ``channel_id``, ``content`` and
            ``embeds`` are required, as is ``author`` unless ``webhook_id`` is set
        :raises KeyError: if a required field is missing
        """
        self.id = kwargs['id']
        # A ``Type`` member when the code is known, the raw code otherwise.
        self.type = _to_known_type(self.Type, kwargs['type'])
        self.channel_id = kwargs['channel_id']
        self.webhook_id = kwargs.get('webhook_id')
        # The author is not parsed for messages that carry a webhook id.
        self.author = User(**kwargs['author']) if not self.webhook_id else None
        self.content = kwargs['content']
        self.embeds = kwargs['embeds']
async def update_user_site_profile(self, user_id, profile):
    """Creates or updates a user's site profile to the given profile.

    The user is written to the database only if something in the profile changed.

    :param user_id: the Discord id of the user
    :param profile: the profile to store for the user
    :return: whether the name or rating changed
    """
    target = self._user_id_to_user.get(user_id)
    needs_save, name_or_rating_changed = target.update_profile(profile)
    if not needs_save:
        return name_or_rating_changed
    await self.db_connector.put_user(target.to_dict())
    self.logger.info(f'Saved user with id {user_id} to db')
    return name_or_rating_changed
class User:
    """A user of the bot.

    Has attributes related to Discord and CP sites.
    """

    def __init__(self, discord_id, dm_channel_id, site_profiles=None):
        """
        :param discord_id: the Discord id of the user
        :param dm_channel_id: the id of the DM channel with the user
        :param site_profiles: the list of site profiles of the user, empty if omitted
        """
        self.discord_id = discord_id
        self.dm_channel_id = dm_channel_id
        if site_profiles is None:
            site_profiles = []
        self.site_profiles = site_profiles
        # The same profiles indexed by site tag.
        self._profile_map = {profile.site_tag: profile for profile in self.site_profiles}

    def update_profile(self, profile):
        """Update or create the user's site profile.

        :param profile: the new profile; it replaces any profile with the same site tag
        :return: a tuple ``(changed_any, changed_name_or_rating)``. ``changed_any`` is ``True`` if
            the profile was created or differs from the old one in its ``dict`` form;
            ``changed_name_or_rating`` is ``True`` if the profile was created or its name or
            rating differs from the old one.
        """
        old_profile = self._profile_map.get(profile.site_tag)
        self._profile_map[profile.site_tag] = profile
        self.site_profiles = list(self._profile_map.values())
        if old_profile is None:
            changed_any = changed_name_or_rating = True
        else:
            changed_any = old_profile.to_dict() != profile.to_dict()
            changed_name_or_rating = (old_profile.name, old_profile.rating) != (profile.name, profile.rating)
        return changed_any, changed_name_or_rating

    def delete_profile(self, site_tag):
        """Delete the user's profile associated with the given site tag.

        Returns ``True`` if the profile was found and deleted, ``False`` if the profile did not exist.
        """
        profile = self._profile_map.get(site_tag)
        if profile is None:
            return False
        del self._profile_map[site_tag]
        self.site_profiles = list(self._profile_map.values())
        return True

    def get_profile_for_site(self, site_tag):
        """Returns the site profile of the user for the given site tag, ``None`` if no such profile exists."""
        return self._profile_map.get(site_tag)

    def get_profile_embed(self, site_tag):
        """Returns an embed ``dict`` for the profile with the given site tag, ``None`` if no such profile exists."""
        profile = self.get_profile_for_site(site_tag)
        if profile is None:
            return None
        return {
            'author': profile.make_embed_author(),
            'description': profile.make_embed_name_and_rating_text(),
            'footer': profile.make_embed_footer(),
        }

    @staticmethod
    def get_profile_change_embed(old_profile, new_profile):
        """Returns an embed ``dict`` showing the old and new name and rating side by side."""
        return {
            'author': new_profile.make_embed_author(),
            'fields': [
                {
                    'name': 'Previous',
                    'value': old_profile.make_embed_name_and_rating_text(),
                    # A real boolean, as in get_all_profiles_embed, rather than the string 'true'.
                    'inline': True,
                },
                {
                    'name': 'Current',
                    'value': new_profile.make_embed_name_and_rating_text(),
                    'inline': True,
                },
            ],
            'footer': new_profile.make_embed_footer(),
        }

    def get_all_profiles_embed(self):
        """Returns an embed ``dict`` with one field per site profile, ``None`` if the user has no profiles.

        The profiles are sorted by site name in place, so ``site_profiles`` keeps that order afterwards.
        """
        if not self.site_profiles:
            return None
        self.site_profiles.sort(key=lambda profile: profile.site_name)
        fields = [
            {
                'name': profile.site_name,
                'value': profile.make_embed_handle_text() + '\n'
                         + profile.make_embed_name_and_rating_text(),
                'inline': True,
            }
            for profile in self.site_profiles
        ]
        return {'fields': fields}

    @classmethod
    def from_dict(cls, user_d):
        """Creates and returns a user object from its ``dict`` representation."""
        return cls(
            user_d['discord_id'],
            user_d['dm_channel_id'],
            [Profile.from_dict(profile_dict) for profile_dict in user_d['site_profiles']]
        )

    def to_dict(self):
        """Returns a ``dict`` representing the user."""
        return {
            'discord_id': self.discord_id,
            'dm_channel_id': self.dm_channel_id,
            'site_profiles': [profile.to_dict() for profile in self.site_profiles]
        }
def schedule_unregister_after(self, time_delay):
    """Schedule unregister after a delay if that is later than current expiry time.

    :param time_delay: the delay from now, in seconds of event loop time
    """
    loop = asyncio.get_running_loop()
    deadline = loop.time() + time_delay
    pending = self.expire_handle
    if pending is not None and deadline < pending.when():
        # The expiry already scheduled is later; keep it.
        return
    if pending is not None:
        pending.cancel()
    self.expire_handle = loop.call_at(when=deadline, callback=self.unregister)
async def _on_reaction_add_or_remove(self, data):
    """Event listener that is triggered when a reaction is added or removed.

    Reactions by the bot itself, on other messages, or with other emoji are ignored.

    :param data: the reaction event data; ``user_id``, ``message_id`` and ``emoji`` are read
    """
    if data['user_id'] == self.bot.client.user['id']:
        return
    if data['message_id'] != self.sent_message.id:
        return
    step = {EMOJI_PREV: -1, EMOJI_NEXT: 1}.get(data['emoji'].get('name'))
    if step is None:
        return
    if step < 0:
        can_move = self.cur_page > 1
    else:
        can_move = self.cur_page < self.num_pages
    if can_move:
        self._set_page(self.cur_page + step)
        partial_message = {'embed': self.message['embed']}
        await self.bot.client.edit_message(self.sent_message.channel_id, self.sent_message.id, partial_message)
        self.logger.debug(f'Updated page')
    # NOTE(review): the expiry is pushed back on every recognised reaction, whether or not the page
    # changed; confirm against the original indentation of this call.
    self.schedule_unregister_after(self.time_delay)
async def paginate_and_send(message, bot, channel_id, *, per_page, initial_page=1, time_active, time_delay):
    """Convenience method to paginate and send the given message.

    :param message: the message to paginate
    :param bot: the Bot instance to send the message through
    :param channel_id: the channel ID to send the message to
    :param per_page: the number of fields to show per page
    :param initial_page: the page number to display initially
    :param time_active: passed on to ``Paginated.send``
    :param time_delay: passed on to ``Paginated.send``
    """
    await Paginated(message, per_page=per_page).send(
        bot,
        channel_id,
        page_num=initial_page,
        time_active=time_active,
        time_delay=time_delay,
    )
async def fetch_future_contests(self):
    """Overrides method in ContestSite

    Reads the table that follows the "Upcoming Contests" heading of the contests page.

    :return: the sorted list of contests, empty if the heading is not on the page
    """
    page = BeautifulSoup(await self._request(self.CONTESTS_PATH), 'html.parser')

    heading_text = page.find(text='Upcoming Contests')
    if heading_text is None:
        self.logger.info('No future contests')
        return []

    # Assuming table has > 0 entries if "Upcoming Contests" title is present.
    # Heading text -> heading tag -> newline node -> div containing the table.
    table_div = heading_text.parent.next_sibling.next_sibling
    # The string format is like so: 2018-09-08 21:00:00+0900
    start_format = '%Y-%m-%d %H:%M:%S%z'

    upcoming = []
    for row in table_div.find('tbody').find_all('tr'):
        cells = row.find_all('td')

        starts_at = datetime.strptime(str(cells[0].find('time').string), start_format)
        start_timestamp = int(starts_at.timestamp())

        link = cells[1].find('a')
        contest_url = self.BASE_URL + link['href']
        contest_name = str(link.string)

        # The duration format is like so: 01:40
        hours, minutes = str(cells[2].string).split(':')
        duration_seconds = (int(hours) * 60 + int(minutes)) * 60

        upcoming.append(Contest(contest_name, self.TAG, self.NAME, contest_url, start_timestamp, duration_seconds))

    upcoming.sort()
    return upcoming
async def fetch_future_contests(self):
    """Overrides method in ContestSite

    Reads the table that follows the "Future Contests" heading of the contests page.

    :return: the sorted list of contests, empty if the heading is not on the page
    """
    html = await self._request(self.CONTESTS_PATH)
    soup = BeautifulSoup(html, 'html.parser')

    title = soup.find(text='Future Contests')
    if title is None:
        self.logger.info('No future contests')
        return []

    # The actual string format is like so: 2018-09-07T15:00:00+05:30
    fmt = '%Y-%m-%dT%H:%M:%S%z'

    def to_timestamp(value):
        """Convert a start/end time attribute of the table to an integer UTC timestamp."""
        # The last colon is removed so that strptime can parse the UTC offset.
        without_last_colon = ''.join(value.rsplit(':', 1))
        return int(datetime.strptime(without_last_colon, fmt).timestamp())

    # Assuming table has > 0 entries if "Future Contests" title is present
    h3 = title.parent
    newline = h3.next_sibling
    div = newline.next_sibling
    tbody = div.find('tbody')
    rows = tbody.find_all('tr')
    future_contests = []
    for row in rows:
        vals = row.find_all('td')
        url = self.BASE_URL + '/' + str(vals[0].string)
        name = str(vals[1].string)
        start = to_timestamp(vals[2]['data-starttime'])
        end = to_timestamp(vals[3]['data-endtime'])
        length = end - start
        future_contests.append(Contest(name, self.TAG, self.NAME, url, start, length))

    future_contests.sort()
    return future_contests
class Codeforces(CPSite):
    """The Codeforces site, read through its JSON API."""

    NAME = 'Codeforces'
    TAG = 'cf'
    API_URL = 'http://codeforces.com/api'
    API_CONTESTS_PATH = '/contest.list'
    API_USERS_PATH = '/user.info'
    BASE_URL = 'http://codeforces.com'
    CONTESTS_PATH = '/contests'
    USERS_PATH = '/profile'

    def __init__(self, *, contest_refresh_interval, user_refresh_interval, user_delay_interval):
        super().__init__(contest_refresh_interval, user_refresh_interval, user_delay_interval)

    async def _request(self, path, params=None, raise_for_status=True):
        """Send a GET request to the Codeforces API and return the decoded JSON body.

        :param path: the API path, appended to ``API_URL``
        :param params: the optional query parameters
        :param raise_for_status: whether to raise if the response status signals an error
        """
        path = self.API_URL + path
        self.logger.debug(f'GET {path} {params}')
        async with aiohttp.request('GET', path, params=params) as response:
            if raise_for_status:
                response.raise_for_status()
            return await response.json()

    async def fetch_future_contests(self):
        """Overrides method in ContestSite

        :return: the sorted list of contests that have not begun and have a known start time
        :raises RuntimeError: if the API reports a status other than ``OK``
        """
        data = await self._request(self.API_CONTESTS_PATH)
        if data['status'] != 'OK':
            # An explicit raise, because assert statements are removed when Python runs with -O.
            raise RuntimeError(f'Codeforces API request failed: {data.get("comment")}')
        contests = [contest for contest in data['result'] if contest['phase'] == 'BEFORE']
        future_contests = [Contest(contest['name'],
                                   self.TAG,
                                   self.NAME,
                                   f'{self.BASE_URL}{self.CONTESTS_PATH}/{contest["id"]}',
                                   contest.get('startTimeSeconds'),
                                   contest['durationSeconds']) for contest in contests]
        # TODO: Consider how to handle contests with missing start
        future_contests = [contest for contest in future_contests if contest.start is not None]
        future_contests.sort()
        return future_contests

    async def fetch_profile(self, handle):
        """Override method in CPSite

        :param handle: the Codeforces handle to look up
        :return: the profile of the user, ``None`` if the API reports that the user was not found
        :raises RuntimeError: if the API reports any other failure
        """
        params = {'handles': handle}
        # The status is not raised for here; the failure reason is read from the response body instead.
        data = await self._request(self.API_USERS_PATH, params=params, raise_for_status=False)
        if data['status'] == 'FAILED' and 'not found' in data['comment']:
            # User not found.
            return None
        if data['status'] != 'OK':
            # Any other failure has no 'result' to read; report the reason given by the API.
            raise RuntimeError(f'Codeforces API request failed: {data.get("comment")}')

        result = data['result'][0]
        # Join only the name parts that are present, so a lone first or last name has no stray space.
        name_parts = [part for part in (result.get('firstName'), result.get('lastName')) if part]
        fullname = ' '.join(name_parts) or None
        rating = result.get('rating')
        url = self.BASE_URL + self.USERS_PATH + '/' + handle
        avatar = result['avatar']
        if avatar.startswith('//'):
            # Avatar may come in the scheme-relative form '//userpic.codeforces.com//avatar/.jpg'.
            avatar = 'http:' + avatar
        return Profile(handle, self.TAG, self.NAME, url, avatar, fullname, rating)
async def run(self):
    """Update the contest list and schedule future updates.

    The initial fetch is awaited, so an exception raised by it propagates to the caller. The
    periodic updates then run in a background task stored in ``self._contest_updater``.
    """
    self.logger.info('Setting up site...')
    # Initial fetch.
    await self.update_contests()
    # Schedule for future. The asyncio documentation advises keeping a reference to a created task
    # so that it cannot disappear mid-execution; the reference also allows cancelling it later.
    self._contest_updater = asyncio.create_task(self._contest_updater_task())
def get_future_contests_cnt(self, cnt, sites_tags):
    """Get the given number of future contests.

    :param cnt: the number of contests to get, an integer or ``"all"``.
    :param sites_tags: the site tags for sites to filter by.
    :return: a list of contests.
    """
    self.logger.info(f'get_future_contests_cnt: {cnt} {sites_tags}')
    upcoming = self._get_future_contests()
    keep = self.filter_by_site(sites_tags)
    matching = [contest for contest in upcoming if keep(contest)]
    # "all" is expressed as a limit that no filtered list can exceed.
    limit = len(self.future_contests) if cnt == 'all' else cnt
    return matching[:limit]
105 | :param user_delay_interval: the delay between requests for two consecutive users. 106 | """ 107 | super().__init__(contest_refresh_interval) 108 | self.user_refresh_interval = user_refresh_interval 109 | self.user_delay_interval = user_delay_interval 110 | self.get_all_users = None 111 | self.on_profile_fetch = None 112 | 113 | async def run(self, get_all_users=None, on_profile_fetch=None): 114 | """ 115 | Schedule regular fetch of contests and profiles. 116 | 117 | :param get_all_users: the function that provides a list of users to fetch. 118 | :param on_profile_fetch: the callback to be executed when a profile is fetched. 119 | :return: 120 | """ 121 | self.get_all_users = get_all_users 122 | self.on_profile_fetch = on_profile_fetch 123 | await super().run() 124 | asyncio.create_task(self._user_updater_task()) 125 | 126 | async def update_users(self): 127 | """Update all users provided by the registered function ``get_all_users``.""" 128 | if self.get_all_users is None or self.on_profile_fetch is None: 129 | self.logger.info('Profile handlers not registered') 130 | return 131 | 132 | for user in self.get_all_users(): 133 | old_profile = user.get_profile_for_site(self.TAG) 134 | if old_profile is None: 135 | # The user has not registered a profile for this site. 
136 | continue 137 | new_profile = await self.fetch_profile(old_profile.handle) 138 | self.logger.info(f'Profile with handle {old_profile.handle} fetched') 139 | await self.on_profile_fetch(user, old_profile, new_profile) 140 | await asyncio.sleep(self.user_delay_interval) 141 | 142 | async def _user_updater_task(self): 143 | """Run forever and update users at regular intervals.""" 144 | while True: 145 | try: 146 | await asyncio.sleep(self.user_refresh_interval) 147 | await self.update_users() 148 | except asyncio.CancelledError: 149 | self.logger.info('Received CancelledError, stopping task') 150 | break 151 | except Exception as ex: 152 | self.logger.exception(f'Exception in fetching: {ex}, continuing regardless') 153 | 154 | async def fetch_profile(self, handle): 155 | raise NotImplementedError('This method must be overridden') 156 | -------------------------------------------------------------------------------- /bot/sites/models.py: -------------------------------------------------------------------------------- 1 | class Contest: 2 | __slots__ = ('name', 'site_tag', 'site_name', 'url', 'start', 'length') 3 | 4 | def __init__(self, name, site_tag, site_name, url, start, length): 5 | """Represents a competitive programming contest. 
6 | 7 | :param name: the name of the contest 8 | :param site_tag: the site tag 9 | :param site_name: the site name 10 | :param url: the URL of the contest 11 | :param start: the timestamp of the UTC start time 12 | :param length: the length of the contest in seconds 13 | """ 14 | 15 | self.name = name 16 | self.site_tag = site_tag 17 | self.site_name = site_name 18 | self.url = url 19 | self.start = start 20 | self.length = length 21 | 22 | def __lt__(self, other): 23 | return (self.start, self.length, self.site_name) < (other.start, other.length, other.site_name) 24 | 25 | def __repr__(self): 26 | return '' 27 | 28 | 29 | class Profile: 30 | __slots__ = ('handle', 'site_tag', 'site_name', 'url', 'avatar', 'name', 'rating') 31 | 32 | def __init__(self, handle, site_tag, site_name, url, avatar, name, rating): 33 | """Represents a user of a competitive programming site. 34 | 35 | :param handle: the user's handle 36 | :param site_tag: the site tag 37 | :param site_name: the site name 38 | :param url: the URL of the profile 39 | :param avatar: the URL of the user's avatar 40 | :param name: the user's full name, ``None`` if unavailable 41 | :param rating: the user's current rating, ``None`` if unrated 42 | """ 43 | 44 | self.handle = handle 45 | self.site_tag = site_tag 46 | self.site_name = site_name 47 | self.url = url 48 | self.avatar = avatar 49 | self.name = name 50 | self.rating = rating 51 | 52 | def make_embed_handle_text(self): 53 | return f'**Handle**: [{self.handle}]({self.url})' 54 | 55 | def make_embed_author(self): 56 | """Make an author section for a Discord embed.""" 57 | return { 58 | 'name': f'{self.handle}', 59 | 'url': self.url, 60 | 'icon_url': self.avatar, 61 | } 62 | 63 | def make_embed_name_and_rating_text(self): 64 | """Make a formatted string containing name and rating.""" 65 | desc = f'**Name**: {self.name}\n' if self.name is not None else '' 66 | desc += f'**Rating**: {self.rating if self.rating is not None else "Unrated"}' 67 | return desc 
68 | 69 | def make_embed_footer(self): 70 | """Make a footer with site name for a Discord embed.""" 71 | return {'text': self.site_name} 72 | 73 | @classmethod 74 | def from_dict(cls, profile_dict): 75 | params = [profile_dict.get(key) for key in cls.__slots__] 76 | return Profile(*params) 77 | 78 | def to_dict(self): 79 | return {key: getattr(self, key) for key in self.__slots__} 80 | -------------------------------------------------------------------------------- /bot/sites/site_container.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | from .competitive_programming_site import ContestSite 4 | 5 | 6 | class SiteContainer(ContestSite): 7 | """Manages multiple sites.""" 8 | 9 | def __init__(self, sites, contest_refresh_interval=None): 10 | """ 11 | :param sites: the list of ``CPSite`` objects to manage. 12 | :param contest_refresh_interval: the interval between updates of the list of contests; defaults to the smallest ``contest_refresh_interval`` among ``sites``. 13 | """ 14 | if contest_refresh_interval is None: 15 | contest_refresh_interval = min(site.contest_refresh_interval for site in sites) 16 | super().__init__(contest_refresh_interval) 17 | self.sites = sites 18 | self._site_map = {site.TAG: site for site in self.sites} 19 | self.logger = logging.getLogger(self.__class__.__qualname__) 20 | 21 | async def run(self, get_all_users=None, on_profile_fetch=None): 22 | """Set up each site being managed. 23 | 24 | :param get_all_users: the function that provides a list of users to fetch. 25 | :param on_profile_fetch: the callback to be executed when a profile is fetched. 
26 | """ 27 | self.logger.info('Setting up the SiteContainer...') 28 | for site in self.sites: 29 | await site.run(get_all_users=get_all_users, on_profile_fetch=on_profile_fetch) 30 | await super().run() 31 | 32 | async def fetch_future_contests(self): 33 | """Merge the cached ``future_contests`` of every managed site into one sorted list (overrides the method in ContestSite).""" 34 | future_contests = [] 35 | for site in self.sites: 36 | future_contests += site.future_contests 37 | future_contests.sort() 38 | return future_contests 39 | 40 | async def fetch_profile(self, handle, site_tag): 41 | """Fetch the profile for the given handle and site.""" 42 | site = self._site_map[site_tag] 43 | profile = await site.fetch_profile(handle) 44 | self.logger.info(f'Fetched profile: {profile}') 45 | return profile 46 | 47 | def get_site_name(self, site_tag): 48 | """Get the site name corresponding to the given site tag, or ``None`` if the tag is unknown.""" 49 | site = self._site_map.get(site_tag) 50 | return None if site is None else site.NAME 51 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | aiohttp==3.4.2 2 | beautifulsoup4==4.6.3 3 | dnspython==1.15.0 4 | motor==2.0.0 5 | --------------------------------------------------------------------------------