├── .editorconfig ├── .env.example ├── .github └── workflows │ ├── build.yaml │ └── lint.yaml ├── .gitignore ├── Dockerfile ├── LICENSE ├── README.md ├── docker-compose.yaml ├── extra ├── fonts.conf └── scrape_cf_contest_writers.py ├── requirements.txt ├── ruff.toml └── tle ├── __init__.py ├── __main__.py ├── cogs ├── cache_control.py ├── codeforces.py ├── contests.py ├── deactivated │ └── cses.py ├── duel.py ├── graphs.py ├── handles.py ├── logging.py ├── meta.py └── starboard.py ├── constants.py └── util ├── __init__.py ├── cache_system2.py ├── codeforces_api.py ├── codeforces_common.py ├── cses_scraper.py ├── db ├── __init__.py ├── cache_db_conn.py └── user_db_conn.py ├── discord_common.py ├── events.py ├── font_downloader.py ├── graph_common.py ├── handledict.py ├── paginator.py ├── ranklist ├── __init__.py ├── ranklist.py └── rating_calculator.py ├── table.py └── tasks.py /.editorconfig: -------------------------------------------------------------------------------- 1 | # Unix-style newlines with a newline ending every file 2 | [*] 3 | end_of_line = lf 4 | insert_final_newline = true 5 | 6 | # Set default charset 7 | [*.{js,py}] 8 | charset = utf-8 9 | 10 | # 4 space indentation 11 | [*.py] 12 | indent_style = space 13 | indent_size = 4 14 | -------------------------------------------------------------------------------- /.env.example: -------------------------------------------------------------------------------- 1 | export BOT_TOKEN="XXXXXXXXXXXXXXXXXXXXXXXX.XXXXXX.XXXXXXXXXXXXXXXXXXXXXXXXXXX" 2 | export LOGGING_COG_CHANNEL_ID="XXXXXXXXXXXXXXXXXX" 3 | export ALLOW_DUEL_SELF_REGISTER="false" 4 | -------------------------------------------------------------------------------- /.github/workflows/build.yaml: -------------------------------------------------------------------------------- 1 | name: Docker Build 2 | 3 | on: 4 | push: 5 | branches: [master] 6 | pull_request: 7 | branches: [master] 8 | 9 | jobs: 10 | build: 11 | runs-on: ubuntu-latest 12 | steps: 
13 | - uses: actions/checkout@v3 14 | - name: Build Docker image 15 | run: docker build -t tle-bot . 16 | - name: Create minimal env file for testing 17 | run: | 18 | echo "BOT_TOKEN=dummy_token" > .env.test 19 | echo "LOGGING_COG_CHANNEL_ID=123456789012345678" >> .env.test 20 | - name: Verify Docker image 21 | run: | 22 | docker run --rm --env-file .env.test tle-bot python -c "import tle; print('Import successful')" 23 | -------------------------------------------------------------------------------- /.github/workflows/lint.yaml: -------------------------------------------------------------------------------- 1 | name: Lint 2 | 3 | on: 4 | push: 5 | branches: [master] 6 | pull_request: 7 | branches: [master] 8 | 9 | jobs: 10 | lint: 11 | runs-on: ubuntu-latest 12 | steps: 13 | - uses: actions/checkout@v3 14 | - uses: actions/setup-python@v4 15 | with: 16 | python-version: "3.11" 17 | - name: Install Ruff 18 | run: pip install ruff 19 | - name: Run Ruff linting 20 | run: ruff check . 21 | - name: Run Ruff formatting check 22 | run: ruff format --check . 23 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | files/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .coverage 43 | .coverage.* 44 | .cache 45 | nosetests.xml 46 | coverage.xml 47 | *.cover 48 | .hypothesis/ 49 | .pytest_cache/ 50 | 51 | # Translations 52 | *.mo 53 | *.pot 54 | 55 | # Django stuff: 56 | *.log 57 | local_settings.py 58 | db.sqlite3 59 | 60 | # Flask stuff: 61 | instance/ 62 | .webassets-cache 63 | 64 | # Scrapy stuff: 65 | .scrapy 66 | 67 | # Sphinx documentation 68 | docs/_build/ 69 | 70 | # PyBuilder 71 | target/ 72 | 73 | # Jupyter Notebook 74 | .ipynb_checkpoints 75 | 76 | # pyenv 77 | .python-version 78 | 79 | # celery beat schedule file 80 | celerybeat-schedule 81 | 82 | # SageMath parsed files 83 | *.sage.py 84 | 85 | # Environments 86 | .env 87 | .venv 88 | env/ 89 | venv/ 90 | ENV/ 91 | env.bak/ 92 | venv.bak/ 93 | environment 94 | 95 | # Spyder project settings 96 | .spyderproject 97 | .spyproject 98 | 99 | # Rope project settings 100 | .ropeproject 101 | 102 | # mkdocs documentation 103 | /site 104 | 105 | # mypy 106 | .mypy_cache/ 107 | 108 | .idea/ 109 | pip-wheel-metadata/ 110 | 111 | # vscode 112 | .vscode/ 113 | 114 | # data 115 | /data 116 | 117 | # logs 118 | /logs 119 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM python:3.11-slim 2 | 3 | RUN apt-get update && apt-get install -y --no-install-recommends \ 4 | libcairo2 gir1.2-pango-1.0 \ 5 | gobject-introspection python3-gi python3-gi-cairo python3-cairo \ 6 | libjpeg-dev zlib1g-dev \ 7 | && rm -rf /var/lib/apt/lists/* 8 | 9 | ENV PYTHONPATH=/usr/lib/python3/dist-packages 10 | ENV FONTCONFIG_FILE=/bot/extra/fonts.conf 11 | ENV PYTHONUNBUFFERED=1 12 | 13 | WORKDIR /bot 14 | COPY requirements.txt . 15 | RUN pip install --no-cache-dir -r requirements.txt 16 | 17 | COPY . . 
18 | 19 | CMD ["python", "-m", "tle"] 20 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 Cheran 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # TLE ― The Competitive-Programming Discord Bot 2 | 3 | TLE is a feature-packed Discord bot aimed at competitive programmers 4 | (Codeforces, CSES, …). 5 | It can recommend problems, show stats & graphs, run duels on your server 6 | and manage starboards – all with a single prefix `;`. 7 | 8 | If you have Docker ≥ 24 (or Docker Desktop on Win/Mac) you are ready to 9 | go. 
10 | 11 | --- 12 | 13 | ## 1 · Features (quick glance) 14 | 15 | | Cog | What it does | 16 | |-----|--------------| 17 | | **Codeforces** | problem / contest recommender, rating changes, user look-ups | 18 | | **Contests** | shows upcoming & live contests | 19 | | **Graphs** | rating distributions, solved-set histograms, etc. | 20 | | **Handles** | link Discord users to CF handles | 21 | | **CSES** | CSES leaderboard, problem info (currently deactivated: the cog lives in `tle/cogs/deactivated/`, which is not loaded) | 22 | | **Starboard** | pins popular messages to a channel | 23 | | **CacheControl** | admin-only commands to force-reload cached Codeforces data | 24 | 25 | All graphs require cairo + pango; the Docker image already contains 26 | everything. 27 | 28 | --- 29 | 30 | ## 2 · Quick start (production) 31 | 32 | ```bash 33 | # 1 · clone the repo 34 | git clone https://github.com/cheran-senthil/TLE 35 | cd TLE 36 | 37 | # 2 · create a config file 38 | cp .env.example .env # then edit BOT_TOKEN, LOGGING_COG_CHANNEL_ID … 39 | 40 | # 3 · build & start the bot (first run takes ~2 min) 41 | docker compose up -d 42 | ``` 43 | 44 | That’s it. 45 | The bot will appear online in your Discord server; use 46 | `;help` inside Discord to explore commands.
47 | 48 | ### Updating to a new release 49 | 50 | ```sh 51 | git pull 52 | docker compose build --pull # fetch newer base images 53 | docker compose up -d # recreate the container on the new image (brief restart) 54 | ``` 55 | 56 | --- 57 | 58 | ## 3 · Environment variables ( `.env` ) 59 | 60 | | Variable | Required | Example | Description | 61 | |----------|----------|---------|-------------| 62 | | `BOT_TOKEN` | ✅ | `MTEz…` | Discord bot token from the Dev Portal | 63 | | `LOGGING_COG_CHANNEL_ID` | ✅ | `123456789012345678` | channel where uncaught errors are sent | 64 | | `ALLOW_DUEL_SELF_REGISTER` | ❌ | `true` | let users self-register for duels | 65 | | `TLE_ADMIN` | ❌ | `Admin` | role name that can run admin cmds | 66 | | `TLE_MODERATOR` | ❌ | `Moderator` | role name that can run mod cmds | 67 | 68 | Feel free to add any extra variables your cogs consume; Compose passes 69 | every key in `.env` to the container. 70 | 71 | --- 72 | 73 | ## 4 · Data & cache folder 74 | 75 | `docker compose` mounts `./data` into the container. 76 | It holds: 77 | 78 | * Codeforces caches & contest writers JSON 79 | * downloaded CJK fonts (~36 MB, fetched automatically) 80 | 81 | You can back this directory up or move it to a dedicated volume; wiping 82 | it only means the bot will re-download items on first run. 83 | 84 | --- 85 | 86 | ## 5 · Local development (optional) 87 | 88 | You can hack on the code without touching your system Python: 89 | 90 | ```bash 91 | # foreground dev run: rebuilds the image, then streams logs (no live reload; re-run after code changes) 92 | docker compose up --build 93 | ``` 94 | 95 | Lint & format (Ruff): 96 | 97 | ```bash 98 | docker run --rm -v "$PWD":/app -w /app python:3.11-slim \ 99 | sh -c "pip install ruff && ruff check . && ruff format --check ." 100 | ``` 101 | 102 | --- 103 | 104 | ## 6 · Repository layout 105 | 106 | ```sh 107 | .
108 | ├─ Dockerfile # single-stage image, installs native cairo stack 109 | ├─ docker-compose.yaml # single-service compose file 110 | ├─ requirements.txt # runtime Python deps (a few version pins) 111 | ├─ .env.example # template for your secrets 112 | ├─ data/ # persisted cache & fonts (git-ignored) 113 | ├─ tle/ … # bot source code 114 | └─ extra/ fonts.conf … # helper resources 115 | ``` 116 | 117 | --- 118 | 119 | ## 7 · Contributing 120 | 121 | Pull requests are welcome! 122 | Before opening a PR, please 123 | 124 | 1. run `ruff check --fix .` and `ruff format .` (CI runs `ruff check .` and `ruff format --check .`), 125 | 2. keep commits focused; large refactors in a separate PR. 126 | 127 | --- 128 | 129 | ## 8 · License 130 | 131 | MIT ― see `LICENSE`. 132 | -------------------------------------------------------------------------------- /docker-compose.yaml: -------------------------------------------------------------------------------- 1 | services: 2 | tle: 3 | build: . 4 | image: tle-bot:latest 5 | 6 | env_file: .env 7 | 8 | volumes: 9 | - ./data:/bot/data 10 | 11 | restart: unless-stopped 12 | environment: 13 | - TZ=UTC 14 | -------------------------------------------------------------------------------- /extra/fonts.conf: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | TLE fontconfig file 5 | data/assets/fonts 6 | 7 | -------------------------------------------------------------------------------- /extra/scrape_cf_contest_writers.py: -------------------------------------------------------------------------------- 1 | """This script scrapes contests and their writers from Codeforces and saves 2 | them to a JSON file.
This exists because there is no way to do this through 3 | the official API :( 4 | """ 5 | 6 | import json 7 | import urllib.request 8 | 9 | from lxml import html 10 | 11 | URL = 'https://codeforces.com/contests/page/{}?complete=true' 12 | JSONFILE = 'contest_writers.json' 13 | 14 | 15 | def get_page(pagenum): 16 | url = URL.format(pagenum) 17 | with urllib.request.urlopen(url) as f: 18 | text = f.read().decode() 19 | return html.fromstring(text) 20 | 21 | 22 | def get_contests(doc): 23 | contests = [] 24 | rows = doc.xpath('//div[@class="contests-table"]//table[1]//tr')[1:] 25 | for row in rows: 26 | contest_id = int(row.get('data-contestid')) 27 | name, writers, start, length, standings, registrants = row.xpath('td') 28 | writers = writers.text_content().split() 29 | contests.append({'id': contest_id, 'writers': writers}) 30 | return contests 31 | 32 | 33 | print('Fetching page 1') 34 | page1 = get_page(1) 35 | lastpage = int(page1.xpath('//span[@class="page-index"]')[-1].get('pageindex')) 36 | 37 | contests = get_contests(page1) 38 | print(f'Found {len(contests)} contests') 39 | 40 | for pagenum in range(2, lastpage + 1): 41 | print(f'Fetching page {pagenum}') 42 | page = get_page(pagenum) 43 | page_contests = get_contests(page) 44 | print(f'Found {len(page_contests)} contests') 45 | contests.extend(page_contests) 46 | 47 | print(f'Found total {len(contests)} contests') 48 | 49 | with open(JSONFILE, 'w') as f: 50 | json.dump(contests, f) 51 | print(f'Data written to {JSONFILE}') 52 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # Discord bot core 2 | discord.py==1.7.3 3 | aiohttp<3.8 4 | aiocache 5 | 6 | # Data & plotting 7 | numpy 8 | pandas 9 | matplotlib 10 | seaborn 11 | 12 | # Images 13 | pillow<10 14 | -------------------------------------------------------------------------------- /ruff.toml: 
-------------------------------------------------------------------------------- 1 | target-version = "py311" 2 | line-length = 88 3 | indent-width = 4 4 | 5 | extend-exclude = [ 6 | ".cache", 7 | ".venv", 8 | "data", 9 | ] 10 | 11 | [lint] 12 | select = ["E", "F", "W", "I", "B"] 13 | ignore = [ 14 | "B904", # Within an except clause, raise exceptions with raise ... from err 15 | ] 16 | 17 | [lint.per-file-ignores] 18 | "tle/util/db/__init__.py" = ["F403"] # Star imports 19 | "tle/util/ranklist/__init__.py" = ["F403"] # Star imports 20 | "tle/cogs/handles.py" = ["E402"] # GTK requires version setting before import 21 | "tle/cogs/logging.py" = ["E722"] # Intentional bare except in error handler 22 | "tle/util/events.py" = ["E722"] # Intentional bare except in event error handler 23 | 24 | [format] 25 | line-ending = "lf" 26 | quote-style = "single" 27 | indent-style = "space" 28 | skip-magic-trailing-comma = false 29 | 30 | [lint.isort] 31 | case-sensitive = true 32 | combine-as-imports = true 33 | order-by-type = true 34 | -------------------------------------------------------------------------------- /tle/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = '0.1.0' 2 | -------------------------------------------------------------------------------- /tle/__main__.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import asyncio 3 | import logging 4 | import os 5 | from logging.handlers import TimedRotatingFileHandler 6 | from os import environ 7 | from pathlib import Path 8 | 9 | import discord 10 | import seaborn as sns 11 | from discord.ext import commands 12 | from matplotlib import pyplot as plt 13 | 14 | from tle import constants 15 | from tle.util import codeforces_common as cf_common, discord_common, font_downloader 16 | 17 | 18 | def setup(): 19 | # Make required directories. 
20 | for path in constants.ALL_DIRS: 21 | os.makedirs(path, exist_ok=True) 22 | 23 | # logging to console and file on daily interval 24 | logging.basicConfig( 25 | format='{asctime}:{levelname}:{name}:{message}', 26 | style='{', 27 | datefmt='%d-%m-%Y %H:%M:%S', 28 | level=logging.INFO, 29 | handlers=[ 30 | logging.StreamHandler(), 31 | TimedRotatingFileHandler( 32 | constants.LOG_FILE_PATH, when='D', backupCount=3, utc=True 33 | ), 34 | ], 35 | ) 36 | 37 | # matplotlib and seaborn 38 | plt.rcParams['figure.figsize'] = 7.0, 3.5 39 | sns.set() 40 | options = { 41 | 'axes.edgecolor': '#A0A0C5', 42 | 'axes.spines.top': False, 43 | 'axes.spines.right': False, 44 | } 45 | sns.set_style('darkgrid', options) 46 | 47 | # Download fonts if necessary 48 | font_downloader.maybe_download() 49 | 50 | 51 | def strtobool(value: str) -> bool: 52 | """ 53 | Convert a string representation of truth to true (1) or false (0). 54 | 55 | True values are y, yes, t, true, on and 1; false values are n, no, f, 56 | false, off and 0. Raises ValueError if val is anything else. 
57 | """ 58 | value = value.lower() 59 | if value in ('y', 'yes', 't', 'true', 'on', '1'): 60 | return True 61 | if value in ('n', 'no', 'f', 'false', 'off', '0'): 62 | return False 63 | raise ValueError(f'Invalid truth value {value!r}.') 64 | 65 | 66 | def main(): 67 | parser = argparse.ArgumentParser() 68 | parser.add_argument('--nodb', action='store_true') 69 | args = parser.parse_args() 70 | 71 | token = environ.get('BOT_TOKEN') 72 | if not token: 73 | logging.error('Token required') 74 | return 75 | 76 | allow_self_register = environ.get('ALLOW_DUEL_SELF_REGISTER') 77 | if allow_self_register: 78 | constants.ALLOW_DUEL_SELF_REGISTER = strtobool(allow_self_register) 79 | 80 | setup() 81 | 82 | intents = discord.Intents.default() 83 | intents.members = True 84 | 85 | bot = commands.Bot(command_prefix=commands.when_mentioned_or(';'), intents=intents) 86 | cogs = [file.stem for file in Path('tle', 'cogs').glob('*.py')] 87 | for extension in cogs: 88 | bot.load_extension(f'tle.cogs.{extension}') 89 | logging.info(f'Cogs loaded: {", ".join(bot.cogs)}') 90 | 91 | def no_dm_check(ctx): 92 | if ctx.guild is None: 93 | raise commands.NoPrivateMessage('Private messages not permitted.') 94 | return True 95 | 96 | # Restrict bot usage to inside guild channels only. 97 | bot.add_check(no_dm_check) 98 | 99 | # cf_common.initialize needs to run first, so it must be set as the bot's 100 | # on_ready event handler rather than an on_ready listener. 
101 | @discord_common.on_ready_event_once(bot) 102 | async def init(): 103 | await cf_common.initialize(args.nodb) 104 | asyncio.create_task(discord_common.presence(bot)) 105 | 106 | bot.add_listener(discord_common.bot_error_handler, name='on_command_error') 107 | bot.run(token) 108 | 109 | 110 | if __name__ == '__main__': 111 | main() 112 | -------------------------------------------------------------------------------- /tle/cogs/cache_control.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import time 3 | 4 | from discord.ext import commands 5 | 6 | from tle import constants 7 | from tle.util import codeforces_common as cf_common 8 | 9 | 10 | def timed_command(coro): 11 | @functools.wraps(coro) 12 | async def wrapper(cog, ctx, *args): 13 | await ctx.send('Running...') 14 | begin = time.time() 15 | await coro(cog, ctx, *args) 16 | elapsed = time.time() - begin 17 | await ctx.send(f'Completed in {elapsed:.2f} seconds') 18 | 19 | return wrapper 20 | 21 | 22 | class CacheControl(commands.Cog): 23 | """Cog to manually trigger update of cached data. 
Intended for dev/admin use.""" 24 | 25 | def __init__(self, bot): 26 | self.bot = bot 27 | 28 | @commands.group( 29 | brief='Commands to force reload of cache', invoke_without_command=True 30 | ) 31 | @commands.has_role(constants.TLE_ADMIN) 32 | async def cache(self, ctx): 33 | await ctx.send_help('cache') 34 | 35 | @cache.command() 36 | @commands.has_role(constants.TLE_ADMIN) 37 | @timed_command 38 | async def contests(self, ctx): 39 | await cf_common.cache2.contest_cache.reload_now() 40 | 41 | @cache.command() 42 | @commands.has_role(constants.TLE_ADMIN) 43 | @timed_command 44 | async def problems(self, ctx): 45 | await cf_common.cache2.problem_cache.reload_now() 46 | 47 | @cache.command(usage='[missing|all|contest_id]') 48 | @commands.has_role(constants.TLE_ADMIN) 49 | @timed_command 50 | async def ratingchanges(self, ctx, contest_id='missing'): 51 | """Defaults to 'missing'. Mode 'all' clears existing cached changes. 52 | Mode 'contest_id' clears existing changes with the given contest id. 53 | """ 54 | if contest_id not in ('all', 'missing'): 55 | try: 56 | contest_id = int(contest_id) 57 | except ValueError: 58 | return 59 | if contest_id == 'all': 60 | await ctx.send('This will take a while') 61 | count = await cf_common.cache2.rating_changes_cache.fetch_all_contests() 62 | elif contest_id == 'missing': 63 | await ctx.send('This may take a while') 64 | count = await cf_common.cache2.rating_changes_cache.fetch_missing_contests() 65 | else: 66 | count = await cf_common.cache2.rating_changes_cache.fetch_contest( 67 | contest_id 68 | ) 69 | await ctx.send(f'Done, fetched {count} changes and recached handle ratings') 70 | 71 | @cache.command(usage='contest_id|all') 72 | @commands.has_role(constants.TLE_ADMIN) 73 | @timed_command 74 | async def problemsets(self, ctx, contest_id): 75 | """Mode 'all' clears all existing cached problems. Mode 'contest_id' 76 | clears existing problems with the given contest id. 
77 | """ 78 | if contest_id == 'all': 79 | await ctx.send('This will take a while') 80 | count = await cf_common.cache2.problemset_cache.update_for_all() 81 | else: 82 | try: 83 | contest_id = int(contest_id) 84 | except ValueError: 85 | return 86 | count = await cf_common.cache2.problemset_cache.update_for_contest( 87 | contest_id 88 | ) 89 | await ctx.send(f'Done, fetched {count} problems') 90 | 91 | 92 | def setup(bot): 93 | bot.add_cog(CacheControl(bot)) 94 | -------------------------------------------------------------------------------- /tle/cogs/codeforces.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import random 3 | from collections import defaultdict 4 | 5 | import discord 6 | from discord.ext import commands 7 | 8 | from tle import constants 9 | from tle.util import ( 10 | cache_system2, 11 | codeforces_api as cf, 12 | codeforces_common as cf_common, 13 | discord_common, 14 | paginator, 15 | ) 16 | from tle.util.db.user_db_conn import Gitgud 17 | 18 | _GITGUD_NO_SKIP_TIME = 3 * 60 * 60 19 | _GITGUD_SCORE_DISTRIB = (2, 3, 5, 8, 12, 17, 23) 20 | _GITGUD_MAX_ABS_DELTA_VALUE = 300 21 | 22 | 23 | class CodeforcesCogError(commands.CommandError): 24 | pass 25 | 26 | 27 | class Codeforces(commands.Cog): 28 | def __init__(self, bot): 29 | self.bot = bot 30 | self.converter = commands.MemberConverter() 31 | 32 | async def _validate_gitgud_status(self, ctx, delta): 33 | if delta is not None and delta % 100 != 0: 34 | raise CodeforcesCogError('Delta must be a multiple of 100.') 35 | 36 | if delta is not None and abs(delta) > _GITGUD_MAX_ABS_DELTA_VALUE: 37 | raise CodeforcesCogError( 38 | f'Delta must range from -{_GITGUD_MAX_ABS_DELTA_VALUE}' 39 | f' to {_GITGUD_MAX_ABS_DELTA_VALUE}.' 
40 | ) 41 | 42 | user_id = ctx.message.author.id 43 | active = cf_common.user_db.check_challenge(user_id) 44 | if active is not None: 45 | _, _, name, contest_id, index, _ = active 46 | url = f'{cf.CONTEST_BASE_URL}{contest_id}/problem/{index}' 47 | raise CodeforcesCogError(f'You have an active challenge {name} at {url}') 48 | 49 | async def _gitgud(self, ctx, handle, problem, delta): 50 | # The caller of this function is responsible for calling 51 | # `_validate_gitgud_status` first. 52 | user_id = ctx.author.id 53 | 54 | issue_time = datetime.datetime.now().timestamp() 55 | rc = cf_common.user_db.new_challenge(user_id, issue_time, problem, delta) 56 | if rc != 1: 57 | raise CodeforcesCogError( 58 | 'Your challenge has already been added to the database!' 59 | ) 60 | 61 | title = f'{problem.index}. {problem.name}' 62 | desc = cf_common.cache2.contest_cache.get_contest(problem.contestId).name 63 | embed = discord.Embed(title=title, url=problem.url, description=desc) 64 | embed.add_field(name='Rating', value=problem.rating) 65 | await ctx.send(f'Challenge problem for `{handle}`', embed=embed) 66 | 67 | @commands.command(brief='Upsolve a problem') 68 | @cf_common.user_guard(group='gitgud') 69 | async def upsolve(self, ctx, choice: int = -1): 70 | """Request an unsolved problem from a contest you participated in 71 | delta | -300 | -200 | -100 | 0 | +100 | +200 | +300 72 | points | 2 | 3 | 5 | 8 | 12 | 17 | 23 73 | """ 74 | await self._validate_gitgud_status(ctx, delta=None) 75 | (handle,) = await cf_common.resolve_handles( 76 | ctx, self.converter, ('!' 
+ str(ctx.author),) 77 | ) 78 | user = cf_common.user_db.fetch_cf_user(handle) 79 | rating = round(user.effective_rating, -2) 80 | resp = await cf.user.rating(handle=handle) 81 | contests = {change.contestId for change in resp} 82 | submissions = await cf.user.status(handle=handle) 83 | solved = {sub.problem.name for sub in submissions if sub.verdict == 'OK'} 84 | problems = [ 85 | prob 86 | for prob in cf_common.cache2.problem_cache.problems 87 | if prob.name not in solved 88 | and prob.contestId in contests 89 | and abs(rating - prob.rating) <= 300 90 | ] 91 | 92 | if not problems: 93 | raise CodeforcesCogError('Problems not found within the search parameters') 94 | 95 | problems.sort( 96 | key=lambda problem: cf_common.cache2.contest_cache.get_contest( 97 | problem.contestId 98 | ).startTimeSeconds, 99 | reverse=True, 100 | ) 101 | 102 | if choice > 0 and choice <= len(problems): 103 | problem = problems[choice - 1] 104 | await self._gitgud(ctx, handle, problem, problem.rating - rating) 105 | else: 106 | msg = '\n'.join( 107 | f'{i + 1}: [{prob.name}]({prob.url}) [{prob.rating}]' 108 | for i, prob in enumerate(problems[:5]) 109 | ) 110 | title = f'Select a problem to upsolve (1-{len(problems)}):' 111 | embed = discord_common.cf_color_embed(title=title, description=msg) 112 | await ctx.send(embed=embed) 113 | 114 | @commands.command(brief='Recommend a problem', usage='[+tag..] [~tag..] [rating]') 115 | @cf_common.user_guard(group='gitgud') 116 | async def gimme(self, ctx, *args): 117 | (handle,) = await cf_common.resolve_handles( 118 | ctx, self.converter, ('!' 
+ str(ctx.author),) 119 | ) 120 | rating = round(cf_common.user_db.fetch_cf_user(handle).effective_rating, -2) 121 | tags = cf_common.parse_tags(args, prefix='+') 122 | bantags = cf_common.parse_tags(args, prefix='~') 123 | rating = cf_common.parse_rating(args, rating) 124 | 125 | submissions = await cf.user.status(handle=handle) 126 | solved = {sub.problem.name for sub in submissions if sub.verdict == 'OK'} 127 | 128 | problems = [ 129 | prob 130 | for prob in cf_common.cache2.problem_cache.problems 131 | if prob.rating == rating 132 | and prob.name not in solved 133 | and not cf_common.is_contest_writer(prob.contestId, handle) 134 | and prob.matches_all_tags(tags) 135 | and not prob.matches_any_tag(bantags) 136 | ] 137 | 138 | if not problems: 139 | raise CodeforcesCogError('Problems not found within the search parameters') 140 | 141 | problems.sort( 142 | key=lambda problem: cf_common.cache2.contest_cache.get_contest( 143 | problem.contestId 144 | ).startTimeSeconds 145 | ) 146 | 147 | choice = max([random.randrange(len(problems)) for _ in range(2)]) 148 | problem = problems[choice] 149 | 150 | title = f'{problem.index}. {problem.name}' 151 | desc = cf_common.cache2.contest_cache.get_contest(problem.contestId).name 152 | embed = discord.Embed(title=title, url=problem.url, description=desc) 153 | embed.add_field(name='Rating', value=problem.rating) 154 | if tags: 155 | tagslist = ', '.join(problem.get_matched_tags(tags)) 156 | embed.add_field(name='Matched tags', value=tagslist) 157 | await ctx.send(f'Recommended problem for `{handle}`', embed=embed) 158 | 159 | @commands.command( 160 | brief='List solved problems', 161 | usage='[handles] [+hardest] [+practice] [+contest] [+virtual] [+outof] [+team] [+tag..] [~tag..] [r>=rating] [r<=rating] [d>=[[dd]mm]yyyy] [d<[[dd]mm]yyyy] [c+marker..] [i+index..]', # noqa: E501 162 | ) 163 | async def stalk(self, ctx, *args): 164 | """Print problems solved by user sorted by time (default) or rating. 
165 | All submission types are included by default (practice, contest, etc.) 166 | """ 167 | (hardest,), args = cf_common.filter_flags(args, ['+hardest']) 168 | filt = cf_common.SubFilter(False) 169 | args = filt.parse(args) 170 | handles = args or ('!' + str(ctx.author),) 171 | handles = await cf_common.resolve_handles(ctx, self.converter, handles) 172 | submissions = [await cf.user.status(handle=handle) for handle in handles] 173 | submissions = [sub for subs in submissions for sub in subs] 174 | submissions = filt.filter_subs(submissions) 175 | 176 | if not submissions: 177 | raise CodeforcesCogError( 178 | 'Submissions not found within the search parameters' 179 | ) 180 | 181 | if hardest: 182 | submissions.sort( 183 | key=lambda sub: (sub.problem.rating or 0, sub.creationTimeSeconds), 184 | reverse=True, 185 | ) 186 | else: 187 | submissions.sort(key=lambda sub: sub.creationTimeSeconds, reverse=True) 188 | 189 | def make_line(sub): 190 | data = ( 191 | f'[{sub.problem.name}]({sub.problem.url})', 192 | f'[{sub.problem.rating if sub.problem.rating else "?"}]', 193 | f'({cf_common.days_ago(sub.creationTimeSeconds)})', 194 | ) 195 | return '\N{EN SPACE}'.join(data) 196 | 197 | def make_page(chunk): 198 | title = '{} solved problems by `{}`'.format( 199 | 'Hardest' if hardest else 'Recently', '`, `'.join(handles) 200 | ) 201 | hist_str = '\n'.join(make_line(sub) for sub in chunk) 202 | embed = discord_common.cf_color_embed(description=hist_str) 203 | return title, embed 204 | 205 | pages = [ 206 | make_page(chunk) for chunk in paginator.chunkify(submissions[:100], 10) 207 | ] 208 | paginator.paginate( 209 | self.bot, ctx.channel, pages, wait_time=5 * 60, set_pagenum_footers=True 210 | ) 211 | 212 | @commands.command(brief='Create a mashup', usage='[handles] [+tag..] [~tag..]') 213 | async def mashup(self, ctx, *args): 214 | """Create a mashup contest. 215 | 216 | The contest uses problems within +-100 of average rating of handles provided. 
217 | Add tags with "+" before them. 218 | Ban tags with "~" before them. 219 | """ 220 | handles = [arg for arg in args if arg[0] not in '+~'] 221 | tags = cf_common.parse_tags(args, prefix='+') 222 | bantags = cf_common.parse_tags(args, prefix='~') 223 | 224 | handles = handles or ('!' + str(ctx.author),) 225 | handles = await cf_common.resolve_handles(ctx, self.converter, handles) 226 | resp = [await cf.user.status(handle=handle) for handle in handles] 227 | submissions = [sub for user in resp for sub in user] 228 | solved = {sub.problem.name for sub in submissions} 229 | info = await cf.user.info(handles=handles) 230 | rating = int( 231 | round(sum(user.effective_rating for user in info) / len(handles), -2) 232 | ) 233 | problems = [ 234 | prob 235 | for prob in cf_common.cache2.problem_cache.problems 236 | if abs(prob.rating - rating) <= 100 237 | and prob.name not in solved 238 | and not any( 239 | cf_common.is_contest_writer(prob.contestId, handle) 240 | for handle in handles 241 | ) 242 | and not cf_common.is_nonstandard_problem(prob) 243 | and prob.matches_all_tags(tags) 244 | and not prob.matches_any_tag(bantags) 245 | ] 246 | 247 | if len(problems) < 4: 248 | raise CodeforcesCogError('Problems not found within the search parameters') 249 | 250 | problems.sort( 251 | key=lambda problem: cf_common.cache2.contest_cache.get_contest( 252 | problem.contestId 253 | ).startTimeSeconds 254 | ) 255 | 256 | choices = [] 257 | for i in range(4): 258 | k = max(random.randrange(len(problems) - i) for _ in range(2)) 259 | for c in choices: 260 | if k >= c: 261 | k += 1 262 | choices.append(k) 263 | choices.sort() 264 | 265 | problems = reversed([problems[k] for k in choices]) 266 | msg = '\n'.join( 267 | f'{"ABCD"[i]}: [{p.name}]({p.url}) [{p.rating}]' 268 | for i, p in enumerate(problems) 269 | ) 270 | str_handles = '`, `'.join(handles) 271 | embed = discord_common.cf_color_embed(description=msg) 272 | await ctx.send(f'Mashup contest for `{str_handles}`', embed=embed) 
273 | 274 | @commands.command(brief='Challenge') 275 | @cf_common.user_guard(group='gitgud') 276 | async def gitgud(self, ctx, delta: int = 0): 277 | """Request a problem for gitgud points. 278 | delta | -300 | -200 | -100 | 0 | +100 | +200 | +300 279 | points | 2 | 3 | 5 | 8 | 12 | 17 | 23 280 | """ 281 | await self._validate_gitgud_status(ctx, delta) 282 | (handle,) = await cf_common.resolve_handles( 283 | ctx, self.converter, ('!' + str(ctx.author),) 284 | ) 285 | user = cf_common.user_db.fetch_cf_user(handle) 286 | rating = round(user.effective_rating, -2) 287 | submissions = await cf.user.status(handle=handle) 288 | solved = {sub.problem.name for sub in submissions} 289 | noguds = cf_common.user_db.get_noguds(ctx.message.author.id) 290 | 291 | problems = [ 292 | prob 293 | for prob in cf_common.cache2.problem_cache.problems 294 | if ( 295 | prob.rating == rating + delta 296 | and prob.name not in solved 297 | and prob.name not in noguds 298 | ) 299 | ] 300 | 301 | def check(problem): 302 | return not cf_common.is_nonstandard_problem( 303 | problem 304 | ) and not cf_common.is_contest_writer(problem.contestId, handle) 305 | 306 | problems = list(filter(check, problems)) 307 | if not problems: 308 | raise CodeforcesCogError('No problem to assign') 309 | 310 | problems.sort( 311 | key=lambda problem: cf_common.cache2.contest_cache.get_contest( 312 | problem.contestId 313 | ).startTimeSeconds 314 | ) 315 | 316 | choice = max(random.randrange(len(problems)) for _ in range(2)) 317 | await self._gitgud(ctx, handle, problems[choice], delta) 318 | 319 | @commands.command(brief='Print user gitgud history') 320 | async def gitlog(self, ctx, member: discord.Member = None): 321 | """Displays the list of gitgud problems issued to the specified member, 322 | excluding those noguded by admins. If the challenge was completed, time 323 | of completion and amount of points gained will also be displayed. 
324 | """ 325 | 326 | def make_line(entry): 327 | issue, finish, name, contest, index, delta, status = entry 328 | problem = cf_common.cache2.problem_cache.problem_by_name[name] 329 | line = f'[{name}]({problem.url})\N{EN SPACE}[{problem.rating}]' 330 | if finish: 331 | time_str = cf_common.days_ago(finish) 332 | points = f'{_GITGUD_SCORE_DISTRIB[delta // 100 + 3]:+}' 333 | line += f'\N{EN SPACE}{time_str}\N{EN SPACE}[{points}]' 334 | return line 335 | 336 | def make_page(chunk): 337 | message = discord.utils.escape_mentions( 338 | f'gitgud log for {member.display_name}' 339 | ) 340 | log_str = '\n'.join(make_line(entry) for entry in chunk) 341 | embed = discord_common.cf_color_embed(description=log_str) 342 | return message, embed 343 | 344 | member = member or ctx.author 345 | data = cf_common.user_db.gitlog(member.id) 346 | if not data: 347 | raise CodeforcesCogError(f'{member.mention} has no gitgud history.') 348 | 349 | pages = [make_page(chunk) for chunk in paginator.chunkify(data, 7)] 350 | paginator.paginate( 351 | self.bot, ctx.channel, pages, wait_time=5 * 60, set_pagenum_footers=True 352 | ) 353 | 354 | @commands.command(brief='Report challenge completion') 355 | @cf_common.user_guard(group='gitgud') 356 | async def gotgud(self, ctx): 357 | (handle,) = await cf_common.resolve_handles( 358 | ctx, self.converter, ('!' 
+ str(ctx.author),) 359 | ) 360 | user_id = ctx.message.author.id 361 | active = cf_common.user_db.check_challenge(user_id) 362 | if not active: 363 | raise CodeforcesCogError('You do not have an active challenge') 364 | 365 | submissions = await cf.user.status(handle=handle) 366 | solved = {sub.problem.name for sub in submissions if sub.verdict == 'OK'} 367 | 368 | challenge_id, issue_time, name, contestId, index, delta = active 369 | if name not in solved: 370 | raise CodeforcesCogError("You haven't completed your challenge.") 371 | 372 | delta = _GITGUD_SCORE_DISTRIB[delta // 100 + 3] 373 | finish_time = int(datetime.datetime.now().timestamp()) 374 | rc = cf_common.user_db.complete_challenge( 375 | user_id, challenge_id, finish_time, delta 376 | ) 377 | if rc == 1: 378 | duration = cf_common.pretty_time_format(finish_time - issue_time) 379 | await ctx.send( 380 | f'Challenge completed in {duration}. {handle} gained {delta} points.' 381 | ) 382 | else: 383 | await ctx.send('You have already claimed your points') 384 | 385 | @commands.command(brief='Skip challenge') 386 | @cf_common.user_guard(group='gitgud') 387 | async def nogud(self, ctx): 388 | await cf_common.resolve_handles(ctx, self.converter, ('!' + str(ctx.author),)) 389 | user_id = ctx.message.author.id 390 | active = cf_common.user_db.check_challenge(user_id) 391 | if not active: 392 | raise CodeforcesCogError('You do not have an active challenge') 393 | 394 | challenge_id, issue_time, name, contestId, index, delta = active 395 | finish_time = int(datetime.datetime.now().timestamp()) 396 | if finish_time - issue_time < _GITGUD_NO_SKIP_TIME: 397 | skip_time = cf_common.pretty_time_format( 398 | issue_time + _GITGUD_NO_SKIP_TIME - finish_time 399 | ) 400 | await ctx.send(f'Think more. 
You can skip your challenge in {skip_time}.') 401 | return 402 | cf_common.user_db.skip_challenge(user_id, challenge_id, Gitgud.NOGUD) 403 | await ctx.send('Challenge skipped.') 404 | 405 | @commands.command(brief='Force skip a challenge') 406 | @cf_common.user_guard(group='gitgud') 407 | @commands.has_any_role(constants.TLE_ADMIN, constants.TLE_MODERATOR) 408 | async def _nogud(self, ctx, member: discord.Member): 409 | active = cf_common.user_db.check_challenge(member.id) 410 | rc = cf_common.user_db.skip_challenge(member.id, active[0], Gitgud.FORCED_NOGUD) 411 | if rc == 1: 412 | await ctx.send('Challenge skip forced.') 413 | else: 414 | await ctx.send('Failed to force challenge skip.') 415 | 416 | @commands.command(brief='Recommend a contest', usage='[handles...] [+pattern...]') 417 | async def vc(self, ctx, *args: str): 418 | """Recommends a contest based on Codeforces rating of the handle provided. 419 | e.g ;vc mblazev c1729 +global +hello +goodbye +avito""" 420 | markers = [x for x in args if x[0] == '+'] 421 | handles = [x for x in args if x[0] != '+'] or ('!' + str(ctx.author),) 422 | handles = await cf_common.resolve_handles( 423 | ctx, self.converter, handles, maxcnt=25 424 | ) 425 | info = await cf.user.info(handles=handles) 426 | contests = cf_common.cache2.contest_cache.get_contests_in_phase('FINISHED') 427 | 428 | if not markers: 429 | divr = sum(user.effective_rating for user in info) / len(handles) 430 | div1_indicators = ['div1', 'global', 'avito', 'goodbye', 'hello'] 431 | markers = ( 432 | ['div3'] 433 | if divr < 1600 434 | else ['div2'] 435 | if divr < 2100 436 | else div1_indicators 437 | ) 438 | 439 | recommendations = { 440 | contest.id 441 | for contest in contests 442 | if contest.matches(markers) 443 | and not cf_common.is_nonstandard_contest(contest) 444 | and not any( 445 | cf_common.is_contest_writer(contest.id, handle) for handle in handles 446 | ) 447 | } 448 | 449 | # Discard contests in which user has non-CE submissions. 
450 | visited_contests = await cf_common.get_visited_contests(handles) 451 | recommendations -= visited_contests 452 | 453 | if not recommendations: 454 | raise CodeforcesCogError('Unable to recommend a contest') 455 | 456 | recommendations = list(recommendations) 457 | random.shuffle(recommendations) 458 | contests = [ 459 | cf_common.cache2.contest_cache.get_contest(contest_id) 460 | for contest_id in recommendations[:25] 461 | ] 462 | 463 | def make_line(c): 464 | return ( 465 | f'[{c.name}]({c.url}) {cf_common.pretty_time_format(c.durationSeconds)}' 466 | ) 467 | 468 | def make_page(chunk): 469 | str_handles = '`, `'.join(handles) 470 | message = f'Recommended contest(s) for `{str_handles}`' 471 | vc_str = '\n'.join(make_line(contest) for contest in chunk) 472 | embed = discord_common.cf_color_embed(description=vc_str) 473 | return message, embed 474 | 475 | pages = [make_page(chunk) for chunk in paginator.chunkify(contests, 5)] 476 | paginator.paginate( 477 | self.bot, ctx.channel, pages, wait_time=5 * 60, set_pagenum_footers=True 478 | ) 479 | 480 | @commands.command( 481 | brief='Display unsolved rounds closest to completion', usage='[keywords]' 482 | ) 483 | async def fullsolve(self, ctx, *args: str): 484 | """Displays a list of contests, sorted by number of unsolved problems. 485 | Contest names matching any of the provided tags will be considered. e.g 486 | ;fullsolve +edu""" 487 | (handle,) = await cf_common.resolve_handles( 488 | ctx, self.converter, ('!' 
+ str(ctx.author),) 489 | ) 490 | tags = [x for x in args if x[0] == '+'] 491 | 492 | problem_to_contests = cf_common.cache2.problemset_cache.problem_to_contests 493 | contests = [ 494 | contest 495 | for contest in cf_common.cache2.contest_cache.get_contests_in_phase( 496 | 'FINISHED' 497 | ) 498 | if (not tags or contest.matches(tags)) 499 | and not cf_common.is_nonstandard_contest(contest) 500 | ] 501 | 502 | # subs_by_contest_id contains contest_id mapped to [list of problem.name] 503 | subs_by_contest_id = defaultdict(set) 504 | for sub in await cf.user.status(handle=handle): 505 | if sub.verdict == 'OK': 506 | try: 507 | contest = cf_common.cache2.contest_cache.get_contest( 508 | sub.problem.contestId 509 | ) 510 | problem_id = (sub.problem.name, contest.startTimeSeconds) 511 | for contestId in problem_to_contests[problem_id]: 512 | subs_by_contest_id[contestId].add(sub.problem.name) 513 | except cache_system2.ContestNotFound: 514 | pass 515 | 516 | contest_unsolved_pairs = [] 517 | for contest in contests: 518 | num_solved = len(subs_by_contest_id[contest.id]) 519 | try: 520 | num_problems = len( 521 | cf_common.cache2.problemset_cache.get_problemset(contest.id) 522 | ) 523 | if 0 < num_solved < num_problems: 524 | contest_unsolved_pairs.append((contest, num_solved, num_problems)) 525 | except cache_system2.ProblemsetNotCached: 526 | # In case of recent contents or cetain bugged contests 527 | pass 528 | 529 | contest_unsolved_pairs.sort(key=lambda p: (p[2] - p[1], -p[0].startTimeSeconds)) 530 | 531 | if not contest_unsolved_pairs: 532 | raise CodeforcesCogError( 533 | f'`{handle}` has no contests to fullsolve :confetti_ball:' 534 | ) 535 | 536 | def make_line(entry): 537 | contest, solved, total = entry 538 | return f'[{contest.name}]({contest.url})\N{EN SPACE}[{solved}/{total}]' 539 | 540 | def make_page(chunk): 541 | message = f'Fullsolve list for `{handle}`' 542 | full_solve_list = '\n'.join(make_line(entry) for entry in chunk) 543 | embed = 
discord_common.cf_color_embed(description=full_solve_list) 544 | return message, embed 545 | 546 | pages = [ 547 | make_page(chunk) for chunk in paginator.chunkify(contest_unsolved_pairs, 10) 548 | ] 549 | paginator.paginate( 550 | self.bot, ctx.channel, pages, wait_time=5 * 60, set_pagenum_footers=True 551 | ) 552 | 553 | @staticmethod 554 | def getEloWinProbability(ra: float, rb: float) -> float: 555 | return 1.0 / (1 + 10 ** ((rb - ra) / 400.0)) 556 | 557 | @staticmethod 558 | def composeRatings(left: float, right: float, ratings: list[float]) -> int: 559 | for _tt in range(20): 560 | r = (left + right) / 2.0 561 | 562 | rWinsProbability = 1.0 563 | for rating, count in ratings: 564 | rWinsProbability *= Codeforces.getEloWinProbability(r, rating) ** count 565 | 566 | if rWinsProbability < 0.5: 567 | left = r 568 | else: 569 | right = r 570 | return round((left + right) / 2) 571 | 572 | @commands.command(brief='Calculate team rating', usage='[handles] [+peak]') 573 | async def teamrate(self, ctx, *args: str): 574 | """Provides the combined rating of the entire team. If +server is 575 | provided as the only handle, will display the rating of the entire 576 | server. Supports multipliers. e.g: ;teamrate gamegame*1000""" 577 | 578 | (is_entire_server, peak), handles = cf_common.filter_flags( 579 | args, ['+server', '+peak'] 580 | ) 581 | handles = handles or ('!' 
+ str(ctx.author),) 582 | 583 | def rating(user): 584 | return user.maxRating if peak else user.rating 585 | 586 | if is_entire_server: 587 | res = cf_common.user_db.get_cf_users_for_guild(ctx.guild.id) 588 | ratings = [ 589 | (rating(user), 1) for user_id, user in res if user.rating is not None 590 | ] 591 | user_str = '+server' 592 | else: 593 | 594 | def normalize(x): 595 | return [i.lower() for i in x] 596 | 597 | handle_counts = {} 598 | parsed_handles = [] 599 | for i in handles: 600 | parse_str = normalize(i.split('*')) 601 | if len(parse_str) > 1: 602 | try: 603 | handle_counts[parse_str[0]] = int(parse_str[1]) 604 | except ValueError: 605 | raise CodeforcesCogError("Can't multiply by non-integer") 606 | else: 607 | handle_counts[parse_str[0]] = 1 608 | parsed_handles.append(parse_str[0]) 609 | 610 | cf_handles = await cf_common.resolve_handles( 611 | ctx, self.converter, parsed_handles, mincnt=1, maxcnt=1000 612 | ) 613 | cf_handles = normalize(cf_handles) 614 | cf_to_original = { 615 | a: b for a, b in zip(cf_handles, parsed_handles, strict=False) 616 | } 617 | original_to_cf = { 618 | a: b for a, b in zip(parsed_handles, cf_handles, strict=False) 619 | } 620 | users = await cf.user.info(handles=cf_handles) 621 | user_strs = [] 622 | for a, b in handle_counts.items(): 623 | if b > 1: 624 | user_strs.append(f'{original_to_cf[a]}*{b}') 625 | elif b == 1: 626 | user_strs.append(original_to_cf[a]) 627 | elif b <= 0: 628 | raise CodeforcesCogError( 629 | 'How can you have nonpositive members in team?' 
630 | ) 631 | 632 | user_str = ', '.join(user_strs) 633 | ratings = [ 634 | (rating(user), handle_counts[cf_to_original[user.handle.lower()]]) 635 | for user in users 636 | if user.rating 637 | ] 638 | 639 | if len(ratings) == 0: 640 | raise CodeforcesCogError('No CF usernames with ratings passed in.') 641 | 642 | left = -100.0 643 | right = 10000.0 644 | teamRating = Codeforces.composeRatings(left, right, ratings) 645 | embed = discord.Embed( 646 | title=user_str, 647 | description=teamRating, 648 | color=cf.rating2rank(teamRating).color_embed, 649 | ) 650 | await ctx.send(embed=embed) 651 | 652 | @discord_common.send_error_if( 653 | CodeforcesCogError, cf_common.ResolveHandleError, cf_common.FilterError 654 | ) 655 | async def cog_command_error(self, ctx, error): 656 | pass 657 | 658 | 659 | def setup(bot): 660 | bot.add_cog(Codeforces(bot)) 661 | -------------------------------------------------------------------------------- /tle/cogs/deactivated/cses.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | 3 | from discord.ext import commands 4 | 5 | from tle.util import cses_scraper as cses, discord_common, table, tasks 6 | 7 | 8 | def score(placings): 9 | points = {1: 8, 2: 5, 3: 3, 4: 2, 5: 1} 10 | # points = {1:5, 2:4, 3:3, 4:2, 5:1} 11 | return sum(points[rank] for rank in placings) 12 | 13 | 14 | class CSES(commands.Cog): 15 | def __init__(self, bot): 16 | self.bot = bot 17 | self.short_placings = {} 18 | self.fast_placings = {} 19 | self.reloading = False 20 | 21 | @commands.Cog.listener() 22 | @discord_common.once 23 | async def on_ready(self): 24 | self._cache_data.start() 25 | 26 | @tasks.task_spec( 27 | name='ProblemsetCacheUpdate', waiter=tasks.Waiter.fixed_delay(30 * 60) 28 | ) 29 | async def _cache_data(self, _): 30 | await self._reload() 31 | 32 | async def _reload(self): 33 | self.reloading = True 34 | short_placings = defaultdict(list) 35 | fast_placings = 
defaultdict(list) 36 | try: 37 | for pid in await cses.get_problems(): 38 | fast, short = await cses.get_problem_leaderboard(pid) 39 | for i in range(len(fast)): 40 | fast_placings[fast[i]].append(i + 1) 41 | for i in range(len(short)): 42 | short_placings[short[i]].append(i + 1) 43 | self.short_placings = short_placings 44 | self.fast_placings = fast_placings 45 | finally: 46 | self.reloading = False 47 | 48 | def format_leaderboard(self, top, placings): 49 | if not top: 50 | return 'Failed to load :<' 51 | 52 | header = ' 1st 2nd 3rd 4th 5th '.split(' ') 53 | 54 | style = table.Style( 55 | header='{:>} {:>} {:>} {:>} {:>} {:>} {:>}', 56 | body='{:>} | {:>} {:>} {:>} {:>} {:>} | {:>} pts', 57 | ) 58 | 59 | t = table.Table(style) 60 | t += table.Header(*header) 61 | 62 | for user, points in top: 63 | hist = [placings[user].count(i + 1) for i in range(5)] 64 | t += table.Data(user, *hist, points) 65 | 66 | return str(t) 67 | 68 | def leaderboard(self, placings, num): 69 | leaderboard = sorted( 70 | ((k, score(v)) for k, v in placings.items() if k != 'N/A'), 71 | key=lambda x: x[1], 72 | reverse=True, 73 | ) 74 | 75 | top = leaderboard[:num] 76 | 77 | return self.format_leaderboard(top, placings) 78 | 79 | def leaderboard_individual(self, placings, handles): 80 | leaderboard = sorted( 81 | ((k, score(v)) for k, v in placings.items() if k != 'N/A' and k in handles), 82 | key=lambda x: x[1], 83 | reverse=True, 84 | ) 85 | 86 | included = [handle for handle, score in leaderboard] 87 | leaderboard += [(handle, 0) for handle in handles if handle not in included] 88 | 89 | top = leaderboard 90 | 91 | return self.format_leaderboard(top, placings) 92 | 93 | @property 94 | def fastest(self, num=10): 95 | return self.leaderboard(self.fast_placings, num) 96 | 97 | @property 98 | def shortest(self, num=10): 99 | return self.leaderboard(self.short_placings, num) 100 | 101 | def fastest_individual(self, handles): 102 | return self.leaderboard_individual(self.fast_placings, 
handles) 103 | 104 | def shortest_individual(self, handles): 105 | return self.leaderboard_individual(self.short_placings, handles) 106 | 107 | @commands.command(brief='Shows compiled CSES leaderboard', usage='[handles...]') 108 | async def cses(self, ctx, *handles: str): 109 | """Shows compiled CSES leaderboard. 110 | 111 | If handles are given, leaderboard will contain only those indicated 112 | handles, otherwise leaderboard will contain overall top ten. 113 | """ 114 | if not handles: 115 | await ctx.send( 116 | '```\nFastest\n' 117 | + self.fastest 118 | + '\n\n' 119 | + 'Shortest\n' 120 | + self.shortest 121 | + '\n' 122 | + '```' 123 | ) 124 | elif len(handles) > 10: 125 | await ctx.send('```Please indicate at most 10 users```') 126 | else: 127 | handles = set(handles) 128 | await ctx.send( 129 | '```\nFastest\n' 130 | + self.fastest_individual(handles) 131 | + '\n\n' 132 | + 'Shortest\n' 133 | + self.shortest_individual(handles) 134 | + '\n' 135 | + '```' 136 | ) 137 | 138 | @commands.command(brief='Force update the CSES leaderboard') 139 | async def _updatecses(self, ctx): 140 | """Shows compiled CSES leaderboard.""" 141 | if self.reloading: 142 | await ctx.send("Have some patience, I'm already reloading!") 143 | else: 144 | await self._reload() 145 | await ctx.send('CSES leaderboards updated!') 146 | 147 | 148 | def setup(bot): 149 | bot.add_cog(CSES(bot)) 150 | -------------------------------------------------------------------------------- /tle/cogs/logging.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import logging 3 | import os 4 | 5 | from discord.ext import commands 6 | 7 | from tle.util import discord_common 8 | 9 | root_logger = logging.getLogger() 10 | logger = logging.getLogger(__name__) 11 | 12 | 13 | class Logging(commands.Cog, logging.Handler): 14 | def __init__(self, bot, channel_id): 15 | logging.Handler.__init__(self) 16 | self.bot = bot 17 | self.channel_id = channel_id 18 | 
self.queue = asyncio.Queue() 19 | self.task = None 20 | self.logger = logging.getLogger(self.__class__.__name__) 21 | 22 | @commands.Cog.listener() 23 | @discord_common.once 24 | async def on_ready(self): 25 | self.task = asyncio.create_task(self._log_task()) 26 | width = 79 27 | stars, msg = f'{"*" * width}', f'***{"Bot running":^{width - 6}}***' 28 | self.logger.log(level=100, msg=stars) 29 | self.logger.log(level=100, msg=msg) 30 | self.logger.log(level=100, msg=stars) 31 | 32 | async def _log_task(self): 33 | while True: 34 | record = await self.queue.get() 35 | channel = self.bot.get_channel(self.channel_id) 36 | if channel is None: 37 | # Channel no longer exists. 38 | root_logger.removeHandler(self) 39 | self.logger.warning( 40 | 'Logging channel not available, disabling Discord log handler.' 41 | ) 42 | break 43 | try: 44 | msg = self.format(record) 45 | # Not all errors will have message_contents or jump urls. 46 | try: 47 | await channel.send( 48 | 'Original Command: {}\nJump Url: {}'.format( 49 | record.message_content, record.jump_url 50 | ) 51 | ) 52 | except AttributeError: 53 | pass 54 | discord_msg_char_limit = 2000 55 | char_limit = discord_msg_char_limit - 2 * len('```') 56 | too_long = len(msg) > char_limit 57 | msg = msg[:char_limit] 58 | await channel.send('```{}```'.format(msg)) 59 | if too_long: 60 | await channel.send('`Check logs for full stack trace`') 61 | except: 62 | self.handleError(record) 63 | 64 | # logging.Handler overrides below. 65 | 66 | def emit(self, record): 67 | self.queue.put_nowait(record) 68 | 69 | def close(self): 70 | if self.task: 71 | self.task.cancel() 72 | 73 | 74 | def setup(bot): 75 | logging_cog_channel_id = os.environ.get('LOGGING_COG_CHANNEL_ID') 76 | if logging_cog_channel_id is None: 77 | logger.info( 78 | 'Skipping installation of logging cog as logging channel is not provided.' 
79 | ) 80 | return 81 | 82 | logging_cog = Logging(bot, int(logging_cog_channel_id)) 83 | logging_cog.setLevel(logging.WARNING) 84 | logging_cog.setFormatter( 85 | logging.Formatter( 86 | fmt='{asctime}:{levelname}:{name}:{message}', 87 | style='{', 88 | datefmt='%d-%m-%Y %H:%M:%S', 89 | ) 90 | ) 91 | root_logger.addHandler(logging_cog) 92 | bot.add_cog(logging_cog) 93 | -------------------------------------------------------------------------------- /tle/cogs/meta.py: -------------------------------------------------------------------------------- 1 | import os 2 | import subprocess 3 | import textwrap 4 | import time 5 | 6 | from discord.ext import commands 7 | 8 | from tle import constants 9 | from tle.util.codeforces_common import pretty_time_format 10 | 11 | RESTART = 42 12 | 13 | 14 | # Adapted from numpy sources. 15 | # https://github.com/numpy/numpy/blob/master/setup.py#L64-85 16 | def git_history(): 17 | def _minimal_ext_cmd(cmd): 18 | # construct minimal environment 19 | env = {} 20 | for k in ['SYSTEMROOT', 'PATH']: 21 | v = os.environ.get(k) 22 | if v is not None: 23 | env[k] = v 24 | # LANGUAGE is used on win32 25 | env['LANGUAGE'] = 'C' 26 | env['LANG'] = 'C' 27 | env['LC_ALL'] = 'C' 28 | out = subprocess.Popen(cmd, stdout=subprocess.PIPE, env=env).communicate()[0] 29 | return out 30 | 31 | try: 32 | out = _minimal_ext_cmd(['git', 'rev-parse', '--abbrev-ref', 'HEAD']) 33 | branch = out.strip().decode('ascii') 34 | out = _minimal_ext_cmd(['git', 'log', '--oneline', '-5']) 35 | history = out.strip().decode('ascii') 36 | return ( 37 | 'Branch:\n' 38 | + textwrap.indent(branch, ' ') 39 | + '\nCommits:\n' 40 | + textwrap.indent(history, ' ') 41 | ) 42 | except OSError: 43 | return 'Fetching git info failed' 44 | 45 | 46 | class Meta(commands.Cog): 47 | def __init__(self, bot): 48 | self.bot = bot 49 | self.start_time = time.time() 50 | 51 | @commands.group(brief='Bot control', invoke_without_command=True) 52 | async def meta(self, ctx): 53 | """Command the 
bot or get information about the bot.""" 54 | await ctx.send_help(ctx.command) 55 | 56 | @meta.command(brief='Restarts TLE') 57 | @commands.has_role(constants.TLE_ADMIN) 58 | async def restart(self, ctx): 59 | """Restarts the bot.""" 60 | # Really, we just exit with a special code 61 | # the magic is handled elsewhere 62 | await ctx.send('Restarting...') 63 | os._exit(RESTART) 64 | 65 | @meta.command(brief='Kill TLE') 66 | @commands.has_role(constants.TLE_ADMIN) 67 | async def kill(self, ctx): 68 | """Restarts the bot.""" 69 | await ctx.send('Dying...') 70 | os._exit(0) 71 | 72 | @meta.command(brief='Is TLE up?') 73 | async def ping(self, ctx): 74 | """Replies to a ping.""" 75 | start = time.perf_counter() 76 | message = await ctx.send(':ping_pong: Pong!') 77 | end = time.perf_counter() 78 | duration = (end - start) * 1000 79 | await message.edit( 80 | content=( 81 | f'REST API latency: {int(duration)}ms\n' 82 | f'Gateway API latency: {int(self.bot.latency * 1000)}ms' 83 | ) 84 | ) 85 | 86 | @meta.command(brief='Get git information') 87 | async def git(self, ctx): 88 | """Replies with git information.""" 89 | await ctx.send('```yaml\n' + git_history() + '```') 90 | 91 | @meta.command(brief='Prints bot uptime') 92 | async def uptime(self, ctx): 93 | """Replies with how long TLE has been up.""" 94 | await ctx.send( 95 | 'TLE has been running for ' 96 | + pretty_time_format(time.time() - self.start_time) 97 | ) 98 | 99 | @meta.command(brief='Print bot guilds') 100 | @commands.has_role(constants.TLE_ADMIN) 101 | async def guilds(self, ctx): 102 | "Replies with info on the bot's guilds" 103 | msg = [ 104 | ' | '.join( 105 | [ 106 | f'Guild ID: {guild.id}', 107 | f'Name: {guild.name}', 108 | f'Owner: {guild.owner.id}', 109 | f'Icon: {guild.icon_url}', 110 | ] 111 | ) 112 | for guild in self.bot.guilds 113 | ] 114 | await ctx.send('```' + '\n'.join(msg) + '```') 115 | 116 | 117 | def setup(bot): 118 | bot.add_cog(Meta(bot)) 119 | 
-------------------------------------------------------------------------------- /tle/cogs/starboard.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import logging 3 | 4 | import discord 5 | from discord.ext import commands 6 | 7 | from tle import constants 8 | from tle.util import codeforces_common as cf_common, discord_common 9 | 10 | 11 | class StarboardCogError(commands.CommandError): 12 | pass 13 | 14 | 15 | class Starboard(commands.Cog): 16 | def __init__(self, bot): 17 | self.bot = bot 18 | self.locks = {} 19 | self.logger = logging.getLogger(self.__class__.__name__) 20 | 21 | @commands.Cog.listener() 22 | async def on_raw_reaction_add(self, payload): 23 | guild_id = payload.guild_id 24 | if guild_id is None: 25 | return 26 | emoji = str(payload.emoji) 27 | entry = cf_common.user_db.get_starboard_entry(guild_id, emoji) 28 | if entry is None: 29 | return 30 | channel_id, threshold, color = entry 31 | try: 32 | await self.check_and_add_to_starboard( 33 | channel_id, threshold, color, emoji, payload 34 | ) 35 | except StarboardCogError as e: 36 | self.logger.info(f'Failed to starboard: {e!r}') 37 | 38 | @commands.Cog.listener() 39 | async def on_raw_message_delete(self, payload): 40 | if payload.guild_id is None: 41 | return 42 | removed = cf_common.user_db.remove_starboard_message( 43 | starboard_msg_id=payload.message_id 44 | ) 45 | if removed: 46 | self.logger.info( 47 | f'Removed starboard record for deleted message {payload.message_id}' 48 | ) 49 | 50 | @staticmethod 51 | def prepare_embed(message, color): 52 | embed = discord.Embed(color=color, timestamp=message.created_at) 53 | embed.add_field(name='Channel', value=message.channel.mention) 54 | embed.add_field(name='Jump to', value=f'[Original]({message.jump_url})') 55 | 56 | if message.content: 57 | embed.add_field(name='Content', value=message.content, inline=False) 58 | 59 | if message.embeds: 60 | data = message.embeds[0] 61 | if data.type 
== 'image': 62 | embed.set_image(url=data.url) 63 | 64 | if message.attachments: 65 | file = message.attachments[0] 66 | if file.filename.lower().endswith(('png', 'jpeg', 'jpg', 'gif', 'webp')): 67 | embed.set_image(url=file.url) 68 | else: 69 | embed.add_field( 70 | name='Attachment', 71 | value=f'[{file.filename}]({file.url})', 72 | inline=False, 73 | ) 74 | 75 | embed.set_footer(text=str(message.author), icon_url=message.author.avatar_url) 76 | return embed 77 | 78 | async def check_and_add_to_starboard( 79 | self, channel_id, threshold, color, emoji, payload 80 | ): 81 | guild = self.bot.get_guild(payload.guild_id) 82 | starboard_channel = guild.get_channel(channel_id) 83 | if starboard_channel is None: 84 | raise StarboardCogError('Starboard channel not found') 85 | 86 | channel = self.bot.get_channel(payload.channel_id) 87 | message = await channel.fetch_message(payload.message_id) 88 | if message.type != discord.MessageType.default or ( 89 | not message.content and not message.attachments 90 | ): 91 | raise StarboardCogError('Cannot starboard this message') 92 | 93 | count = sum(r.count for r in message.reactions if str(r) == emoji) 94 | if count < threshold: 95 | return 96 | 97 | lock = self.locks.setdefault(payload.guild_id, asyncio.Lock()) 98 | async with lock: 99 | if cf_common.user_db.check_exists_starboard_message(message.id, emoji): 100 | return 101 | embed = self.prepare_embed(message, color or constants._DEFAULT_COLOR) 102 | star_msg = await starboard_channel.send(embed=embed) 103 | cf_common.user_db.add_starboard_message( 104 | message.id, star_msg.id, payload.guild_id, emoji 105 | ) 106 | self.logger.info(f'Added message {message.id} to starboard under {emoji}') 107 | 108 | @commands.group(brief='Starboard commands', invoke_without_command=True) 109 | async def starboard(self, ctx): 110 | """Group for commands involving the starboard.""" 111 | await ctx.send_help(ctx.command) 112 | 113 | @starboard.command(brief='Add an emoji to starboard list') 
114 | @commands.has_role(constants.TLE_ADMIN) 115 | async def add(self, ctx, emoji: str, threshold: int, color: str = None): 116 | """Register an emoji with a reaction threshold and optional hex color.""" 117 | clr = int(color, 16) if color else constants._DEFAULT_COLOR 118 | cf_common.user_db.add_starboard_emoji(ctx.guild.id, emoji, threshold, clr) 119 | await ctx.send( 120 | embed=discord_common.embed_success( 121 | f'Added {emoji}: threshold={threshold}, color={hex(clr)}' 122 | ) 123 | ) 124 | 125 | @starboard.command(brief='Delete an emoji from starboard list') 126 | @commands.has_role(constants.TLE_ADMIN) 127 | async def delete(self, ctx, emoji: str): 128 | """Unregister an emoji from starboard.""" 129 | cf_common.user_db.remove_starboard_emoji(ctx.guild.id, emoji) 130 | cf_common.user_db.clear_starboard_channel(ctx.guild.id, emoji) 131 | await ctx.send(embed=discord_common.embed_success(f'Removed {emoji}')) 132 | 133 | @starboard.command(brief='Edit threshold for an emoji') 134 | @commands.has_role(constants.TLE_ADMIN) 135 | async def edit_threshold(self, ctx, emoji: str, threshold: int): 136 | """Update reaction threshold for an emoji.""" 137 | cf_common.user_db.update_starboard_threshold(ctx.guild.id, emoji, threshold) 138 | await ctx.send( 139 | embed=discord_common.embed_success( 140 | f'Updated {emoji} threshold to {threshold}' 141 | ) 142 | ) 143 | 144 | @starboard.command(brief='Edit embed color for an emoji') 145 | @commands.has_role(constants.TLE_ADMIN) 146 | async def edit_color(self, ctx, emoji: str, color: str): 147 | """Update embed color (hex) for an emoji.""" 148 | clr = int(color, 16) 149 | cf_common.user_db.update_starboard_color(ctx.guild.id, emoji, clr) 150 | await ctx.send( 151 | embed=discord_common.embed_success(f'Updated {emoji} color to {hex(clr)}') 152 | ) 153 | 154 | @starboard.command(brief='Set starboard channel (and optional color) for an emoji') 155 | @commands.has_role(constants.TLE_ADMIN) 156 | async def here(self, ctx, emoji: 
str): 157 | """Set the channel and optional color for an emoji.""" 158 | cf_common.user_db.set_starboard_channel(ctx.guild.id, emoji, ctx.channel.id) 159 | msg = f'Set {emoji} channel to {ctx.channel.mention}' 160 | await ctx.send(embed=discord_common.embed_success(msg)) 161 | 162 | @starboard.command(brief='Clear starboard channel for an emoji') 163 | @commands.has_role(constants.TLE_ADMIN) 164 | async def clear(self, ctx, emoji: str): 165 | """Remove the starboard channel (and color) setting for an emoji.""" 166 | cf_common.user_db.clear_starboard_channel(ctx.guild.id, emoji) 167 | await ctx.send( 168 | embed=discord_common.embed_success(f'Cleared channel for {emoji}') 169 | ) 170 | 171 | @starboard.command(brief='Remove a message from starboard') 172 | @commands.has_role(constants.TLE_ADMIN) 173 | async def remove(self, ctx, emoji: str, original_message_id: int): 174 | """Remove a particular message from the starboard database.""" 175 | rc = cf_common.user_db.remove_starboard_message( 176 | original_msg_id=original_message_id, emoji=emoji 177 | ) 178 | if rc: 179 | await ctx.send(embed=discord_common.embed_success('Successfully removed')) 180 | else: 181 | await ctx.send(embed=discord_common.embed_alert('Not found')) 182 | 183 | @discord_common.send_error_if(StarboardCogError) 184 | async def cog_command_error(self, ctx, error): 185 | pass 186 | 187 | 188 | def setup(bot): 189 | bot.add_cog(Starboard(bot)) 190 | -------------------------------------------------------------------------------- /tle/constants.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | DATA_DIR = 'data' 4 | LOGS_DIR = 'logs' 5 | 6 | ASSETS_DIR = os.path.join(DATA_DIR, 'assets') 7 | DB_DIR = os.path.join(DATA_DIR, 'db') 8 | MISC_DIR = os.path.join(DATA_DIR, 'misc') 9 | TEMP_DIR = os.path.join(DATA_DIR, 'temp') 10 | 11 | USER_DB_FILE_PATH = os.path.join(DB_DIR, 'user.db') 12 | CACHE_DB_FILE_PATH = os.path.join(DB_DIR, 'cache.db') 13 | 14 
| FONTS_DIR = os.path.join(ASSETS_DIR, 'fonts') 15 | 16 | NOTO_SANS_CJK_BOLD_FONT_PATH = os.path.join(FONTS_DIR, 'NotoSansCJK-Bold.ttc') 17 | NOTO_SANS_CJK_REGULAR_FONT_PATH = os.path.join(FONTS_DIR, 'NotoSansCJK-Regular.ttc') 18 | 19 | CONTEST_WRITERS_JSON_FILE_PATH = os.path.join(MISC_DIR, 'contest_writers.json') 20 | 21 | LOG_FILE_PATH = os.path.join(LOGS_DIR, 'tle.log') 22 | 23 | ALL_DIRS = ( 24 | attrib_value 25 | for attrib_name, attrib_value in list(globals().items()) 26 | if attrib_name.endswith('DIR') 27 | ) 28 | 29 | ALLOW_DUEL_SELF_REGISTER = False 30 | 31 | TLE_ADMIN = os.environ.get('TLE_ADMIN', 'Admin') 32 | TLE_MODERATOR = os.environ.get('TLE_MODERATOR', 'Moderator') 33 | TLE_TRUSTED = os.environ.get('TLE_TRUSTED', 'Trusted') 34 | TLE_PURGATORY = os.environ.get('TLE_PURGATORY', 'Purgatory') 35 | 36 | _DEFAULT_COLOR = 0xFFAA10 37 | _DEFAULT_STAR = '\N{WHITE MEDIUM STAR}' 38 | -------------------------------------------------------------------------------- /tle/util/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cheran-senthil/TLE/4b5620ced4bb9014a31226184eb5c5c905e0b3e4/tle/util/__init__.py -------------------------------------------------------------------------------- /tle/util/cache_system2.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import logging 3 | import time 4 | from collections import defaultdict 5 | 6 | from aiocache import cached 7 | from discord.ext import commands 8 | 9 | from tle.util import ( 10 | codeforces_api as cf, 11 | codeforces_common as cf_common, 12 | events, 13 | paginator, 14 | tasks, 15 | ) 16 | from tle.util.ranklist import Ranklist 17 | 18 | logger = logging.getLogger(__name__) 19 | _CONTESTS_PER_BATCH_IN_CACHE_UPDATES = 100 20 | CONTEST_BLACKLIST = {1308, 1309, 1431, 1432} 21 | 22 | 23 | def _is_blacklisted(contest): 24 | return contest.id in CONTEST_BLACKLIST 25 | 26 | 27 | 
class CacheError(commands.CommandError): 28 | pass 29 | 30 | 31 | class ContestCacheError(CacheError): 32 | pass 33 | 34 | 35 | class ContestNotFound(ContestCacheError): 36 | def __init__(self, contest_id): 37 | super().__init__(f'Contest with id `{contest_id}` not found') 38 | self.contest_id = contest_id 39 | 40 | 41 | class ContestCache: 42 | _NORMAL_CONTEST_RELOAD_DELAY = 30 * 60 43 | _EXCEPTION_CONTEST_RELOAD_DELAY = 5 * 60 44 | _ACTIVE_CONTEST_RELOAD_DELAY = 5 * 60 45 | _ACTIVATE_BEFORE = 20 * 60 46 | 47 | _RUNNING_PHASES = ('CODING', 'PENDING_SYSTEM_TEST', 'SYSTEM_TEST') 48 | 49 | def __init__(self, cache_master): 50 | self.cache_master = cache_master 51 | 52 | self.contests = [] 53 | self.contest_by_id = {} 54 | self.contests_by_phase = {phase: [] for phase in cf.Contest.PHASES} 55 | self.contests_by_phase['_RUNNING'] = [] 56 | self.contests_last_cache = 0 57 | 58 | self.reload_lock = asyncio.Lock() 59 | self.reload_exception = None 60 | self.next_delay = None 61 | 62 | self.logger = logging.getLogger(self.__class__.__name__) 63 | 64 | async def run(self): 65 | await self._try_disk() 66 | self._update_task.start() 67 | 68 | async def reload_now(self): 69 | """Force a reload. If currently reloading it will wait until done.""" 70 | reloading = self.reload_lock.locked() 71 | if reloading: 72 | # Wait until reload complete. 73 | # To wait until lock is free, await acquire then release immediately. 
74 | async with self.reload_lock: 75 | pass 76 | else: 77 | await self._update_task.manual_trigger() 78 | 79 | if self.reload_exception: 80 | raise self.reload_exception 81 | 82 | def get_contest(self, contest_id): 83 | try: 84 | return self.contest_by_id[contest_id] 85 | except KeyError: 86 | raise ContestNotFound(contest_id) 87 | 88 | def get_problemset(self, contest_id): 89 | return self.cache_master.conn.get_problemset_from_contest(contest_id) 90 | 91 | def get_contests_in_phase(self, phase): 92 | return self.contests_by_phase[phase] 93 | 94 | async def _try_disk(self): 95 | async with self.reload_lock: 96 | contests = self.cache_master.conn.fetch_contests() 97 | if not contests: 98 | self.logger.info('Contest cache on disk is empty.') 99 | return 100 | await self._update(contests, from_api=False) 101 | 102 | @tasks.task_spec(name='ContestCacheUpdate') 103 | async def _update_task(self, _): 104 | async with self.reload_lock: 105 | self.next_delay = await self._reload_contests() 106 | self.reload_exception = None 107 | 108 | @_update_task.waiter() 109 | async def _update_task_waiter(self): 110 | await asyncio.sleep(self.next_delay) 111 | 112 | @_update_task.exception_handler() 113 | async def _update_task_exception_handler(self, ex): 114 | self.reload_exception = ex 115 | self.next_delay = self._EXCEPTION_CONTEST_RELOAD_DELAY 116 | 117 | async def _reload_contests(self): 118 | contests = await cf.contest.to_list() 119 | delay = await self._update(contests) 120 | return delay 121 | 122 | async def _update(self, contests, from_api=True): 123 | self.logger.info( 124 | f'{len(contests)} contests fetched from {"API" if from_api else "disk"}' 125 | ) 126 | contests.sort(key=lambda contest: (contest.startTimeSeconds, contest.id)) 127 | 128 | if from_api: 129 | rc = self.cache_master.conn.cache_contests(contests) 130 | self.logger.info(f'{rc} contests stored in database') 131 | 132 | contests_by_phase = {phase: [] for phase in cf.Contest.PHASES} 133 | 
contests_by_phase['_RUNNING'] = [] 134 | contest_by_id = {} 135 | for contest in contests: 136 | contests_by_phase[contest.phase].append(contest) 137 | contest_by_id[contest.id] = contest 138 | if contest.phase in self._RUNNING_PHASES: 139 | contests_by_phase['_RUNNING'].append(contest) 140 | 141 | now = time.time() 142 | delay = self._NORMAL_CONTEST_RELOAD_DELAY 143 | 144 | for contest in contests_by_phase['BEFORE']: 145 | at = contest.startTimeSeconds - self._ACTIVATE_BEFORE 146 | if at > now: 147 | # Reload at _ACTIVATE_BEFORE before contest to monitor contest delays. 148 | delay = min(delay, at - now) 149 | else: 150 | # The contest starts in <= _ACTIVATE_BEFORE. Reload at contest start, 151 | # or after _ACTIVE_CONTEST_RELOAD_DELAY, whichever comes first. Keep any 152 | # shorter delay already found; a past start time must not yield delay <= 0. 153 | until_start = contest.startTimeSeconds - now 154 | active = self._ACTIVE_CONTEST_RELOAD_DELAY 155 | delay = min(delay, active, until_start if until_start > 0 else active) 156 | 157 | if contests_by_phase['_RUNNING']: 158 | # If any contest is running, reload at an increased rate to detect FINISHED 159 | delay = min(delay, self._ACTIVE_CONTEST_RELOAD_DELAY) 160 | 161 | self.contests = contests 162 | self.contests_by_phase = contests_by_phase 163 | self.contest_by_id = contest_by_id 164 | self.contests_last_cache = time.time() 165 | 166 | cf_common.event_sys.dispatch(events.ContestListRefresh, self.contests.copy()) 167 | 168 | return delay 169 | 170 | 171 | class ProblemCache: 172 | _RELOAD_INTERVAL = 6 * 60 * 60 173 | 174 | def __init__(self, cache_master): 175 | self.cache_master = cache_master 176 | 177 | self.problems = [] 178 | self.problem_by_name = {} 179 | self.problems_last_cache = 0 180 | 181 | self.reload_lock = asyncio.Lock() 182 | self.reload_exception = None 183 | 184 | self.logger = logging.getLogger(self.__class__.__name__) 185 | 186 | async def run(self): 187 | await self._try_disk() 188 | self._update_task.start() 189 | 190 | async def reload_now(self): 191 | """Force a reload.
If currently reloading it will wait until done.""" 192 | reloading = self.reload_lock.locked() 193 | if reloading: 194 | # Wait until reload complete. 195 | # To wait until lock is free, await acquire then release immediately. 196 | async with self.reload_lock: 197 | pass 198 | else: 199 | await self._update_task.manual_trigger() 200 | 201 | if self.reload_exception: 202 | raise self.reload_exception 203 | 204 | async def _try_disk(self): 205 | async with self.reload_lock: 206 | problems = self.cache_master.conn.fetch_problems() 207 | if not problems: 208 | self.logger.info('Problem cache on disk is empty.') 209 | return 210 | self.problems = problems 211 | self.problem_by_name = {problem.name: problem for problem in problems} 212 | self.logger.info(f'{len(self.problems)} problems fetched from disk') 213 | 214 | @tasks.task_spec( 215 | name='ProblemCacheUpdate', waiter=tasks.Waiter.fixed_delay(_RELOAD_INTERVAL) 216 | ) 217 | async def _update_task(self, _): 218 | async with self.reload_lock: 219 | await self._reload_problems() 220 | self.reload_exception = None 221 | 222 | @_update_task.exception_handler() 223 | async def _update_task_exception_handler(self, ex): 224 | self.reload_exception = ex 225 | 226 | async def _reload_problems(self): 227 | problems, _ = await cf.problemset.problems() 228 | await self._update(problems) 229 | 230 | async def _update(self, problems): 231 | self.logger.info(f'{len(problems)} problems fetched from API') 232 | contest_map = { 233 | problem.contestId: self.cache_master.contest_cache.contest_by_id.get( 234 | problem.contestId 235 | ) 236 | for problem in problems 237 | } 238 | 239 | def keep(problem): 240 | return contest_map[problem.contestId] and problem.has_metadata() 241 | 242 | filtered_problems = list(filter(keep, problems)) 243 | problem_by_name = { 244 | problem.name: problem # This will discard some valid problems 245 | for problem in filtered_problems 246 | } 247 | self.logger.info(f'Keeping {len(problem_by_name)} 
problems') 248 | 249 | self.problems = list(problem_by_name.values()) 250 | self.problem_by_name = problem_by_name 251 | self.problems_last_cache = time.time() 252 | 253 | rc = self.cache_master.conn.cache_problems(self.problems) 254 | self.logger.info(f'{rc} problems stored in database') 255 | 256 | 257 | class ProblemsetCacheError(CacheError): 258 | pass 259 | 260 | 261 | class ProblemsetNotCached(ProblemsetCacheError): 262 | def __init__(self, contest_id): 263 | super().__init__(f'Problemset for contest with id {contest_id} not cached.') 264 | 265 | 266 | class ProblemsetCache: 267 | _MONITOR_PERIOD_SINCE_CONTEST_END = 14 * 24 * 60 * 60 268 | _RELOAD_DELAY = 60 * 60 269 | 270 | def __init__(self, cache_master): 271 | self.problems = [] 272 | # problem -> list of contests in which it appears 273 | self.problem_to_contests = defaultdict(list) 274 | self.cache_master = cache_master 275 | self.update_lock = asyncio.Lock() 276 | self.logger = logging.getLogger(self.__class__.__name__) 277 | 278 | async def run(self): 279 | if self.cache_master.conn.problemset_empty(): 280 | self.logger.warning( 281 | 'Problemset cache on disk is empty.' 282 | ' This must be populated manually before use.' 283 | ) 284 | self._update_task.start() 285 | 286 | async def update_for_contest(self, contest_id): 287 | """Update problemset for a particular contest. Intended for manual trigger.""" 288 | async with self.update_lock: 289 | contest = self.cache_master.contest_cache.get_contest(contest_id) 290 | problemset, _ = await self._fetch_problemsets([contest], force_fetch=True) 291 | self.cache_master.conn.clear_problemset(contest_id) 292 | self._save_problems(problemset) 293 | return len(problemset) 294 | 295 | async def update_for_all(self): 296 | """Update problemsets for all finished contests. 
Intended for manual trigger.""" 297 | async with self.update_lock: 298 | contests = self.cache_master.contest_cache.contests_by_phase['FINISHED'] 299 | problemsets, _ = await self._fetch_problemsets(contests, force_fetch=True) 300 | self.cache_master.conn.clear_problemset() 301 | self._save_problems(problemsets) 302 | return len(problemsets) 303 | 304 | @tasks.task_spec( 305 | name='ProblemsetCacheUpdate', waiter=tasks.Waiter.fixed_delay(_RELOAD_DELAY) 306 | ) 307 | async def _update_task(self, _): 308 | async with self.update_lock: 309 | contests = self.cache_master.contest_cache.contests_by_phase['FINISHED'] 310 | new_problems, updated_problems = await self._fetch_problemsets(contests) 311 | self._save_problems(new_problems + updated_problems) 312 | self._update_from_disk() 313 | self.logger.info( 314 | f'{len(new_problems)} new problems saved and' 315 | f' {len(updated_problems)} saved problems updated.' 316 | ) 317 | 318 | async def _fetch_problemsets(self, contests, *, force_fetch=False): 319 | # We assume it is possible for problems in the same contest to get 320 | # assigned rating at different times. 321 | new_contest_ids = [] 322 | # List of (id, set of saved rated problem indices) pairs. 323 | contests_to_refetch = [] 324 | if force_fetch: 325 | new_contest_ids = [contest.id for contest in contests] 326 | else: 327 | now = time.time() 328 | for contest in contests: 329 | if now > contest.end_time + self._MONITOR_PERIOD_SINCE_CONTEST_END: 330 | # Contest too old, we do not want to check it. 
331 | continue 332 | problemset = self.cache_master.conn.fetch_problemset(contest.id) 333 | if not problemset: 334 | new_contest_ids.append(contest.id) 335 | continue 336 | rated_problem_idx = { 337 | prob.index for prob in problemset if prob.rating is not None 338 | } 339 | if len(rated_problem_idx) < len(problemset): 340 | contests_to_refetch.append((contest.id, rated_problem_idx)) 341 | 342 | new_problems, updated_problems = [], [] 343 | for contest_id in new_contest_ids: 344 | new_problems += await self._fetch_for_contest(contest_id) 345 | for contest_id, rated_problem_idx in contests_to_refetch: 346 | updated_problems += [ 347 | prob 348 | for prob in await self._fetch_for_contest(contest_id) 349 | if prob.rating is not None and prob.index not in rated_problem_idx 350 | ] 351 | 352 | return new_problems, updated_problems 353 | 354 | async def _fetch_for_contest(self, contest_id): 355 | try: 356 | _, problemset, _ = await cf.contest.standings( 357 | contest_id=contest_id, from_=1, count=1 358 | ) 359 | except cf.CodeforcesApiError as er: 360 | self.logger.warning( 361 | f'Problemset fetch failed for contest {contest_id}. 
{er!r}' 362 | ) 363 | problemset = [] 364 | return problemset 365 | 366 | def _save_problems(self, problems): 367 | rc = self.cache_master.conn.cache_problemset(problems) 368 | self.logger.info(f'Saved {rc} problems to database.') 369 | 370 | def get_problemset(self, contest_id): 371 | problemset = self.cache_master.conn.fetch_problemset(contest_id) 372 | if not problemset: 373 | raise ProblemsetNotCached(contest_id) 374 | return problemset 375 | 376 | def _update_from_disk(self): 377 | self.problems = self.cache_master.conn.fetch_problems2() 378 | self.problem_to_contests = defaultdict(list) 379 | for problem in self.problems: 380 | try: 381 | contest = cf_common.cache2.contest_cache.get_contest(problem.contestId) 382 | problem_id = (problem.name, contest.startTimeSeconds) 383 | self.problem_to_contests[problem_id].append(contest.id) 384 | except ContestNotFound: 385 | pass 386 | 387 | 388 | class RatingChangesCache: 389 | _RATED_DELAY = 36 * 60 * 60 390 | _RELOAD_DELAY = 10 * 60 391 | 392 | def __init__(self, cache_master): 393 | self.cache_master = cache_master 394 | self.monitored_contests = [] 395 | self.handle_rating_cache = {} 396 | self.logger = logging.getLogger(self.__class__.__name__) 397 | 398 | async def run(self): 399 | self._refresh_handle_cache() 400 | if not self.handle_rating_cache: 401 | self.logger.warning( 402 | 'Rating changes cache on disk is empty.' 403 | ' This must be populated manually before use.' 404 | ) 405 | self._update_task.start() 406 | 407 | async def fetch_contest(self, contest_id): 408 | """Fetch rating changes for a specific contest. 409 | 410 | Intended for manual trigger. 411 | """ 412 | contest = self.cache_master.contest_cache.contest_by_id[contest_id] 413 | changes = await self._fetch([contest]) 414 | self.cache_master.conn.clear_rating_changes(contest_id=contest_id) 415 | self._save_changes(changes) 416 | return len(changes) 417 | 418 | async def fetch_all_contests(self): 419 | """Fetch rating changes for all contests. 
420 | 421 | Intended for manual trigger. 422 | """ 423 | self.cache_master.conn.clear_rating_changes() 424 | return await self.fetch_missing_contests() 425 | 426 | async def fetch_missing_contests(self): 427 | """Fetch rating changes for contests which are not saved in database. 428 | 429 | Intended for manual trigger. 430 | """ 431 | contests = self.cache_master.contest_cache.contests_by_phase['FINISHED'] 432 | contests = [ 433 | contest 434 | for contest in contests 435 | if not self.has_rating_changes_saved(contest.id) 436 | ] 437 | total_changes = 0 438 | for contests_chunk in paginator.chunkify( 439 | contests, _CONTESTS_PER_BATCH_IN_CACHE_UPDATES 440 | ): 441 | contests_chunk = await self._fetch(contests_chunk) 442 | self._save_changes(contests_chunk) 443 | total_changes += len(contests_chunk) 444 | return total_changes 445 | 446 | def is_newly_finished_without_rating_changes(self, contest): 447 | now = time.time() 448 | return ( 449 | contest.phase == 'FINISHED' 450 | and now - contest.end_time < self._RATED_DELAY 451 | and not self.has_rating_changes_saved(contest.id) 452 | ) 453 | 454 | @tasks.task_spec( 455 | name='RatingChangesCacheUpdate', 456 | waiter=tasks.Waiter.for_event(events.ContestListRefresh), 457 | ) 458 | async def _update_task(self, _): 459 | # Some notes: 460 | # A hack phase is tagged as FINISHED with empty list of rating changes. 461 | # After the hack phase, the phase changes to systest then again 462 | # FINISHED. Since we cannot differentiate between the two FINISHED 463 | # phases, we are forced to fetch during both. 464 | # A contest also has empty list if it is unrated. We assume that is the 465 | # case if _RATED_DELAY time has passed since the contest end. 
466 | 467 | to_monitor = [ 468 | contest 469 | for contest in self.cache_master.contest_cache.contests_by_phase['FINISHED'] 470 | if self.is_newly_finished_without_rating_changes(contest) 471 | and not _is_blacklisted(contest) 472 | ] 473 | 474 | cur_ids = {contest.id for contest in self.monitored_contests} 475 | new_ids = {contest.id for contest in to_monitor} 476 | if new_ids != cur_ids: 477 | await self._monitor_task.stop() 478 | if to_monitor: 479 | self.monitored_contests = to_monitor 480 | self._monitor_task.start() 481 | else: 482 | self.monitored_contests = [] 483 | 484 | @tasks.task_spec( 485 | name='RatingChangesCacheUpdate.MonitorNewlyFinishedContests', 486 | waiter=tasks.Waiter.fixed_delay(_RELOAD_DELAY), 487 | ) 488 | async def _monitor_task(self, _): 489 | self.monitored_contests = [ 490 | contest 491 | for contest in self.monitored_contests 492 | if self.is_newly_finished_without_rating_changes(contest) 493 | and not _is_blacklisted(contest) 494 | ] 495 | 496 | if not self.monitored_contests: 497 | self.logger.info( 498 | 'Rated changes fetched for contests that were being monitored.' 499 | ) 500 | await self._monitor_task.stop() 501 | return 502 | 503 | contest_changes_pairs = await self._fetch(self.monitored_contests) 504 | # Sort by the rating update time of the first change in the list of 505 | # changes, assuming every change in the list has the same time. 
506 | contest_changes_pairs.sort(key=lambda pair: pair[1][0].ratingUpdateTimeSeconds) 507 | self._save_changes(contest_changes_pairs) 508 | for contest, changes in contest_changes_pairs: 509 | cf_common.event_sys.dispatch( 510 | events.RatingChangesUpdate, contest=contest, rating_changes=changes 511 | ) 512 | 513 | async def _fetch(self, contests): 514 | all_changes = [] 515 | for contest in contests: 516 | try: 517 | changes = await cf.contest.ratingChanges(contest_id=contest.id) 518 | self.logger.info( 519 | f'{len(changes)} rating changes fetched for contest {contest.id}' 520 | ) 521 | if changes: 522 | all_changes.append((contest, changes)) 523 | except cf.CodeforcesApiError as er: 524 | self.logger.warning( 525 | f'Fetch rating changes failed for contest {contest.id},' 526 | f' ignoring. {er!r}' 527 | ) 528 | pass 529 | return all_changes 530 | 531 | def _save_changes(self, contest_changes_pairs): 532 | flattened = [ 533 | change for _, changes in contest_changes_pairs for change in changes 534 | ] 535 | if not flattened: 536 | return 537 | rc = self.cache_master.conn.save_rating_changes(flattened) 538 | self.logger.info(f'Saved {rc} changes to database.') 539 | self._refresh_handle_cache() 540 | 541 | def _refresh_handle_cache(self): 542 | changes = self.cache_master.conn.get_all_rating_changes() 543 | handle_rating_cache = {} 544 | for change in changes: 545 | handle_rating_cache[change.handle] = change.newRating 546 | self.handle_rating_cache = handle_rating_cache 547 | self.logger.info(f'Ratings for {len(handle_rating_cache)} handles cached') 548 | 549 | def get_users_with_more_than_n_contests(self, time_cutoff, n): 550 | return self.cache_master.conn.get_users_with_more_than_n_contests( 551 | time_cutoff, n 552 | ) 553 | 554 | def get_rating_changes_for_contest(self, contest_id): 555 | return self.cache_master.conn.get_rating_changes_for_contest(contest_id) 556 | 557 | def has_rating_changes_saved(self, contest_id): 558 | return 
self.cache_master.conn.has_rating_changes_saved(contest_id) 559 | 560 | def get_rating_changes_for_handle(self, handle): 561 | return self.cache_master.conn.get_rating_changes_for_handle(handle) 562 | 563 | def get_current_rating(self, handle, default_if_absent=False): 564 | return self.handle_rating_cache.get( 565 | handle, cf.DEFAULT_RATING if default_if_absent else None 566 | ) 567 | 568 | def get_all_ratings(self): 569 | return list(self.handle_rating_cache.values()) 570 | 571 | 572 | class RanklistCacheError(CacheError): 573 | pass 574 | 575 | 576 | class RanklistNotMonitored(RanklistCacheError): 577 | def __init__(self, contest): 578 | super().__init__(f'The ranklist for `{contest.name}` is not being monitored') 579 | self.contest = contest 580 | 581 | 582 | class RanklistCache: 583 | _RELOAD_DELAY = 2 * 60 584 | 585 | def __init__(self, cache_master): 586 | self.cache_master = cache_master 587 | self.monitored_contests = [] 588 | self.ranklist_by_contest = {} 589 | self.logger = logging.getLogger(self.__class__.__name__) 590 | 591 | async def run(self): 592 | self._update_task.start() 593 | 594 | # Currently ranklist monitoring only supports caching unofficial ranklists 595 | # If official ranklist is asked, the cache will throw RanklistNotMonitored Error 596 | def get_ranklist(self, contest, show_official): 597 | if show_official or contest.id not in self.ranklist_by_contest: 598 | raise RanklistNotMonitored(contest) 599 | return self.ranklist_by_contest[contest.id] 600 | 601 | @tasks.task_spec( 602 | name='RanklistCacheUpdate', 603 | waiter=tasks.Waiter.for_event(events.ContestListRefresh), 604 | ) 605 | async def _update_task(self, _): 606 | contests_by_phase = self.cache_master.contest_cache.contests_by_phase 607 | running_contests = contests_by_phase['_RUNNING'] 608 | 609 | rating_cache = self.cache_master.rating_changes_cache 610 | finished_contests = [ 611 | contest 612 | for contest in contests_by_phase['FINISHED'] 613 | if not 
_is_blacklisted(contest) 614 | and rating_cache.is_newly_finished_without_rating_changes(contest) 615 | ] 616 | 617 | to_monitor = running_contests + finished_contests 618 | cur_ids = {contest.id for contest in self.monitored_contests} 619 | new_ids = {contest.id for contest in to_monitor} 620 | if new_ids != cur_ids: 621 | await self._monitor_task.stop() 622 | if to_monitor: 623 | self.monitored_contests = to_monitor 624 | self._monitor_task.start() 625 | else: 626 | self.ranklist_by_contest = {} 627 | 628 | @tasks.task_spec( 629 | name='RanklistCacheUpdate.MonitorActiveContests', 630 | waiter=tasks.Waiter.fixed_delay(_RELOAD_DELAY), 631 | ) 632 | async def _monitor_task(self, _): 633 | cache = self.cache_master.rating_changes_cache 634 | self.monitored_contests = [ 635 | contest 636 | for contest in self.monitored_contests 637 | if not _is_blacklisted(contest) 638 | and ( 639 | contest.phase != 'FINISHED' 640 | or cache.is_newly_finished_without_rating_changes(contest) 641 | ) 642 | ] 643 | 644 | if not self.monitored_contests: 645 | self.ranklist_by_contest = {} 646 | self.logger.info('No more active contests for which to monitor ranklists.') 647 | await self._monitor_task.stop() 648 | return 649 | 650 | ranklist_by_contest = await self._fetch(self.monitored_contests) 651 | # If any ranklist could not be fetched, the old ranklist is kept. 
652 | for contest_id, ranklist in ranklist_by_contest.items(): 653 | self.ranklist_by_contest[contest_id] = ranklist 654 | 655 | @staticmethod 656 | async def _get_contest_details(contest_id, show_unofficial): 657 | contest, problems, standings = await cf.contest.standings( 658 | contest_id=contest_id, show_unofficial=show_unofficial 659 | ) 660 | 661 | # Exclude PRACTICE and MANAGER 662 | standings = [ 663 | row 664 | for row in standings 665 | if row.party.participantType 666 | in ('CONTESTANT', 'OUT_OF_COMPETITION', 'VIRTUAL') 667 | ] 668 | 669 | return contest, problems, standings 670 | 671 | # Fetch final rating changes from CF. 672 | # For older contests. 673 | async def _get_ranklist_with_fetched_changes(self, contest_id, show_unofficial): 674 | contest, problems, standings = await self._get_contest_details( 675 | contest_id, show_unofficial 676 | ) 677 | now = time.time() 678 | 679 | is_rated = False 680 | try: 681 | changes = await cf.contest.ratingChanges(contest_id=contest_id) 682 | # For contests intended to be rated but declared unrated 683 | # an empty list is returned. 684 | is_rated = len(changes) > 0 685 | except cf.RatingChangesUnavailableError: 686 | pass 687 | 688 | ranklist = None 689 | if is_rated: 690 | ranklist = Ranklist(contest, problems, standings, now, is_rated=is_rated) 691 | delta_by_handle = { 692 | change.handle: change.newRating - change.oldRating for change in changes 693 | } 694 | ranklist.set_deltas(delta_by_handle) 695 | 696 | return ranklist 697 | 698 | # Rating changes have not been applied yet, predict rating changes. 699 | # For running/recent/unrated contests. 
700 | async def _get_ranklist_with_predicted_changes(self, contest_id, show_unofficial): 701 | contest, problems, standings = await self._get_contest_details( 702 | contest_id, show_unofficial 703 | ) 704 | now = time.time() 705 | 706 | standings_official = None 707 | if not show_unofficial: 708 | standings_official = standings 709 | else: 710 | _, _, standings_official = await cf.contest.standings(contest_id=contest_id) 711 | 712 | has_teams = any(row.party.teamId is not None for row in standings_official) 713 | if cf_common.is_nonstandard_contest(contest) or has_teams: 714 | # The contest is not traditionally rated 715 | ranklist = Ranklist(contest, problems, standings, now, is_rated=False) 716 | else: 717 | current_rating = await CacheSystem.getUsersEffectiveRating(activeOnly=False) 718 | current_rating = { 719 | row.party.members[0].handle: current_rating.get( 720 | row.party.members[0].handle, cf.DEFAULT_RATING 721 | ) 722 | for row in standings_official 723 | } 724 | if 'Educational' in contest.name: 725 | # For some reason educational contests return all contestants 726 | # in ranklist even when unofficial contestants are not 727 | # requested.
728 | current_rating = { 729 | handle: rating 730 | for handle, rating in current_rating.items() 731 | if rating < 2100 732 | } 733 | ranklist = Ranklist(contest, problems, standings, now, is_rated=True) 734 | ranklist.predict(current_rating) 735 | return ranklist 736 | 737 | async def generate_ranklist( 738 | self, 739 | contest_id, 740 | *, 741 | fetch_changes=False, 742 | predict_changes=False, 743 | show_unofficial=True, 744 | ): 745 | assert fetch_changes ^ predict_changes 746 | 747 | ranklist = None 748 | if fetch_changes: 749 | ranklist = await self._get_ranklist_with_fetched_changes( 750 | contest_id, show_unofficial 751 | ) 752 | if ranklist is None: 753 | # Either predict_changes was true or fetching rating changes failed 754 | ranklist = await self._get_ranklist_with_predicted_changes( 755 | contest_id, show_unofficial 756 | ) 757 | 758 | # For some reason Educational contests also have div1 peeps in the 759 | # official standings. Hence we need to manually weed them out 760 | if not show_unofficial and 'Educational' in ranklist.contest.name: 761 | ranklist.remove_unofficial_contestants() 762 | 763 | return ranklist 764 | 765 | async def generate_vc_ranklist(self, contest_id, handle_to_member_id): 766 | handles = list(handle_to_member_id.keys()) 767 | contest, problems, standings = await cf.contest.standings( 768 | contest_id=contest_id, show_unofficial=True 769 | ) 770 | # Exclude PRACTICE, MANAGER and OUT_OF_COMPETITION 771 | standings = [ 772 | row 773 | for row in standings 774 | if row.party.participantType == 'CONTESTANT' 775 | or row.party.members[0].handle in handles 776 | ] 777 | standings.sort(key=lambda row: row.rank) 778 | standings = [row._replace(rank=i + 1) for i, row in enumerate(standings)] 779 | now = time.time() 780 | rating_changes = await cf.contest.ratingChanges(contest_id=contest_id) 781 | current_official_rating = { 782 | rating_change.handle: rating_change.oldRating 783 | for rating_change in rating_changes 784 | } 785 | 786 | #
TODO: assert that none of the given handles are in the official standings. 787 | handles = [ 788 | row.party.members[0].handle 789 | for row in standings 790 | if row.party.members[0].handle in handles 791 | and row.party.participantType == 'VIRTUAL' 792 | ] 793 | current_vc_rating = { 794 | handle: cf_common.user_db.get_vc_rating(handle_to_member_id.get(handle)) 795 | for handle in handles 796 | } 797 | ranklist = Ranklist(contest, problems, standings, now, is_rated=True) 798 | delta_by_handle = {} 799 | for handle in handles: 800 | mixed_ratings = current_official_rating.copy() 801 | mixed_ratings[handle] = current_vc_rating.get(handle) 802 | ranklist.predict(mixed_ratings) 803 | delta_by_handle[handle] = ranklist.delta_by_handle.get(handle, 0) 804 | 805 | ranklist.delta_by_handle = delta_by_handle 806 | return ranklist 807 | 808 | async def _fetch(self, contests): 809 | ranklist_by_contest = {} 810 | for contest in contests: 811 | try: 812 | ranklist = await self.generate_ranklist( 813 | contest.id, predict_changes=True 814 | ) 815 | ranklist_by_contest[contest.id] = ranklist 816 | self.logger.info(f'Ranklist fetched for contest {contest.id}') 817 | except cf.CodeforcesApiError as er: 818 | self.logger.warning( 819 | f'Ranklist fetch failed for contest {contest.id}. 
{er!r}' 820 | ) 821 | 822 | return ranklist_by_contest 823 | 824 | 825 | class CacheSystem: 826 | def __init__(self, conn): 827 | self.conn = conn 828 | self.contest_cache = ContestCache(self) 829 | self.problem_cache = ProblemCache(self) 830 | self.rating_changes_cache = RatingChangesCache(self) 831 | self.ranklist_cache = RanklistCache(self) 832 | self.problemset_cache = ProblemsetCache(self) 833 | 834 | async def run(self): 835 | await self.rating_changes_cache.run() 836 | await self.ranklist_cache.run() 837 | await self.contest_cache.run() 838 | await self.problem_cache.run() 839 | await self.problemset_cache.run() 840 | 841 | @staticmethod 842 | @cached(ttl=30 * 60) 843 | async def getUsersEffectiveRating(*, activeOnly=None): 844 | """Returns a mapping from user handles to their effective rating.""" 845 | ratedList = await cf.user.ratedList(activeOnly=activeOnly) 846 | users_effective_rating_dict = { 847 | user.handle: user.effective_rating for user in ratedList 848 | } 849 | return users_effective_rating_dict 850 | -------------------------------------------------------------------------------- /tle/util/codeforces_api.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import functools 3 | import itertools 4 | import logging 5 | import time 6 | from collections import defaultdict, deque 7 | from collections.abc import Iterable, Iterator, Sequence 8 | from typing import Any, NamedTuple, Optional 9 | 10 | import aiohttp 11 | from discord.ext import commands 12 | 13 | from tle.util import codeforces_common as cf_common 14 | 15 | # ruff: noqa: N815 16 | 17 | API_BASE_URL = 'https://codeforces.com/api/' 18 | CONTEST_BASE_URL = 'https://codeforces.com/contest/' 19 | CONTESTS_BASE_URL = 'https://codeforces.com/contests/' 20 | GYM_BASE_URL = 'https://codeforces.com/gym/' 21 | PROFILE_BASE_URL = 'https://codeforces.com/profile/' 22 | ACMSGURU_BASE_URL = 'https://codeforces.com/problemsets/acmsguru/' 23 | 
GYM_ID_THRESHOLD = 100000 24 | DEFAULT_RATING = 1500 25 | 26 | logger = logging.getLogger(__name__) 27 | 28 | 29 | class Rank(NamedTuple): 30 | """Codeforces rank.""" 31 | 32 | low: Optional[int] 33 | high: Optional[int] 34 | title: str 35 | title_abbr: Optional[str] 36 | color_graph: Optional[str] 37 | color_embed: Optional[int] 38 | 39 | 40 | RATED_RANKS = ( 41 | Rank(-(10**9), 1200, 'Newbie', 'N', '#CCCCCC', 0x808080), 42 | Rank(1200, 1400, 'Pupil', 'P', '#77FF77', 0x008000), 43 | Rank(1400, 1600, 'Specialist', 'S', '#77DDBB', 0x03A89E), 44 | Rank(1600, 1900, 'Expert', 'E', '#AAAAFF', 0x0000FF), 45 | Rank(1900, 2100, 'Candidate Master', 'CM', '#FF88FF', 0xAA00AA), 46 | Rank(2100, 2300, 'Master', 'M', '#FFCC88', 0xFF8C00), 47 | Rank(2300, 2400, 'International Master', 'IM', '#FFBB55', 0xF57500), 48 | Rank(2400, 2600, 'Grandmaster', 'GM', '#FF7777', 0xFF3030), 49 | Rank(2600, 3000, 'International Grandmaster', 'IGM', '#FF3333', 0xFF0000), 50 | Rank(3000, 10**9, 'Legendary Grandmaster', 'LGM', '#AA0000', 0xCC0000), 51 | ) 52 | UNRATED_RANK = Rank(None, None, 'Unrated', None, None, None) 53 | 54 | 55 | def rating2rank(rating: Optional[int]) -> Rank: 56 | """Returns the rank corresponding to the given rating.""" 57 | if rating is None: 58 | return UNRATED_RANK 59 | for rank in RATED_RANKS: 60 | assert rank.low is not None and rank.high is not None 61 | if rank.low <= rating < rank.high: 62 | return rank 63 | raise ValueError(f'Rating {rating} outside range of known ranks.') 64 | 65 | 66 | # Data classes 67 | 68 | 69 | class User(NamedTuple): 70 | """Codeforces user.""" 71 | 72 | handle: str 73 | firstName: Optional[str] 74 | lastName: Optional[str] 75 | country: Optional[str] 76 | city: Optional[str] 77 | organization: Optional[str] 78 | contribution: int 79 | rating: Optional[int] 80 | maxRating: Optional[int] 81 | lastOnlineTimeSeconds: int 82 | registrationTimeSeconds: int 83 | friendOfCount: int 84 | titlePhoto: str 85 | 86 | @property 87 | def 
effective_rating(self) -> int: 88 | """Returns the effective rating of the user.""" 89 | return self.rating if self.rating is not None else DEFAULT_RATING 90 | 91 | @property 92 | def rank(self) -> Rank: 93 | """Returns the rank corresponding to the user's rating.""" 94 | return rating2rank(self.rating) 95 | 96 | @property 97 | def url(self) -> str: 98 | """Returns the URL of the user's profile.""" 99 | return f'{PROFILE_BASE_URL}{self.handle}' 100 | 101 | 102 | class RatingChange(NamedTuple): 103 | """Codeforces rating change.""" 104 | 105 | contestId: int 106 | contestName: str 107 | handle: str 108 | rank: int 109 | ratingUpdateTimeSeconds: int 110 | oldRating: int 111 | newRating: int 112 | 113 | 114 | class Contest(NamedTuple): 115 | """Codeforces contest.""" 116 | 117 | id: int 118 | name: str 119 | startTimeSeconds: Optional[int] 120 | durationSeconds: Optional[int] 121 | type: str 122 | phase: str 123 | preparedBy: Optional[str] 124 | 125 | PHASES = 'BEFORE CODING PENDING_SYSTEM_TEST SYSTEM_TEST FINISHED'.split() 126 | 127 | @property 128 | def end_time(self) -> Optional[int]: 129 | """Returns the end time of the contest.""" 130 | if self.startTimeSeconds is None or self.durationSeconds is None: 131 | return None 132 | return self.startTimeSeconds + self.durationSeconds 133 | 134 | @property 135 | def url(self) -> str: 136 | """Returns the URL of the contest.""" 137 | if self.id < GYM_ID_THRESHOLD: 138 | return f'{CONTEST_BASE_URL}{self.id}' 139 | return f'{GYM_BASE_URL}{self.id}' 140 | 141 | @property 142 | def register_url(self) -> str: 143 | """Returns the URL to register for the contest.""" 144 | return f'{CONTESTS_BASE_URL}{self.id}' 145 | 146 | def matches(self, markers: Iterable[str]) -> bool: 147 | """Returns whether the contest matches any of the given markers.""" 148 | 149 | def filter_and_normalize(s: str) -> str: 150 | return ''.join(x for x in s.lower() if x.isalnum()) 151 | 152 | return any( 153 | filter_and_normalize(marker) in 
filter_and_normalize(self.name) 154 | for marker in markers 155 | ) 156 | 157 | 158 | class Member(NamedTuple): 159 | """Codeforces party member.""" 160 | 161 | handle: str 162 | 163 | 164 | class Party(NamedTuple): 165 | """Codeforces party.""" 166 | 167 | contestId: Optional[int] 168 | members: list[Member] 169 | participantType: str 170 | teamId: Optional[int] 171 | teamName: Optional[str] 172 | ghost: bool 173 | room: Optional[int] 174 | startTimeSeconds: Optional[int] 175 | 176 | PARTICIPANT_TYPES = ( 177 | 'CONTESTANT', 178 | 'PRACTICE', 179 | 'VIRTUAL', 180 | 'MANAGER', 181 | 'OUT_OF_COMPETITION', 182 | ) 183 | 184 | 185 | class Problem(NamedTuple): 186 | """Codeforces problem.""" 187 | 188 | contestId: Optional[int] 189 | problemsetName: Optional[str] 190 | index: str 191 | name: str 192 | type: str 193 | points: Optional[float] 194 | rating: Optional[int] 195 | tags: list[str] 196 | 197 | @property 198 | def contest_identifier(self) -> str: 199 | """Returns a string identifying the contest.""" 200 | return f'{self.contestId}{self.index}' 201 | 202 | @property 203 | def url(self) -> str: 204 | """Returns the URL of the problem.""" 205 | if self.contestId is None: 206 | assert self.problemsetName == 'acmsguru', ( 207 | f'Unknown problemset {self.problemsetName}' 208 | ) 209 | return f'{ACMSGURU_BASE_URL}problem/99999/{self.index}' 210 | base = CONTEST_BASE_URL if self.contestId < GYM_ID_THRESHOLD else GYM_BASE_URL 211 | return f'{base}{self.contestId}/problem/{self.index}' 212 | 213 | def has_metadata(self) -> bool: 214 | """Returns whether the problem has metadata.""" 215 | return self.contestId is not None and self.rating is not None 216 | 217 | def _matching_tags_dict(self, match_tags: Iterable[str]) -> dict[str, list[str]]: 218 | """Returns a dict with matching tags.""" 219 | tags = defaultdict(list) 220 | for match_tag in match_tags: 221 | for tag in self.tags: 222 | if match_tag in tag: 223 | tags[match_tag].append(tag) 224 | return dict(tags) 225 | 
226 | def matches_all_tags(self, match_tags: Iterable[str]) -> bool: 227 | """Returns whether the problem matches all of the given tags.""" 228 | match_tags = set(match_tags) 229 | return len(self._matching_tags_dict(match_tags)) == len(match_tags) 230 | 231 | def matches_any_tag(self, match_tags: Iterable[str]) -> bool: 232 | """Returns whether the problem matches any of the given tags.""" 233 | match_tags = set(match_tags) 234 | return len(self._matching_tags_dict(match_tags)) > 0 235 | 236 | def get_matched_tags(self, match_tags: Iterable[str]) -> list[str]: 237 | """Returns a list of tags that match any of the given tags.""" 238 | return [ 239 | tag 240 | for tags in self._matching_tags_dict(match_tags).values() 241 | for tag in tags 242 | ] 243 | 244 | 245 | class ProblemStatistics(NamedTuple): 246 | """Codeforces problem statistics.""" 247 | 248 | contestId: Optional[int] 249 | index: str 250 | solvedCount: int 251 | 252 | 253 | class Submission(NamedTuple): 254 | """Codeforces submission for a problem.""" 255 | 256 | id: int 257 | contestId: Optional[int] 258 | problem: Problem 259 | author: Party 260 | programmingLanguage: str 261 | verdict: Optional[str] 262 | creationTimeSeconds: int 263 | relativeTimeSeconds: int 264 | 265 | 266 | class RanklistRow(NamedTuple): 267 | """Codeforces ranklist row.""" 268 | 269 | party: Party 270 | rank: int 271 | points: float 272 | penalty: int 273 | problemResults: list['ProblemResult'] 274 | 275 | 276 | class ProblemResult(NamedTuple): 277 | """Codeforces problem result.""" 278 | 279 | points: float 280 | penalty: Optional[int] 281 | rejectedAttemptCount: int 282 | type: str 283 | bestSubmissionTimeSeconds: Optional[int] 284 | 285 | 286 | def make_from_dict(namedtuple_cls, dict_): 287 | """Creates a namedtuple from a subset of values in a dict.""" 288 | field_vals = [dict_.get(field) for field in namedtuple_cls._fields] 289 | return namedtuple_cls._make(field_vals) 290 | 291 | 292 | # Error classes 293 | 294 | 295 | 
class CodeforcesApiError(commands.CommandError): 296 | """Base class for all API related errors.""" 297 | 298 | def __init__(self, message: Optional[str] = None): 299 | super().__init__(message or 'Codeforces API error') 300 | 301 | 302 | class TrueApiError(CodeforcesApiError): 303 | """An error originating from a valid response of the API.""" 304 | 305 | def __init__(self, comment: str, message: Optional[str] = None): 306 | super().__init__(message) 307 | self.comment = comment 308 | 309 | 310 | class ClientError(CodeforcesApiError): 311 | """An error caused by a request to the API failing.""" 312 | 313 | def __init__(self): 314 | super().__init__('Error connecting to Codeforces API') 315 | 316 | 317 | class HandleNotFoundError(TrueApiError): 318 | """An error caused by a handle not being found on Codeforces.""" 319 | 320 | def __init__(self, comment: str, handle: str): 321 | super().__init__(comment, f'Handle `{handle}` not found on Codeforces') 322 | self.handle = handle 323 | 324 | 325 | class HandleInvalidError(TrueApiError): 326 | """An error caused by a handle not being valid on Codeforces.""" 327 | 328 | def __init__(self, comment: str, handle: str): 329 | super().__init__(comment, f'`{handle}` is not a valid Codeforces handle') 330 | self.handle = handle 331 | 332 | 333 | class CallLimitExceededError(TrueApiError): 334 | """An error caused by the call limit being exceeded.""" 335 | 336 | def __init__(self, comment: str): 337 | super().__init__(comment, 'Codeforces API call limit exceeded') 338 | 339 | 340 | class ContestNotFoundError(TrueApiError): 341 | """An error caused by a contest not being found on Codeforces.""" 342 | 343 | def __init__(self, comment: str, contest_id: Any): 344 | super().__init__( 345 | comment, f'Contest with ID `{contest_id}` not found on Codeforces' 346 | ) 347 | 348 | 349 | class RatingChangesUnavailableError(TrueApiError): 350 | """An error caused by rating changes being unavailable for a contest.""" 351 | 352 | def 
__init__(self, comment: str, contest_id: Any): 353 | super().__init__( 354 | comment, f'Rating changes unavailable for contest with ID `{contest_id}`' 355 | ) 356 | 357 | 358 | # Codeforces API query methods 359 | 360 | _session: aiohttp.ClientSession = None 361 | 362 | 363 | async def initialize() -> None: 364 | """Initialization for the Codeforces API module.""" 365 | global _session 366 | _session = aiohttp.ClientSession() 367 | 368 | 369 | def _bool_to_str(value: bool) -> str: 370 | if isinstance(value, bool): 371 | return 'true' if value else 'false' 372 | raise TypeError(f'Expected bool, got {value} of type {type(value)}') 373 | 374 | 375 | def cf_ratelimit(f): 376 | tries = 3 377 | per_second = 1 378 | last = deque([0.0] * per_second) 379 | 380 | @functools.wraps(f) 381 | async def wrapped(*args, **kwargs): 382 | for i in itertools.count(): 383 | now = time.time() 384 | 385 | # Next valid slot is 1s after the `per_second`th last request 386 | next_valid = max(now, 1 + last[0]) 387 | last.append(next_valid) 388 | last.popleft() 389 | 390 | # Delay as needed 391 | delay = next_valid - now 392 | if delay > 0: 393 | await asyncio.sleep(delay) 394 | 395 | try: 396 | return await f(*args, **kwargs) 397 | except (ClientError, CallLimitExceededError) as e: 398 | logger.info(f'Try {i + 1}/{tries} at query failed.') 399 | logger.info(repr(e)) 400 | if i < tries - 1: 401 | logger.info('Retrying...') 402 | else: 403 | logger.info('Aborting.') 404 | raise e 405 | raise AssertionError('Unreachable') 406 | 407 | return wrapped 408 | 409 | 410 | @cf_ratelimit 411 | async def _query_api(path: str, data: Any = None): 412 | url = API_BASE_URL + path 413 | try: 414 | logger.info(f'Querying CF API at {url} with {data}') 415 | # Explicitly state encoding (though aiohttp accepts gzip by default) 416 | headers = {'Accept-Encoding': 'gzip'} 417 | async with _session.post(url, data=data, headers=headers) as resp: 418 | try: 419 | respjson = await resp.json() 420 | except 
aiohttp.ContentTypeError: 421 | logger.warning( 422 | f'CF API did not respond with JSON, status {resp.status}.' 423 | ) 424 | raise CodeforcesApiError 425 | if resp.status == 200: 426 | return respjson['result'] 427 | comment = f'HTTP Error {resp.status}, {respjson.get("comment")}' 428 | except aiohttp.ClientError as e: 429 | logger.error(f'Request to CF API encountered error: {e!r}') 430 | raise ClientError from e 431 | logger.warning(f'Query to CF API failed: {comment}') 432 | if 'limit exceeded' in comment: 433 | raise CallLimitExceededError(comment) 434 | raise TrueApiError(comment) 435 | 436 | 437 | class contest: 438 | @staticmethod 439 | async def to_list(*, gym: Optional[bool] = None) -> list[Contest]: 440 | """Returns a list of contests.""" 441 | params = {} 442 | if gym is not None: 443 | params['gym'] = _bool_to_str(gym) 444 | resp = await _query_api('contest.list', params) 445 | return [make_from_dict(Contest, contest_dict) for contest_dict in resp] 446 | 447 | @staticmethod 448 | async def ratingChanges(*, contest_id: Any) -> list[RatingChange]: 449 | """Returns a list of rating changes for a contest.""" 450 | params = {'contestId': contest_id} 451 | try: 452 | resp = await _query_api('contest.ratingChanges', params) 453 | except TrueApiError as e: 454 | if 'not found' in e.comment: 455 | raise ContestNotFoundError(e.comment, contest_id) 456 | if 'Rating changes are unavailable' in e.comment: 457 | raise RatingChangesUnavailableError(e.comment, contest_id) 458 | raise 459 | return [make_from_dict(RatingChange, change_dict) for change_dict in resp] 460 | 461 | @staticmethod 462 | async def standings( 463 | *, 464 | contest_id: Any, 465 | from_: Optional[int] = None, 466 | count: Optional[int] = None, 467 | handles: Optional[list[str]] = None, 468 | room: Optional[Any] = None, 469 | show_unofficial: Optional[bool] = None, 470 | ) -> tuple[Contest, list[Problem], list[RanklistRow]]: 471 | params = {'contestId': contest_id} 472 | if from_ is not None: 473 
| params['from'] = from_ 474 | if count is not None: 475 | params['count'] = count 476 | if handles is not None: 477 | params['handles'] = ';'.join(handles) 478 | if room is not None: 479 | params['room'] = room 480 | if show_unofficial is not None: 481 | params['showUnofficial'] = _bool_to_str(show_unofficial) 482 | try: 483 | resp = await _query_api('contest.standings', params) 484 | except TrueApiError as e: 485 | if 'not found' in e.comment: 486 | raise ContestNotFoundError(e.comment, contest_id) 487 | raise 488 | contest_ = make_from_dict(Contest, resp['contest']) 489 | problems = [ 490 | make_from_dict(Problem, problem_dict) for problem_dict in resp['problems'] 491 | ] 492 | for row in resp['rows']: 493 | row['party']['members'] = [ 494 | make_from_dict(Member, member) for member in row['party']['members'] 495 | ] 496 | row['party'] = make_from_dict(Party, row['party']) 497 | row['problemResults'] = [ 498 | make_from_dict(ProblemResult, problem_result) 499 | for problem_result in row['problemResults'] 500 | ] 501 | ranklist = [make_from_dict(RanklistRow, row_dict) for row_dict in resp['rows']] 502 | return contest_, problems, ranklist 503 | 504 | 505 | class problemset: 506 | @staticmethod 507 | async def problems( 508 | *, tags=None, problemset_name=None 509 | ) -> tuple[list[Problem], list[ProblemStatistics]]: 510 | """Returns a list of problems.""" 511 | params = {} 512 | if tags is not None: 513 | params['tags'] = ';'.join(tags) 514 | if problemset_name is not None: 515 | params['problemsetName'] = problemset_name 516 | resp = await _query_api('problemset.problems', params) 517 | problems = [ 518 | make_from_dict(Problem, problem_dict) for problem_dict in resp['problems'] 519 | ] 520 | problemstats = [ 521 | make_from_dict(ProblemStatistics, problemstat_dict) 522 | for problemstat_dict in resp['problemStatistics'] 523 | ] 524 | return problems, problemstats 525 | 526 | 527 | def user_info_chunkify(handles: Iterable[str]) -> Iterator[list[str]]: 528 | 
"""Yields chunks of handles that can be queried with user.info.""" 529 | # Querying user.info using POST requests is limited to 10000 handles or 2**16 530 | # bytes, so requests might need to be split into chunks 531 | SIZE_LIMIT = 2**16 532 | HANDLE_LIMIT = 10000 533 | chunk = [] 534 | size = 0 535 | for handle in handles: 536 | if size + len(handle) > SIZE_LIMIT or len(chunk) == HANDLE_LIMIT: 537 | yield chunk 538 | chunk = [] 539 | size = 0 540 | chunk.append(handle) 541 | size += len(handle) + 1 542 | if chunk: 543 | yield chunk 544 | 545 | 546 | class user: 547 | @staticmethod 548 | async def info(*, handles: Sequence[str]) -> list[User]: 549 | """Returns a list of user info.""" 550 | chunks = list(user_info_chunkify(handles)) 551 | if len(chunks) > 1: 552 | logger.warning( 553 | f'cf.info request with {len(handles)} handles,' 554 | f' will be chunkified into {len(chunks)} requests.' 555 | ) 556 | 557 | result = [] 558 | for chunk in chunks: 559 | params = {'handles': ';'.join(chunk)} 560 | try: 561 | resp = await _query_api('user.info', params) 562 | except TrueApiError as e: 563 | if 'not found' in e.comment: 564 | # Comment format is "handles: User with handle ***** not found" 565 | handle = e.comment.partition('not found')[0].split()[-1] 566 | raise HandleNotFoundError(e.comment, handle) 567 | raise 568 | result += [make_from_dict(User, user_dict) for user_dict in resp] 569 | return [cf_common.fix_urls(user) for user in result] 570 | 571 | @staticmethod 572 | async def rating(*, handle: str): 573 | """Returns a list of rating changes for a user.""" 574 | params = {'handle': handle} 575 | try: 576 | resp = await _query_api('user.rating', params) 577 | except TrueApiError as e: 578 | if 'not found' in e.comment: 579 | raise HandleNotFoundError(e.comment, handle) 580 | if 'should contain' in e.comment: 581 | raise HandleInvalidError(e.comment, handle) 582 | raise 583 | return [ 584 | make_from_dict(RatingChange, ratingchange_dict) 585 | for ratingchange_dict 
in resp 586 | ] 587 | 588 | @staticmethod 589 | async def ratedList(*, activeOnly: bool = None) -> list[User]: 590 | """Returns a list of rated users.""" 591 | params = {} 592 | if activeOnly is not None: 593 | params['activeOnly'] = _bool_to_str(activeOnly) 594 | resp = await _query_api('user.ratedList', params) 595 | return [make_from_dict(User, user_dict) for user_dict in resp] 596 | 597 | @staticmethod 598 | async def status( 599 | *, handle: str, from_: Optional[int] = None, count: Optional[int] = None 600 | ) -> list[Submission]: 601 | """Returns a list of submissions for a user.""" 602 | params: dict[str, Any] = {'handle': handle} 603 | if from_ is not None: 604 | params['from'] = from_ 605 | if count is not None: 606 | params['count'] = count 607 | try: 608 | resp = await _query_api('user.status', params) 609 | except TrueApiError as e: 610 | if 'not found' in e.comment: 611 | raise HandleNotFoundError(e.comment, handle) 612 | if 'should contain' in e.comment: 613 | raise HandleInvalidError(e.comment, handle) 614 | raise 615 | for submission in resp: 616 | submission['problem'] = make_from_dict(Problem, submission['problem']) 617 | submission['author']['members'] = [ 618 | make_from_dict(Member, member) 619 | for member in submission['author']['members'] 620 | ] 621 | submission['author'] = make_from_dict(Party, submission['author']) 622 | return [make_from_dict(Submission, submission_dict) for submission_dict in resp] 623 | 624 | 625 | async def _resolve_redirect(handle: str) -> Optional[str]: 626 | url = PROFILE_BASE_URL + handle 627 | async with _session.head(url) as r: 628 | if r.status == 200: 629 | return handle 630 | if r.status == 302: 631 | redirected = r.headers.get('Location') 632 | if '/profile/' not in redirected: 633 | # Ended up not on profile page, probably invalid handle 634 | return None 635 | return redirected.split('/profile/')[-1] 636 | raise CodeforcesApiError(f'Something went wrong trying to redirect {url}') 637 | 638 | 639 | async 
def _resolve_handle_to_new_user( 640 | handle: str, 641 | ) -> Optional[User]: 642 | new_handle = await _resolve_redirect(handle) 643 | if new_handle is None: 644 | return None 645 | (cf_user,) = await user.info(handles=[new_handle]) 646 | return cf_user 647 | 648 | 649 | async def _resolve_handles(handles: Iterable[str]) -> dict[str, User]: 650 | chunks = user_info_chunkify(handles) 651 | 652 | resolved_handles: dict[str, User] = {} 653 | for handle_chunk in chunks: 654 | while handle_chunk: 655 | try: 656 | cf_users = await user.info(handles=handle_chunk) 657 | # No failure, all handles resolve to users, 658 | for handle, cf_user in zip(handle_chunk, cf_users, strict=False): 659 | if cf_user is not None: 660 | resolved_handles[handle] = cf_user 661 | break 662 | except HandleNotFoundError as e: 663 | # Handle not found, drop it. 664 | logger.warning(f'Handle {e.handle} not found, dropping it.') 665 | handle_chunk.remove(e.handle) 666 | return resolved_handles 667 | 668 | 669 | async def resolve_redirects( 670 | handles: Iterable[str], skip_filter: bool = False 671 | ) -> dict[str, User]: 672 | """Returns a mapping of handles to their resolved CF users.""" 673 | resolved_handles = await _resolve_handles(handles) 674 | if skip_filter: 675 | return resolved_handles 676 | 677 | return { 678 | handle: cf_user 679 | for handle, cf_user in resolved_handles.items() 680 | if cf_user is not None and handle != cf_user.handle 681 | } 682 | -------------------------------------------------------------------------------- /tle/util/codeforces_common.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import functools 3 | import itertools 4 | import json 5 | import logging 6 | import math 7 | import time 8 | from collections import defaultdict 9 | 10 | import discord 11 | from discord.ext import commands 12 | 13 | from tle import constants 14 | from tle.util import cache_system2, codeforces_api as cf, db, events 15 | 16 | 
logger = logging.getLogger(__name__) 17 | 18 | # Connection to database 19 | user_db = None 20 | 21 | # Cache system 22 | cache2 = None 23 | 24 | # Event system 25 | event_sys = events.EventSystem() 26 | 27 | _contest_id_to_writers_map = None 28 | 29 | _initialize_done = False 30 | 31 | active_groups = defaultdict(set) 32 | 33 | 34 | async def initialize(nodb): 35 | global cache2 36 | global user_db 37 | global event_sys 38 | global _contest_id_to_writers_map 39 | global _initialize_done 40 | 41 | if _initialize_done: 42 | # This happens if the bot loses connection to Discord and on_ready is 43 | # triggered again when it reconnects. 44 | return 45 | 46 | await cf.initialize() 47 | 48 | if nodb: 49 | user_db = db.DummyUserDbConn() 50 | else: 51 | user_db = db.UserDbConn(constants.USER_DB_FILE_PATH) 52 | 53 | cache_db = db.CacheDbConn(constants.CACHE_DB_FILE_PATH) 54 | cache2 = cache_system2.CacheSystem(cache_db) 55 | await cache2.run() 56 | 57 | try: 58 | with open(constants.CONTEST_WRITERS_JSON_FILE_PATH) as f: 59 | data = json.load(f) 60 | _contest_id_to_writers_map = { 61 | contest['id']: [s.lower() for s in contest['writers']] for contest in data 62 | } 63 | logger.info('Contest writers loaded from JSON file') 64 | except FileNotFoundError: 65 | logger.warning('JSON file containing contest writers not found') 66 | 67 | _initialize_done = True 68 | 69 | 70 | # algmyr's guard idea: 71 | def user_guard(*, group, get_exception=None): 72 | active = active_groups[group] 73 | 74 | def guard(fun): 75 | @functools.wraps(fun) 76 | async def f(self, ctx, *args, **kwargs): 77 | user = ctx.message.author.id 78 | if user in active: 79 | logger.info(f'{user} repeatedly calls {group} group') 80 | if get_exception is not None: 81 | raise get_exception() 82 | return 83 | active.add(user) 84 | try: 85 | await fun(self, ctx, *args, **kwargs) 86 | finally: 87 | active.remove(user) 88 | 89 | return f 90 | 91 | return guard 92 | 93 | 94 | def is_contest_writer(contest_id, handle): 95 
| if _contest_id_to_writers_map is None: 96 | return False 97 | writers = _contest_id_to_writers_map.get(contest_id) 98 | return writers and handle.lower() in writers 99 | 100 | 101 | _NONSTANDARD_CONTEST_INDICATORS = [ 102 | 'wild', 103 | 'fools', 104 | 'unrated', 105 | 'surprise', 106 | 'unknown', 107 | 'friday', 108 | 'q#', 109 | 'testing', 110 | 'marathon', 111 | 'kotlin', 112 | 'onsite', 113 | 'experimental', 114 | 'abbyy', 115 | 'icpc', 116 | ] 117 | 118 | 119 | def is_nonstandard_contest(contest): 120 | return any( 121 | string in contest.name.lower() for string in _NONSTANDARD_CONTEST_INDICATORS 122 | ) 123 | 124 | 125 | def is_nonstandard_problem(problem): 126 | return is_nonstandard_contest( 127 | cache2.contest_cache.get_contest(problem.contestId) 128 | ) or problem.matches_all_tags(['*special']) 129 | 130 | 131 | async def get_visited_contests(handles: [str]): 132 | """Returns a set of contest ids of contests that any of the given handles 133 | has at least one non-CE submission. 134 | """ 135 | user_submissions = [await cf.user.status(handle=handle) for handle in handles] 136 | problem_to_contests = cache2.problemset_cache.problem_to_contests 137 | 138 | contest_ids = [] 139 | for sub in itertools.chain.from_iterable(user_submissions): 140 | if sub.verdict == 'COMPILATION_ERROR': 141 | continue 142 | try: 143 | contest = cache2.contest_cache.get_contest(sub.problem.contestId) 144 | problem_id = (sub.problem.name, contest.startTimeSeconds) 145 | contest_ids += problem_to_contests[problem_id] 146 | except cache_system2.ContestNotFound: 147 | pass 148 | return set(contest_ids) 149 | 150 | 151 | # These are special rated-for-all contests which have a combined ranklist for 152 | # onsite and online participants. The onsite participants have their 153 | # submissions marked as out of competition. Just Codeforces things. 
154 | _RATED_FOR_ONSITE_CONTEST_IDS = [ 155 | 86, # Yandex.Algorithm 2011 Round 2 https://codeforces.com/contest/86 156 | 173, # Croc Champ 2012 - Round 1 https://codeforces.com/contest/173 157 | 335, # MemSQL start[c]up Round 2 - online version https://codeforces.com/contest/335 158 | ] 159 | 160 | 161 | def is_rated_for_onsite_contest(contest): 162 | return contest.id in _RATED_FOR_ONSITE_CONTEST_IDS 163 | 164 | 165 | class ResolveHandleError(commands.CommandError): 166 | pass 167 | 168 | 169 | class HandleCountOutOfBoundsError(ResolveHandleError): 170 | def __init__(self, mincnt, maxcnt): 171 | super().__init__(f'Number of handles must be between {mincnt} and {maxcnt}') 172 | 173 | 174 | class FindMemberFailedError(ResolveHandleError): 175 | def __init__(self, member): 176 | super().__init__(f'Unable to convert `{member}` to a server member') 177 | 178 | 179 | class HandleNotRegisteredError(ResolveHandleError): 180 | def __init__(self, member): 181 | super().__init__( 182 | f'Codeforces handle for {member.mention} not found in database' 183 | ) 184 | 185 | 186 | class HandleIsVjudgeError(ResolveHandleError): 187 | HANDLES = """ 188 | vjudge1 vjudge2 vjudge3 vjudge4 vjudge5 189 | luogu_bot1 luogu_bot2 luogu_bot3 luogu_bot4 luogu_bot5 190 | """.split() 191 | 192 | def __init__(self, handle): 193 | super().__init__(f"`{handle}`? 
I'm not doing that!\n\n(╯°□°)╯︵ ┻━┻") 194 | 195 | 196 | class FilterError(commands.CommandError): 197 | pass 198 | 199 | 200 | class ParamParseError(FilterError): 201 | pass 202 | 203 | 204 | def time_format(seconds): 205 | seconds = int(seconds) 206 | days, seconds = divmod(seconds, 86400) 207 | hours, seconds = divmod(seconds, 3600) 208 | minutes, seconds = divmod(seconds, 60) 209 | return days, hours, minutes, seconds 210 | 211 | 212 | def pretty_time_format( 213 | seconds, *, shorten=False, only_most_significant=False, always_seconds=False 214 | ): 215 | days, hours, minutes, seconds = time_format(seconds) 216 | timespec = [ 217 | (days, 'day', 'days'), 218 | (hours, 'hour', 'hours'), 219 | (minutes, 'minute', 'minutes'), 220 | ] 221 | timeprint = [(cnt, singular, plural) for cnt, singular, plural in timespec if cnt] 222 | if not timeprint or always_seconds: 223 | timeprint.append((seconds, 'second', 'seconds')) 224 | if only_most_significant: 225 | timeprint = [timeprint[0]] 226 | 227 | def format_(triple): 228 | cnt, singular, plural = triple 229 | return ( 230 | f'{cnt}{singular[0]}' 231 | if shorten 232 | else f'{cnt} {singular if cnt == 1 else plural}' 233 | ) 234 | 235 | return ' '.join(map(format_, timeprint)) 236 | 237 | 238 | def days_ago(t): 239 | days = (time.time() - t) / (60 * 60 * 24) 240 | if days < 1: 241 | return 'today' 242 | if days < 2: 243 | return 'yesterday' 244 | return f'{math.floor(days)} days ago' 245 | 246 | 247 | async def resolve_handles( 248 | ctx, converter, handles, *, mincnt=1, maxcnt=5, default_to_all_server=False 249 | ): 250 | """Convert an iterable of strings to CF handles. 251 | 252 | A string beginning with ! indicates Discord username, otherwise it is a raw 253 | CF handle to be left unchanged. 
254 | """ 255 | handles = set(handles) 256 | if default_to_all_server and not handles: 257 | handles.add('+server') 258 | if '+server' in handles: 259 | handles.remove('+server') 260 | guild_handles = { 261 | handle for discord_id, handle in user_db.get_handles_for_guild(ctx.guild.id) 262 | } 263 | handles.update(guild_handles) 264 | if len(handles) < mincnt or (maxcnt and maxcnt < len(handles)): 265 | raise HandleCountOutOfBoundsError(mincnt, maxcnt) 266 | resolved_handles = [] 267 | for handle in handles: 268 | if handle.startswith('!'): 269 | # ! denotes Discord user 270 | member_identifier = handle[1:] 271 | # suffix removal as quickfix for new username changes 272 | if member_identifier[-2:] == '#0': 273 | member_identifier = member_identifier[:-2] 274 | 275 | try: 276 | member = await converter.convert(ctx, member_identifier) 277 | except commands.errors.CommandError: 278 | raise FindMemberFailedError(member_identifier) 279 | handle = user_db.get_handle(member.id, ctx.guild.id) 280 | if handle is None: 281 | raise HandleNotRegisteredError(member) 282 | if handle in HandleIsVjudgeError.HANDLES: 283 | raise HandleIsVjudgeError(handle) 284 | resolved_handles.append(handle) 285 | return resolved_handles 286 | 287 | 288 | def members_to_handles(members: [discord.Member], guild_id): 289 | handles = [] 290 | for member in members: 291 | handle = user_db.get_handle(member.id, guild_id) 292 | if handle is None: 293 | raise HandleNotRegisteredError(member) 294 | handles.append(handle) 295 | return handles 296 | 297 | 298 | def filter_flags(args, params): 299 | args = list(args) 300 | flags = [False] * len(params) 301 | rest = [] 302 | for arg in args: 303 | try: 304 | flags[params.index(arg)] = True 305 | except ValueError: 306 | rest.append(arg) 307 | return flags, rest 308 | 309 | 310 | def negate_flags(*args): 311 | return [not x for x in args] 312 | 313 | 314 | def parse_date(arg): 315 | try: 316 | if len(arg) == 8: 317 | fmt = '%d%m%Y' 318 | elif len(arg) == 6: 
319 | fmt = '%m%Y' 320 | elif len(arg) == 4: 321 | fmt = '%Y' 322 | else: 323 | raise ValueError 324 | return time.mktime(datetime.datetime.strptime(arg, fmt).timetuple()) 325 | except ValueError: 326 | raise ParamParseError(f'{arg} is an invalid date argument') 327 | 328 | 329 | def parse_tags(args, *, prefix): 330 | tags = [x[1:] for x in args if x[0] == prefix] 331 | return tags 332 | 333 | 334 | def parse_rating(args, default_value=None): 335 | for arg in args: 336 | if arg.isdigit(): 337 | return int(arg) 338 | return default_value 339 | 340 | 341 | def fix_urls(user: cf.User): 342 | if user.titlePhoto.startswith('//'): 343 | user = user._replace(titlePhoto='https:' + user.titlePhoto) 344 | return user 345 | 346 | 347 | class SubFilter: 348 | def __init__(self, rated=True): 349 | self.team = False 350 | self.rated = rated 351 | self.dlo, self.dhi = 0, 10**10 352 | self.rlo, self.rhi = 500, 3800 353 | self.types = [] 354 | self.tags = [] 355 | self.bantags = [] 356 | self.contests = [] 357 | self.indices = [] 358 | 359 | def parse(self, args): 360 | args = list(set(args)) 361 | rest = [] 362 | 363 | for arg in args: 364 | if arg == '+team': 365 | self.team = True 366 | elif arg == '+contest': 367 | self.types.append('CONTESTANT') 368 | elif arg == '+outof': 369 | self.types.append('OUT_OF_COMPETITION') 370 | elif arg == '+virtual': 371 | self.types.append('VIRTUAL') 372 | elif arg == '+practice': 373 | self.types.append('PRACTICE') 374 | elif arg[0:2] == 'c+': 375 | self.contests.append(arg[2:]) 376 | elif arg[0:2] == 'i+': 377 | self.indices.append(arg[2:]) 378 | elif arg[0] == '+': 379 | if len(arg) == 1: 380 | raise ParamParseError('Problem tag cannot be empty.') 381 | self.tags.append(arg[1:]) 382 | elif arg[0] == '~': 383 | if len(arg) == 1: 384 | raise ParamParseError('Problem tag cannot be empty.') 385 | self.bantags.append(arg[1:]) 386 | elif arg[0:2] == 'd<': 387 | self.dhi = min(self.dhi, parse_date(arg[2:])) 388 | elif arg[0:3] == 'd>=': 389 | 
self.dlo = max(self.dlo, parse_date(arg[3:])) 390 | elif arg[0:3] in ['r<=', 'r>=']: 391 | if len(arg) < 4: 392 | raise ParamParseError(f'{arg} is an invalid rating argument') 393 | elif arg[1] == '>': 394 | self.rlo = max(self.rlo, int(arg[3:])) 395 | else: 396 | self.rhi = min(self.rhi, int(arg[3:])) 397 | self.rated = True 398 | else: 399 | rest.append(arg) 400 | 401 | self.types = self.types or [ 402 | 'CONTESTANT', 403 | 'OUT_OF_COMPETITION', 404 | 'VIRTUAL', 405 | 'PRACTICE', 406 | ] 407 | return rest 408 | 409 | @staticmethod 410 | def filter_solved(submissions): 411 | """Filters and keeps only solved submissions. 412 | 413 | If a problem is solved multiple times the first accepted submission is 414 | kept. The unique id for a problem is 415 | (problem name, contest start time). 416 | """ 417 | submissions.sort(key=lambda sub: sub.creationTimeSeconds) 418 | problems = set() 419 | solved_subs = [] 420 | 421 | for submission in submissions: 422 | problem = submission.problem 423 | contest = cache2.contest_cache.contest_by_id.get(problem.contestId, None) 424 | if submission.verdict == 'OK': 425 | # Assume (name, contest start time) is a unique identifier for problems 426 | problem_key = (problem.name, contest.startTimeSeconds if contest else 0) 427 | if problem_key not in problems: 428 | solved_subs.append(submission) 429 | problems.add(problem_key) 430 | return solved_subs 431 | 432 | def filter_subs(self, submissions): 433 | submissions = SubFilter.filter_solved(submissions) 434 | filtered_subs = [] 435 | for submission in submissions: 436 | problem = submission.problem 437 | contest = cache2.contest_cache.contest_by_id.get(problem.contestId, None) 438 | type_ok = submission.author.participantType in self.types 439 | date_ok = self.dlo <= submission.creationTimeSeconds < self.dhi 440 | tag_ok = problem.matches_all_tags(self.tags) 441 | bantag_ok = not problem.matches_any_tag(self.bantags) 442 | index_ok = not self.indices or any( 443 | index.lower() == 
problem.index.lower() for index in self.indices 444 | ) 445 | contest_ok = not self.contests or ( 446 | contest and contest.matches(self.contests) 447 | ) 448 | team_ok = self.team or len(submission.author.members) == 1 449 | if self.rated: 450 | problem_ok = ( 451 | contest 452 | and contest.id < cf.GYM_ID_THRESHOLD 453 | and not is_nonstandard_problem(problem) 454 | ) 455 | rating_ok = problem.rating and self.rlo <= problem.rating <= self.rhi 456 | else: 457 | # acmsguru and gym allowed 458 | problem_ok = ( 459 | not contest 460 | or contest.id >= cf.GYM_ID_THRESHOLD 461 | or not is_nonstandard_problem(problem) 462 | ) 463 | rating_ok = True 464 | if ( 465 | type_ok 466 | and date_ok 467 | and rating_ok 468 | and tag_ok 469 | and bantag_ok 470 | and team_ok 471 | and problem_ok 472 | and contest_ok 473 | and index_ok 474 | ): 475 | filtered_subs.append(submission) 476 | return filtered_subs 477 | 478 | def filter_rating_changes(self, rating_changes): 479 | rating_changes = [ 480 | change 481 | for change in rating_changes 482 | if self.dlo <= change.ratingUpdateTimeSeconds < self.dhi 483 | ] 484 | return rating_changes 485 | -------------------------------------------------------------------------------- /tle/util/cses_scraper.py: -------------------------------------------------------------------------------- 1 | import aiohttp 2 | from lxml import html 3 | 4 | 5 | class CSESError(Exception): 6 | pass 7 | 8 | 9 | session = aiohttp.ClientSession() 10 | 11 | 12 | async def _fetch(url): 13 | async with session.get(url) as response: 14 | if response.status != 200: 15 | raise CSESError(f'Bad response from CSES, status code {response.status}') 16 | tree = html.fromstring(await response.read()) 17 | return tree 18 | 19 | 20 | async def get_problems(): 21 | tree = await _fetch('https://cses.fi/problemset/list/') 22 | links = [li.get('href') for li in tree.xpath('//*[@class="task"]/a')] 23 | ids = sorted(int(x.split('/')[-1]) for x in links) 24 | return ids 25 | 26 | 27 
async def get_problem_leaderboard(num):
    """Fetch the fastest and shortest solution leaderboards of a CSES problem.

    `num` is the CSES problem id. Returns a pair `(fastest, shortest)` of lists
    holding the text of every link in the two leaderboard tables of the stats
    page.

    Raises `CSESError` on a non-200 response (via `_fetch`) and when the page
    does not contain exactly two leaderboard tables, i.e. when the page layout
    is not the one this scraper was written for.
    """
    tree = await _fetch(f'https://cses.fi/problemset/stats/{num}/')
    # Every table other than the summary table and the `bot-killer` table is
    # taken to be a leaderboard: the first is "fastest", the second "shortest".
    tables = tree.xpath('//table[@class!="summary-table" and @class!="bot-killer"]')
    if len(tables) != 2:
        raise CSESError(
            'Unexpected CSES stats page layout: '
            f'found {len(tables)} leaderboard tables, expected 2'
        )
    fastest_table, shortest_table = tables

    fastest = [a.text for a in fastest_table.xpath('.//a')]
    shortest = [a.text for a in shortest_table.xpath('.//a')]
    return fastest, shortest


# --------------------------------------------------------------------------------
# /tle/util/db/__init__.py:
# --------------------------------------------------------------------------------
from .cache_db_conn import *
from .user_db_conn import *


# --------------------------------------------------------------------------------
# /tle/util/db/cache_db_conn.py:
# --------------------------------------------------------------------------------
import json
import sqlite3

from tle.util import codeforces_api as cf


class CacheDbConn:
    """SQLite-backed cache for Codeforces contests, problems and rating changes."""

    def __init__(self, db_file):
        """Open (or create) the SQLite database `db_file` and ensure the schema."""
        self.conn = sqlite3.connect(db_file)
        self.create_tables()

    def create_tables(self):
        """Create all cache tables and indexes if they do not exist yet."""
        # Table for contests from the contest.list endpoint.
        self.conn.execute(
            'CREATE TABLE IF NOT EXISTS contest ('
            'id INTEGER NOT NULL,'
            'name TEXT,'
            'start_time INTEGER,'
            'duration INTEGER,'
            'type TEXT,'
            'phase TEXT,'
            'prepared_by TEXT,'
            'PRIMARY KEY (id)'
            ')'
        )

        # Table for problems from the problemset.problems endpoint.
        # Note that problems are keyed by name alone.
        self.conn.execute(
            'CREATE TABLE IF NOT EXISTS problem ('
            'contest_id INTEGER,'
            'problemset_name TEXT,'
            '[index] TEXT,'
            'name TEXT NOT NULL,'
            'type TEXT,'
            'points REAL,'
            'rating INTEGER,'
            'tags TEXT,'
            'PRIMARY KEY (name)'
            ')'
        )

        # Table for rating changes fetched from contest.ratingChanges endpoint
        # for every contest.
        self.conn.execute(
            'CREATE TABLE IF NOT EXISTS rating_change ('
            'contest_id INTEGER NOT NULL,'
            'handle TEXT NOT NULL,'
            'rank INTEGER,'
            'rating_update_time INTEGER,'
            'old_rating INTEGER,'
            'new_rating INTEGER,'
            'UNIQUE (contest_id, handle)'
            ')'
        )
        self.conn.execute("""
            CREATE INDEX IF NOT EXISTS ix_rating_change_contest_id ON rating_change (
                contest_id
            )
        """)
        self.conn.execute("""
            CREATE INDEX IF NOT EXISTS ix_rating_change_handle ON rating_change (handle)
        """)

        # Table for problems fetched from contest.standings endpoint for every
        # contest. This is separate from table problem as it contains the same
        # problem twice if it appeared in both Div 1 and Div 2 of some round.
        self.conn.execute("""
            CREATE TABLE IF NOT EXISTS problem2 (
                contest_id INTEGER,
                problemset_name TEXT,
                [index] TEXT,
                name TEXT NOT NULL,
                type TEXT,
                points REAL,
                rating INTEGER,
                tags TEXT,
                PRIMARY KEY (contest_id, [index])
            )
        """)
        self.conn.execute("""
            CREATE INDEX IF NOT EXISTS ix_problem2_contest_id ON problem2 (contest_id)
        """)

    def cache_contests(self, contests):
        """Insert or replace contest rows and commit.

        Each item of `contests` must be a 7-sequence in the column order
        (id, name, start_time, duration, type, phase, prepared_by) -- presumably
        `cf.Contest` tuples, matching what `fetch_contests` builds; confirm.
        Returns the cursor's `rowcount` for the batch.
        """
        query = """
            INSERT OR REPLACE INTO contest (
                id, name, start_time, duration, type, phase, prepared_by
            ) VALUES (?, ?, ?, ?, ?, ?, ?)
        """
        rc = self.conn.executemany(query, contests).rowcount
        self.conn.commit()
        return rc

    def fetch_contests(self):
        """Return every cached contest as a `cf.Contest`."""
        query = """
            SELECT id, name, start_time, duration, type, phase, prepared_by FROM contest
        """
        res = self.conn.execute(query).fetchall()
        return [cf.Contest._make(contest) for contest in res]

    @staticmethod
    def _squish_tags(problem):
        """Flatten a problem into a row tuple, JSON-encoding its `tags` list."""
        return (
            problem.contestId,
            problem.problemsetName,
            problem.index,
            problem.name,
            problem.type,
            problem.points,
            problem.rating,
            json.dumps(problem.tags),
        )

    def cache_problems(self, problems):
        """Insert or replace rows in `problem` and commit; returns the rowcount."""
        query = """
            INSERT OR REPLACE INTO problem (
                contest_id, problemset_name, [index], name, type, points, rating, tags
            ) VALUES (?, ?, ?, ?, ?, ?, ?, ?)
        """
        rc = self.conn.executemany(
            query, list(map(self._squish_tags, problems))
        ).rowcount
        self.conn.commit()
        return rc

    @staticmethod
    def _unsquish_tags(problem):
        """Inverse of `_squish_tags`: build a `cf.Problem` from a row tuple."""
        args, tags = problem[:-1], json.loads(problem[-1])
        return cf.Problem(*args, tags)

    def fetch_problems(self):
        """Return every row of `problem` as a `cf.Problem`."""
        query = """
            SELECT
                contest_id, problemset_name, [index], name, type, points, rating, tags
            FROM problem
        """
        res = self.conn.execute(query).fetchall()
        return list(map(self._unsquish_tags, res))

    def save_rating_changes(self, changes):
        """Insert or replace rating changes and commit; returns the rowcount."""
        change_tuples = [
            (
                change.contestId,
                change.handle,
                change.rank,
                change.ratingUpdateTimeSeconds,
                change.oldRating,
                change.newRating,
            )
            for change in changes
        ]
        query = """
            INSERT OR REPLACE INTO rating_change (
                contest_id, handle, rank, rating_update_time, old_rating, new_rating
            ) VALUES (?, ?, ?, ?, ?, ?)
        """
        rc = self.conn.executemany(query, change_tuples).rowcount
        self.conn.commit()
        return rc

    def clear_rating_changes(self, contest_id=None):
        """Delete rating changes for `contest_id` (all when None) and commit."""
        if contest_id is None:
            query = 'DELETE FROM rating_change'
            self.conn.execute(query)
        else:
            query = 'DELETE FROM rating_change WHERE contest_id = ?'
            self.conn.execute(query, (contest_id,))
        self.conn.commit()

    def get_users_with_more_than_n_contests(self, time_cutoff, n):
        """Return handles with at least `n` rating changes, the latest of which
        happened at or after `time_cutoff`.

        Despite the name, the count condition is `>= n`, not `> n`.
        """
        query = """
            SELECT
                handle,
                COUNT(*) AS num_contests
            FROM rating_change
            GROUP BY handle
            HAVING num_contests >= ? AND MAX(rating_update_time) >= ?
        """
        res = self.conn.execute(query, (n, time_cutoff)).fetchall()
        return [user[0] for user in res]

    def get_all_rating_changes(self):
        """Lazily yield every rating change as `cf.RatingChange`, oldest first.

        The contest name is joined in from `contest` (NULL when not cached).
        The result is a generator over a live cursor.
        """
        query = """
            SELECT
                contest_id,
                name,
                handle,
                rank,
                rating_update_time,
                old_rating,
                new_rating
            FROM rating_change r
            LEFT JOIN contest c ON r.contest_id = c.id
            ORDER BY rating_update_time
        """
        res = self.conn.execute(query)
        return (cf.RatingChange._make(change) for change in res)

    def get_rating_changes_for_contest(self, contest_id):
        """Return the rating changes of one contest as `cf.RatingChange` list."""
        query = """
            SELECT
                contest_id,
                name,
                handle,
                rank,
                rating_update_time,
                old_rating,
                new_rating
            FROM rating_change r
            LEFT JOIN contest c ON r.contest_id = c.id
            WHERE r.contest_id = ?
        """
        res = self.conn.execute(query, (contest_id,)).fetchall()
        return [cf.RatingChange._make(change) for change in res]

    def has_rating_changes_saved(self, contest_id):
        """Return True if at least one rating change is stored for `contest_id`."""
        query = 'SELECT contest_id FROM rating_change WHERE contest_id = ?'
        res = self.conn.execute(query, (contest_id,)).fetchone()
        return res is not None

    def get_rating_changes_for_handle(self, handle):
        """Return all rating changes of `handle` as `cf.RatingChange` list."""
        query = """
            SELECT
                contest_id,
                name,
                handle,
                rank,
                rating_update_time,
                old_rating,
                new_rating
            FROM rating_change r
            LEFT JOIN contest c ON r.contest_id = c.id
            WHERE r.handle = ?
        """
        res = self.conn.execute(query, (handle,)).fetchall()
        return [cf.RatingChange._make(change) for change in res]

    def cache_problemset(self, problemset):
        """Insert or replace rows in `problem2` and commit; returns the rowcount."""
        query = """
            INSERT OR REPLACE INTO problem2 (
                contest_id, problemset_name, [index], name, type, points, rating, tags
            )
            VALUES (?, ?, ?, ?, ?, ?, ?, ?)
        """
        rc = self.conn.executemany(
            query, list(map(self._squish_tags, problemset))
        ).rowcount
        self.conn.commit()
        return rc

    def fetch_problems2(self):
        """Return every row of `problem2` as a `cf.Problem`."""
        query = """
            SELECT
                contest_id, problemset_name, [index], name, type, points, rating, tags
            FROM problem2
        """
        res = self.conn.execute(query).fetchall()
        return list(map(self._unsquish_tags, res))

    def clear_problemset(self, contest_id=None):
        """Delete `problem2` rows for `contest_id` (all rows when None).

        Unlike `clear_rating_changes`, this does NOT commit.
        NOTE(review): presumably intentional, so that a following
        `cache_problemset` commits the delete and the re-insert together;
        confirm with the callers before adding a commit here.
        """
        if contest_id is None:
            query = 'DELETE FROM problem2'
            self.conn.execute(query)
        else:
            query = 'DELETE FROM problem2 WHERE contest_id = ?'
            self.conn.execute(query, (contest_id,))

    def fetch_problemset(self, contest_id):
        """Return the `problem2` rows of one contest as `cf.Problem` list."""
        query = """
            SELECT
                contest_id, problemset_name, [index], name, type, points, rating, tags
            FROM problem2
            WHERE contest_id = ?
        """
        res = self.conn.execute(query, (contest_id,)).fetchall()
        return list(map(self._unsquish_tags, res))

    def problemset_empty(self):
        """Return True if `problem2` has no rows."""
        query = 'SELECT 1 FROM problem2'
        res = self.conn.execute(query).fetchone()
        return res is None

    def close(self):
        """Close the underlying SQLite connection."""
        self.conn.close()


# --------------------------------------------------------------------------------
# /tle/util/discord_common.py:
# --------------------------------------------------------------------------------
import asyncio
import functools
import logging
import random

import discord
from discord.ext import commands

from tle import constants
from tle.util import codeforces_api as cf, db, tasks

logger = logging.getLogger(__name__)

_CF_COLORS = (0xFFCA1F, 0x198BCC, 0xFF2020)
_SUCCESS_GREEN = 0x28A745
_ALERT_AMBER = 0xFFBF00


def embed_neutral(desc, color=discord.Embed.Empty):
    """Build an embed whose description is `str(desc)`, with optional color."""
    return discord.Embed(description=str(desc), color=color)


def embed_success(desc):
    """Build a green embed whose description is `str(desc)`."""
    return discord.Embed(description=str(desc), color=_SUCCESS_GREEN)


def embed_alert(desc):
    """Build an amber embed whose description is `str(desc)`."""
    return discord.Embed(description=str(desc), color=_ALERT_AMBER)


def random_cf_color():
    """Return one of the Codeforces logo colors at random."""
    return random.choice(_CF_COLORS)


def cf_color_embed(**kwargs):
    """Build an embed from `kwargs` with a random Codeforces color."""
    return discord.Embed(**kwargs, color=random_cf_color())


def set_same_cf_color(embeds):
    """Give all `embeds` the same randomly chosen Codeforces color."""
    color = random_cf_color()
    for embed in embeds:
        embed.color = color


def attach_image(embed, img_file):
    """Show the attached `img_file` (a `discord.File`) as the embed's image."""
    embed.set_image(url=f'attachment://{img_file.filename}')


def set_author_footer(embed, user):
    """Set a 'Requested by <user>' footer with the user's avatar."""
    embed.set_footer(text=f'Requested by {user}', icon_url=user.avatar_url)


def send_error_if(*error_cls):
    """Decorator for `cog_command_error` methods.

    Decorated methods send the error in an alert embed when the error is an
    instance of one of the specified errors, otherwise the wrapped function is
    invoked. Errors that are sent are marked with `handled = True` so that
    `bot_error_handler` skips them.
    """

    def decorator(func):
        @functools.wraps(func)
        async def wrapper(cog, ctx, error):
            if isinstance(error, error_cls):
                await ctx.send(embed=embed_alert(error))
                error.handled = True
            else:
                await func(cog, ctx, error)

        return wrapper

    return decorator


async def bot_error_handler(ctx, exception):
    """Global command error handler.

    Known user-facing errors are reported in an alert embed; anything else is
    logged with the triggering message content and jump URL.
    """
    if getattr(exception, 'handled', False):
        # Errors already handled in cogs should have .handled = True
        return

    if isinstance(exception, db.DatabaseDisabledError):
        await ctx.send(
            embed=embed_alert(
                'Sorry, the database is not available. Some features are disabled.'
            )
        )
    elif isinstance(exception, commands.NoPrivateMessage):
        await ctx.send(embed=embed_alert('Commands are disabled in private channels'))
    elif isinstance(exception, commands.DisabledCommand):
        await ctx.send(embed=embed_alert('Sorry, this command is temporarily disabled'))
    elif isinstance(exception, (cf.CodeforcesApiError, commands.UserInputError)):
        await ctx.send(embed=embed_alert(exception))
    else:
        msg = 'Ignoring exception in command {}:'.format(ctx.command)
        exc_info = type(exception), exception, exception.__traceback__
        extra = {
            'message_content': ctx.message.content,
            'jump_url': ctx.message.jump_url,
        }
        logger.exception(msg, exc_info=exc_info, extra=extra)


def once(func):
    """Decorator that wraps a coroutine function such that it runs only once.

    The first call awaits `func`; every later call returns None immediately.
    """
    first = True

    @functools.wraps(func)
    async def wrapper(*args, **kwargs):
        nonlocal first
        if first:
            first = False
            await func(*args, **kwargs)

    return wrapper


def on_ready_event_once(bot):
    """Decorator to run a coroutine only once, when the bot first becomes ready.

    The coroutine is wrapped in an `on_ready` handler registered through
    `bot.event`; later `on_ready` events are ignored thanks to `once`.
    """

    def register_on_ready(func):
        @bot.event
        @once
        async def on_ready():
            await func()

    return register_on_ready


async def presence(bot):
    """Set a 'listening to your commands' status, then rotate 'orz' statuses.

    After 60 seconds, starts a task that every 10 minutes shows a random
    member (excluding those with the `constants.TLE_PURGATORY` role).
    """
    await bot.change_presence(
        activity=discord.Activity(
            type=discord.ActivityType.listening, name='your commands'
        )
    )
    await asyncio.sleep(60)

    @tasks.task(name='OrzUpdate', waiter=tasks.Waiter.fixed_delay(5 * 60))
    async def presence_task(_):
        while True:
            candidates = [
                member
                for member in bot.get_all_members()
                if constants.TLE_PURGATORY not in {role.name for role in member.roles}
            ]
            # random.choice raises IndexError on an empty list; with no
            # eligible member just keep the current presence and retry later.
            if candidates:
                target = random.choice(candidates)
                await bot.change_presence(
                    activity=discord.Game(name=f'{target.display_name} orz')
                )
            await asyncio.sleep(10 * 60)

    presence_task.start()


# --------------------------------------------------------------------------------
# /tle/util/events.py:
# --------------------------------------------------------------------------------
import asyncio
import logging

from discord.ext import commands

# Event types


class Event:
    """Base class for events."""

    pass


class ContestListRefresh(Event):
    """Dispatched with the refreshed list of contests."""

    def __init__(self, contests):
        self.contests = contests


class RatingChangesUpdate(Event):
    """Dispatched with a contest and its rating changes."""

    def __init__(self, *, contest, rating_changes):
        self.contest = contest
        self.rating_changes = rating_changes


# Event errors


class EventError(commands.CommandError):
    pass


class ListenerNotRegistered(EventError):
    def __init__(self, listener):
        super().__init__(
            f'Listener {listener.name} is not registered for event'
            f' {listener.event_cls.__name__}.'
        )


# Event system


class EventSystem:
    """Rudimentary event system.

    Listeners are registered per event class; `dispatch` triggers every
    listener of that class and resolves all pending `wait_for` futures.
    """

    def __init__(self):
        self.listeners_by_event = {}
        self.futures_by_event = {}
        self.logger = logging.getLogger(self.__class__.__name__)

    def add_listener(self, listener):
        """Register `listener` for its `event_cls` (idempotent: stored in a set)."""
        listeners = self.listeners_by_event.setdefault(listener.event_cls, set())
        listeners.add(listener)

    def remove_listener(self, listener):
        """Unregister `listener`; raises `ListenerNotRegistered` if absent."""
        try:
            self.listeners_by_event[listener.event_cls].remove(listener)
        except KeyError:
            raise ListenerNotRegistered(listener)

    async def wait_for(self, event_cls, *, timeout=None):
        """Wait for the next dispatch of `event_cls` and return that event.

        Raises `asyncio.TimeoutError` if `timeout` seconds elapse first.
        """
        future = asyncio.get_running_loop().create_future()
        futures = self.futures_by_event.setdefault(event_cls, [])
        futures.append(future)
        return await asyncio.wait_for(future, timeout)

    def dispatch(self, event_cls, *args, **kwargs):
        """Build `event_cls(*args, **kwargs)` and deliver it.

        Every registered listener is triggered (each runs as its own task) and
        every still-pending `wait_for` future is resolved with the event.
        """
        self.logger.info(f'Dispatching event `{event_cls.__name__}`')
        event = event_cls(*args, **kwargs)
        for listener in self.listeners_by_event.get(event_cls, []):
            listener.trigger(event)
        futures = self.futures_by_event.pop(event_cls, [])
        for future in futures:
            # Futures of timed-out waiters are already cancelled (done).
            if not future.done():
                future.set_result(event)


# Listener


def _ensure_coroutine_func(func):
    if not asyncio.iscoroutinefunction(func):
        raise TypeError('The listener function must be a coroutine function.')


class Listener:
    """A listener for a particular event.

    A listener must have a name, the event it should listen to and a coroutine
    function `func` that is called when the event is dispatched.
    """

    def __init__(self, name, event_cls, func, *, with_lock=False):
        """Initialize the listener.

        `with_lock` controls whether execution of `func` should be guarded by
        an asyncio.Lock.
        """
        _ensure_coroutine_func(func)
        self.name = name
        self.event_cls = event_cls
        self.func = func
        self.lock = asyncio.Lock() if with_lock else None
        self.logger = logging.getLogger(self.__class__.__name__)

    def trigger(self, event):
        """Schedule `func(event)` as a task on the running loop."""
        asyncio.create_task(self._trigger(event))

    async def _trigger(self, event):
        try:
            if self.lock:
                async with self.lock:
                    await self.func(event)
            else:
                await self.func(event)
        except asyncio.CancelledError:
            # Kept explicit: on Python < 3.8 CancelledError is an Exception.
            raise
        except Exception:
            # Log ordinary errors but let KeyboardInterrupt/SystemExit through.
            self.logger.exception(f'Exception in listener `{self.name}`.')

    def __eq__(self, other):
        return isinstance(other, Listener) and (self.event_cls, self.func) == (
            other.event_cls,
            other.func,
        )

    def __hash__(self):
        return hash((self.event_cls, self.func))


class ListenerSpec:
    """A descriptor intended to be an interface between an instance and its listeners.

    It creates the expected listener when `__get__` is called from an instance
    for the first time. No two listener specs in the same class should have the
    same name.
    """

    def __init__(self, name, event_cls, func, *, with_lock=False):
        """Initialize the listener spec.

        `with_lock` controls whether execution of `func` should be guarded by
        an asyncio.Lock.
        """
        _ensure_coroutine_func(func)
        self.name = name
        self.event_cls = event_cls
        self.func = func
        self.with_lock = with_lock

    def __get__(self, instance, owner):
        if instance is None:
            return self
        try:
            listeners = instance.___listeners___
        except AttributeError:
            listeners = instance.___listeners___ = {}
        if self.name not in listeners:
            # In Python <=3.7 iscoroutinefunction returns False for async
            # functions wrapped by functools.partial.
            # TODO: Use functools.partial when we move to Python 3.8.
            async def wrapper(event):
                return await self.func(instance, event)

            listeners[self.name] = Listener(
                self.name, self.event_cls, wrapper, with_lock=self.with_lock
            )
        return listeners[self.name]


def listener(*, name, event_cls, with_lock=False):
    """Returns a decorator that creates a `Listener` with the given options."""

    def decorator(func):
        return Listener(name, event_cls, func, with_lock=with_lock)

    return decorator


def listener_spec(*, name, event_cls, with_lock=False):
    """Returns a decorator that creates a `ListenerSpec` with the given options."""

    def decorator(func):
        return ListenerSpec(name, event_cls, func, with_lock=with_lock)

    return decorator


# --------------------------------------------------------------------------------
# /tle/util/font_downloader.py:
# --------------------------------------------------------------------------------
import logging
import os
import urllib.request
from io import BytesIO
from zipfile import ZipFile

from tle import constants

URL_BASE = 'https://noto-website-2.storage.googleapis.com/pkgs/'
FONTS = [
    constants.NOTO_SANS_CJK_BOLD_FONT_PATH,
    constants.NOTO_SANS_CJK_REGULAR_FONT_PATH,
]

logger = logging.getLogger(__name__)


def _unzip(font, archive):
    """Extract member `font` of zip `archive` into `constants.FONTS_DIR`.

    Raises KeyError if the archive does not contain `font`.
    """
    with ZipFile(archive) as zipfile:
        if font not in zipfile.namelist():
            raise KeyError(
                f'Expected font file {font} not present in downloaded zip archive.'
            )
        zipfile.extract(font, constants.FONTS_DIR)


def _download(font_path):
    """Download `<basename>.zip` from URL_BASE and extract the font from it."""
    font = os.path.basename(font_path)
    logger.info(f'Downloading font `{font}`.')
    with urllib.request.urlopen(f'{URL_BASE}{font}.zip') as resp:
        _unzip(font, BytesIO(resp.read()))


def maybe_download():
    """Download every font in FONTS that is not present on disk.

    NOTE(review): assumes each path in FONTS lives directly in
    `constants.FONTS_DIR`, since that is where `_unzip` extracts to.
    """
    for font_path in FONTS:
        if not os.path.isfile(font_path):
            _download(font_path)


# --------------------------------------------------------------------------------
# /tle/util/graph_common.py:
# --------------------------------------------------------------------------------
import io
import os
import time

import discord
import matplotlib
import matplotlib.font_manager

# Select the non-interactive Agg backend; this must happen before pyplot is
# imported, which is why the imports below come after this call.
matplotlib.use('agg')

from cycler import cycler
from matplotlib import pyplot as plt

from tle import constants

rating_color_cycler = cycler(
    'color', ['#5d4dff', '#009ccc', '#00ba6a', '#b99d27', '#cb2aff']
)

fontprop = matplotlib.font_manager.FontProperties(
    fname=constants.NOTO_SANS_CJK_REGULAR_FONT_PATH
)


# String wrapper to avoid the underscore behavior in legends
#
# In legends, matplotlib ignores labels that begin with an underscore:
# https://matplotlib.org/api/pyplot_api.html#matplotlib.pyplot.legend
# However, this check is only done for actual string objects, so wrapping a
# label in a non-str object keeps handles such as "_foo" in the legend.
class StrWrap:
    """Wrap a string in a non-`str` object for use as a matplotlib legend label.

    The wrapped value is kept in `self.string` and returned by `__str__`.
    """

    def __init__(self, s):
        self.string = s

    def __str__(self):
        return self.string


def get_current_figure_as_file():
    """Render the current pyplot figure to PNG and wrap it in a `discord.File`.

    The figure is saved to a timestamp-named file in `constants.TEMP_DIR` and
    read back into memory. The temporary file is removed even when saving or
    reading fails. The returned attachment is named `plot.png`.
    """
    filename = os.path.join(constants.TEMP_DIR, f'tempplot_{time.time()}.png')
    try:
        plt.savefig(
            filename,
            facecolor=plt.gca().get_facecolor(),
            bbox_inches='tight',
            pad_inches=0.25,
        )
        with open(filename, 'rb') as file:
            discord_file = discord.File(io.BytesIO(file.read()), filename='plot.png')
    finally:
        # Always clean up; if savefig failed the file may not exist at all.
        if os.path.exists(filename):
            os.remove(filename)
    return discord_file


def plot_rating_bg(ranks):
    """Shade the current axes with one horizontal band per rating rank.

    Each rank draws a band from `rank.low` to `rank.high` in
    `rank.color_graph`, and thin vertical lines in the axes' background color
    are drawn at the x tick positions. The y-limits in effect before drawing
    are restored, so the bands do not rescale the plot.
    """
    ymin, ymax = plt.gca().get_ylim()
    bgcolor = plt.gca().get_facecolor()
    for rank in ranks:
        plt.axhspan(
            rank.low,
            rank.high,
            facecolor=rank.color_graph,
            alpha=0.8,
            edgecolor=bgcolor,
            linewidth=0.5,
        )

    locs, _labels = plt.xticks()
    for loc in locs:
        plt.axvline(loc, color=bgcolor, linewidth=0.5)
    plt.ylim(ymin, ymax)


# --------------------------------------------------------------------------------
# /tle/util/handledict.py:
# --------------------------------------------------------------------------------
class HandleDict:
    """A case insensitive dictionary for handling usernames.

    String keys are compared via `str.lower()`, while the most recently
    assigned spelling of each key is kept and returned by iteration,
    `items()` and `repr()`. Non-string keys are used unchanged.
    """

    def __init__(self):
        # Maps lowercased key -> (key as last assigned, value).
        self._store = {}

    @staticmethod
    def _getlower(key):
        return key.lower() if isinstance(key, str) else key

    def __setitem__(self, key, value):
        # Use the lowercased key for lookups, but store the actual
        # key alongside the value.
        self._store[self._getlower(key)] = (key, value)

    def __getitem__(self, key):
        return self._store[self._getlower(key)][1]

    def __delitem__(self, key):
        del self._store[self._getlower(key)]

    def __contains__(self, key):
        # Without this, `in` falls back to __iter__, which yields the cased
        # keys -- membership would be case-sensitive and O(n).
        return self._getlower(key) in self._store

    def get(self, key, default=None):
        """Return the value for `key` (matched case-insensitively) or `default`."""
        try:
            return self[key]
        except KeyError:
            return default

    def __iter__(self):
        return (cased_key for cased_key, _value in self._store.values())

    def items(self):
        """Return `(cased key, value)` pairs as a dict items view."""
        return dict(self._store.values()).items()

    def __repr__(self):
        return str(self.items())


# --------------------------------------------------------------------------------
# /tle/util/paginator.py:
# --------------------------------------------------------------------------------
import asyncio
import functools

_REACT_FIRST = '\N{BLACK LEFT-POINTING DOUBLE TRIANGLE WITH VERTICAL BAR}'
_REACT_PREV = '\N{BLACK LEFT-POINTING TRIANGLE}'
_REACT_NEXT = '\N{BLACK RIGHT-POINTING TRIANGLE}'
_REACT_LAST = '\N{BLACK RIGHT-POINTING DOUBLE TRIANGLE WITH VERTICAL BAR}'


def chunkify(sequence, chunk_size):
    """Utility method to split a sequence into fixed size chunks."""
    return [sequence[i : i + chunk_size] for i in range(0, len(sequence), chunk_size)]


class PaginatorError(Exception):
    pass


class NoPagesError(PaginatorError):
    pass


class InsufficientPermissionsError(PaginatorError):
    pass


class Paginated:
    """A message whose `(content, embed)` pages are flipped via reactions."""

    def __init__(self, pages):
        self.pages = pages
        self.cur_page = None  # 1-based once pagination starts
        self.message = None
        self.reaction_map = {
            _REACT_FIRST: functools.partial(self.show_page, 1),
            _REACT_PREV: self.prev_page,
            _REACT_NEXT: self.next_page,
            _REACT_LAST: functools.partial(self.show_page, len(pages)),
        }

    async def show_page(self, page_num):
        """Edit the message to show 1-based `page_num`; out-of-range is a no-op."""
        if 1 <= page_num <= len(self.pages):
            content, embed = self.pages[page_num - 1]
            await self.message.edit(content=content, embed=embed)
            self.cur_page = page_num

    async def prev_page(self):
        await self.show_page(self.cur_page - 1)

    async def next_page(self):
        await self.show_page(self.cur_page + 1)

    async def paginate(self, bot, channel, wait_time, delete_after: float = None):
        """Send the first page and react to page-flip reactions.

        Returns right away for a single page. Otherwise, waits for reactions
        until none arrives within `wait_time` seconds, then clears the
        reactions.
        """
        content, embed = self.pages[0]
        self.message = await channel.send(
            content, embed=embed, delete_after=delete_after
        )

        if len(self.pages) == 1:
            # No need to paginate.
            return

        self.cur_page = 1
        for react in self.reaction_map.keys():
            await self.message.add_reaction(react)

        def check(reaction, user):
            return (
                bot.user != user
                and reaction.message.id == self.message.id
                and reaction.emoji in self.reaction_map
            )

        while True:
            try:
                reaction, user = await bot.wait_for(
                    'reaction_add', timeout=wait_time, check=check
                )
                # Remove the user's reaction so the same button can be reused.
                await reaction.remove(user)
                await self.reaction_map[reaction.emoji]()
            except asyncio.TimeoutError:
                await self.message.clear_reactions()
                break


def paginate(
    bot,
    channel,
    pages,
    *,
    wait_time,
    set_pagenum_footers=False,
    delete_after: float = None,
):
    """Start paginating `pages` in `channel` as a background task.

    Raises `NoPagesError` for empty `pages` and `InsufficientPermissionsError`
    if the bot cannot manage messages in `channel`. With `set_pagenum_footers`,
    each embed gets a 'Page i / n' footer when there is more than one page.
    """
    if not pages:
        raise NoPagesError()
    permissions = channel.permissions_for(channel.guild.me)
    if not permissions.manage_messages:
        raise InsufficientPermissionsError('Permission to manage messages required')
    if len(pages) > 1 and set_pagenum_footers:
        for i, (_content, embed) in enumerate(pages):
            embed.set_footer(text=f'Page {i + 1} / {len(pages)}')
    paginated = Paginated(pages)
    asyncio.create_task(paginated.paginate(bot, channel, wait_time, delete_after))


# --------------------------------------------------------------------------------
# /tle/util/ranklist/__init__.py:
# --------------------------------------------------------------------------------
from .ranklist import *
-------------------------------------------------------------------------------- /tle/util/ranklist/ranklist.py: -------------------------------------------------------------------------------- 1 | from discord.ext import commands 2 | 3 | from tle.util.codeforces_api import RanklistRow, make_from_dict 4 | from tle.util.handledict import HandleDict 5 | from tle.util.ranklist.rating_calculator import CodeforcesRatingCalculator 6 | 7 | 8 | class RanklistError(commands.CommandError): 9 | def __init__(self, contest, message=None): 10 | if message is not None: 11 | super().__init__(message) 12 | self.contest = contest 13 | 14 | 15 | class ContestNotRatedError(RanklistError): 16 | def __init__(self, contest): 17 | super().__init__(contest, f'`{contest.name}` is not rated') 18 | 19 | 20 | class HandleNotPresentError(RanklistError): 21 | def __init__(self, contest, handle): 22 | super().__init__( 23 | contest, f'Handle `{handle}`` not present in standings of `{contest.name}`' 24 | ) 25 | self.handle = handle 26 | 27 | 28 | class DeltasNotPresentError(RanklistError): 29 | def __init__(self, contest): 30 | super().__init__( 31 | contest, f'Rating changes for `{contest.name}` not calculated or set.' 32 | ) 33 | 34 | 35 | class Ranklist: 36 | def __init__(self, contest, problems, standings, fetch_time, *, is_rated): 37 | self.contest = contest 38 | self.problems = problems 39 | self.standings = standings 40 | self.fetch_time = fetch_time 41 | self.is_rated = is_rated 42 | self.delta_by_handle = None 43 | self.deltas_status = None 44 | self.standing_by_id = None 45 | self._create_inverse_standings() 46 | 47 | def _create_inverse_standings(self): 48 | self.standing_by_id = HandleDict() 49 | for row in self.standings: 50 | id_ = self.get_ranklist_lookup_key(row) 51 | self.standing_by_id[id_] = row 52 | 53 | def remove_unofficial_contestants(self): 54 | """Remove unofficial contestants from the ranklist. 
55 | 56 | To be used for cases when official ranklist contains unofficial 57 | contestants Currently this is seen is Educational Contests ranklist 58 | where div1 contestants are marked official in api result 59 | """ 60 | 61 | if self.delta_by_handle is None: 62 | raise DeltasNotPresentError(self.contest) 63 | 64 | official_standings = [] 65 | current_rated_rank = 1 66 | last_rated_rank = 0 67 | last_rated_score = (-1, -1) 68 | for contestant in self.standings: 69 | handle = self.get_ranklist_lookup_key(contestant) 70 | if handle in self.delta_by_handle: 71 | current_score = (contestant.points, contestant.penalty) 72 | standings_row = self.standing_by_id[handle]._asdict() 73 | standings_row['rank'] = ( 74 | current_rated_rank 75 | if current_score != last_rated_score 76 | else last_rated_rank 77 | ) 78 | official_standings.append(make_from_dict(RanklistRow, standings_row)) 79 | last_rated_rank = standings_row['rank'] 80 | last_rated_score = current_score 81 | current_rated_rank += 1 82 | 83 | self.standings = official_standings 84 | self._create_inverse_standings() 85 | 86 | def set_deltas(self, delta_by_handle): 87 | if not self.is_rated: 88 | raise ContestNotRatedError(self.contest) 89 | self.delta_by_handle = delta_by_handle.copy() 90 | self.deltas_status = 'Final' 91 | 92 | def predict(self, current_rating): 93 | if not self.is_rated: 94 | raise ContestNotRatedError(self.contest) 95 | standings = [ 96 | (id_, row.points, row.penalty, current_rating[id_]) 97 | for id_, row in self.standing_by_id.items() 98 | if id_ in current_rating 99 | ] 100 | if standings: 101 | self.delta_by_handle = CodeforcesRatingCalculator( 102 | standings 103 | ).calculate_rating_changes() 104 | self.deltas_status = 'Predicted' 105 | 106 | def get_delta(self, handle): 107 | if not self.is_rated: 108 | raise ContestNotRatedError(self.contest) 109 | if handle not in self.standing_by_id: 110 | raise HandleNotPresentError(self.contest, handle) 111 | return self.delta_by_handle.get(handle) 
112 | 113 | def get_standing_row(self, handle): 114 | try: 115 | return self.standing_by_id[handle] 116 | except KeyError: 117 | raise HandleNotPresentError(self.contest, handle) 118 | 119 | @staticmethod 120 | def get_ranklist_lookup_key(contestant): 121 | return contestant.party.teamName or contestant.party.members[0].handle 122 | -------------------------------------------------------------------------------- /tle/util/ranklist/rating_calculator.py: -------------------------------------------------------------------------------- 1 | """ 2 | Adapted from Codeforces code to recalculate ratings 3 | by Mike Mirzayanov (mirzayanovmr@gmail.com) at https://codeforces.com/contest/1/submission/13861109 4 | Updated to use the current rating formula. 5 | """ 6 | 7 | from dataclasses import dataclass 8 | 9 | import numpy as np 10 | from numpy.fft import fft, ifft 11 | 12 | 13 | def intdiv(x, y): 14 | return -(-x // y) if x < 0 else x // y 15 | 16 | 17 | @dataclass 18 | class Contestant: 19 | party: str 20 | points: float 21 | penalty: int 22 | rating: int 23 | need_rating: int = 0 24 | delta: int = 0 25 | rank: float = 0.0 26 | seed: float = 0.0 27 | 28 | 29 | class CodeforcesRatingCalculator: 30 | """Class to calculate rating changes and seeds based on contest standings.""" 31 | 32 | def __init__(self, standings): 33 | self.contestants = [ 34 | Contestant(handle, points, penalty, rating) 35 | for handle, points, penalty, rating in standings 36 | ] 37 | self._precalc_seed() 38 | self._reassign_ranks() 39 | self._process() 40 | self._update_delta() 41 | 42 | def calculate_rating_changes(self): 43 | """Return a mapping between contestants and their corresponding delta.""" 44 | return {contestant.party: contestant.delta for contestant in self.contestants} 45 | 46 | def get_seed(self, rating, me=None): 47 | """Get seed given a rating and user.""" 48 | seed = self.seed[rating] 49 | if me: 50 | seed -= self.elo_win_prob[rating - me.rating] 51 | return seed 52 | 53 | def 
_precalc_seed(self): 54 | MAX = 6144 55 | 56 | # Precompute the ELO win probability for all possible rating differences. 57 | self.elo_win_prob = np.roll(1 / (1 + pow(10, np.arange(-MAX, MAX) / 400)), -MAX) 58 | 59 | # Compute the rating histogram. 60 | count = np.zeros(2 * MAX) 61 | for a in self.contestants: 62 | count[a.rating] += 1 63 | 64 | # Precompute the seed for all possible ratings using FFT. 65 | self.seed = 1 + ifft(fft(count) * fft(self.elo_win_prob)).real 66 | 67 | def _reassign_ranks(self): 68 | """Find the rank of each contestant.""" 69 | contestants = self.contestants 70 | contestants.sort(key=lambda o: (-o.points, o.penalty)) 71 | points = penalty = rank = None 72 | for i in reversed(range(len(contestants))): 73 | if contestants[i].points != points or contestants[i].penalty != penalty: 74 | rank = i + 1 75 | points = contestants[i].points 76 | penalty = contestants[i].penalty 77 | contestants[i].rank = rank 78 | 79 | def _process(self): 80 | """Process and assign approximate delta for each contestant.""" 81 | for a in self.contestants: 82 | a.seed = self.get_seed(a.rating, a) 83 | mid_rank = (a.rank * a.seed) ** 0.5 84 | a.need_rating = self._rank_to_rating(mid_rank, a) 85 | a.delta = intdiv(a.need_rating - a.rating, 2) 86 | 87 | def _rank_to_rating(self, rank, me): 88 | """Binary Search to find the performance rating for a given rank.""" 89 | left, right = 1, 8000 90 | while right - left > 1: 91 | mid = (left + right) // 2 92 | if self.get_seed(mid, me) < rank: 93 | right = mid 94 | else: 95 | left = mid 96 | return left 97 | 98 | def _update_delta(self): 99 | """Update the delta of each contestant.""" 100 | contestants = self.contestants 101 | n = len(contestants) 102 | 103 | contestants.sort(key=lambda o: -o.rating) 104 | correction = intdiv(-sum(c.delta for c in contestants), n) - 1 105 | for contestant in contestants: 106 | contestant.delta += correction 107 | 108 | zero_sum_count = min(4 * round(n**0.5), n) 109 | delta_sum = 
-sum(contestants[i].delta for i in range(zero_sum_count)) 110 | correction = min(0, max(-10, intdiv(delta_sum, zero_sum_count))) 111 | for contestant in contestants: 112 | contestant.delta += correction 113 | -------------------------------------------------------------------------------- /tle/util/table.py: -------------------------------------------------------------------------------- 1 | import unicodedata 2 | 3 | FULL_WIDTH = 1.66667 4 | WIDTH_MAPPING = {'F': FULL_WIDTH, 'H': 1, 'W': FULL_WIDTH, 'Na': 1, 'N': 1, 'A': 1} 5 | 6 | 7 | def width(s): 8 | return round(sum(WIDTH_MAPPING[unicodedata.east_asian_width(c)] for c in s)) 9 | 10 | 11 | class Content: 12 | def __init__(self, *args): 13 | self.data = args 14 | 15 | def sizes(self): 16 | return [width(str(x)) for x in self.data] 17 | 18 | def __len__(self): 19 | return len(self.data) 20 | 21 | 22 | class Header(Content): 23 | def layout(self, style): 24 | return style.format_header(self.data) 25 | 26 | 27 | class Data(Content): 28 | def layout(self, style): 29 | return style.format_body(self.data) 30 | 31 | 32 | class Line: 33 | def __init__(self, c='-'): 34 | self.c = c 35 | 36 | def layout(self, style): 37 | self.data = [''] * style.ncols 38 | return style.format_line(self.c) 39 | 40 | 41 | class Style: 42 | def __init__(self, body, header=None): 43 | self._body = body 44 | self._header = header or body 45 | self.ncols = body.count('}') 46 | 47 | def _pad(self, data, fmt): 48 | S = [] 49 | lastc = None 50 | size = iter(self.sizes) 51 | datum = iter(data) 52 | for c in fmt: 53 | if lastc == ':': 54 | dstr = str(next(datum)) 55 | sz = str(next(size) - (width(dstr) - len(dstr))) 56 | if c in '<>^': 57 | S.append(c + sz) 58 | else: 59 | S.append(sz + c) 60 | else: 61 | S.append(c) 62 | lastc = c 63 | return ''.join(S) 64 | 65 | def format_header(self, data): 66 | return self._pad(data, self._header).format(*data) 67 | 68 | def format_line(self, c): 69 | data = [''] * self.ncols 70 | return self._pad(data, 
self._header).replace(':', ':' + c).format(*data) 71 | 72 | def format_body(self, data): 73 | return self._pad(data, self._body).format(*data) 74 | 75 | def set_colwidths(self, sizes): 76 | self.sizes = sizes 77 | 78 | 79 | class Table: 80 | def __init__(self, style): 81 | self.style = style 82 | self.rows = [] 83 | 84 | def append(self, row): 85 | self.rows.append(row) 86 | return self 87 | 88 | __add__ = append 89 | 90 | def __repr__(self): 91 | sizes = [row.sizes() for row in self.rows if isinstance(row, Content)] 92 | max_colsize = [max(s[i] for s in sizes) for i in range(self.style.ncols)] 93 | self.style.set_colwidths(max_colsize) 94 | return '\n'.join(row.layout(self.style) for row in self.rows) 95 | 96 | __str__ = __repr__ 97 | -------------------------------------------------------------------------------- /tle/util/tasks.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import logging 3 | 4 | from discord.ext import commands 5 | 6 | import tle.util.codeforces_common as cf_common 7 | 8 | 9 | class TaskError(commands.CommandError): 10 | pass 11 | 12 | 13 | class WaiterRequired(TaskError): 14 | def __init__(self, name): 15 | super().__init__(f'No waiter set for task `{name}`') 16 | 17 | 18 | class TaskAlreadyRunning(TaskError): 19 | def __init__(self, name): 20 | super().__init__(f'Attempt to start task `{name}` which is already running') 21 | 22 | 23 | def _ensure_coroutine_func(func): 24 | if not asyncio.iscoroutinefunction(func): 25 | raise TypeError('The decorated function must be a coroutine function.') 26 | 27 | 28 | class Waiter: 29 | def __init__(self, func, *, run_first=False, needs_instance=False): 30 | """Initializes a waiter with the given coroutine function `func`. 31 | 32 | `run_first` indicates whether this waiter should be run before the task's `func` 33 | when run for the first time. `needs_instance` indicates whether a self argument 34 | is required by the `func`. 
35 | """ 36 | _ensure_coroutine_func(func) 37 | self.func = func 38 | self.run_first = run_first 39 | self.needs_instance = needs_instance 40 | 41 | async def wait(self, instance=None): 42 | if self.needs_instance: 43 | return await self.func(instance) 44 | else: 45 | return await self.func() 46 | 47 | @staticmethod 48 | def fixed_delay(delay, run_first=False): 49 | """Returns a waiter that always waits for a fixed time. 50 | 51 | `delay` is in seconds and the waiter returns the time waited. 52 | """ 53 | 54 | async def wait_func(): 55 | await asyncio.sleep(delay) 56 | return delay 57 | 58 | return Waiter(wait_func, run_first=run_first) 59 | 60 | @staticmethod 61 | def for_event(event_cls, run_first=True): 62 | """Returns a waiter that waits for the given event. 63 | 64 | The waiter returns the result of the event. 65 | """ 66 | 67 | async def wait_func(): 68 | return await cf_common.event_sys.wait_for(event_cls) 69 | 70 | return Waiter(wait_func, run_first=run_first) 71 | 72 | 73 | class ExceptionHandler: 74 | def __init__(self, func, *, needs_instance=False): 75 | """Initializes an exception handler with the given coroutine function `func`. 76 | 77 | `needs_instance` indicates whether a self argument is required by the `func`. 78 | """ 79 | _ensure_coroutine_func(func) 80 | self.func = func 81 | self.needs_instance = needs_instance 82 | 83 | async def handle(self, exception, instance=None): 84 | if self.needs_instance: 85 | await self.func(instance, exception) 86 | else: 87 | await self.func(exception) 88 | 89 | 90 | class Task: 91 | """A task that repeats until stopped. 92 | 93 | A task must have a name, a coroutine function `func` to execute 94 | periodically and another coroutine function `waiter` to wait on between 95 | calls to `func`. The return value of `waiter` is passed to `func` in the 96 | next call. An optional coroutine function `exception_handler` may be 97 | provided to which exceptions will be reported. 
98 | """ 99 | 100 | def __init__(self, name, func, waiter, exception_handler=None, *, instance=None): 101 | """`instance`, if present, is passed as the first argument to `func`.""" 102 | _ensure_coroutine_func(func) 103 | self.name = name 104 | self.func = func 105 | self._waiter = waiter 106 | self._exception_handler = exception_handler 107 | self.instance = instance 108 | self.asyncio_task = None 109 | self.logger = logging.getLogger(self.__class__.__name__) 110 | 111 | def waiter(self, run_first=False): 112 | """Decorator that sets the coroutine as the waiter for this Task.""" 113 | 114 | def decorator(func): 115 | self._waiter = Waiter(func, run_first=run_first) 116 | return func 117 | 118 | return decorator 119 | 120 | def exception_handler(self): 121 | """Decorator that sets the function as the exception handler for this Task.""" 122 | 123 | def decorator(func): 124 | self._exception_handler = ExceptionHandler(func) 125 | return func 126 | 127 | return decorator 128 | 129 | @property 130 | def running(self): 131 | return self.asyncio_task is not None and not self.asyncio_task.done() 132 | 133 | def start(self): 134 | """Starts up the task.""" 135 | if self._waiter is None: 136 | raise WaiterRequired(self.name) 137 | if self.running: 138 | raise TaskAlreadyRunning(self.name) 139 | self.logger.info(f'Starting up task `{self.name}`.') 140 | self.asyncio_task = asyncio.create_task(self._task()) 141 | 142 | async def manual_trigger(self, arg=None): 143 | """Manually triggers the `func` with the optionally provided `arg`.""" 144 | self.logger.info(f'Manually triggering task `{self.name}`.') 145 | await self._execute_func(arg) 146 | 147 | async def stop(self): 148 | """Stops the task, interrupting the currently running coroutines.""" 149 | if self.running: 150 | self.logger.info(f'Stopping task `{self.name}`.') 151 | self.asyncio_task.cancel() 152 | await asyncio.sleep( 153 | 0 154 | ) # To ensure cancellation if called from within the task itself. 
155 | 156 | async def _task(self): 157 | arg = None 158 | if self._waiter.run_first: 159 | arg = await self._waiter.wait(self.instance) 160 | while True: 161 | await self._execute_func(arg) 162 | arg = await self._waiter.wait(self.instance) 163 | 164 | async def _execute_func(self, arg): 165 | try: 166 | if self.instance is not None: 167 | await self.func(self.instance, arg) 168 | else: 169 | await self.func(arg) 170 | except asyncio.CancelledError: 171 | raise 172 | except Exception as ex: 173 | self.logger.warning( 174 | f'Exception in task `{self.name}`, ignoring.', exc_info=True 175 | ) 176 | if self._exception_handler is not None: 177 | await self._exception_handler.handle(ex, self.instance) 178 | 179 | 180 | class TaskSpec: 181 | """A descriptor intended to be an interface between an instance and its tasks. 182 | 183 | It creates the expected task when `__get__` is called from an instance for 184 | the first time. No two task specs in the same class should have the same 185 | name. 
186 | """ 187 | 188 | def __init__(self, name, func, waiter=None, exception_handler=None): 189 | _ensure_coroutine_func(func) 190 | self.name = name 191 | self.func = func 192 | self._waiter = waiter 193 | self._exception_handler = exception_handler 194 | 195 | def waiter(self, run_first=False, needs_instance=True): 196 | """Decorator that sets the coroutine as the waiter for this TaskSpec.""" 197 | 198 | def decorator(func): 199 | self._waiter = Waiter( 200 | func, run_first=run_first, needs_instance=needs_instance 201 | ) 202 | return func 203 | 204 | return decorator 205 | 206 | def exception_handler(self, needs_instance=True): 207 | """Decorator that sets the coroutine as the exception handler for this TaskSpec.""" # noqa: E501 208 | 209 | def decorator(func): 210 | self._exception_handler = ExceptionHandler( 211 | func, needs_instance=needs_instance 212 | ) 213 | return func 214 | 215 | return decorator 216 | 217 | def __get__(self, instance, owner): 218 | if instance is None: 219 | return self 220 | try: 221 | tasks = instance.___tasks___ 222 | except AttributeError: 223 | tasks = instance.___tasks___ = {} 224 | if self.name not in tasks: 225 | tasks[self.name] = Task( 226 | self.name, 227 | self.func, 228 | self._waiter, 229 | self._exception_handler, 230 | instance=instance, 231 | ) 232 | return tasks[self.name] 233 | 234 | 235 | def task(*, name, waiter=None, exception_handler=None): 236 | """Returns a decorator that creates a `Task` with the given options.""" 237 | 238 | def decorator(func): 239 | return Task(name, func, waiter, exception_handler, instance=None) 240 | 241 | return decorator 242 | 243 | 244 | def task_spec(*, name, waiter=None, exception_handler=None): 245 | """Decorator that creates a `TaskSpec` descriptor with the given options.""" 246 | 247 | def decorator(func): 248 | return TaskSpec(name, func, waiter, exception_handler) 249 | 250 | return decorator 251 | --------------------------------------------------------------------------------