├── .gitattributes ├── .gitignore ├── Docker.md ├── Dockerfile ├── LICENSE ├── README.md ├── environment.template ├── extra ├── fonts.conf └── scrape_cf_contest_writers.py ├── poetry.lock ├── pyproject.toml ├── run.sh └── tle ├── __init__.py ├── __main__.py ├── cogs ├── acd_ai.py ├── cache_control.py ├── codeforces.py ├── contests.py ├── deactivated │ └── cses.py ├── duel.py ├── graphs.py ├── handles.py ├── hard75Challenge.py ├── lockout.py ├── logging.py ├── meta.py ├── ref_bot.py ├── starboard.py └── training.py ├── constants.py └── util ├── ACDLaddersProblems.py ├── __init__.py ├── cache_system2.py ├── codeforces_api.py ├── codeforces_common.py ├── cses_scraper.py ├── db ├── __init__.py ├── cache_db_conn.py └── user_db_conn.py ├── discord_common.py ├── elo.py ├── events.py ├── font_downloader.py ├── gemini_model_settings.py ├── graph_common.py ├── handledict.py ├── paginator.py ├── ranklist ├── __init__.py ├── ranklist.py └── rating_calculator.py ├── table.py └── tasks.py /.gitattributes: -------------------------------------------------------------------------------- 1 | # Auto detect text files and perform LF normalization 2 | * text=auto 3 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | files/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .coverage 43 | .coverage.* 44 | .cache 45 | nosetests.xml 46 | coverage.xml 47 | *.cover 48 | .hypothesis/ 49 | .pytest_cache/ 50 | 51 | # Translations 52 | *.mo 53 | *.pot 54 | 55 | # Django stuff: 56 | *.log 57 | local_settings.py 58 | db.sqlite3 59 | 60 | # Flask stuff: 61 | instance/ 62 | .webassets-cache 63 | 64 | # Scrapy stuff: 65 | .scrapy 66 | 67 | # Sphinx documentation 68 | docs/_build/ 69 | 70 | # PyBuilder 71 | target/ 72 | 73 | # Jupyter Notebook 74 | .ipynb_checkpoints 75 | 76 | # pyenv 77 | .python-version 78 | 79 | # celery beat schedule file 80 | celerybeat-schedule 81 | 82 | # SageMath parsed files 83 | *.sage.py 84 | 85 | # Environments 86 | .env 87 | .venv 88 | env/ 89 | venv/ 90 | ENV/ 91 | env.bak/ 92 | venv.bak/ 93 | environment 94 | 95 | # Spyder project settings 96 | .spyderproject 97 | .spyproject 98 | 99 | # Rope project settings 100 | .ropeproject 101 | 102 | # mkdocs documentation 103 | /site 104 | 105 | # mypy 106 | .mypy_cache/ 107 | 108 | .idea/ 109 | pip-wheel-metadata/ 110 | 111 | # vscode 112 | .vscode/ 113 | 114 | # data 115 | /data 116 | 117 | # logs 118 | /logs 119 | -------------------------------------------------------------------------------- /Docker.md: -------------------------------------------------------------------------------- 1 | # How to run the bot inside a docker container 2 | ## Motivation 3 | Docker is a service that helps in creating isolation in the local environment. 
For example, even if your machine runs Windows with Python 2, you can still run the bot, which was developed for Linux with Python 3.7 or 3.8. 4 | 5 | The provided `Dockerfile` uses `Ubuntu 18.04` and `Python 3.8` to run the bot in an isolated environment. 6 | ### Clone the repository 7 | 8 | ```console 9 | foo@bar:~$ git clone https://github.com/cheran-senthil/TLE 10 | ``` 11 | 12 | ### Building the Docker image 13 | 14 | 15 | - Navigate to `TLE` and build the image using the following command: 16 | ```console 17 | foo@bar:~$ sudo docker build . 18 | ``` 19 | 20 | ### Setting up Environment Variables 21 | 22 | 23 | - Create a new file `environment` from `environment.template`. 24 | 25 | ```bash 26 | cp environment.template environment 27 | ``` 28 | 29 | Fill in the appropriate variables in the new `environment` file. 30 | 31 | 32 | - Open the file `environment`. It should look like this: 33 | ```console 34 | export BOT_TOKEN="XXXXXXXXXXXXXXXXXXXXXXXX.XXXXXX.XXXXXXXXXXXXXXXXXXXXXXXXXXX" 35 | export LOGGING_COG_CHANNEL_ID="XXXXXXXXXXXXXXXXXX" 36 | ``` 37 | - Replace the value of `BOT_TOKEN` with the token of the bot you created in [this step](https://github.com/reactiflux/discord-irc/wiki/Creating-a-discord-bot-&-getting-a-token). 38 | 39 | - Replace the value of `LOGGING_COG_CHANNEL_ID` with the Discord [channel id](https://support.discord.com/hc/en-us/articles/206346498-Where-can-I-find-my-User-Server-Message-ID-) of the channel you want to use as a logging channel. 40 | 41 | ### Running the container 42 | 43 | 44 | - Get the id of the image you just built from `sudo docker images` and run: 45 | 46 | ```console 47 | foo@bar:~$ sudo docker run -v ${PWD}:/TLE -it --net host <image_id> 48 | ``` 49 | 50 | PS: Use the `-d` flag to run in the background. To kill a background container, get its id using `sudo docker ps` and run `sudo docker kill <container_id>`. 51 | 52 | ### Debugging/Running Commands inside the container 53 | 54 | To run commands inside the container: 55 | 56 | - Get the id of the container you just started using `sudo docker ps` and run: 57 | 58 | ```console 59 | foo@bar:~$ sudo docker exec -it <container_id> bash 60 | ``` 61 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM ubuntu:18.04 2 | USER root 3 | WORKDIR /TLE 4 | 5 | RUN apt-get update 6 | RUN apt-get install -y git apt-utils sqlite3 7 | RUN DEBIAN_FRONTEND="noninteractive" apt-get install -y libcairo2-dev libgirepository1.0-dev libpango1.0-dev pkg-config python3-dev gir1.2-pango-1.0 python3.8-venv libpython3.8-dev libjpeg-dev zlib1g-dev python3-pip 8 | RUN python3.8 -m pip install poetry 9 | 10 | COPY ./poetry.lock ./poetry.lock 11 | COPY ./pyproject.toml ./pyproject.toml 12 | 13 | RUN python3.8 -m poetry install 14 | 15 | COPY . .
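# run.sh (the entrypoint below) sources ./environment and then starts the bot via Poetry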
16 | 17 | ENTRYPOINT ["/TLE/run.sh"] 18 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Cheran 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # TLE 2 | 3 | TLE is a Discord bot centered around Competitive Programming. 4 | 5 | ## Features 6 | 7 | The features of the bot are split into a number of cogs, each handling their own set of commands. 8 | 9 | ### Codeforces cogs 10 | 11 | - **Codeforces** Commands that can recommend problems or contests to users, taking their rating into account. 12 | - **Contests** Shows details of upcoming/running contests. 13 | - **Graphs** Plots various data gathered from Codeforces, e.g. rating distributions and user problem statistics. 14 | - **Handles** Gets or sets information about a specific user's Codeforces handle, or shows a list of Codeforces handles. 15 | - **Duel** Commands to set up a duel between two users. 16 | - **Training** Start a training session that gets harder and harder. 17 | - **Lockout** Integration of the round command of the Lockout Bot. 18 | 19 | ### Other cogs 20 | 21 | - **Starboard** Commands related to the starboard, which adds messages to a specific channel when enough users react with a ⭐️. 22 | - **CacheControl** Commands related to data caching. 23 | 24 | ## Installation 25 | > If you want to run the bot inside a docker container follow these [instructions](/Docker.md) 26 | 27 | Clone the repository 28 | 29 | ```bash 30 | git clone https://github.com/Denjell/TLE 31 | ``` 32 | 33 | ### Dependencies 34 | 35 | Now all dependencies need to be installed. TLE uses [Poetry](https://poetry.eustace.io/) to manage its python dependencies. 
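If Poetry is not installed yet, one way to get it (the same approach this repo's Dockerfile takes) is via pip; the `python3` command below is only an assumption and may need to be replaced with the specific interpreter on your system:

```bash
# Install Poetry for the interpreter that will run TLE (adjust the interpreter name if needed)
python3 -m pip install poetry
```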
After installing Poetry, navigate to the root of the repo and run 36 | 37 | ```bash 38 | poetry install 39 | ``` 40 | 41 | > :warning: **TLE requires Python 3.8 or later!** 42 | 43 | If you are using Ubuntu with an older version of Python, do the following: 44 | 45 | ```bash 46 | apt-get install python3.8-venv libpython3.8-dev 47 | python3.8 -m pip install poetry 48 | python3.8 -m poetry install 49 | ``` 50 | 51 | --- 52 | 53 | #### Library dependencies 54 | 55 | TLE also depends on cairo and pango for graphics and text rendering, which you need to install. For Ubuntu, the relevant packages can be installed with: 56 | 57 | ```bash 58 | apt-get install libcairo2-dev libgirepository1.0-dev libpango1.0-dev pkg-config python3-dev gir1.2-pango-1.0 59 | ``` 60 | 61 | Additionally, TLE uses pillow for graphics, which requires the following packages: 62 | 63 | ```bash 64 | apt-get install libjpeg-dev zlib1g-dev 65 | ``` 66 | 67 | ### Final steps 68 | 69 | You will need to set up a bot on your server before continuing; follow the directions [here](https://github.com/reactiflux/discord-irc/wiki/Creating-a-discord-bot-&-getting-a-token). After this, your bot should appear in your server and you should have the Discord bot token. Finally, go to the `Bot` settings in your App's Developer Portal (the same page where you copied your Bot Token) and enable the `Server Members Intent` and `Message Content Intent`. 70 | 71 | Create a new file `environment`. 72 | 73 | ```bash 74 | cp environment.template environment 75 | ``` 76 | 77 | Fill in the appropriate variables in the new `environment` file. 78 | 79 | #### Environment Variables 80 | 81 | - **BOT_TOKEN**: the Discord Bot Token for your bot. 82 | - **LOGGING_COG_CHANNEL_ID**: the [Discord Channel ID](https://support.discord.com/hc/en-us/articles/206346498-Where-can-I-find-my-User-Server-Message-ID-) of the Discord channel where you want error messages to be sent. 83 | - **GEMINI_API_KEY**: the [Gemini API key](https://makersuite.google.com/app/apikey) needed for the AI chat to function. 84 | - **TLE_ADMIN**: the name of the role that can run admin commands of the bot. If this is not set, the role name will default to "Admin". 85 | - **TLE_MODERATOR**: the name of the role that can run moderator commands of the bot. If this is not set, the role name will default to "Moderator". 86 | 87 | To start TLE, just run: 88 | 89 | ```bash 90 | ./run.sh 91 | ``` 92 | 93 | ### Notes 94 | 95 | - In order to run admin-only commands, you need to have the `Admin` role, which needs to be created in your Discord server and assigned to yourself/other administrators. 96 | - In order to prevent the bot from suggesting an author's own problems to that author, a Python script needs to be run (since this cannot be done through the Codeforces API), which saves the authors of specific contests to a file. To do this, run `python extra/scrape_cf_contest_writers.py`, which will generate a JSON file that should be placed in the `data/misc/` folder. 97 | - In order to display CJK (East Asian) characters for usernames, we need appropriate fonts. Their size is ~36MB, so we don't keep them in the repo itself; they are gitignored. They will be downloaded automatically when the bot is run if not already present. 98 | - One of the bot's features is to assign roles to users based on their rating on Codeforces.
In order for this functionality to work properly, the following roles need to exist in your Discord server: 99 | - Unrated 100 | - Newbie 101 | - Pupil 102 | - Specialist 103 | - Expert 104 | - Candidate Master 105 | - Master 106 | - International Master 107 | - Grandmaster 108 | - International Grandmaster 109 | - Legendary Grandmaster 110 | - One of the bot's commands requires problemsets to be cached. Run `;cache problemsets all` the very first time the bot is used. The command may take around 10 minutes to run. 111 | 112 | ## Usage 113 | 114 | In order to run bot commands, you can either ping the bot at the beginning of the command or prefix the command with a semicolon (;), e.g. `;handle pretty`. 115 | 116 | In order to find available commands, you can run `;help`, which will bring up a list of commands/groups of commands which are available. To get more details about a specific command, you can type `;help <command>`. 117 | 118 | ## Contributing 119 | 120 | Pull requests are welcome. For major changes, please open an issue first to discuss what you would like to change. 121 | 122 | Before submitting your PR, consider running a code formatter on the lines you touched or added. This will help reduce the time spent on fixing small styling issues in code review. Good options are [yapf](https://github.com/google/yapf) or [autopep8](https://github.com/hhatto/autopep8), which can likely be integrated into your favorite editor. 123 | 124 | Please refrain from formatting the whole file if you just change some small part of it. If you feel the need to tidy up some particularly egregious code, then do that in a separate PR. 125 | 126 | ## License 127 | 128 | [MIT](https://choosealicense.com/licenses/mit/) 129 | -------------------------------------------------------------------------------- /environment.template: -------------------------------------------------------------------------------- 1 | export BOT_TOKEN="XXXXXXXXXXXXXXXXXXXXXXXX.XXXXXX.XXXXXXXXXXXXXXXXXXXXXXXXXXX" 2 | export LOGGING_COG_CHANNEL_ID="XXXXXXXXXXXXXXXXXX" 3 | export GEMINI_API_KEY="XXXXXXXXXXXXXXXXXXXXXXXXXXX" 4 | -------------------------------------------------------------------------------- /extra/fonts.conf: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | TLE fontconfig file 5 | data/assets/fonts 6 | 7 | -------------------------------------------------------------------------------- /extra/scrape_cf_contest_writers.py: -------------------------------------------------------------------------------- 1 | """This script scrapes contests and their writers from Codeforces and saves 2 | them to a JSON file.
This exists because there is no way to do this through 3 | the official API :( 4 | """ 5 | 6 | import json 7 | import urllib.request 8 | 9 | from lxml import html 10 | 11 | URL = 'https://codeforces.com/contests/page/{}?complete=true' 12 | JSONFILE = 'contest_writers.json' 13 | 14 | def get_page(pagenum): 15 | url = URL.format(pagenum) 16 | with urllib.request.urlopen(url) as f: 17 | text = f.read().decode() 18 | return html.fromstring(text) 19 | 20 | def get_contests(doc): 21 | contests = [] 22 | rows = doc.xpath('//div[@class="contests-table"]//table[1]//tr')[1:] 23 | for row in rows: 24 | contest_id = int(row.get('data-contestid')) 25 | name, writers, start, length, standings, registrants = row.xpath('td') 26 | writers = writers.text_content().split() 27 | contests.append({'id': contest_id, 'writers': writers}) 28 | return contests 29 | 30 | 31 | print('Fetching page 1') 32 | page1 = get_page(1) 33 | lastpage = int(page1.xpath('//span[@class="page-index"]')[-1].get('pageindex')) 34 | 35 | contests = get_contests(page1) 36 | print(f'Found {len(contests)} contests') 37 | 38 | for pagenum in range(2, lastpage + 1): 39 | print(f'Fetching page {pagenum}') 40 | page = get_page(pagenum) 41 | page_contests = get_contests(page) 42 | print(f'Found {len(page_contests)} contests') 43 | contests.extend(page_contests) 44 | 45 | print(f'Found total {len(contests)} contests') 46 | 47 | with open(JSONFILE, 'w') as f: 48 | json.dump(contests, f) 49 | print(f'Data written to {JSONFILE}') 50 | -------------------------------------------------------------------------------- /poetry.lock: -------------------------------------------------------------------------------- 1 | [[package]] 2 | name = "aiocache" 3 | version = "0.11.1" 4 | description = "multi backend asyncio cache" 5 | category = "main" 6 | optional = false 7 | python-versions = "*" 8 | 9 | [package.extras] 10 | dev = ["asynctest (>=0.11.0)", "codecov", "coverage", "flake8", "ipdb", "marshmallow", "pystache", "pytest", "pytest-asyncio", "pytest-mock", "sphinx", "sphinx-autobuild", "sphinx-rtd-theme", "black"] 11 | memcached = ["aiomcache (>=0.5.2)"] 12 | msgpack = ["msgpack (>=0.5.5)"] 13 | redis = ["aioredis (>=0.3.3)", "aioredis (>=1.0.0)"] 14 | 15 | [[package]] 16 | name = "aiohttp" 17 | version = "3.8.3" 18 | description = "Async http client/server framework (asyncio)" 19 | category = "main" 20 | optional = false 21 | python-versions = ">=3.6" 22 | 23 | [package.dependencies] 24 | aiosignal = ">=1.1.2" 25 | async-timeout = ">=4.0.0a3,<5.0" 26 | attrs = ">=17.3.0" 27 | charset-normalizer = ">=2.0,<3.0" 28 | frozenlist = ">=1.1.1" 29 | multidict = ">=4.5,<7.0" 30 | yarl = ">=1.0,<2.0" 31 | 32 | [package.extras] 33 | speedups = ["aiodns", "brotli", "cchardet"] 34 | 35 | [[package]] 36 | name = "aiosignal" 37 | version = "1.3.1" 38 | description = "aiosignal: a list of registered asynchronous callbacks" 39 | category = "main" 40 | optional = false 41 | python-versions = ">=3.7" 42 | 43 | [package.dependencies] 44 | frozenlist = ">=1.1.0" 45 | 46 | [[package]] 47 | name = "async-timeout" 48 | version = "4.0.2" 49 | description = "Timeout context manager for asyncio programs" 50 | category = "main" 51 | optional = false 52 | python-versions = ">=3.6" 53 | 54 | [[package]] 55 | name = "atomicwrites" 56 | version = "1.4.1" 57 | description = "Atomic file writes." 
58 | category = "dev" 59 | optional = false 60 | python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" 61 | 62 | [[package]] 63 | name = "attrs" 64 | version = "22.1.0" 65 | description = "Classes Without Boilerplate" 66 | category = "main" 67 | optional = false 68 | python-versions = ">=3.5" 69 | 70 | [package.extras] 71 | dev = ["coverage[toml] (>=5.0.2)", "hypothesis", "pympler", "pytest (>=4.3.0)", "mypy (>=0.900,!=0.940)", "pytest-mypy-plugins", "zope.interface", "furo", "sphinx", "sphinx-notfound-page", "pre-commit", "cloudpickle"] 72 | docs = ["furo", "sphinx", "zope.interface", "sphinx-notfound-page"] 73 | tests = ["coverage[toml] (>=5.0.2)", "hypothesis", "pympler", "pytest (>=4.3.0)", "mypy (>=0.900,!=0.940)", "pytest-mypy-plugins", "zope.interface", "cloudpickle"] 74 | tests_no_zope = ["coverage[toml] (>=5.0.2)", "hypothesis", "pympler", "pytest (>=4.3.0)", "mypy (>=0.900,!=0.940)", "pytest-mypy-plugins", "cloudpickle"] 75 | 76 | [[package]] 77 | name = "cachetools" 78 | version = "5.3.2" 79 | description = "Extensible memoizing collections and decorators" 80 | category = "main" 81 | optional = false 82 | python-versions = ">=3.7" 83 | 84 | [[package]] 85 | name = "certifi" 86 | version = "2023.11.17" 87 | description = "Python package for providing Mozilla's CA Bundle." 88 | category = "main" 89 | optional = false 90 | python-versions = ">=3.6" 91 | 92 | [[package]] 93 | name = "charset-normalizer" 94 | version = "2.1.1" 95 | description = "The Real First Universal Charset Detector. Open, modern and actively maintained alternative to Chardet." 96 | category = "main" 97 | optional = false 98 | python-versions = ">=3.6.0" 99 | 100 | [package.extras] 101 | unicode_backport = ["unicodedata2"] 102 | 103 | [[package]] 104 | name = "colorama" 105 | version = "0.4.6" 106 | description = "Cross-platform colored terminal text." 
107 | category = "main" 108 | optional = false 109 | python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,>=2.7" 110 | 111 | [[package]] 112 | name = "contourpy" 113 | version = "1.0.6" 114 | description = "Python library for calculating contours of 2D quadrilateral grids" 115 | category = "main" 116 | optional = false 117 | python-versions = ">=3.7" 118 | 119 | [package.dependencies] 120 | numpy = ">=1.16" 121 | 122 | [package.extras] 123 | bokeh = ["bokeh", "selenium"] 124 | docs = ["docutils (<0.18)", "sphinx (<=5.2.0)", "sphinx-rtd-theme"] 125 | test = ["pytest", "matplotlib", "pillow", "flake8", "isort"] 126 | test-minimal = ["pytest"] 127 | test-no-codebase = ["pytest", "matplotlib", "pillow"] 128 | 129 | [[package]] 130 | name = "cycler" 131 | version = "0.11.0" 132 | description = "Composable style cycles" 133 | category = "main" 134 | optional = false 135 | python-versions = ">=3.6" 136 | 137 | [[package]] 138 | name = "discord.py" 139 | version = "2.1.0" 140 | description = "A Python wrapper for the Discord API" 141 | category = "main" 142 | optional = false 143 | python-versions = ">=3.8.0" 144 | 145 | [package.dependencies] 146 | aiohttp = ">=3.7.4,<4" 147 | 148 | [package.extras] 149 | docs = ["sphinx (==4.4.0)", "sphinxcontrib-trio (==1.1.2)", "sphinxcontrib-websupport", "typing-extensions (>=4.3,<5)"] 150 | speed = ["orjson (>=3.5.4)", "aiodns (>=1.1)", "brotli", "cchardet (==2.1.7)"] 151 | test = ["coverage", "pytest", "pytest-asyncio", "pytest-cov", "pytest-mock", "typing-extensions (>=4.3,<5)"] 152 | voice = ["PyNaCl (>=1.3.0,<1.6)"] 153 | 154 | [[package]] 155 | name = "fonttools" 156 | version = "4.38.0" 157 | description = "Tools to manipulate font files" 158 | category = "main" 159 | optional = false 160 | python-versions = ">=3.7" 161 | 162 | [package.extras] 163 | all = ["fs (>=2.2.0,<3)", "lxml (>=4.0,<5)", "zopfli (>=0.1.4)", "lz4 (>=1.7.4.2)", "matplotlib", "sympy", "skia-pathops (>=0.5.0)", "uharfbuzz (>=0.23.0)", "brotlicffi (>=0.8.0)", "scipy", "brotli (>=1.0.1)", "munkres", "unicodedata2 (>=14.0.0)", "xattr"] 164 | graphite = ["lz4 (>=1.7.4.2)"] 165 | interpolatable = ["scipy", "munkres"] 166 | lxml = ["lxml (>=4.0,<5)"] 167 | pathops = ["skia-pathops (>=0.5.0)"] 168 | plot = ["matplotlib"] 169 | repacker = ["uharfbuzz (>=0.23.0)"] 170 | symfont = ["sympy"] 171 | type1 = ["xattr"] 172 | ufo = ["fs (>=2.2.0,<3)"] 173 | unicode = ["unicodedata2 (>=14.0.0)"] 174 | woff = ["zopfli (>=0.1.4)", "brotlicffi (>=0.8.0)", "brotli (>=1.0.1)"] 175 | 176 | [[package]] 177 | name = "frozenlist" 178 | version = "1.3.3" 179 | description = "A list-like structure which implements collections.abc.MutableSequence" 180 | category = "main" 181 | optional = false 182 | python-versions = ">=3.7" 183 | 184 | [[package]] 185 | name = "google-ai-generativelanguage" 186 | version = "0.4.0" 187 | description = "Google Ai Generativelanguage API client library" 188 | category = "main" 189 | optional = false 190 | python-versions = ">=3.7" 191 | 192 | [package.dependencies] 193 | google-api-core = {version = ">=1.34.0,<2.0.0 || >=2.11.0,<3.0.0dev", extras = ["grpc"]} 194 | proto-plus = ">=1.22.3,<2.0.0dev" 195 | protobuf = ">=3.19.5,<3.20.0 || >3.20.0,<3.20.1 || >3.20.1,<4.21.0 || >4.21.0,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<5.0.0dev" 196 | 197 | [[package]] 198 | name = "google-api-core" 199 | version = "2.15.0" 200 | description = "Google API client core library" 201 | category = "main" 202 | optional = 
false 203 | python-versions = ">=3.7" 204 | 205 | [package.dependencies] 206 | google-auth = ">=2.14.1,<3.0.dev0" 207 | googleapis-common-protos = ">=1.56.2,<2.0.dev0" 208 | grpcio = [ 209 | {version = ">=1.33.2,<2.0dev", optional = true, markers = "extra == \"grpc\""}, 210 | {version = ">=1.49.1,<2.0dev", optional = true, markers = "python_version >= \"3.11\" or python_version >= \"3.11\" and extra == \"grpc\""}, 211 | ] 212 | grpcio-status = [ 213 | {version = ">=1.33.2,<2.0.dev0", optional = true, markers = "extra == \"grpc\""}, 214 | {version = ">=1.49.1,<2.0.dev0", optional = true, markers = "python_version >= \"3.11\" or python_version >= \"3.11\" and extra == \"grpc\""}, 215 | ] 216 | protobuf = ">=3.19.5,<3.20.0 || >3.20.0,<3.20.1 || >3.20.1,<4.21.0 || >4.21.0,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<5.0.0.dev0" 217 | requests = ">=2.18.0,<3.0.0.dev0" 218 | 219 | [package.extras] 220 | grpc = ["grpcio (>=1.33.2,<2.0dev)", "grpcio-status (>=1.33.2,<2.0.dev0)", "grpcio (>=1.49.1,<2.0dev)", "grpcio-status (>=1.49.1,<2.0.dev0)"] 221 | grpcgcp = ["grpcio-gcp (>=0.2.2,<1.0.dev0)"] 222 | grpcio-gcp = ["grpcio-gcp (>=0.2.2,<1.0.dev0)"] 223 | 224 | [[package]] 225 | name = "google-auth" 226 | version = "2.27.0" 227 | description = "Google Authentication Library" 228 | category = "main" 229 | optional = false 230 | python-versions = ">=3.7" 231 | 232 | [package.dependencies] 233 | cachetools = ">=2.0.0,<6.0" 234 | pyasn1-modules = ">=0.2.1" 235 | rsa = ">=3.1.4,<5" 236 | 237 | [package.extras] 238 | aiohttp = ["aiohttp (>=3.6.2,<4.0.0.dev0)", "requests (>=2.20.0,<3.0.0.dev0)"] 239 | enterprise_cert = ["cryptography (==36.0.2)", "pyopenssl (==22.0.0)"] 240 | pyopenssl = ["pyopenssl (>=20.0.0)", "cryptography (>=38.0.3)"] 241 | reauth = ["pyu2f (>=0.1.5)"] 242 | requests = ["requests (>=2.20.0,<3.0.0.dev0)"] 243 | 244 | [[package]] 245 | name = "google-generativeai" 246 | version = "0.3.2" 247 | description = "Google Generative AI High level API client library and tools." 
248 | category = "main" 249 | optional = false 250 | python-versions = ">=3.9" 251 | 252 | [package.dependencies] 253 | google-ai-generativelanguage = "0.4.0" 254 | google-api-core = "*" 255 | google-auth = "*" 256 | protobuf = "*" 257 | tqdm = "*" 258 | typing-extensions = "*" 259 | 260 | [package.extras] 261 | dev = ["absl-py", "black", "nose2", "pandas", "pytype", "pyyaml", "pillow", "ipython"] 262 | 263 | [[package]] 264 | name = "googleapis-common-protos" 265 | version = "1.62.0" 266 | description = "Common protobufs used in Google APIs" 267 | category = "main" 268 | optional = false 269 | python-versions = ">=3.7" 270 | 271 | [package.dependencies] 272 | protobuf = ">=3.19.5,<3.20.0 || >3.20.0,<3.20.1 || >3.20.1,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<5.0.0.dev0" 273 | 274 | [package.extras] 275 | grpc = ["grpcio (>=1.44.0,<2.0.0.dev0)"] 276 | 277 | [[package]] 278 | name = "grpcio" 279 | version = "1.60.0" 280 | description = "HTTP/2-based RPC framework" 281 | category = "main" 282 | optional = false 283 | python-versions = ">=3.7" 284 | 285 | [package.extras] 286 | protobuf = ["grpcio-tools (>=1.60.0)"] 287 | 288 | [[package]] 289 | name = "grpcio-status" 290 | version = "1.60.0" 291 | description = "Status proto mapping for gRPC" 292 | category = "main" 293 | optional = false 294 | python-versions = ">=3.6" 295 | 296 | [package.dependencies] 297 | googleapis-common-protos = ">=1.5.5" 298 | grpcio = ">=1.60.0" 299 | protobuf = ">=4.21.6" 300 | 301 | [[package]] 302 | name = "idna" 303 | version = "3.4" 304 | description = "Internationalized Domain Names in Applications (IDNA)" 305 | category = "main" 306 | optional = false 307 | python-versions = ">=3.5" 308 | 309 | [[package]] 310 | name = "kiwisolver" 311 | version = "1.4.4" 312 | description = "A fast implementation of the Cassowary constraint solver" 313 | category = "main" 314 | optional = false 315 | python-versions = ">=3.7" 316 | 317 | [[package]] 318 | name = "lxml" 319 | version = "4.9.1" 320 | description = "Powerful and Pythonic XML processing library combining libxml2/libxslt with the ElementTree API." 
321 | category = "main" 322 | optional = false 323 | python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, != 3.4.*" 324 | 325 | [package.extras] 326 | cssselect = ["cssselect (>=0.7)"] 327 | html5 = ["html5lib"] 328 | htmlsoup = ["beautifulsoup4"] 329 | source = ["Cython (>=0.29.7)"] 330 | 331 | [[package]] 332 | name = "matplotlib" 333 | version = "3.6.2" 334 | description = "Python plotting package" 335 | category = "main" 336 | optional = false 337 | python-versions = ">=3.8" 338 | 339 | [package.dependencies] 340 | contourpy = ">=1.0.1" 341 | cycler = ">=0.10" 342 | fonttools = ">=4.22.0" 343 | kiwisolver = ">=1.0.1" 344 | numpy = ">=1.19" 345 | packaging = ">=20.0" 346 | pillow = ">=6.2.0" 347 | pyparsing = ">=2.2.1" 348 | python-dateutil = ">=2.7" 349 | setuptools_scm = ">=7" 350 | 351 | [[package]] 352 | name = "more-itertools" 353 | version = "9.0.0" 354 | description = "More routines for operating on iterables, beyond itertools" 355 | category = "dev" 356 | optional = false 357 | python-versions = ">=3.7" 358 | 359 | [[package]] 360 | name = "multidict" 361 | version = "6.0.2" 362 | description = "multidict implementation" 363 | category = "main" 364 | optional = false 365 | python-versions = ">=3.7" 366 | 367 | [[package]] 368 | name = "numpy" 369 | version = "1.23.4" 370 | description = "NumPy is the fundamental package for array computing with Python." 371 | category = "main" 372 | optional = false 373 | python-versions = ">=3.8" 374 | 375 | [[package]] 376 | name = "packaging" 377 | version = "21.3" 378 | description = "Core utilities for Python packages" 379 | category = "main" 380 | optional = false 381 | python-versions = ">=3.6" 382 | 383 | [package.dependencies] 384 | pyparsing = ">=2.0.2,<3.0.5 || >3.0.5" 385 | 386 | [[package]] 387 | name = "pandas" 388 | version = "1.5.1" 389 | description = "Powerful data structures for data analysis, time series, and statistics" 390 | category = "main" 391 | optional = false 392 | python-versions = ">=3.8" 393 | 394 | [package.dependencies] 395 | numpy = [ 396 | {version = ">=1.20.3", markers = "python_version < \"3.10\""}, 397 | {version = ">=1.21.0", markers = "python_version >= \"3.10\""}, 398 | ] 399 | python-dateutil = ">=2.8.1" 400 | pytz = ">=2020.1" 401 | 402 | [package.extras] 403 | test = ["hypothesis (>=5.5.3)", "pytest (>=6.0)", "pytest-xdist (>=1.31)"] 404 | 405 | [[package]] 406 | name = "pillow" 407 | version = "9.3.0" 408 | description = "Python Imaging Library (Fork)" 409 | category = "main" 410 | optional = false 411 | python-versions = ">=3.7" 412 | 413 | [package.extras] 414 | docs = ["furo", "olefile", "sphinx (>=2.4)", "sphinx-copybutton", "sphinx-issues (>=3.0.1)", "sphinx-removed-in", "sphinxext-opengraph"] 415 | tests = ["check-manifest", "coverage", "defusedxml", "markdown2", "olefile", "packaging", "pyroma", "pytest", "pytest-cov", "pytest-timeout"] 416 | 417 | [[package]] 418 | name = "pluggy" 419 | version = "1.0.0" 420 | description = "plugin and hook calling mechanisms for python" 421 | category = "dev" 422 | optional = false 423 | python-versions = ">=3.6" 424 | 425 | [package.extras] 426 | dev = ["pre-commit", "tox"] 427 | testing = ["pytest", "pytest-benchmark"] 428 | 429 | [[package]] 430 | name = "proto-plus" 431 | version = "1.23.0" 432 | description = "Beautiful, Pythonic protocol buffers." 
433 | category = "main" 434 | optional = false 435 | python-versions = ">=3.6" 436 | 437 | [package.dependencies] 438 | protobuf = ">=3.19.0,<5.0.0dev" 439 | 440 | [package.extras] 441 | testing = ["google-api-core[grpc] (>=1.31.5)"] 442 | 443 | [[package]] 444 | name = "protobuf" 445 | version = "4.25.2" 446 | description = "" 447 | category = "main" 448 | optional = false 449 | python-versions = ">=3.8" 450 | 451 | [[package]] 452 | name = "py" 453 | version = "1.11.0" 454 | description = "library with cross-python path, ini-parsing, io, code, log facilities" 455 | category = "dev" 456 | optional = false 457 | python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" 458 | 459 | [[package]] 460 | name = "pyasn1" 461 | version = "0.5.1" 462 | description = "Pure-Python implementation of ASN.1 types and DER/BER/CER codecs (X.208)" 463 | category = "main" 464 | optional = false 465 | python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,>=2.7" 466 | 467 | [[package]] 468 | name = "pyasn1-modules" 469 | version = "0.3.0" 470 | description = "A collection of ASN.1-based protocols modules" 471 | category = "main" 472 | optional = false 473 | python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,>=2.7" 474 | 475 | [package.dependencies] 476 | pyasn1 = ">=0.4.6,<0.6.0" 477 | 478 | [[package]] 479 | name = "pycairo" 480 | version = "1.22.0" 481 | description = "Python interface for cairo" 482 | category = "main" 483 | optional = false 484 | python-versions = ">=3.7" 485 | 486 | [[package]] 487 | name = "pygobject" 488 | version = "3.42.2" 489 | description = "Python bindings for GObject Introspection" 490 | category = "main" 491 | optional = false 492 | python-versions = ">=3.6, <4" 493 | 494 | [package.dependencies] 495 | pycairo = ">=1.16,<2.0" 496 | 497 | [[package]] 498 | name = "pyparsing" 499 | version = "3.0.9" 500 | description = "pyparsing module - Classes and methods to define and execute parsing grammars" 501 | category = "main" 502 | optional = false 503 | python-versions = ">=3.6.8" 504 | 505 | [package.extras] 506 | diagrams = ["railroad-diagrams", "jinja2"] 507 | 508 | [[package]] 509 | name = "pytest" 510 | version = "3.10.1" 511 | description = "pytest: simple powerful testing with Python" 512 | category = "dev" 513 | optional = false 514 | python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" 515 | 516 | [package.dependencies] 517 | atomicwrites = ">=1.0" 518 | attrs = ">=17.4.0" 519 | colorama = {version = "*", markers = "sys_platform == \"win32\""} 520 | more-itertools = ">=4.0.0" 521 | pluggy = ">=0.7" 522 | py = ">=1.5.0" 523 | six = ">=1.10.0" 524 | 525 | [[package]] 526 | name = "python-dateutil" 527 | version = "2.8.2" 528 | description = "Extensions to the standard Python datetime module" 529 | category = "main" 530 | optional = false 531 | python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,>=2.7" 532 | 533 | [package.dependencies] 534 | six = ">=1.5" 535 | 536 | [[package]] 537 | name = "pytz" 538 | version = "2022.6" 539 | description = "World timezone definitions, modern and historical" 540 | category = "main" 541 | optional = false 542 | python-versions = "*" 543 | 544 | [[package]] 545 | name = "ratelimit" 546 | version = "2.2.1" 547 | description = "API rate limit decorator" 548 | category = "main" 549 | optional = false 550 | python-versions = "*" 551 | 552 | [[package]] 553 | name = "requests" 554 | version = "2.31.0" 555 | description = "Python HTTP for Humans." 
556 | category = "main" 557 | optional = false 558 | python-versions = ">=3.7" 559 | 560 | [package.dependencies] 561 | certifi = ">=2017.4.17" 562 | charset-normalizer = ">=2,<4" 563 | idna = ">=2.5,<4" 564 | urllib3 = ">=1.21.1,<3" 565 | 566 | [package.extras] 567 | socks = ["PySocks (>=1.5.6,!=1.5.7)"] 568 | use_chardet_on_py3 = ["chardet (>=3.0.2,<6)"] 569 | 570 | [[package]] 571 | name = "rsa" 572 | version = "4.9" 573 | description = "Pure-Python RSA implementation" 574 | category = "main" 575 | optional = false 576 | python-versions = ">=3.6,<4" 577 | 578 | [package.dependencies] 579 | pyasn1 = ">=0.1.3" 580 | 581 | [[package]] 582 | name = "scipy" 583 | version = "1.9.3" 584 | description = "Fundamental algorithms for scientific computing in Python" 585 | category = "main" 586 | optional = false 587 | python-versions = ">=3.8" 588 | 589 | [package.dependencies] 590 | numpy = ">=1.18.5,<1.26.0" 591 | 592 | [package.extras] 593 | test = ["pytest", "pytest-cov", "pytest-xdist", "asv", "mpmath", "gmpy2", "threadpoolctl", "scikit-umfpack"] 594 | doc = ["sphinx (!=4.1.0)", "pydata-sphinx-theme (==0.9.0)", "sphinx-panels (>=0.5.2)", "matplotlib (>2)", "numpydoc", "sphinx-tabs"] 595 | dev = ["mypy", "typing-extensions", "pycodestyle", "flake8"] 596 | 597 | [[package]] 598 | name = "seaborn" 599 | version = "0.10.1" 600 | description = "seaborn: statistical data visualization" 601 | category = "main" 602 | optional = false 603 | python-versions = ">=3.6" 604 | 605 | [package.dependencies] 606 | matplotlib = ">=2.1.2" 607 | numpy = ">=1.13.3" 608 | pandas = ">=0.22.0" 609 | scipy = ">=1.0.1" 610 | 611 | [[package]] 612 | name = "setuptools-scm" 613 | version = "7.0.5" 614 | description = "the blessed package to manage your versions by scm tags" 615 | category = "main" 616 | optional = false 617 | python-versions = ">=3.7" 618 | 619 | [package.dependencies] 620 | packaging = ">=20.0" 621 | tomli = ">=1.0.0" 622 | typing-extensions = "*" 623 | 624 | [package.extras] 625 | test = ["pytest (>=6.2)", "virtualenv (>20)"] 626 | toml = ["setuptools (>=42)"] 627 | 628 | [[package]] 629 | name = "six" 630 | version = "1.16.0" 631 | description = "Python 2 and 3 compatibility utilities" 632 | category = "main" 633 | optional = false 634 | python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*" 635 | 636 | [[package]] 637 | name = "tomli" 638 | version = "2.0.1" 639 | description = "A lil' TOML parser" 640 | category = "main" 641 | optional = false 642 | python-versions = ">=3.7" 643 | 644 | [[package]] 645 | name = "tqdm" 646 | version = "4.66.1" 647 | description = "Fast, Extensible Progress Meter" 648 | category = "main" 649 | optional = false 650 | python-versions = ">=3.7" 651 | 652 | [package.dependencies] 653 | colorama = {version = "*", markers = "platform_system == \"Windows\""} 654 | 655 | [package.extras] 656 | dev = ["pytest (>=6)", "pytest-cov", "pytest-timeout", "pytest-xdist"] 657 | notebook = ["ipywidgets (>=6)"] 658 | slack = ["slack-sdk"] 659 | telegram = ["requests"] 660 | 661 | [[package]] 662 | name = "typing-extensions" 663 | version = "4.4.0" 664 | description = "Backported and Experimental Type Hints for Python 3.7+" 665 | category = "main" 666 | optional = false 667 | python-versions = ">=3.7" 668 | 669 | [[package]] 670 | name = "urllib3" 671 | version = "2.1.0" 672 | description = "HTTP library with thread-safe connection pooling, file post, and more." 
673 | category = "main" 674 | optional = false 675 | python-versions = ">=3.8" 676 | 677 | [package.extras] 678 | brotli = ["brotli (>=1.0.9)", "brotlicffi (>=0.8.0)"] 679 | socks = ["pysocks (>=1.5.6,!=1.5.7,<2.0)"] 680 | zstd = ["zstandard (>=0.18.0)"] 681 | 682 | [[package]] 683 | name = "yarl" 684 | version = "1.8.1" 685 | description = "Yet another URL library" 686 | category = "main" 687 | optional = false 688 | python-versions = ">=3.7" 689 | 690 | [package.dependencies] 691 | idna = ">=2.0" 692 | multidict = ">=4.0" 693 | 694 | [metadata] 695 | lock-version = "1.1" 696 | python-versions = "^3.9" 697 | content-hash = "3c5ae8ce9273a451c7e006d6a19f5841e8068f8359ffe597f5493a1f90d74d89" 698 | 699 | [metadata.files] 700 | aiocache = [] 701 | aiohttp = [] 702 | aiosignal = [] 703 | async-timeout = [] 704 | atomicwrites = [] 705 | attrs = [] 706 | cachetools = [] 707 | certifi = [] 708 | charset-normalizer = [] 709 | colorama = [] 710 | contourpy = [] 711 | cycler = [] 712 | "discord.py" = [] 713 | fonttools = [] 714 | frozenlist = [] 715 | google-ai-generativelanguage = [] 716 | google-api-core = [] 717 | google-auth = [] 718 | google-generativeai = [] 719 | googleapis-common-protos = [] 720 | grpcio = [] 721 | grpcio-status = [] 722 | idna = [] 723 | kiwisolver = [] 724 | lxml = [] 725 | matplotlib = [] 726 | more-itertools = [] 727 | multidict = [] 728 | numpy = [] 729 | packaging = [] 730 | pandas = [] 731 | pillow = [] 732 | pluggy = [] 733 | proto-plus = [] 734 | protobuf = [] 735 | py = [] 736 | pyasn1 = [] 737 | pyasn1-modules = [] 738 | pycairo = [] 739 | pygobject = [] 740 | pyparsing = [] 741 | pytest = [] 742 | python-dateutil = [] 743 | pytz = [] 744 | ratelimit = [] 745 | requests = [] 746 | rsa = [] 747 | scipy = [] 748 | seaborn = [] 749 | setuptools-scm = [] 750 | six = [] 751 | tomli = [] 752 | tqdm = [] 753 | typing-extensions = [] 754 | urllib3 = [] 755 | yarl = [] 756 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "TLE" 3 | version = "0.2.0" 4 | description = "" 5 | authors = ["meooow25 "] 6 | 7 | [tool.poetry.dependencies] 8 | python = "^3.9" 9 | "discord.py" = "^2.1.0" 10 | seaborn = "^0.10.1" 11 | lxml = "^4.6" 12 | pillow = "^9.0" 13 | pycairo = "^1.19.1" 14 | PyGObject = "^3.34.0" 15 | aiocache = "^0.11.1" 16 | requests = "^2.31.0" 17 | google-generativeai = "^0.3.2" 18 | ratelimit = "^2.2.1" 19 | 20 | [tool.poetry.dev-dependencies] 21 | pytest = "^3.0" 22 | 23 | [build-system] 24 | requires = ["poetry>=0.12"] 25 | build-backend = "poetry.masonry.api" 26 | -------------------------------------------------------------------------------- /run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Get to a predictable directory, the directory of this script 4 | cd "$(dirname "$0")" 5 | 6 | [ -e environment ] && . ./environment 7 | 8 | while true; do 9 | git pull 10 | poetry install 11 | FONTCONFIG_FILE=$PWD/extra/fonts.conf poetry run python -m tle 12 | 13 | (( $? 
!= 42 )) && break 14 | 15 | echo '===================================================================' 16 | echo '= Restarting =' 17 | echo '===================================================================' 18 | done 19 | -------------------------------------------------------------------------------- /tle/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = '0.2.0' 2 | -------------------------------------------------------------------------------- /tle/__main__.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import asyncio 3 | import distutils.util 4 | import logging 5 | import os 6 | import discord 7 | from logging.handlers import TimedRotatingFileHandler 8 | from os import environ 9 | from pathlib import Path 10 | 11 | import seaborn as sns 12 | from discord.ext import commands 13 | from matplotlib import pyplot as plt 14 | 15 | from tle import constants 16 | from tle.util import codeforces_common as cf_common 17 | from tle.util import discord_common, font_downloader 18 | 19 | 20 | 21 | def setup(): 22 | # Make required directories. 23 | for path in constants.ALL_DIRS: 24 | os.makedirs(path, exist_ok=True) 25 | 26 | # logging to console and file on daily interval 27 | logging.basicConfig(format='{asctime}:{levelname}:{name}:{message}', style='{', 28 | datefmt='%d-%m-%Y %H:%M:%S', level=logging.INFO, 29 | handlers=[logging.StreamHandler(), 30 | TimedRotatingFileHandler(constants.LOG_FILE_PATH, when='D', 31 | backupCount=3, utc=True)]) 32 | 33 | # matplotlib and seaborn 34 | plt.rcParams['figure.figsize'] = 7.0, 3.5 35 | sns.set() 36 | options = { 37 | 'axes.edgecolor': '#A0A0C5', 38 | 'axes.spines.top': False, 39 | 'axes.spines.right': False, 40 | } 41 | sns.set_style('darkgrid', options) 42 | 43 | # Download fonts if necessary 44 | font_downloader.maybe_download() 45 | 46 | 47 | async def main(): 48 | parser = argparse.ArgumentParser() 49 | parser.add_argument('--nodb', action='store_true') 50 | args = parser.parse_args() 51 | 52 | token = environ.get('BOT_TOKEN') 53 | if not token: 54 | logging.error('Token required') 55 | return 56 | 57 | setup() 58 | 59 | intents = discord.Intents.default() 60 | intents.members = True 61 | intents.message_content = True 62 | 63 | bot = commands.Bot(command_prefix=commands.when_mentioned_or(discord_common._BOT_PREFIX), intents=intents) 64 | bot.help_command = discord_common.TleHelp() 65 | cogs = [file.stem for file in Path('tle', 'cogs').glob('*.py')] 66 | for extension in cogs: 67 | await bot.load_extension(f'tle.cogs.{extension}') 68 | logging.info(f'Cogs loaded: {", ".join(bot.cogs)}') 69 | 70 | def no_dm_check(ctx): 71 | if ctx.guild is None: 72 | raise commands.NoPrivateMessage('Private messages not permitted.') 73 | return True 74 | 75 | # Restrict bot usage to inside guild channels only. 76 | bot.add_check(no_dm_check) 77 | 78 | # cf_common.initialize needs to run first, so it must be set as the bot's 79 | # on_ready event handler rather than an on_ready listener. 
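# (on_ready can fire again on reconnects; the once-only wrapper below ensures init runs a single time.)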
80 | @discord_common.on_ready_event_once(bot) 81 | async def init(): 82 | await cf_common.initialize(args.nodb) 83 | asyncio.create_task(discord_common.presence(bot)) 84 | 85 | bot.add_listener(discord_common.bot_error_handler, name='on_command_error') 86 | await bot.start(token) 87 | 88 | 89 | if __name__ == '__main__': 90 | asyncio.run(main()) 91 | -------------------------------------------------------------------------------- /tle/cogs/acd_ai.py: -------------------------------------------------------------------------------- 1 | from discord.ext import commands 2 | import logging 3 | import google.generativeai as genai 4 | from tle.util import gemini_model_settings as settings 5 | from tle.util import discord_common 6 | from tle.util import codeforces_common as cf_common 7 | import requests 8 | from PIL import Image 9 | from tle import constants 10 | from ratelimit import limits, sleep_and_retry 11 | 12 | 13 | 14 | ONE_MINUTE = 60 15 | MAX_CALLS_PER_MINUTE = 60 16 | GEMINI_API_KEY = constants.GEMINI_API_KEY 17 | 18 | IMAGE_TYPES = ["jpeg", "png", "webp", "heic", "heif"] 19 | NO_RESPONSE_MESSAGE = "Sorry, can't answer that. possible reasons might be no response, recitation, safety issue or blocked content. if you are getting this message multiple times, make a separate chat/private thread." 20 | DEV_ID = 501026469569363988 21 | logger = logging.getLogger(__name__) 22 | 23 | class ACD_AI_COG_ERROR(commands.CommandError): 24 | pass 25 | 26 | class ACD_AI(commands.Cog): 27 | def __init__(self, bot): 28 | self.bot = bot 29 | 30 | '''developer's discord user id (vedantmishra69)''' 31 | self.dev_id: int = DEV_ID 32 | 33 | genai.configure(api_key=GEMINI_API_KEY) 34 | 35 | '''gemini model for text inputs''' 36 | self.text_model = genai.GenerativeModel(model_name="gemini-pro", 37 | generation_config=settings.text_generation_config, 38 | safety_settings=settings.text_safety_settings) 39 | 40 | '''gemini model for image inputs''' 41 | self.image_model = genai.GenerativeModel(model_name="gemini-1.5-flash", 42 | generation_config=settings.image_generation_config, 43 | safety_settings=settings.image_safety_settings) 44 | 45 | '''to store chat session instances for each opened thread (thread_id: chat_session)''' 46 | self.chats: dict = {} 47 | 48 | @commands.Cog.listener() 49 | async def on_ready(self): 50 | '''to delete existing threads when the bot wakes up''' 51 | await self.delete_threads_on_start() 52 | 53 | @commands.group(brief='AI Chat Bot', 54 | invoke_without_command=True) 55 | @cf_common.user_guard(group='ai') 56 | async def ai(self,ctx,*args): 57 | ''' 58 | AI chat powered by Google's Gemini Pro. 59 | it responds to both text and image inputs, 60 | feel free to drop code snippets as well. 
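use ;ai chat to open a public thread, or ;ai private for a private one.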
61 | ''' 62 | await ctx.send_help(ctx.command) 63 | 64 | async def delete_threads(self, channel_id): 65 | '''deletes the bot's threads in the given channel''' 66 | try: 67 | channel = await self.bot.fetch_channel(channel_id) 68 | for thread in channel.threads: 69 | if thread.owner.id == self.bot.user.id: 70 | await thread.delete() 71 | if str(thread.id) in self.chats: self.chats.pop(str(thread.id)) 72 | except Exception as e: 73 | logger.warn(f"Couldn't delete threads for {channel_id}: {e}") 74 | 75 | async def delete_threads_on_start(self): 76 | '''deletes existing threads when the bot wakes up''' 77 | try: 78 | async for guild in self.bot.fetch_guilds(): 79 | channel_id = cf_common.user_db.get_ai_channel(guild.id) 80 | if channel_id: await self.delete_threads(channel_id) 81 | except Exception as e: 82 | logger.warn(f"Couldn't fetch guilds\n{e}") 83 | 84 | def check_channel(self, ctx): 85 | '''checks if a channel has been set and the current channel is the set one or not''' 86 | channel_id = cf_common.user_db.get_ai_channel(ctx.guild.id) 87 | if not channel_id: 88 | raise ACD_AI_COG_ERROR('There is no ai channel. Set one with ``;ai set_channel``.') 89 | if ctx.channel.id != channel_id: 90 | raise ACD_AI_COG_ERROR(f"You must use this command in ai channel: <#{channel_id}>") 91 | 92 | async def print_response(self, response, message): 93 | '''prints the response received from get_response() to the thread channel''' 94 | async with message.channel.typing(): 95 | strings = [] 96 | for index in range(0, len(response), 2000): 97 | strings.append(response[index: min(len(response), index + 2000)]) 98 | try: 99 | await message.reply(strings[0]) 100 | for string in strings[1:]: 101 | await message.channel.send(string) 102 | except Exception as e: 103 | await message.reply("No response, discord issue :(") 104 | logger.warn(f"last message: {message.content}\n{e}") 105 | 106 | @sleep_and_retry 107 | @limits(calls=MAX_CALLS_PER_MINUTE, period=ONE_MINUTE) 108 | async def get_response(self, message, chat): 109 | ''' 110 | takes message object and chat session instance 111 | sends the message content (text/image) to gemini model 112 | sends the response received from gemini to print_response() 113 | ''' 114 | async with message.channel.typing(): 115 | image = None 116 | response = NO_RESPONSE_MESSAGE 117 | attachment = message.attachments[0] if message.attachments else None 118 | if not attachment: 119 | try: 120 | response = chat.send_message(message.content).text 121 | if not isinstance(response, str): 122 | raise Exception 123 | except Exception as e: 124 | logger.warn(f"last message: {message.content}\n{e}") 125 | elif attachment.content_type.split('/')[0] == "image": 126 | if attachment.content_type.split('/')[1] in IMAGE_TYPES: 127 | try: 128 | image = Image.open(requests.get(attachment.url, stream = True).raw) 129 | response = self.image_model.generate_content([message.content, image] if message.content else image).text 130 | if not isinstance(response, str): 131 | raise Exception 132 | except Exception as e: 133 | response = "Unable to process that image." 134 | logger.warn(f"last message: {attachment.url}\n{e}") 135 | else: response = "Invalid image format. Please use JPEG, PNG, WEBP, HEIC or HEIF" 136 | else: response = "Attachment not supported."
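# hand the reply off to a separate task so this rate-limited coroutine can return promptly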
137 | self.bot.loop.create_task(self.print_response(response, message)) 138 | 139 | async def start_thread(self, message, is_public): 140 | '''starts a thread on ;ai chat or ;ai private commands''' 141 | try: 142 | thread = await message.channel.create_thread(name = f"Session with {message.author.name}", slowmode_delay = 1, auto_archive_duration = 60, message = message if is_public else None) 143 | await thread.add_user(message.author) 144 | hello_text = self.text_model.generate_content("hi").text 145 | async with thread.typing(): 146 | await thread.send(content = f'<@{message.author.id}> {hello_text}') 147 | self.chats[str(thread.id)] = self.text_model.start_chat(history=[{'role':'user', 'parts': ["hi"]}, {'role': 'model', 'parts': [hello_text]}]) 148 | except Exception as e: 149 | await message.channel.send(embed=discord_common.embed_alert("Could not start a thread.")) 150 | logger.warn(e) 151 | 152 | @ai.command(brief='gets channel for ai.') 153 | async def get_channel(self, ctx): 154 | '''gets channel to be used for ai.''' 155 | channel_id = cf_common.user_db.get_ai_channel(ctx.guild.id) 156 | if not channel_id: 157 | raise ACD_AI_COG_ERROR('There is no ai channel. Set one with ``;ai set_channel``.') 158 | await ctx.send(embed=discord_common.embed_success(f"Current ai channel: <#{channel_id}>")) 159 | 160 | @ai.command(brief='sets channel for ai.') 161 | @commands.has_any_role(constants.TLE_ADMIN, constants.TLE_MODERATOR) 162 | async def set_channel(self, ctx): 163 | '''sets channel to be used for ai.''' 164 | channel_id = cf_common.user_db.get_ai_channel(ctx.guild.id) 165 | if channel_id: await self.delete_threads(channel_id) 166 | cf_common.user_db.set_ai_channel(ctx.guild.id, ctx.channel.id) 167 | await ctx.send(embed=discord_common.embed_success('AI channel saved successfully')) 168 | 169 | @ai.command(brief='deletes all active chat threads.') 170 | @commands.has_any_role(constants.TLE_ADMIN, constants.TLE_MODERATOR) 171 | async def clear(self, ctx): 172 | '''deletes all active chat and private threads''' 173 | self.check_channel(ctx) 174 | async with ctx.channel.typing(): 175 | await self.delete_threads(ctx.channel.id) 176 | await ctx.send(embed=discord_common.embed_success(f"All threads deleted for ``{ctx.channel.name}``.")) 177 | 178 | @commands.Cog.listener() 179 | async def on_message(self, message): 180 | '''sends user's input for processing to gemini''' 181 | if str(message.channel.id) in self.chats and message.author != self.bot.user: 182 | self.bot.loop.create_task(self.get_response(message, self.chats[str(message.channel.id)])) 183 | 184 | @ai.command(brief='creates a public chat thread.') 185 | async def chat(self, ctx): 186 | self.check_channel(ctx) 187 | await self.start_thread(ctx.message, True) 188 | 189 | @ai.command(brief='creates a private chat thread.') 190 | async def private(self, ctx): 191 | self.check_channel(ctx) 192 | await self.start_thread(ctx.message, False) 193 | 194 | @discord_common.send_error_if(ACD_AI_COG_ERROR, cf_common.ResolveHandleError, 195 | cf_common.FilterError) 196 | async def cog_command_error(self, ctx, error): 197 | pass 198 | 199 | async def setup(bot): 200 | await bot.add_cog(ACD_AI(bot)) -------------------------------------------------------------------------------- /tle/cogs/cache_control.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import time 3 | import traceback 4 | 5 | from discord.ext import commands 6 | 7 | from tle import constants 8 | from tle.util import 
codeforces_common as cf_common 9 | 10 | 11 | def timed_command(coro): 12 | @functools.wraps(coro) 13 | async def wrapper(cog, ctx, *args): 14 | await ctx.send('Running...') 15 | begin = time.time() 16 | await coro(cog, ctx, *args) 17 | elapsed = time.time() - begin 18 | await ctx.send(f'Completed in {elapsed:.2f} seconds') 19 | 20 | return wrapper 21 | 22 | 23 | class CacheControl(commands.Cog): 24 | """Cog to manually trigger update of cached data. Intended for dev/admin use.""" 25 | 26 | def __init__(self, bot): 27 | self.bot = bot 28 | 29 | @commands.group(brief='Commands to force reload of cache', 30 | invoke_without_command=True) 31 | @commands.has_role(constants.TLE_ADMIN) 32 | async def cache(self, ctx): 33 | await ctx.send_help('cache') 34 | 35 | @cache.command() 36 | @commands.has_role(constants.TLE_ADMIN) 37 | @timed_command 38 | async def contests(self, ctx): 39 | await cf_common.cache2.contest_cache.reload_now() 40 | 41 | @cache.command() 42 | @commands.has_role(constants.TLE_ADMIN) 43 | @timed_command 44 | async def problems(self, ctx): 45 | await cf_common.cache2.problem_cache.reload_now() 46 | 47 | @cache.command(usage='[missing|all|contest_id]') 48 | @commands.has_role(constants.TLE_ADMIN) 49 | @timed_command 50 | async def ratingchanges(self, ctx, contest_id='missing'): 51 | """Defaults to 'missing'. Mode 'all' clears existing cached changes. 52 | Mode 'contest_id' clears existing changes with the given contest id. 53 | """ 54 | if contest_id not in ('all', 'missing'): 55 | try: 56 | contest_id = int(contest_id) 57 | except ValueError: 58 | return 59 | if contest_id == 'all': 60 | await ctx.send('This will take a while') 61 | count = await cf_common.cache2.rating_changes_cache.fetch_all_contests() 62 | elif contest_id == 'missing': 63 | await ctx.send('This may take a while') 64 | count = await cf_common.cache2.rating_changes_cache.fetch_missing_contests() 65 | else: 66 | count = await cf_common.cache2.rating_changes_cache.fetch_contest(contest_id) 67 | await ctx.send(f'Done, fetched {count} changes and recached handle ratings') 68 | 69 | @cache.command(usage='contest_id|all') 70 | @commands.has_role(constants.TLE_ADMIN) 71 | @timed_command 72 | async def problemsets(self, ctx, contest_id): 73 | """Mode 'all' clears all existing cached problems. Mode 'contest_id' 74 | clears existing problems with the given contest id. 
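Example: ;cache problemsets all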
75 | """ 76 | if contest_id == 'all': 77 | await ctx.send('This will take a while') 78 | count = await cf_common.cache2.problemset_cache.update_for_all() 79 | else: 80 | try: 81 | contest_id = int(contest_id) 82 | except ValueError: 83 | return 84 | count = await cf_common.cache2.problemset_cache.update_for_contest(contest_id) 85 | await ctx.send(f'Done, fetched {count} problems') 86 | 87 | 88 | async def setup(bot): 89 | await bot.add_cog(CacheControl(bot)) 90 | -------------------------------------------------------------------------------- /tle/cogs/deactivated/cses.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | from collections import defaultdict 3 | 4 | from discord.ext import commands 5 | from tle.util import cses_scraper as cses 6 | from tle.util import discord_common 7 | from tle.util import table 8 | from tle.util import tasks 9 | 10 | 11 | def score(placings): 12 | points = {1: 8, 2: 5, 3: 3, 4: 2, 5: 1} 13 | #points = {1:5, 2:4, 3:3, 4:2, 5:1} 14 | return sum(points[rank] for rank in placings) 15 | 16 | 17 | class CSES(commands.Cog): 18 | def __init__(self, bot): 19 | self.bot = bot 20 | self.short_placings = {} 21 | self.fast_placings = {} 22 | self.reloading = False 23 | 24 | @commands.Cog.listener() 25 | @discord_common.once 26 | async def on_ready(self): 27 | self._cache_data.start() 28 | 29 | @tasks.task_spec(name='ProblemsetCacheUpdate', 30 | waiter=tasks.Waiter.fixed_delay(30*60)) 31 | async def _cache_data(self, _): 32 | await self._reload() 33 | 34 | async def _reload(self): 35 | self.reloading = True 36 | short_placings = defaultdict(list) 37 | fast_placings = defaultdict(list) 38 | try: 39 | for pid in await cses.get_problems(): 40 | fast, short = await cses.get_problem_leaderboard(pid) 41 | for i in range(len(fast)): 42 | fast_placings[fast[i]].append(i + 1) 43 | for i in range(len(short)): 44 | short_placings[short[i]].append(i + 1) 45 | self.short_placings = short_placings 46 | self.fast_placings = fast_placings 47 | finally: 48 | self.reloading = False 49 | 50 | def format_leaderboard(self, top, placings): 51 | if not top: 52 | return 'Failed to load :<' 53 | 54 | header = ' 1st 2nd 3rd 4th 5th '.split(' ') 55 | 56 | style = table.Style( 57 | header = '{:>} {:>} {:>} {:>} {:>} {:>} {:>}', 58 | body = '{:>} | {:>} {:>} {:>} {:>} {:>} | {:>} pts' 59 | ) 60 | 61 | t = table.Table(style) 62 | t += table.Header(*header) 63 | 64 | for user, points in top: 65 | hist = [placings[user].count(i + 1) for i in range(5)] 66 | t += table.Data(user, *hist, points) 67 | 68 | return str(t) 69 | 70 | def leaderboard(self, placings, num): 71 | leaderboard = sorted( 72 | ((k, score(v)) for k, v in placings.items() if k != 'N/A'), 73 | key=lambda x: x[1], 74 | reverse=True) 75 | 76 | top = leaderboard[:num] 77 | 78 | return self.format_leaderboard(top, placings) 79 | 80 | def leaderboard_individual(self, placings, handles): 81 | leaderboard = sorted( 82 | ((k, score(v)) for k, v in placings.items() if k != 'N/A' and k in handles), 83 | key=lambda x: x[1], 84 | reverse=True) 85 | 86 | included = [handle for handle, score in leaderboard] 87 | leaderboard += [(handle, 0) for handle in handles if handle not in included] 88 | 89 | top = leaderboard 90 | 91 | return self.format_leaderboard(top, placings) 92 | 93 | @property 94 | def fastest(self, num=10): 95 | return self.leaderboard(self.fast_placings, num) 96 | 97 | @property 98 | def shortest(self, num=10): 99 | return self.leaderboard(self.short_placings, num) 100 | 101 | def 
fastest_individual(self, handles): 102 | return self.leaderboard_individual(self.fast_placings, handles) 103 | 104 | def shortest_individual(self, handles): 105 | return self.leaderboard_individual(self.short_placings, handles) 106 | 107 | @commands.command(brief='Shows compiled CSES leaderboard', usage='[handles...]') 108 | async def cses(self, ctx, *handles: str): 109 | """Shows the compiled CSES leaderboard. If handles are given, the leaderboard contains only those handles; otherwise it shows the overall top ten.""" 110 | if not handles: 111 | await ctx.send('```\n' 'Fastest\n' + self.fastest + '\n\n' + 'Shortest\n' + self.shortest + '\n' + '```') 112 | elif len(handles) > 10: 113 | await ctx.send('```Please indicate at most 10 users```') 114 | else: 115 | handles = set(handles) 116 | await ctx.send('```\n' 'Fastest\n' + self.fastest_individual(handles) + '\n\n' + 'Shortest\n' + self.shortest_individual(handles) + '\n' + '```') 117 | 118 | @commands.command(brief='Force update the CSES leaderboard') 119 | async def _updatecses(self, ctx): 120 | """Force updates the CSES leaderboard.""" 121 | if self.reloading: 122 | await ctx.send("Have some patience, I'm already reloading!") 123 | else: 124 | await self._reload() 125 | await ctx.send('CSES leaderboards updated!') 126 | 127 | 128 | async def setup(bot): 129 | await bot.add_cog(CSES(bot)) 130 | -------------------------------------------------------------------------------- /tle/cogs/hard75Challenge.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import random 3 | from typing import List 4 | import math 5 | import time 6 | from collections import defaultdict 7 | 8 | import discord 9 | from discord.ext import commands 10 | 11 | from tle import constants 12 | from tle.util import codeforces_api as cf 13 | from tle.util import codeforces_common as cf_common 14 | from tle.util import discord_common 15 | from tle.util.db.user_db_conn import Gitgud 16 | from tle.util import paginator 17 | from tle.util import cache_system2 18 | from tle.util import table 19 | from tle.util import ACDLaddersProblems as acdProbs 20 | 21 | class Hard75CogError(commands.CommandError): 22 | pass 23 | 24 | class Hard75Challenge(commands.Cog): 25 | def __init__(self, bot): 26 | self.bot = bot 27 | self.converter = commands.MemberConverter() 28 | 29 | @commands.group(brief='Hard 75 challenge', 30 | invoke_without_command=True) 31 | @cf_common.user_guard(group='hard75') 32 | async def hard75(self,ctx,*args): 33 | """ 34 | Hard75 is a challenge mode. The goal is to solve 2 Codeforces problems every day for 75 days. 35 | You can request your daily problems by using `;hard75 letsgo`. 36 | If you manage to solve both problems before midnight (UTC), use `;hard75 completed` to increase your current streak. 37 | If you don't solve both problems or miss a single day, your current streak will reset back to 0. 38 | The bot will keep track of your streak (current and longest) and there is also a leaderboard. 39 | """ 40 | await ctx.send_help(ctx.command) 41 | 42 | async def _postProblemEmbed(self, ctx, problem_name): 43 | problem = cf_common.cache2.problem_cache.problem_by_name[problem_name] 44 | title = f'{problem.index}. 
{problem.name}' 45 | desc = cf_common.cache2.contest_cache.get_contest(problem.contestId).name 46 | embed = discord.Embed(title=title, url=problem.url, description=desc) 47 | embed.add_field(name='Rating', value=problem.rating) 48 | await ctx.send(embed=embed) 49 | 50 | 51 | async def _checkAcdProbs(self,rating,submissions): 52 | solved = {sub.problem.name for sub in submissions} 53 | problems = [prob for prob in acdProbs.getProblems(rating) 54 | if (prob['name'] not in solved)] 55 | 56 | if not problems: 57 | return {} 58 | class dotdict(dict): 59 | """dot.notation access to dictionary attributes""" 60 | __getattr__ = dict.get 61 | __setattr__ = dict.__setitem__ 62 | __delattr__ = dict.__delitem__ 63 | for problem in problems: 64 | if (cf_common.cache2.problem_cache.problem_by_name[problem['name']].rating ==rating 65 | and cf_common.cache2.problem_cache.problem_by_name[problem['name']].contestId == problem['contestId']): 66 | return dotdict(problem) 67 | return {} 68 | 69 | 70 | async def _pickProblem(self, handle, rating, submissions): 71 | #if a ACD Ladder problem is available then give that! 72 | acdProblem=await self._checkAcdProbs(rating,submissions) 73 | if(len(acdProblem)): 74 | return acdProblem 75 | solved = {sub.problem.name for sub in submissions} 76 | problems = [prob for prob in cf_common.cache2.problem_cache.problems 77 | if (prob.rating == rating 78 | and prob.name not in solved)] 79 | 80 | def check(problem): # check that the user isn't the author and it's not a nonstanard problem 81 | return (not cf_common.is_nonstandard_problem(problem) and 82 | not cf_common.is_contest_writer(problem.contestId, handle)) 83 | 84 | problems = list(filter(check, problems)) 85 | if not problems: 86 | raise Hard75CogError('Great! You have finished all available problems, do atcoder now lol!') 87 | 88 | problems.sort(key=lambda problem: cf_common.cache2.contest_cache.get_contest(problem.contestId).startTimeSeconds) 89 | choice = max(random.randrange(len(problems)) for _ in range(5)) 90 | return problems[choice] 91 | 92 | async def _checkProblemsSolved(self, handle, p1_name, p2_name): 93 | submissions = await cf.user.status(handle=handle) 94 | solved = {sub.problem.name for sub in submissions if sub.verdict == 'OK'} 95 | return p1_name in solved,p2_name in solved 96 | 97 | def _generateStreakEmbed(self, handle, current_streak, longest_streak, last_updated): 98 | embed = discord.Embed(title=f'{handle}s Hard75 grind!') 99 | today=datetime.datetime.utcnow().strftime('%Y-%m-%d') 100 | last_updated_str = "never" if last_updated=='0' else last_updated 101 | last_updated_str = "today" if last_updated==today else last_updated_str 102 | embed.add_field(name='current streak', value=current_streak) 103 | embed.add_field(name='longest streak', value=longest_streak) 104 | embed.add_field(name='last problem solved', value=last_updated_str) 105 | return embed 106 | 107 | @hard75.command(brief='Get Hard75 leaderboard', aliases=['lb', 'ranklist']) 108 | @cf_common.user_guard(group='hard75') 109 | async def leaderboard(self,ctx): 110 | """ 111 | Ranklist of the top contestants (based on longest streak) 112 | """ 113 | data = [(ctx.guild.get_member(int(user_id)), longest_streak, current_streak) 114 | for user_id, longest_streak, current_streak in cf_common.user_db.get_hard75_LeaderBoard()] 115 | data = [(member, longest_streak, current_streak) 116 | for member, longest_streak, current_streak in data 117 | if member is not None] 118 | if not data: 119 | raise Hard75CogError('No One has completed anything as of 
now - leaderboard is empty!') 120 | 121 | _PER_PAGE = 10 122 | 123 | def make_page(chunk, page_num): 124 | style = table.Style('{:>} {:<} {:>} {:>}') 125 | t = table.Table(style) 126 | t += table.Header('#', 'Name', 'Longest', 'Current') 127 | t += table.Line() 128 | for index, (member, longestStreak, currentStreak) in enumerate(chunk): 129 | lstreakstr = f'{longestStreak}' 130 | cstreakstr = f'{currentStreak}' 131 | memberstr = f'{member.display_name}' 132 | t += table.Data(_PER_PAGE * page_num + index + 1, 133 | memberstr, lstreakstr, cstreakstr) 134 | 135 | table_str = f'```\n{t}\n```' 136 | embed = discord_common.cf_color_embed(description = table_str) 137 | return 'Leaderboard', embed 138 | 139 | pages = [make_page(chunk, k) for k, chunk in enumerate( 140 | paginator.chunkify(data, _PER_PAGE))] 141 | paginator.paginate(self.bot, ctx.channel, pages, 142 | wait_time=5 * 60, set_pagenum_footers=True) 143 | 144 | 145 | @hard75.command(brief='Get a user\'s streak statistics', aliases=['st'], usage='[@member|user_id]') 146 | @cf_common.user_guard(group='hard75') 147 | async def streak(self,ctx, member: discord.Member = None): 148 | """ 149 | See the progress of @member on the challenge. If member is not given, you see your own progress. 150 | """ 151 | user_id = member.id if member else ctx.author.id 152 | handle, = await cf_common.resolve_handles(ctx, self.converter, ('!' + str(user_id),)) 153 | res=cf_common.user_db.get_hard75_status(user_id) 154 | if res is None: 155 | raise Hard75CogError(f'{member.display_name if member else ctx.author.display_name} hasn\'t started the Hard75 challenge (`;hard75 letsgo`)') 156 | current_streak,longest_streak,last_updated=res 157 | 158 | embed = self._generateStreakEmbed(handle, current_streak, longest_streak, last_updated) 159 | await ctx.send(f'Thanks for participating in the challenge!', embed=embed) 160 | 161 | @hard75.command(brief='Plots rating change during hard75', aliases=['prog'], usage='[@member|user_id]') 162 | @cf_common.user_guard(group='hard75') 163 | async def progress(self, ctx, member: discord.Member = None): 164 | """ 165 | Plots the rating graph for member during the hard75 challenge period. 166 | """ 167 | user_id = member.id if member else ctx.author.id 168 | user_name = member.display_name if member else ctx.author.display_name 169 | handle, = await cf_common.resolve_handles(ctx, self.converter, ('!' + str(user_id),)) 170 | user = cf_common.user_db.fetch_cf_user(handle) 171 | if not user.maxRating: 172 | raise Hard75CogError(f'User {handle} is not rated') 173 | dates = cf_common.user_db.get_Hard75Window(user_id) 174 | if not dates: 175 | raise Hard75CogError(f'{user_name} hasn\'t started the Hard75 challenge (`;hard75 letsgo`)') 176 | first_date, last_date = dates 177 | if len(last_date) == 1: 178 | raise Hard75CogError(f'{user_name} needs to complete at least one challenge (`;hard75 completed`)') 179 | yy1, mm1, dd1 = first_date.split('-') 180 | yy2, mm2, dd2 = last_date.split('-') 181 | start_date = str(f"d>={dd1}{mm1}{yy1}") 182 | end_date = str(f"d<{dd2}{mm2}{yy2}") 183 | graphs = self.bot.get_cog('Graphs') 184 | for command in graphs.walk_commands(): 185 | if command.name == 'rating': 186 | await command.__call__(ctx, handle, start_date, end_date) 187 | break 188 | 189 | 190 | 191 | 192 | @hard75.command(brief='Request hard75 problems for today', aliases=['start']) 193 | @cf_common.user_guard(group='hard75') 194 | async def letsgo(self,ctx): 195 | """ 196 | Assigns 2 problems per day (would be fetched from ACDLadders later) 197 | 1. same level* 198 | 2. 
level+ 200* 199 | *-> both of them are rounded to the nearest 100 200 | """ 201 | handle, = await cf_common.resolve_handles(ctx, self.converter, ('!' + str(ctx.author),)) 202 | user = cf_common.user_db.fetch_cf_user(handle) 203 | user_id = ctx.author.id 204 | today=datetime.datetime.utcnow().strftime('%Y-%m-%d') 205 | activeChallenge = cf_common.user_db.check_Hard75Challenge(user_id, today) 206 | if activeChallenge: # problems are already there simply return from the DB 207 | c1_id,p1_id,p1_name,c2_id,p2_id,p2_name=cf_common.user_db.get_Hard75Challenge(user_id, today) 208 | p1_solved, p2_solved = await self._checkProblemsSolved(handle, p1_name, p2_name) 209 | if p1_solved and p2_solved: 210 | # TODO: make function for it and use beautifier for printing 211 | dt = datetime.datetime.now() 212 | timeLeft=((24 - dt.hour - 1) * 60 * 60) + ((60 - dt.minute - 1) * 60) + (60 - dt.second) 213 | h=int(timeLeft/3600) 214 | m=int((timeLeft-h*3600)/60) 215 | embed = discord.Embed(title="Life isn't just about coding!",description=f"You need to wait {handle}!") 216 | embed.add_field(name='Time Remaining for next challenge', value=f"{h} Hours : {m} Mins") 217 | await ctx.send(f'You have already completed todays challenge! Life isn\'t just about coding!! Go home, talk to family and friends, touch grass, hit the gym!', embed=embed) 218 | return 219 | #else return that problems have already been assigned. 220 | await ctx.send(f'You have already been assigned the problems for [`{datetime.datetime.utcnow().strftime("%Y-%m-%d")}`] `{handle}` ') 221 | await self._postProblemEmbed(ctx, p1_name) 222 | await self._postProblemEmbed(ctx, p2_name) 223 | return 224 | rating = round(user.effective_rating, -2) 225 | rating = max(800, rating) 226 | rating = min(3000, rating) 227 | rating1 = rating # this is the rating for the problem 1 228 | rating2 = rating1+200 # this is the rating for the problem 2 229 | submissions = await cf.user.status(handle=handle) 230 | problem1 = await self._pickProblem(handle, rating1, submissions) 231 | problem2 = await self._pickProblem(handle, rating2, submissions) 232 | res=cf_common.user_db.new_Hard75Challenge(user_id,handle,problem1.index,problem1.contestId,problem1.name,problem2.index,problem2.contestId,problem2.name,user.effective_rating, today) 233 | if res!=1: 234 | raise Hard75CogError("Issues while writing to db please contact mod team!") 235 | await ctx.send(f'Hard75 problems for `{handle}` [`{datetime.datetime.utcnow().strftime("%Y-%m-%d")}`]') 236 | await self._postProblemEmbed(ctx, problem1.name) 237 | await self._postProblemEmbed(ctx, problem2.name) 238 | 239 | 240 | @hard75.command(brief='Mark hard75 problems for today as completed', aliases=['done']) 241 | @cf_common.user_guard(group='hard75') 242 | async def completed(self, ctx): 243 | """ 244 | Use this command once you have completed both of your daily problems 245 | """ 246 | handle, = await cf_common.resolve_handles(ctx, self.converter, ('!' + str(ctx.author),)) 247 | user_id = ctx.message.author.id 248 | today=datetime.datetime.utcnow().strftime('%Y-%m-%d') 249 | activeChallenge = cf_common.user_db.check_Hard75Challenge(user_id, today) 250 | if not activeChallenge: 251 | raise Hard75CogError(f'You have not been assigned any problems today! 
Use `;hard75 letsgo` to get the pair of problems!') 252 | 253 | c1_id,p1_id,p1_name,c2_id,p2_id,p2_name=cf_common.user_db.get_Hard75Challenge(user_id, today) 254 | p1_solved,p2_solved = await self._checkProblemsSolved(handle, p1_name, p2_name) 255 | 256 | if not p1_solved and not p2_solved: 257 | await ctx.send('You haven\'t completed any of the problems!') 258 | await self._postProblemEmbed(ctx, p1_name) 259 | await self._postProblemEmbed(ctx, p2_name) 260 | return 261 | if not p1_solved: 262 | await ctx.send('You haven\'t completed the easy problem!') 263 | await self._postProblemEmbed(ctx, p1_name) 264 | return 265 | if not p2_solved: 266 | await ctx.send('You haven\'t completed the hard problem!') 267 | await self._postProblemEmbed(ctx, p2_name) 268 | return 269 | 270 | # else update accordingly DB 271 | assigned_date,last_update=cf_common.user_db.get_Hard75Date(user_id) 272 | if(last_update==today): 273 | raise Hard75CogError(f"Your progress has already been updated for `{today}`") 274 | if(assigned_date!=today): 275 | await ctx.send(f"OOPS! you didn't solve the problems in the 24H window! You were required to solve it on `{assigned_date}`") 276 | 277 | # else the user has completed his task on the given day hence let's update it 278 | current_streak, longest_streak=cf_common.user_db.get_Hard75UserStat(user_id) 279 | 280 | yesterday=datetime.datetime.utcnow()-datetime.timedelta(days=1) 281 | yesterday=yesterday.strftime('%Y-%m-%d') 282 | 283 | #check if streak continues! 284 | if(last_update==yesterday): 285 | current_streak+=1 286 | else: 287 | current_streak=0 288 | if(current_streak==0): # on first day! 289 | current_streak=1 290 | if current_streak % 75 == 0: 291 | await ctx.guild.get_member(user_id).add_roles(discord.utils.get(ctx.guild.roles, name = "Hard75 x 1")) 292 | 293 | longest_streak=max(current_streak,longest_streak) 294 | rc=cf_common.user_db.updateStreak_Hard75Challenge(user_id,current_streak,longest_streak, today) 295 | if(rc!=1): 296 | raise Hard75CogError('Some issue while monitoring progress! Please contact the mod team!.') 297 | 298 | embed = self._generateStreakEmbed(handle, current_streak, longest_streak, today) 299 | 300 | # mention an embed which includes the streak day of the user! 301 | await ctx.send(f'Congratulations `{handle}`! 
You have completed your daily challenge ', embed=embed) 302 | 303 | 304 | 305 | @discord_common.send_error_if(Hard75CogError, cf_common.ResolveHandleError, 306 | cf_common.FilterError) 307 | async def cog_command_error(self, ctx, error): 308 | pass 309 | 310 | 311 | async def setup(bot): 312 | await bot.add_cog(Hard75Challenge(bot)) 313 | -------------------------------------------------------------------------------- /tle/cogs/logging.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import logging 3 | import os 4 | 5 | from discord.ext import commands 6 | 7 | from tle.util import discord_common 8 | 9 | root_logger = logging.getLogger() 10 | logger = logging.getLogger(__name__) 11 | 12 | 13 | class Logging(commands.Cog, logging.Handler): 14 | def __init__(self, bot, channel_id): 15 | logging.Handler.__init__(self) 16 | self.bot = bot 17 | self.channel_id = channel_id 18 | self.queue = asyncio.Queue() 19 | self.task = None 20 | self.logger = logging.getLogger(self.__class__.__name__) 21 | 22 | @commands.Cog.listener() 23 | @discord_common.once 24 | async def on_ready(self): 25 | self.task = asyncio.create_task(self._log_task()) 26 | width = 79 27 | stars, msg = f'{"*" * width}', f'***{"Bot running":^{width - 6}}***' 28 | self.logger.log(level=100, msg=stars) 29 | self.logger.log(level=100, msg=msg) 30 | self.logger.log(level=100, msg=stars) 31 | 32 | async def _log_task(self): 33 | while True: 34 | record = await self.queue.get() 35 | channel = self.bot.get_channel(self.channel_id) 36 | if channel is None: 37 | # Channel no longer exists. 38 | root_logger.removeHandler(self) 39 | self.logger.warning('Logging channel not available, disabling Discord log handler.') 40 | break 41 | try: 42 | msg = self.format(record) 43 | # Not all errors will have message_contents or jump urls. 44 | try: 45 | await channel.send( 46 | 'Original Command: {}\nJump Url: {}'.format( 47 | record.message_content, record.jump_url)) 48 | except AttributeError: 49 | pass 50 | discord_msg_char_limit = 2000 51 | char_limit = discord_msg_char_limit - 2 * len('```') 52 | too_long = len(msg) > char_limit 53 | msg = msg[:char_limit] 54 | await channel.send('```{}```'.format(msg)) 55 | if too_long: 56 | await channel.send('`Check logs for full stack trace`') 57 | except: 58 | self.handleError(record) 59 | 60 | # logging.Handler overrides below. 
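# emit() is called synchronously by the logging machinery, so it must not block: it only enqueues the record.
# The _log_task coroutine started in on_ready() drains the queue, formats each record and forwards it to the configured logging channel.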
61 | 62 | def emit(self, record): 63 | self.queue.put_nowait(record) 64 | 65 | def close(self): 66 | if self.task: 67 | self.task.cancel() 68 | 69 | 70 | async def setup(bot): 71 | logging_cog_channel_id = os.environ.get('LOGGING_COG_CHANNEL_ID') 72 | if logging_cog_channel_id is None: 73 | logger.info('Skipping installation of logging cog as logging channel is not provided.') 74 | return 75 | 76 | logging_cog = Logging(bot, int(logging_cog_channel_id)) 77 | logging_cog.setLevel(logging.WARNING) 78 | logging_cog.setFormatter(logging.Formatter(fmt='{asctime}:{levelname}:{name}:{message}', 79 | style='{', datefmt='%d-%m-%Y %H:%M:%S')) 80 | root_logger.addHandler(logging_cog) 81 | await bot.add_cog(logging_cog) 82 | -------------------------------------------------------------------------------- /tle/cogs/meta.py: -------------------------------------------------------------------------------- 1 | import os 2 | import subprocess 3 | import sys 4 | import time 5 | import textwrap 6 | 7 | from discord.ext import commands 8 | 9 | from tle import constants 10 | from tle.util.codeforces_common import pretty_time_format 11 | 12 | RESTART = 42 13 | 14 | 15 | # Adapted from numpy sources. 16 | # https://github.com/numpy/numpy/blob/master/setup.py#L64-85 17 | def git_history(): 18 | def _minimal_ext_cmd(cmd): 19 | # construct minimal environment 20 | env = {} 21 | for k in ['SYSTEMROOT', 'PATH']: 22 | v = os.environ.get(k) 23 | if v is not None: 24 | env[k] = v 25 | # LANGUAGE is used on win32 26 | env['LANGUAGE'] = 'C' 27 | env['LANG'] = 'C' 28 | env['LC_ALL'] = 'C' 29 | out = subprocess.Popen(cmd, stdout = subprocess.PIPE, env=env).communicate()[0] 30 | return out 31 | try: 32 | out = _minimal_ext_cmd(['git', 'rev-parse', '--abbrev-ref', 'HEAD']) 33 | branch = out.strip().decode('ascii') 34 | out = _minimal_ext_cmd(['git', 'log', '--oneline', '-5']) 35 | history = out.strip().decode('ascii') 36 | return ( 37 | 'Branch:\n' + 38 | textwrap.indent(branch, ' ') + 39 | '\nCommits:\n' + 40 | textwrap.indent(history, ' ') 41 | ) 42 | except OSError: 43 | return "Fetching git info failed" 44 | 45 | 46 | class Meta(commands.Cog): 47 | def __init__(self, bot): 48 | self.bot = bot 49 | self.start_time = time.time() 50 | 51 | @commands.group(brief='Bot control', invoke_without_command=True) 52 | async def meta(self, ctx): 53 | """Command the bot or get information about the bot.""" 54 | await ctx.send_help(ctx.command) 55 | 56 | @meta.command(brief='Restarts TLE') 57 | @commands.has_role(constants.TLE_ADMIN) 58 | async def restart(self, ctx): 59 | """Restarts the bot.""" 60 | # Really, we just exit with a special code 61 | # the magic is handled elsewhere 62 | await ctx.send('Restarting...') 63 | os._exit(RESTART) 64 | 65 | @meta.command(brief='Kill TLE') 66 | @commands.has_role(constants.TLE_ADMIN) 67 | async def kill(self, ctx): 68 | """Restarts the bot.""" 69 | await ctx.send('Dying...') 70 | os._exit(0) 71 | 72 | @meta.command(brief='Is TLE up?') 73 | async def ping(self, ctx): 74 | """Replies to a ping.""" 75 | start = time.perf_counter() 76 | message = await ctx.send(':ping_pong: Pong!') 77 | end = time.perf_counter() 78 | duration = (end - start) * 1000 79 | await message.edit(content=f'REST API latency: {int(duration)}ms\n' 80 | f'Gateway API latency: {int(self.bot.latency * 1000)}ms') 81 | 82 | @meta.command(brief='Get git information') 83 | async def git(self, ctx): 84 | """Replies with git information.""" 85 | await ctx.send('```yaml\n' + git_history() + '```') 86 | 87 | @meta.command(brief='Prints 
bot uptime') 88 | async def uptime(self, ctx): 89 | """Replies with how long TLE has been up.""" 90 | await ctx.send('TLE has been running for ' + 91 | pretty_time_format(time.time() - self.start_time)) 92 | 93 | @meta.command(brief='Print bot guilds') 94 | @commands.has_role(constants.TLE_ADMIN) 95 | async def guilds(self, ctx): 96 | "Replies with info on the bot's guilds" 97 | msg = [f'Guild ID: {guild.id} | Name: {guild.name} | Owner: {guild.owner.id} | Icon: {guild.icon}' 98 | for guild in self.bot.guilds] 99 | await ctx.send('```' + '\n'.join(msg) + '```') 100 | 101 | 102 | async def setup(bot): 103 | await bot.add_cog(Meta(bot)) 104 | -------------------------------------------------------------------------------- /tle/cogs/ref_bot.py: -------------------------------------------------------------------------------- 1 | import json 2 | import requests 3 | from discord.ext import commands 4 | from tle.util import discord_common 5 | import logging 6 | from tle.util import codeforces_common as cf_common 7 | from tle import constants 8 | 9 | logger = logging.getLogger(__name__) 10 | RATING_LIMIT = 1500 11 | 12 | class ReferralBotCogError(commands.CommandError): 13 | pass 14 | 15 | class ReferralBot(commands.Cog): 16 | def __init__(self, bot): 17 | self.bot = bot 18 | self.url = "https://ref-portal-indol.vercel.app/api/secret?cfUserName={cf_handle}&discordId={discord_id}" 19 | self.converter = commands.MemberConverter() 20 | 21 | @commands.group(brief='Referral bot', 22 | invoke_without_command=True) 23 | @cf_common.user_guard(group='ref') 24 | async def ref(self,ctx,*args): 25 | """ 26 | Submits request for referral on ref portal. 27 | """ 28 | await ctx.send_help(ctx.command) 29 | 30 | @ref.command(brief='gets channel for referral request.') 31 | async def get_channel(self, ctx): 32 | """ 33 | Gets channel to be used for requesting referral form. 34 | """ 35 | channel_id = cf_common.user_db.get_ref_channel(ctx.guild.id) 36 | if not channel_id: 37 | raise ReferralBotCogError('There is no referral channel. Set one with ``;ref set_channel``.') 38 | await ctx.send(embed=discord_common.embed_success(f"Current referral channel: <#{channel_id}>")) 39 | 40 | @ref.command(brief='sets channel for referral request.') 41 | @commands.has_any_role(constants.TLE_ADMIN, constants.TLE_MODERATOR) 42 | async def set_channel(self, ctx): 43 | """ 44 | Sets channel to be used for requesting referral form. 45 | """ 46 | cf_common.user_db.set_ref_channel(ctx.guild.id, ctx.channel.id) 47 | await ctx.send(embed=discord_common.embed_success('Referral channel saved successfully')) 48 | 49 | @ref.command(brief='requests referral form.') 50 | @cf_common.user_guard(group='ref') 51 | async def get(self, ctx): 52 | """ 53 | Requests referral form from the ref portal. 54 | """ 55 | channel_id = cf_common.user_db.get_ref_channel(ctx.guild.id) 56 | if not channel_id: 57 | raise ReferralBotCogError('There is no referral channel. Set one with ``;ref set_channel``.') 58 | if ctx.channel.id != channel_id: 59 | raise ReferralBotCogError(f"You must use this command in referral channel: <#{channel_id}>") 60 | cf_handle, = await cf_common.resolve_handles(ctx, self.converter, ('!' 
+ str(ctx.author.id),)) 61 | discord_id = ctx.author.name 62 | url = self.url.format(cf_handle=cf_handle, discord_id=discord_id) 63 | user = cf_common.user_db.fetch_cf_user(cf_handle) 64 | if user.maxRating < RATING_LIMIT: 65 | await ctx.reply(embed=discord_common.embed_alert(f"You need to have your maximum codeforces rating >= {RATING_LIMIT}.")) 66 | else: 67 | dm_channel = await ctx.author.create_dm() 68 | payload = { 69 | "secretKey": "6QGMP4QD8amDPnTBC3Tfwo8L4Ckny4Cl", 70 | "secretAccessKey": "M4MICU67LFq5UH2NLaLSgbOaRBjliuO5" 71 | } 72 | try: 73 | response = json.loads(requests.get(url, headers=payload).text) 74 | if "url" in response: 75 | res = response["url"] 76 | await dm_channel.send(embed=discord_common.embed_success(f"Here is your referral form link: {res}\n\nPlease note that entering an **invalid job id** may result in your **banishment** from the server.")) 77 | await ctx.reply(embed=discord_common.embed_success("Sent!")) 78 | else: await dm_channel.send(embed=discord_common.embed_alert("No URL available.")) 79 | except Exception as e: 80 | ctx.reply(embed=discord_common.embed_alert("No response from the server.")) 81 | logger.warn(e) 82 | 83 | @discord_common.send_error_if(ReferralBotCogError, cf_common.ResolveHandleError, 84 | cf_common.FilterError) 85 | async def cog_command_error(self, ctx, error): 86 | pass 87 | 88 | 89 | 90 | async def setup(bot): 91 | await bot.add_cog(ReferralBot(bot)) 92 | -------------------------------------------------------------------------------- /tle/cogs/starboard.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import logging 3 | 4 | import discord 5 | from discord.ext import commands 6 | 7 | from tle import constants 8 | from tle.util import codeforces_common as cf_common 9 | from tle.util import discord_common 10 | 11 | _STAR = '\N{WHITE MEDIUM STAR}' 12 | _STAR_ORANGE = 0xffaa10 13 | _STAR_THRESHOLD = 5 14 | 15 | 16 | class StarboardCogError(commands.CommandError): 17 | pass 18 | 19 | 20 | class Starboard(commands.Cog): 21 | def __init__(self, bot): 22 | self.bot = bot 23 | self.locks = {} 24 | self.logger = logging.getLogger(self.__class__.__name__) 25 | 26 | @commands.Cog.listener() 27 | async def on_raw_reaction_add(self, payload): 28 | if str(payload.emoji) != _STAR or payload.guild_id is None: 29 | return 30 | res = cf_common.user_db.get_starboard(payload.guild_id) 31 | if res is None: 32 | return 33 | starboard_channel_id = int(res[0]) 34 | try: 35 | await self.check_and_add_to_starboard(starboard_channel_id, payload) 36 | except StarboardCogError as e: 37 | self.logger.info(f'Failed to starboard: {e!r}') 38 | 39 | @commands.Cog.listener() 40 | async def on_raw_message_delete(self, payload): 41 | if payload.guild_id is None: 42 | return 43 | res = cf_common.user_db.get_starboard(payload.guild_id) 44 | if res is None: 45 | return 46 | starboard_channel_id = int(res[0]) 47 | if payload.channel_id != starboard_channel_id: 48 | return 49 | cf_common.user_db.remove_starboard_message(starboard_msg_id=payload.message_id) 50 | self.logger.info(f'Removed message {payload.message_id} from starboard') 51 | 52 | @staticmethod 53 | def prepare_embed(message): 54 | # Adapted from https://github.com/Rapptz/RoboDanny/blob/rewrite/cogs/stars.py 55 | embed = discord.Embed(color=_STAR_ORANGE, timestamp=message.created_at) 56 | embed.add_field(name='Channel', value=message.channel.mention) 57 | embed.add_field(name='Jump to', value=f'[Original]({message.jump_url})') 58 | 59 | if 
message.content: 60 | embed.add_field(name='Content', value=message.content, inline=False) 61 | 62 | if message.embeds: 63 | data = message.embeds[0] 64 | if data.type == 'image': 65 | embed.set_image(url=data.url) 66 | 67 | if message.attachments: 68 | file = message.attachments[0] 69 | if file.url.lower().endswith(('png', 'jpeg', 'jpg', 'gif', 'webp')): 70 | embed.set_image(url=file.url) 71 | else: 72 | embed.add_field(name='Attachment', value=f'[{file.filename}]({file.url})', inline=False) 73 | 74 | embed.set_footer(text=str(message.author), icon_url=message.author.avatar) 75 | return embed 76 | 77 | async def check_and_add_to_starboard(self, starboard_channel_id, payload): 78 | guild = self.bot.get_guild(payload.guild_id) 79 | starboard_channel = guild.get_channel(starboard_channel_id) 80 | if starboard_channel is None: 81 | raise StarboardCogError('Starboard channel not found') 82 | 83 | channel = self.bot.get_channel(payload.channel_id) 84 | message = await channel.fetch_message(payload.message_id) 85 | if ((message.type != discord.MessageType.default and message.type != discord.MessageType.reply) 86 | or (len(message.content) == 0 and len(message.attachments) == 0)): 87 | raise StarboardCogError('Cannot starboard this message') 88 | 89 | reaction_count = sum(reaction.count for reaction in message.reactions 90 | if str(reaction) == _STAR) 91 | if reaction_count < _STAR_THRESHOLD: 92 | return 93 | lock = self.locks.get(payload.guild_id) 94 | if lock is None: 95 | self.locks[payload.guild_id] = lock = asyncio.Lock() 96 | 97 | async with lock: 98 | if cf_common.user_db.check_exists_starboard_message(message.id): 99 | return 100 | embed = self.prepare_embed(message) 101 | starboard_message = await starboard_channel.send(embed=embed) 102 | cf_common.user_db.add_starboard_message(message.id, starboard_message.id, guild.id) 103 | self.logger.info(f'Added message {message.id} to starboard (Last reaction by {payload.user_id})') 104 | 105 | @commands.group(brief='Starboard commands', 106 | invoke_without_command=True) 107 | async def starboard(self, ctx): 108 | """Group for commands involving the starboard.""" 109 | await ctx.send_help(ctx.command) 110 | 111 | @starboard.command(brief='Set starboard to current channel') 112 | @commands.has_role(constants.TLE_ADMIN) 113 | async def here(self, ctx): 114 | """Set the current channel as starboard.""" 115 | res = cf_common.user_db.get_starboard(ctx.guild.id) 116 | if res is not None: 117 | raise StarboardCogError('The starboard channel is already set. 
Use `clear` before ' 118 | 'attempting to set a different channel as starboard.') 119 | cf_common.user_db.set_starboard(ctx.guild.id, ctx.channel.id) 120 | await ctx.send(embed=discord_common.embed_success('Starboard channel set')) 121 | 122 | @starboard.command(brief='Clear starboard settings') 123 | @commands.has_role(constants.TLE_ADMIN) 124 | async def clear(self, ctx): 125 | """Stop tracking starboard messages and remove the currently set starboard channel 126 | from settings.""" 127 | cf_common.user_db.clear_starboard(ctx.guild.id) 128 | cf_common.user_db.clear_starboard_messages_for_guild(ctx.guild.id) 129 | await ctx.send(embed=discord_common.embed_success('Starboard channel cleared')) 130 | 131 | @starboard.command(brief='Remove a message from starboard') 132 | @commands.has_role(constants.TLE_ADMIN) 133 | async def remove(self, ctx, original_message_id: int): 134 | """Remove a particular message from the starboard database.""" 135 | rc = cf_common.user_db.remove_starboard_message(original_msg_id=original_message_id) 136 | if rc: 137 | await ctx.send(embed=discord_common.embed_success('Successfully removed')) 138 | else: 139 | await ctx.send(embed=discord_common.embed_alert('Not found in database')) 140 | 141 | @discord_common.send_error_if(StarboardCogError) 142 | async def cog_command_error(self, ctx, error): 143 | pass 144 | 145 | 146 | async def setup(bot): 147 | await bot.add_cog(Starboard(bot)) 148 | -------------------------------------------------------------------------------- /tle/constants.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | DATA_DIR = 'data' 4 | LOGS_DIR = 'logs' 5 | 6 | ASSETS_DIR = os.path.join(DATA_DIR, 'assets') 7 | DB_DIR = os.path.join(DATA_DIR, 'db') 8 | MISC_DIR = os.path.join(DATA_DIR, 'misc') 9 | TEMP_DIR = os.path.join(DATA_DIR, 'temp') 10 | 11 | USER_DB_FILE_PATH = os.path.join(DB_DIR, 'user.db') 12 | CACHE_DB_FILE_PATH = os.path.join(DB_DIR, 'cache.db') 13 | 14 | FONTS_DIR = os.path.join(ASSETS_DIR, 'fonts') 15 | 16 | NOTO_SANS_CJK_BOLD_FONT_PATH = os.path.join(FONTS_DIR, 'NotoSansCJK-Bold.ttc') 17 | NOTO_SANS_CJK_REGULAR_FONT_PATH = os.path.join(FONTS_DIR, 'NotoSansCJK-Regular.ttc') 18 | 19 | CONTEST_WRITERS_JSON_FILE_PATH = os.path.join(MISC_DIR, 'contest_writers.json') 20 | 21 | LOG_FILE_PATH = os.path.join(LOGS_DIR, 'tle.log') 22 | 23 | ALL_DIRS = (attrib_value for attrib_name, attrib_value in list(globals().items()) 24 | if attrib_name.endswith('DIR')) 25 | 26 | TLE_ADMIN = os.environ.get('TLE_ADMIN', 'Admin') 27 | TLE_MODERATOR = os.environ.get('TLE_MODERATOR', 'Moderator') 28 | GEMINI_API_KEY = os.environ["GEMINI_API_KEY"] 29 | -------------------------------------------------------------------------------- /tle/util/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ACodeDaily/TLE-ACodeDaily/288f74898ae7bf6ace2ee69118eb54ccb1eadf58/tle/util/__init__.py -------------------------------------------------------------------------------- /tle/util/cache_system2.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import logging 3 | import time 4 | from aiocache import cached 5 | 6 | from collections import defaultdict 7 | from discord.ext import commands 8 | 9 | from tle.util import codeforces_common as cf_common 10 | from tle.util import codeforces_api as cf 11 | from tle.util import events 12 | from tle.util import tasks 13 | from tle.util import 
paginator 14 | from tle.util.ranklist import Ranklist 15 | 16 | logger = logging.getLogger(__name__) 17 | _CONTESTS_PER_BATCH_IN_CACHE_UPDATES = 100 18 | CONTEST_BLACKLIST = {1308, 1309, 1431, 1432} 19 | 20 | 21 | def _is_blacklisted(contest): 22 | return contest.id in CONTEST_BLACKLIST 23 | 24 | 25 | class CacheError(commands.CommandError): 26 | pass 27 | 28 | 29 | class ContestCacheError(CacheError): 30 | pass 31 | 32 | 33 | class ContestNotFound(ContestCacheError): 34 | def __init__(self, contest_id): 35 | super().__init__(f'Contest with id `{contest_id}` not found') 36 | self.contest_id = contest_id 37 | 38 | 39 | class ContestCache: 40 | _NORMAL_CONTEST_RELOAD_DELAY = 30 * 60 41 | _EXCEPTION_CONTEST_RELOAD_DELAY = 5 * 60 42 | _ACTIVE_CONTEST_RELOAD_DELAY = 5 * 60 43 | _ACTIVATE_BEFORE = 20 * 60 44 | 45 | _RUNNING_PHASES = ('CODING', 'PENDING_SYSTEM_TEST', 'SYSTEM_TEST') 46 | 47 | def __init__(self, cache_master): 48 | self.cache_master = cache_master 49 | 50 | self.contests = [] 51 | self.contest_by_id = {} 52 | self.contests_by_phase = {phase: [] for phase in cf.Contest.PHASES} 53 | self.contests_by_phase['_RUNNING'] = [] 54 | self.contests_last_cache = 0 55 | 56 | self.reload_lock = asyncio.Lock() 57 | self.reload_exception = None 58 | self.next_delay = None 59 | 60 | self.logger = logging.getLogger(self.__class__.__name__) 61 | 62 | async def run(self): 63 | await self._try_disk() 64 | self._update_task.start() 65 | 66 | async def reload_now(self): 67 | """Force a reload. If currently reloading it will wait until done.""" 68 | reloading = self.reload_lock.locked() 69 | if reloading: 70 | # Wait until reload complete. 71 | # To wait until lock is free, await acquire then release immediately. 72 | async with self.reload_lock: 73 | pass 74 | else: 75 | await self._update_task.manual_trigger() 76 | 77 | if self.reload_exception: 78 | raise self.reload_exception 79 | 80 | def get_contest(self, contest_id): 81 | try: 82 | return self.contest_by_id[contest_id] 83 | except KeyError: 84 | raise ContestNotFound(contest_id) 85 | 86 | def get_problemset(self, contest_id): 87 | return self.cache_master.conn.get_problemset_from_contest(contest_id) 88 | 89 | def get_contests_in_phase(self, phase): 90 | return self.contests_by_phase[phase] 91 | 92 | async def _try_disk(self): 93 | async with self.reload_lock: 94 | contests = self.cache_master.conn.fetch_contests() 95 | if not contests: 96 | self.logger.info('Contest cache on disk is empty.') 97 | return 98 | await self._update(contests, from_api=False) 99 | 100 | @tasks.task_spec(name='ContestCacheUpdate') 101 | async def _update_task(self, _): 102 | async with self.reload_lock: 103 | self.next_delay = await self._reload_contests() 104 | self.reload_exception = None 105 | 106 | @_update_task.waiter() 107 | async def _update_task_waiter(self): 108 | await asyncio.sleep(self.next_delay) 109 | 110 | @_update_task.exception_handler() 111 | async def _update_task_exception_handler(self, ex): 112 | self.reload_exception = ex 113 | self.next_delay = self._EXCEPTION_CONTEST_RELOAD_DELAY 114 | 115 | async def _reload_contests(self): 116 | contests = await cf.contest.list() 117 | delay = await self._update(contests) 118 | return delay 119 | 120 | async def _update(self, contests, from_api=True): 121 | self.logger.info(f'{len(contests)} contests fetched from {"API" if from_api else "disk"}') 122 | contests.sort(key=lambda contest: (contest.startTimeSeconds, contest.id)) 123 | 124 | if from_api: 125 | rc = self.cache_master.conn.cache_contests(contests) 126 | 
self.logger.info(f'{rc} contests stored in database') 127 | 128 | contests_by_phase = {phase: [] for phase in cf.Contest.PHASES} 129 | contests_by_phase['_RUNNING'] = [] 130 | contest_by_id = {} 131 | for contest in contests: 132 | contests_by_phase[contest.phase].append(contest) 133 | contest_by_id[contest.id] = contest 134 | if contest.phase in self._RUNNING_PHASES: 135 | contests_by_phase['_RUNNING'].append(contest) 136 | 137 | now = time.time() 138 | delay = self._NORMAL_CONTEST_RELOAD_DELAY 139 | 140 | for contest in contests_by_phase['BEFORE']: 141 | at = contest.startTimeSeconds - self._ACTIVATE_BEFORE 142 | if at > now: 143 | # Reload at _ACTIVATE_BEFORE before contest to monitor contest delays. 144 | delay = min(delay, at - now) 145 | else: 146 | # The contest starts in <= _ACTIVATE_BEFORE. 147 | # Reload at contest start, or after _ACTIVE_CONTEST_RELOAD_DELAY, whichever comes first. 148 | delay = min(contest.startTimeSeconds - now, self._ACTIVE_CONTEST_RELOAD_DELAY) 149 | 150 | if contests_by_phase['_RUNNING']: 151 | # If any contest is running, reload at an increased rate to detect FINISHED 152 | delay = min(delay, self._ACTIVE_CONTEST_RELOAD_DELAY) 153 | 154 | self.contests = contests 155 | self.contests_by_phase = contests_by_phase 156 | self.contest_by_id = contest_by_id 157 | self.contests_last_cache = time.time() 158 | 159 | cf_common.event_sys.dispatch(events.ContestListRefresh, self.contests.copy()) 160 | 161 | return delay 162 | 163 | 164 | class ProblemCache: 165 | _RELOAD_INTERVAL = 6 * 60 * 60 166 | 167 | def __init__(self, cache_master): 168 | self.cache_master = cache_master 169 | 170 | self.problems = [] 171 | self.problem_by_name = {} 172 | self.problems_last_cache = 0 173 | 174 | self.reload_lock = asyncio.Lock() 175 | self.reload_exception = None 176 | 177 | self.logger = logging.getLogger(self.__class__.__name__) 178 | 179 | async def run(self): 180 | await self._try_disk() 181 | self._update_task.start() 182 | 183 | async def reload_now(self): 184 | """Force a reload. If currently reloading it will wait until done.""" 185 | reloading = self.reload_lock.locked() 186 | if reloading: 187 | # Wait until reload complete. 188 | # To wait until lock is free, await acquire then release immediately. 
189 | async with self.reload_lock: 190 | pass 191 | else: 192 | await self._update_task.manual_trigger() 193 | 194 | if self.reload_exception: 195 | raise self.reload_exception 196 | 197 | async def _try_disk(self): 198 | async with self.reload_lock: 199 | problems = self.cache_master.conn.fetch_problems() 200 | if not problems: 201 | self.logger.info('Problem cache on disk is empty.') 202 | return 203 | self.problems = problems 204 | self.problem_by_name = {problem.name: problem for problem in problems} 205 | self.logger.info(f'{len(self.problems)} problems fetched from disk') 206 | 207 | @tasks.task_spec(name='ProblemCacheUpdate', 208 | waiter=tasks.Waiter.fixed_delay(_RELOAD_INTERVAL)) 209 | async def _update_task(self, _): 210 | async with self.reload_lock: 211 | await self._reload_problems() 212 | self.reload_exception = None 213 | 214 | @_update_task.exception_handler() 215 | async def _update_task_exception_handler(self, ex): 216 | self.reload_exception = ex 217 | 218 | async def _reload_problems(self): 219 | problems, _ = await cf.problemset.problems() 220 | await self._update(problems) 221 | 222 | async def _update(self, problems): 223 | self.logger.info(f'{len(problems)} problems fetched from API') 224 | contest_map = {problem.contestId: self.cache_master.contest_cache.contest_by_id.get(problem.contestId) 225 | for problem in problems} 226 | 227 | def keep(problem): 228 | return (contest_map[problem.contestId] and 229 | problem.has_metadata()) 230 | 231 | filtered_problems = list(filter(keep, problems)) 232 | problem_by_name = { 233 | problem.name: problem # This will discard some valid problems 234 | for problem in filtered_problems 235 | } 236 | self.logger.info(f'Keeping {len(problem_by_name)} problems') 237 | 238 | self.problems = list(problem_by_name.values()) 239 | self.problem_by_name = problem_by_name 240 | self.problems_last_cache = time.time() 241 | 242 | rc = self.cache_master.conn.cache_problems(self.problems) 243 | self.logger.info(f'{rc} problems stored in database') 244 | 245 | 246 | class ProblemsetCacheError(CacheError): 247 | pass 248 | 249 | 250 | class ProblemsetNotCached(ProblemsetCacheError): 251 | def __init__(self, contest_id): 252 | super().__init__(f'Problemset for contest with id {contest_id} not cached.') 253 | 254 | 255 | class ProblemsetCache: 256 | _MONITOR_PERIOD_SINCE_CONTEST_END = 14 * 24 * 60 * 60 257 | _RELOAD_DELAY = 60 * 60 258 | 259 | def __init__(self, cache_master): 260 | self.problems = [] 261 | # problem -> list of contests in which it appears 262 | self.problem_to_contests = defaultdict(list) 263 | self.cache_master = cache_master 264 | self.update_lock = asyncio.Lock() 265 | self.logger = logging.getLogger(self.__class__.__name__) 266 | 267 | async def run(self): 268 | if self.cache_master.conn.problemset_empty(): 269 | self.logger.warning('Problemset cache on disk is empty. This must be populated ' 270 | 'manually before use.') 271 | self._update_task.start() 272 | 273 | async def update_for_contest(self, contest_id): 274 | """Update problemset for a particular contest. Intended for manual trigger.""" 275 | async with self.update_lock: 276 | contest = self.cache_master.contest_cache.get_contest(contest_id) 277 | problemset, _ = await self._fetch_problemsets([contest], force_fetch=True) 278 | self.cache_master.conn.clear_problemset(contest_id) 279 | self._save_problems(problemset) 280 | return len(problemset) 281 | 282 | async def update_for_all(self): 283 | """Update problemsets for all finished contests. 
Intended for manual trigger.""" 284 | async with self.update_lock: 285 | contests = self.cache_master.contest_cache.contests_by_phase['FINISHED'] 286 | problemsets, _ = await self._fetch_problemsets(contests, force_fetch=True) 287 | self.cache_master.conn.clear_problemset() 288 | self._save_problems(problemsets) 289 | return len(problemsets) 290 | 291 | @tasks.task_spec(name='ProblemsetCacheUpdate', 292 | waiter=tasks.Waiter.fixed_delay(_RELOAD_DELAY)) 293 | async def _update_task(self, _): 294 | async with self.update_lock: 295 | contests = self.cache_master.contest_cache.contests_by_phase['FINISHED'] 296 | new_problems, updated_problems = await self._fetch_problemsets(contests) 297 | self._save_problems(new_problems + updated_problems) 298 | self._update_from_disk() 299 | self.logger.info(f'{len(new_problems)} new problems saved and {len(updated_problems)} ' 300 | 'saved problems updated.') 301 | 302 | async def _fetch_problemsets(self, contests, *, force_fetch=False): 303 | # We assume it is possible for problems in the same contest to get assigned rating at 304 | # different times. 305 | new_contest_ids = [] 306 | contests_to_refetch = [] # List of (id, set of saved rated problem indices) pairs. 307 | if force_fetch: 308 | new_contest_ids = [contest.id for contest in contests] 309 | else: 310 | now = time.time() 311 | for contest in contests: 312 | if now > contest.end_time + self._MONITOR_PERIOD_SINCE_CONTEST_END: 313 | # Contest too old, we do not want to check it. 314 | continue 315 | problemset = self.cache_master.conn.fetch_problemset(contest.id) 316 | if not problemset: 317 | new_contest_ids.append(contest.id) 318 | continue 319 | rated_problem_idx = {prob.index for prob in problemset if prob.rating is not None} 320 | if len(rated_problem_idx) < len(problemset): 321 | contests_to_refetch.append((contest.id, rated_problem_idx)) 322 | 323 | new_problems, updated_problems = [], [] 324 | for contest_id in new_contest_ids: 325 | new_problems += await self._fetch_for_contest(contest_id) 326 | for contest_id, rated_problem_idx in contests_to_refetch: 327 | updated_problems += [prob for prob in await self._fetch_for_contest(contest_id) 328 | if prob.rating is not None and prob.index not in rated_problem_idx] 329 | 330 | return new_problems, updated_problems 331 | 332 | async def _fetch_for_contest(self, contest_id): 333 | try: 334 | _, problemset, _ = await cf.contest.standings(contest_id=contest_id, from_=1, 335 | count=1) 336 | except cf.CodeforcesApiError as er: 337 | self.logger.warning(f'Problemset fetch failed for contest {contest_id}. 
{er!r}') 338 | problemset = [] 339 | return problemset 340 | 341 | def _save_problems(self, problems): 342 | rc = self.cache_master.conn.cache_problemset(problems) 343 | self.logger.info(f'Saved {rc} problems to database.') 344 | 345 | def get_problemset(self, contest_id): 346 | problemset = self.cache_master.conn.fetch_problemset(contest_id) 347 | if not problemset: 348 | raise ProblemsetNotCached(contest_id) 349 | return problemset 350 | 351 | def _update_from_disk(self): 352 | self.problems = self.cache_master.conn.fetch_problems2() 353 | self.problem_to_contests = defaultdict(list) 354 | for problem in self.problems: 355 | try: 356 | contest = cf_common.cache2.contest_cache.get_contest(problem.contestId) 357 | problem_id = (problem.name, contest.startTimeSeconds) 358 | self.problem_to_contests[problem_id].append(contest.id) 359 | except ContestNotFound: 360 | pass 361 | 362 | 363 | class RatingChangesCache: 364 | _RATED_DELAY = 36 * 60 * 60 365 | _RELOAD_DELAY = 10 * 60 366 | 367 | def __init__(self, cache_master): 368 | self.cache_master = cache_master 369 | self.monitored_contests = [] 370 | self.handle_rating_cache = {} 371 | self.logger = logging.getLogger(self.__class__.__name__) 372 | 373 | async def run(self): 374 | self._refresh_handle_cache() 375 | if not self.handle_rating_cache: 376 | self.logger.warning('Rating changes cache on disk is empty. This must be populated ' 377 | 'manually before use.') 378 | self._update_task.start() 379 | 380 | async def fetch_contest(self, contest_id): 381 | """Fetch rating changes for a particular contest. Intended for manual trigger.""" 382 | contest = self.cache_master.contest_cache.contest_by_id[contest_id] 383 | changes = await self._fetch([contest]) 384 | self.cache_master.conn.clear_rating_changes(contest_id=contest_id) 385 | self._save_changes(changes) 386 | return len(changes) 387 | 388 | async def fetch_all_contests(self): 389 | """Fetch rating changes for all contests. Intended for manual trigger.""" 390 | self.cache_master.conn.clear_rating_changes() 391 | return await self.fetch_missing_contests() 392 | 393 | async def fetch_missing_contests(self): 394 | """Fetch rating changes for contests which are not saved in database. Intended for 395 | manual trigger.""" 396 | contests = self.cache_master.contest_cache.contests_by_phase['FINISHED'] 397 | contests = [ 398 | contest for contest in contests if not self.has_rating_changes_saved(contest.id)] 399 | total_changes = 0 400 | for contests_chunk in paginator.chunkify(contests, _CONTESTS_PER_BATCH_IN_CACHE_UPDATES): 401 | contests_chunk = await self._fetch(contests_chunk) 402 | self._save_changes(contests_chunk) 403 | total_changes += len(contests_chunk) 404 | return total_changes 405 | 406 | def is_newly_finished_without_rating_changes(self, contest): 407 | now = time.time() 408 | return (contest.phase == 'FINISHED' and 409 | now - contest.end_time < self._RATED_DELAY and 410 | not self.has_rating_changes_saved(contest.id)) 411 | 412 | @tasks.task_spec(name='RatingChangesCacheUpdate', 413 | waiter=tasks.Waiter.for_event(events.ContestListRefresh)) 414 | async def _update_task(self, _): 415 | # Some notes: 416 | # A hack phase is tagged as FINISHED with empty list of rating changes. After the hack 417 | # phase, the phase changes to systest then again FINISHED. Since we cannot differentiate 418 | # between the two FINISHED phases, we are forced to fetch during both. 419 | # A contest also has empty list if it is unrated. 
We assume that is the case if 420 | # _RATED_DELAY time has passed since the contest end. 421 | 422 | to_monitor = [ 423 | contest for contest in 424 | self.cache_master.contest_cache.contests_by_phase['FINISHED'] 425 | if self.is_newly_finished_without_rating_changes(contest) 426 | and not _is_blacklisted(contest) 427 | ] 428 | 429 | cur_ids = {contest.id for contest in self.monitored_contests} 430 | new_ids = {contest.id for contest in to_monitor} 431 | if new_ids != cur_ids: 432 | await self._monitor_task.stop() 433 | if to_monitor: 434 | self.monitored_contests = to_monitor 435 | self._monitor_task.start() 436 | else: 437 | self.monitored_contests = [] 438 | 439 | @tasks.task_spec(name='RatingChangesCacheUpdate.MonitorNewlyFinishedContests', 440 | waiter=tasks.Waiter.fixed_delay(_RELOAD_DELAY)) 441 | async def _monitor_task(self, _): 442 | self.monitored_contests = [ 443 | contest for contest in self.monitored_contests 444 | if self.is_newly_finished_without_rating_changes(contest) 445 | and not _is_blacklisted(contest) 446 | ] 447 | 448 | if not self.monitored_contests: 449 | self.logger.info('Rated changes fetched for contests that were being monitored.') 450 | await self._monitor_task.stop() 451 | return 452 | 453 | contest_changes_pairs = await self._fetch(self.monitored_contests) 454 | # Sort by the rating update time of the first change in the list of changes, assuming 455 | # every change in the list has the same time. 456 | contest_changes_pairs.sort(key=lambda pair: pair[1][0].ratingUpdateTimeSeconds) 457 | self._save_changes(contest_changes_pairs) 458 | for contest, changes in contest_changes_pairs: 459 | cf_common.event_sys.dispatch(events.RatingChangesUpdate, contest=contest, 460 | rating_changes=changes) 461 | 462 | async def _fetch(self, contests): 463 | all_changes = [] 464 | for contest in contests: 465 | try: 466 | changes = await cf.contest.ratingChanges(contest_id=contest.id) 467 | self.logger.info(f'{len(changes)} rating changes fetched for contest {contest.id}') 468 | if changes: 469 | all_changes.append((contest, changes)) 470 | except cf.CodeforcesApiError as er: 471 | self.logger.warning(f'Fetch rating changes failed for contest {contest.id}, ignoring. 
{er!r}') 472 | pass 473 | return all_changes 474 | 475 | def _save_changes(self, contest_changes_pairs): 476 | flattened = [change for _, changes in contest_changes_pairs for change in changes] 477 | if not flattened: 478 | return 479 | rc = self.cache_master.conn.save_rating_changes(flattened) 480 | self.logger.info(f'Saved {rc} changes to database.') 481 | self._refresh_handle_cache() 482 | 483 | def _refresh_handle_cache(self): 484 | changes = self.cache_master.conn.get_all_rating_changes() 485 | handle_rating_cache = {} 486 | for change in changes: 487 | handle_rating_cache[change.handle] = change.newRating 488 | self.handle_rating_cache = handle_rating_cache 489 | self.logger.info(f'Ratings for {len(handle_rating_cache)} handles cached') 490 | 491 | def get_users_with_more_than_n_contests(self, time_cutoff, n): 492 | return self.cache_master.conn.get_users_with_more_than_n_contests(time_cutoff, n) 493 | 494 | def get_rating_changes_for_contest(self, contest_id): 495 | return self.cache_master.conn.get_rating_changes_for_contest(contest_id) 496 | 497 | def has_rating_changes_saved(self, contest_id): 498 | return self.cache_master.conn.has_rating_changes_saved(contest_id) 499 | 500 | def get_rating_changes_for_handle(self, handle): 501 | return self.cache_master.conn.get_rating_changes_for_handle(handle) 502 | 503 | def get_current_rating(self, handle, default_if_absent=False): 504 | return self.handle_rating_cache.get(handle, 505 | cf.DEFAULT_RATING if default_if_absent else None) 506 | 507 | def get_all_ratings_before_timestamp(self, timestamp): 508 | res = self.cache_master.conn.get_all_ratings_before_timestamp(timestamp) 509 | return {ratingchange.handle: ratingchange for ratingchange in res} 510 | 511 | def get_all_ratings(self): 512 | return list(self.handle_rating_cache.values()) 513 | 514 | 515 | class RanklistCacheError(CacheError): 516 | pass 517 | 518 | 519 | class RanklistNotMonitored(RanklistCacheError): 520 | def __init__(self, contest): 521 | super().__init__(f'The ranklist for `{contest.name}` is not being monitored') 522 | self.contest = contest 523 | 524 | 525 | class RanklistCache: 526 | _RELOAD_DELAY = 2 * 60 527 | 528 | def __init__(self, cache_master): 529 | self.cache_master = cache_master 530 | self.monitored_contests = [] 531 | self.ranklist_by_contest = {} 532 | self.logger = logging.getLogger(self.__class__.__name__) 533 | 534 | async def run(self): 535 | self._update_task.start() 536 | 537 | # Currently ranklist monitoring only supports caching unofficial ranklists 538 | # If official ranklist is asked, the cache will throw RanklistNotMonitored Error 539 | def get_ranklist(self, contest, show_official): 540 | if show_official or contest.id not in self.ranklist_by_contest: 541 | raise RanklistNotMonitored(contest) 542 | return self.ranklist_by_contest[contest.id] 543 | 544 | @tasks.task_spec(name='RanklistCacheUpdate', 545 | waiter=tasks.Waiter.for_event(events.ContestListRefresh)) 546 | async def _update_task(self, _): 547 | contests_by_phase = self.cache_master.contest_cache.contests_by_phase 548 | running_contests = contests_by_phase['_RUNNING'] 549 | 550 | rating_cache = self.cache_master.rating_changes_cache 551 | finished_contests = [ 552 | contest for contest in contests_by_phase['FINISHED'] 553 | if not _is_blacklisted(contest) 554 | and rating_cache.is_newly_finished_without_rating_changes(contest) 555 | ] 556 | 557 | to_monitor = running_contests + finished_contests 558 | cur_ids = {contest.id for contest in self.monitored_contests} 559 | new_ids = 
{contest.id for contest in to_monitor} 560 | if new_ids != cur_ids: 561 | await self._monitor_task.stop() 562 | if to_monitor: 563 | self.monitored_contests = to_monitor 564 | self._monitor_task.start() 565 | else: 566 | self.ranklist_by_contest = {} 567 | 568 | @tasks.task_spec(name='RanklistCacheUpdate.MonitorActiveContests', 569 | waiter=tasks.Waiter.fixed_delay(_RELOAD_DELAY)) 570 | async def _monitor_task(self, _): 571 | cache = self.cache_master.rating_changes_cache 572 | self.monitored_contests = [ 573 | contest for contest in self.monitored_contests 574 | if not _is_blacklisted(contest) and ( 575 | contest.phase != 'FINISHED' 576 | or cache.is_newly_finished_without_rating_changes(contest)) 577 | ] 578 | 579 | if not self.monitored_contests: 580 | self.ranklist_by_contest = {} 581 | self.logger.info('No more active contests for which to monitor ranklists.') 582 | await self._monitor_task.stop() 583 | return 584 | 585 | ranklist_by_contest = await self._fetch(self.monitored_contests) 586 | # If any ranklist could not be fetched, the old ranklist is kept. 587 | for contest_id, ranklist in ranklist_by_contest.items(): 588 | self.ranklist_by_contest[contest_id] = ranklist 589 | 590 | @staticmethod 591 | async def _get_contest_details(contest_id, show_unofficial): 592 | contest, problems, standings = await cf.contest.standings(contest_id=contest_id, 593 | show_unofficial=show_unofficial) 594 | # Exclude PRACTICE and MANAGER 595 | standings = [row for row in standings 596 | if row.party.participantType in ('CONTESTANT', 'OUT_OF_COMPETITION', 'VIRTUAL')] 597 | 598 | return contest, problems, standings 599 | 600 | # Fetch final rating changes from CF. 601 | # For older contests. 602 | async def _get_ranklist_with_fetched_changes(self, contest_id, show_unofficial): 603 | contest, problems, standings = await self._get_contest_details(contest_id, show_unofficial) 604 | now = time.time() 605 | 606 | is_rated = False 607 | try: 608 | changes = await cf.contest.ratingChanges(contest_id=contest_id) 609 | # For contests intended to be rated but declared unrated, an empty list is returned. 610 | is_rated = len(changes) > 0 611 | except cf.RatingChangesUnavailableError: 612 | pass 613 | 614 | ranklist = None 615 | if is_rated: 616 | ranklist = Ranklist(contest, problems, standings, now, is_rated=is_rated) 617 | delta_by_handle = {change.handle: change.newRating - change.oldRating 618 | for change in changes} 619 | ranklist.set_deltas(delta_by_handle) 620 | 621 | return ranklist 622 | 623 | # Rating changes have not been applied yet, predict rating changes. 624 | # For running/recent/unrated contests. 
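# Annotation (illustrative sketch, not part of the original code): callers pick one
# of the two strategies through `generate_ranklist` further below, which asserts
# that exactly one of `fetch_changes` / `predict_changes` is set. Assuming `cache`
# is a RanklistCache instance and `contest_id` is some contest id:
#
#     # Older contest whose rating changes are already published:
#     ranklist = await cache.generate_ranklist(contest_id, fetch_changes=True)
#     # Running or freshly finished contest, so deltas are predicted instead:
#     ranklist = await cache.generate_ranklist(contest_id, predict_changes=True)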
625 | async def _get_ranklist_with_predicted_changes(self, contest_id, show_unofficial): 626 | contest, problems, standings = await self._get_contest_details(contest_id, show_unofficial) 627 | now = time.time() 628 | 629 | standings_official = None 630 | if not show_unofficial: 631 | standings_official = standings 632 | else: 633 | _, _, standings_official = await cf.contest.standings(contest_id=contest_id) 634 | 635 | has_teams = any(row.party.teamId is not None for row in standings_official) 636 | if cf_common.is_nonstandard_contest(contest) or has_teams: 637 | # The contest is not traditionally rated 638 | ranklist = Ranklist(contest, problems, standings, now, is_rated=False) 639 | else: 640 | current_rating = await CacheSystem.getUsersEffectiveRating(activeOnly=False) 641 | current_rating = {row.party.members[0].handle: current_rating.get(row.party.members[0].handle, 1500) 642 | for row in standings_official} 643 | if 'Educational' in contest.name: 644 | # For some reason educational contests return all contestants in ranklist even 645 | # when unofficial contestants are not requested. 646 | current_rating = {handle: rating 647 | for handle, rating in current_rating.items() if rating < 2100} 648 | ranklist = Ranklist(contest, problems, standings, now, is_rated=True) 649 | ranklist.predict(current_rating) 650 | return ranklist 651 | 652 | async def generate_ranklist(self, contest_id, *, fetch_changes=False, predict_changes=False, show_unofficial=True): 653 | assert fetch_changes ^ predict_changes 654 | 655 | ranklist = None 656 | if fetch_changes: 657 | ranklist = await self._get_ranklist_with_fetched_changes(contest_id, show_unofficial) 658 | if ranklist is None: 659 | # Either predict_changes was true or fetching rating changes failed 660 | ranklist = await self._get_ranklist_with_predicted_changes(contest_id, show_unofficial) 661 | 662 | # for some reason Educational contests also have div1 peeps in the official standings. 663 | # hence we need to manually weed them out 664 | if not show_unofficial and 'Educational' in ranklist.contest.name: 665 | ranklist.remove_unofficial_contestants() 666 | 667 | return ranklist 668 | 669 | async def generate_vc_ranklist(self, contest_id, handle_to_member_id): 670 | handles = list(handle_to_member_id.keys()) 671 | contest, problems, standings = await cf.contest.standings(contest_id=contest_id, 672 | show_unofficial=True) 673 | # Exclude PRACTICE, MANAGER and OUR_OF_COMPETITION 674 | standings = [row for row in standings 675 | if row.party.participantType == 'CONTESTANT' or 676 | row.party.members[0].handle in handles] 677 | standings.sort(key=lambda row: row.rank) 678 | standings = [row._replace(rank=i + 1) for i, row in enumerate(standings)] 679 | now = time.time() 680 | rating_changes = await cf.contest.ratingChanges(contest_id=contest_id) 681 | current_official_rating = {rating_change.handle: rating_change.oldRating 682 | for rating_change in rating_changes} 683 | 684 | # TODO: assert that none of the given handles are in the official standings. 
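# Annotation (illustrative sketch, not part of the original code): for each virtual
# participant, the loop below overlays that single handle's stored VC rating on top
# of the official contestants' pre-contest ratings, runs the usual predictor on the
# mixed map, and keeps only that handle's delta. Roughly, for one handle `h`:
#
#     mixed = dict(current_official_rating)        # official old ratings
#     mixed[h] = current_vc_rating[h]              # overlay the stored VC rating
#     ranklist.predict(mixed)
#     delta_by_handle[h] = ranklist.delta_by_handle.get(h, 0)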
685 | handles = [row.party.members[0].handle for row in standings 686 | if row.party.members[0].handle in handles and 687 | row.party.participantType == 'VIRTUAL'] 688 | current_vc_rating = {handle: cf_common.user_db.get_vc_rating(handle_to_member_id.get(handle)) 689 | for handle in handles} 690 | ranklist = Ranklist(contest, problems, standings, now, is_rated=True) 691 | delta_by_handle = {} 692 | for handle in handles: 693 | mixed_ratings = current_official_rating.copy() 694 | mixed_ratings[handle] = current_vc_rating.get(handle) 695 | ranklist.predict(mixed_ratings) 696 | delta_by_handle[handle] = ranklist.delta_by_handle.get(handle, 0) 697 | 698 | ranklist.delta_by_handle = delta_by_handle 699 | return ranklist 700 | 701 | async def _fetch(self, contests): 702 | ranklist_by_contest = {} 703 | for contest in contests: 704 | try: 705 | ranklist = await self.generate_ranklist(contest.id, predict_changes=True) 706 | ranklist_by_contest[contest.id] = ranklist 707 | self.logger.info(f'Ranklist fetched for contest {contest.id}') 708 | except cf.CodeforcesApiError as er: 709 | self.logger.warning(f'Ranklist fetch failed for contest {contest.id}. {er!r}') 710 | 711 | return ranklist_by_contest 712 | 713 | 714 | class CacheSystem: 715 | def __init__(self, conn): 716 | self.conn = conn 717 | self.contest_cache = ContestCache(self) 718 | self.problem_cache = ProblemCache(self) 719 | self.rating_changes_cache = RatingChangesCache(self) 720 | self.ranklist_cache = RanklistCache(self) 721 | self.problemset_cache = ProblemsetCache(self) 722 | 723 | async def run(self): 724 | await self.rating_changes_cache.run() 725 | await self.ranklist_cache.run() 726 | await self.contest_cache.run() 727 | await self.problem_cache.run() 728 | await self.problemset_cache.run() 729 | 730 | @staticmethod 731 | @cached(ttl=30 * 60) 732 | async def getUsersEffectiveRating(*, activeOnly=None): 733 | """ Returns a dictionary mapping user handle to his effective rating for all the users. 
734 | """ 735 | ratedList = await cf.user.ratedList(activeOnly=activeOnly) 736 | users_effective_rating_dict = {user.handle: user.effective_rating 737 | for user in ratedList} 738 | return users_effective_rating_dict 739 | -------------------------------------------------------------------------------- /tle/util/codeforces_api.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import logging 3 | import time 4 | import functools 5 | from collections import namedtuple, deque, defaultdict 6 | 7 | import aiohttp 8 | 9 | from discord.ext import commands 10 | from tle.util import codeforces_common as cf_common 11 | 12 | API_BASE_URL = 'https://codeforces.com/api/' 13 | CONTEST_BASE_URL = 'https://codeforces.com/contest/' 14 | CONTESTS_BASE_URL = 'https://codeforces.com/contests/' 15 | GYM_BASE_URL = 'https://codeforces.com/gym/' 16 | PROFILE_BASE_URL = 'https://codeforces.com/profile/' 17 | ACMSGURU_BASE_URL = 'https://codeforces.com/problemsets/acmsguru/' 18 | GYM_ID_THRESHOLD = 100000 19 | DEFAULT_RATING = 800 20 | 21 | logger = logging.getLogger(__name__) 22 | 23 | Rank = namedtuple('Rank', 'low high title title_abbr color_graph color_embed') 24 | 25 | RATED_RANKS = ( 26 | Rank(-10 ** 9, 1200, 'Newbie', 'N', '#CCCCCC', 0x808080), 27 | Rank(1200, 1400, 'Pupil', 'P', '#77FF77', 0x008000), 28 | Rank(1400, 1600, 'Specialist', 'S', '#77DDBB', 0x03a89e), 29 | Rank(1600, 1900, 'Expert', 'E', '#AAAAFF', 0x0000ff), 30 | Rank(1900, 2100, 'Candidate Master', 'CM', '#FF88FF', 0xaa00aa), 31 | Rank(2100, 2300, 'Master', 'M', '#FFCC88', 0xff8c00), 32 | Rank(2300, 2400, 'International Master', 'IM', '#FFBB55', 0xf57500), 33 | Rank(2400, 2600, 'Grandmaster', 'GM', '#FF7777', 0xff3030), 34 | Rank(2600, 3000, 'International Grandmaster', 'IGM', '#FF3333', 0xff0000), 35 | Rank(3000, 10 ** 9, 'Legendary Grandmaster', 'LGM', '#AA0000', 0xcc0000) 36 | ) 37 | UNRATED_RANK = Rank(None, None, 'Unrated', None, None, None) 38 | 39 | 40 | def rating2rank(rating): 41 | if rating is None: 42 | return UNRATED_RANK 43 | for rank in RATED_RANKS: 44 | if rank.low <= rating < rank.high: 45 | return rank 46 | 47 | 48 | # Data classes 49 | 50 | class User(namedtuple('User', 'handle firstName lastName country city organization contribution ' 51 | 'rating maxRating lastOnlineTimeSeconds registrationTimeSeconds ' 52 | 'friendOfCount titlePhoto')): 53 | __slots__ = () 54 | 55 | @property 56 | def effective_rating(self): 57 | return self.rating if self.rating is not None else DEFAULT_RATING 58 | 59 | @property 60 | def rank(self): 61 | return rating2rank(self.rating) 62 | 63 | @property 64 | def url(self): 65 | return f'{PROFILE_BASE_URL}{self.handle}' 66 | 67 | 68 | RatingChange = namedtuple('RatingChange', 69 | 'contestId contestName handle rank ratingUpdateTimeSeconds oldRating newRating') 70 | 71 | 72 | class Contest(namedtuple('Contest', 'id name startTimeSeconds durationSeconds type phase preparedBy')): 73 | __slots__ = () 74 | PHASES = 'BEFORE CODING PENDING_SYSTEM_TEST SYSTEM_TEST FINISHED'.split() 75 | 76 | @property 77 | def end_time(self): 78 | return self.startTimeSeconds + self.durationSeconds 79 | 80 | @property 81 | def url(self): 82 | return f'{CONTEST_BASE_URL if self.id < GYM_ID_THRESHOLD else GYM_BASE_URL}{self.id}' 83 | 84 | @property 85 | def register_url(self): 86 | return f'{CONTESTS_BASE_URL}{self.id}' 87 | 88 | def matches(self, markers): 89 | def strfilt(s): 90 | return ''.join(x for x in s.lower() if x.isalnum()) 91 | return any(strfilt(marker) in 
strfilt(self.name) for marker in markers) 92 | 93 | class Party(namedtuple('Party', ('contestId members participantType teamId teamName ghost room ' 94 | 'startTimeSeconds'))): 95 | __slots__ = () 96 | PARTICIPANT_TYPES = ('CONTESTANT', 'PRACTICE', 'VIRTUAL', 'MANAGER', 'OUT_OF_COMPETITION') 97 | 98 | 99 | Member = namedtuple('Member', 'handle') 100 | 101 | 102 | class Problem(namedtuple('Problem', 'contestId problemsetName index name type points rating tags')): 103 | __slots__ = () 104 | 105 | @property 106 | def contest_identifier(self): 107 | return f'{self.contestId}{self.index}' 108 | 109 | @property 110 | def url(self): 111 | if self.contestId is None: 112 | assert self.problemsetName == 'acmsguru', f'Unknown problemset {self.problemsetName}' 113 | return f'{ACMSGURU_BASE_URL}problem/99999/{self.index}' 114 | base = CONTEST_BASE_URL if self.contestId < GYM_ID_THRESHOLD else GYM_BASE_URL 115 | return f'{base}{self.contestId}/problem/{self.index}' 116 | 117 | def has_metadata(self): 118 | return self.contestId is not None and self.rating is not None 119 | 120 | def _matching_tags_dict(self, match_tags): 121 | """Returns a dict with matching tags.""" 122 | tags = defaultdict(list) 123 | for match_tag in match_tags: 124 | for tag in self.tags: 125 | if match_tag in tag: 126 | tags[match_tag].append(tag) 127 | return dict(tags) 128 | 129 | def matches_all_tags(self, match_tags): 130 | match_tags = set(match_tags) 131 | return len(self._matching_tags_dict(match_tags)) == len(match_tags) 132 | 133 | def matches_any_tag(self, match_tags): 134 | match_tags = set(match_tags) 135 | return len(self._matching_tags_dict(match_tags)) > 0 136 | 137 | def get_matched_tags(self, match_tags): 138 | return [ 139 | tag for tags in self._matching_tags_dict(match_tags).values() 140 | for tag in tags 141 | ] 142 | 143 | ProblemStatistics = namedtuple('ProblemStatistics', 'contestId index solvedCount') 144 | 145 | Submission = namedtuple('Submissions', 146 | 'id contestId problem author programmingLanguage verdict creationTimeSeconds relativeTimeSeconds') 147 | 148 | RanklistRow = namedtuple('RanklistRow', 'party rank points penalty problemResults') 149 | 150 | ProblemResult = namedtuple('ProblemResult', 151 | 'points penalty rejectedAttemptCount type bestSubmissionTimeSeconds') 152 | 153 | 154 | def make_from_dict(namedtuple_cls, dict_): 155 | field_vals = [dict_.get(field) for field in namedtuple_cls._fields] 156 | return namedtuple_cls._make(field_vals) 157 | 158 | 159 | # Error classes 160 | 161 | class CodeforcesApiError(commands.CommandError): 162 | """Base class for all API related errors.""" 163 | def __init__(self, message=None): 164 | super().__init__(message or 'Codeforces API error. There is nothing you or the Admins of the Discord server can do to fix it. 
We need to wait until Mike does his job.') 165 | 166 | 167 | class TrueApiError(CodeforcesApiError): 168 | """An error originating from a valid response of the API.""" 169 | def __init__(self, comment, message=None): 170 | super().__init__(message) 171 | self.comment = comment 172 | 173 | 174 | class ClientError(CodeforcesApiError): 175 | """An error caused by a request to the API failing.""" 176 | def __init__(self): 177 | super().__init__('Error connecting to Codeforces API') 178 | 179 | 180 | class HandleNotFoundError(TrueApiError): 181 | def __init__(self, comment, handle): 182 | super().__init__(comment, f'Handle `{handle}` not found on Codeforces') 183 | self.handle = handle 184 | 185 | 186 | class HandleInvalidError(TrueApiError): 187 | def __init__(self, comment, handle): 188 | super().__init__(comment, f'`{handle}` is not a valid Codeforces handle') 189 | self.handle = handle 190 | 191 | 192 | class CallLimitExceededError(TrueApiError): 193 | def __init__(self, comment): 194 | super().__init__(comment, 'Codeforces API call limit exceeded') 195 | 196 | 197 | class ContestNotFoundError(TrueApiError): 198 | def __init__(self, comment, contest_id): 199 | super().__init__(comment, f'Contest with ID `{contest_id}` not found on Codeforces') 200 | 201 | 202 | class RatingChangesUnavailableError(TrueApiError): 203 | def __init__(self, comment, contest_id): 204 | super().__init__(comment, f'Rating changes unavailable for contest with ID `{contest_id}`') 205 | 206 | 207 | # Codeforces API query methods 208 | 209 | _session = None 210 | 211 | 212 | async def initialize(): 213 | global _session 214 | _session = aiohttp.ClientSession() 215 | 216 | 217 | def _bool_to_str(value): 218 | if type(value) is bool: 219 | return 'true' if value else 'false' 220 | raise TypeError(f'Expected bool, got {value} of type {type(value)}') 221 | 222 | 223 | def cf_ratelimit(f): 224 | tries = 3 225 | per_second = 1 226 | last = deque([0]*per_second) 227 | 228 | @functools.wraps(f) 229 | async def wrapped(*args, **kwargs): 230 | for i in range(tries): 231 | now = time.time() 232 | 233 | # Next valid slot is 1s after the `per_second`th last request 234 | next_valid = max(now, 1 + last[0]) 235 | last.append(next_valid) 236 | last.popleft() 237 | 238 | # Delay as needed 239 | delay = next_valid - now 240 | if delay > 0: 241 | await asyncio.sleep(delay) 242 | 243 | try: 244 | return await f(*args, **kwargs) 245 | except (ClientError, CallLimitExceededError) as e: 246 | logger.info(f'Try {i+1}/{tries} at query failed.') 247 | logger.info(repr(e)) 248 | if i < tries - 1: 249 | logger.info(f'Retrying...') 250 | else: 251 | logger.info(f'Aborting.') 252 | raise e 253 | return wrapped 254 | 255 | 256 | @cf_ratelimit 257 | async def _query_api(path, data=None): 258 | url = API_BASE_URL + path 259 | try: 260 | logger.info(f'Querying CF API at {url} with {data}') 261 | # Explicitly state encoding (though aiohttp accepts gzip by default) 262 | headers = {'Accept-Encoding': 'gzip'} 263 | async with _session.post(url, data=data, headers=headers) as resp: 264 | try: 265 | respjson = await resp.json() 266 | except aiohttp.ContentTypeError: 267 | logger.warning(f'CF API did not respond with JSON, status {resp.status}.') 268 | raise CodeforcesApiError 269 | if resp.status == 200: 270 | return respjson['result'] 271 | comment = f'HTTP Error {resp.status}, {respjson.get("comment")}' 272 | except aiohttp.ClientError as e: 273 | logger.error(f'Request to CF API encountered error: {e!r}') 274 | raise ClientError from e 275 | 
logger.warning(f'Query to CF API failed: {comment}') 276 | if 'limit exceeded' in comment: 277 | raise CallLimitExceededError(comment) 278 | raise TrueApiError(comment) 279 | 280 | 281 | class contest: 282 | @staticmethod 283 | async def list(*, gym=None): 284 | params = {} 285 | if gym is not None: 286 | params['gym'] = _bool_to_str(gym) 287 | resp = await _query_api('contest.list', params) 288 | return [make_from_dict(Contest, contest_dict) for contest_dict in resp] 289 | 290 | @staticmethod 291 | async def ratingChanges(*, contest_id): 292 | params = {'contestId': contest_id} 293 | try: 294 | resp = await _query_api('contest.ratingChanges', params) 295 | except TrueApiError as e: 296 | if 'not found' in e.comment: 297 | raise ContestNotFoundError(e.comment, contest_id) 298 | if 'Rating changes are unavailable' in e.comment: 299 | raise RatingChangesUnavailableError(e.comment, contest_id) 300 | raise 301 | return [make_from_dict(RatingChange, change_dict) for change_dict in resp] 302 | 303 | @staticmethod 304 | async def standings(*, contest_id, from_=None, count=None, handles=None, room=None, 305 | show_unofficial=None): 306 | params = {'contestId': contest_id} 307 | if from_ is not None: 308 | params['from'] = from_ 309 | if count is not None: 310 | params['count'] = count 311 | if handles is not None: 312 | params['handles'] = ';'.join(handles) 313 | if room is not None: 314 | params['room'] = room 315 | if show_unofficial is not None: 316 | params['showUnofficial'] = _bool_to_str(show_unofficial) 317 | try: 318 | resp = await _query_api('contest.standings', params) 319 | except TrueApiError as e: 320 | if 'not found' in e.comment: 321 | raise ContestNotFoundError(e.comment, contest_id) 322 | raise 323 | contest_ = make_from_dict(Contest, resp['contest']) 324 | problems = [make_from_dict(Problem, problem_dict) for problem_dict in resp['problems']] 325 | for row in resp['rows']: 326 | row['party']['members'] = [make_from_dict(Member, member) 327 | for member in row['party']['members']] 328 | row['party'] = make_from_dict(Party, row['party']) 329 | row['problemResults'] = [make_from_dict(ProblemResult, problem_result) 330 | for problem_result in row['problemResults']] 331 | ranklist = [make_from_dict(RanklistRow, row_dict) for row_dict in resp['rows']] 332 | return contest_, problems, ranklist 333 | 334 | 335 | class problemset: 336 | @staticmethod 337 | async def problems(*, tags=None, problemset_name=None): 338 | params = {} 339 | if tags is not None: 340 | params['tags'] = ';'.join(tags) 341 | if problemset_name is not None: 342 | params['problemsetName'] = problemset_name 343 | resp = await _query_api('problemset.problems', params) 344 | problems = [make_from_dict(Problem, problem_dict) for problem_dict in resp['problems']] 345 | problemstats = [make_from_dict(ProblemStatistics, problemstat_dict) for problemstat_dict in 346 | resp['problemStatistics']] 347 | return problems, problemstats 348 | 349 | def user_info_chunkify(handles): 350 | """ 351 | Querying user.info using POST requests is limited to 10000 handles or 2**16 352 | bytes, so requests might need to be split into chunks 353 | """ 354 | SIZE_LIMIT = 2**16 355 | HANDLE_LIMIT = 10000 356 | chunk = [] 357 | size = 0 358 | for handle in handles: 359 | if size + len(handle) > SIZE_LIMIT or len(chunk) == HANDLE_LIMIT: 360 | yield chunk 361 | chunk = [] 362 | size = 0 363 | chunk.append(handle) 364 | size += len(handle) + 1 365 | if chunk: 366 | yield chunk 367 | 368 | class user: 369 | @staticmethod 370 | async def info(*, 
handles): 371 | chunks = list(user_info_chunkify(handles)) 372 | if len(chunks) > 1: 373 | logger.warning(f'cf.info request with {len(handles)} handles,' 374 | f'will be chunkified into {len(chunks)} requests.') 375 | 376 | result = [] 377 | for chunk in chunks: 378 | params = {'handles': ';'.join(chunk)} 379 | try: 380 | resp = await _query_api('user.info', params) 381 | except TrueApiError as e: 382 | if 'not found' in e.comment: 383 | # Comment format is "handles: User with handle ***** not found" 384 | handle = e.comment.partition('not found')[0].split()[-1] 385 | raise HandleNotFoundError(e.comment, handle) 386 | raise 387 | result += [make_from_dict(User, user_dict) for user_dict in resp] 388 | return [cf_common.fix_urls(user) for user in result] 389 | 390 | @staticmethod 391 | def correct_rating_changes(*, resp): 392 | adaptO = [1400, 900, 550, 300, 150, 50] 393 | adaptN = [900, 550, 300, 150, 50, 0] 394 | for r in resp: 395 | if (len(r) > 0): 396 | if (r[0].newRating <= 1200): 397 | for ind in range(0,(min(6, len(r)))): 398 | r[ind] = RatingChange(r[ind].contestId, r[ind].contestName, r[ind].handle, r[ind].rank, r[ind].ratingUpdateTimeSeconds, r[ind].oldRating+adaptO[ind], r[ind].newRating+adaptN[ind]) 399 | else: 400 | r[0] = RatingChange(r[0].contestId, r[0].contestName, r[0].handle, r[0].rank, r[0].ratingUpdateTimeSeconds, r[0].oldRating+1500, r[0].newRating) 401 | for r in resp: 402 | oldPerf = 0 403 | for ind in range(0,len(r)): 404 | r[ind] = RatingChange(r[ind].contestId, r[ind].contestName, r[ind].handle, r[ind].rank, r[ind].ratingUpdateTimeSeconds, oldPerf, r[ind].oldRating + 4*(r[ind].newRating-r[ind].oldRating)) 405 | oldPerf = r[ind].oldRating + 4*(r[ind].newRating-r[ind].oldRating) 406 | return resp 407 | 408 | 409 | @staticmethod 410 | async def rating(*, handle): 411 | params = {'handle': handle} 412 | try: 413 | resp = await _query_api('user.rating', params) 414 | except TrueApiError as e: 415 | if 'not found' in e.comment: 416 | raise HandleNotFoundError(e.comment, handle) 417 | if 'should contain' in e.comment: 418 | raise HandleInvalidError(e.comment, handle) 419 | raise 420 | return [make_from_dict(RatingChange, ratingchange_dict) for ratingchange_dict in resp] 421 | 422 | @staticmethod 423 | async def ratedList(*, activeOnly=None): 424 | params = {} 425 | if activeOnly is not None: 426 | params['activeOnly'] = _bool_to_str(activeOnly) 427 | resp = await _query_api('user.ratedList', params) 428 | return [make_from_dict(User, user_dict) for user_dict in resp] 429 | 430 | @staticmethod 431 | async def status(*, handle, from_=None, count=None): 432 | params = {'handle': handle} 433 | if from_ is not None: 434 | params['from'] = from_ 435 | if count is not None: 436 | params['count'] = count 437 | try: 438 | resp = await _query_api('user.status', params) 439 | except TrueApiError as e: 440 | if 'not found' in e.comment: 441 | raise HandleNotFoundError(e.comment, handle) 442 | if 'should contain' in e.comment: 443 | raise HandleInvalidError(e.comment, handle) 444 | raise 445 | for submission in resp: 446 | submission['problem'] = make_from_dict(Problem, submission['problem']) 447 | submission['author']['members'] = [make_from_dict(Member, member) 448 | for member in submission['author']['members']] 449 | submission['author'] = make_from_dict(Party, submission['author']) 450 | return [make_from_dict(Submission, submission_dict) for submission_dict in resp] 451 | 452 | 453 | async def _needs_fixing(handles): 454 | to_fix = [] 455 | chunks = user_info_chunkify(handles) 
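# Annotation (illustrative sketch, not part of the original code): each chunk below
# is retried after dropping any handle the API reports as missing, and handles whose
# canonical capitalization differs are collected for fixing. The chunking helper only
# splits when the 10000-handle or 2**16-byte limit would be exceeded, so a short
# example list yields a single chunk:
#
#     >>> list(user_info_chunkify(['tourist', 'Benq', 'jiangly']))
#     [['tourist', 'Benq', 'jiangly']]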
456 | for handle_chunk in chunks: 457 | while handle_chunk: 458 | try: 459 | cf_users = await user.info(handles=handle_chunk) 460 | 461 | # Users could still have changed capitalization 462 | for handle, cf_user in zip(handle_chunk, cf_users): 463 | assert handle.lower() == cf_user.handle.lower() 464 | if handle != cf_user.handle: 465 | to_fix.append(handle) 466 | break 467 | except HandleNotFoundError as e: 468 | to_fix.append(e.handle) 469 | handle_chunk.remove(e.handle) 470 | time.sleep(1) 471 | return to_fix 472 | 473 | 474 | async def _resolve_redirect(handle): 475 | url = PROFILE_BASE_URL + handle 476 | async with _session.head(url) as r: 477 | if r.status == 200: 478 | return handle 479 | if r.status == 301 or r.status == 302: 480 | redirected = r.headers.get('Location') 481 | if '/profile/' not in redirected: 482 | # Ended up not on profile page, probably invalid handle 483 | return None 484 | return redirected.split('/profile/')[-1] 485 | raise CodeforcesApiError( 486 | f'Something went wrong trying to redirect {url}') 487 | 488 | 489 | async def _resolve_handle_mapping(handles_to_fix): 490 | redirections = {} 491 | failed = [] 492 | for handle in handles_to_fix: 493 | new_handle = await _resolve_redirect(handle) 494 | if not new_handle: 495 | redirections[handle] = None 496 | else: 497 | cf_user, = await user.info(handles=[new_handle]) 498 | redirections[handle] = cf_user 499 | time.sleep(1) 500 | return redirections 501 | 502 | 503 | async def resolve_redirects(handles): 504 | handles_to_fix = await _needs_fixing(handles) 505 | handle_mapping = await _resolve_handle_mapping(handles_to_fix) 506 | return handle_mapping 507 | -------------------------------------------------------------------------------- /tle/util/codeforces_common.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import json 3 | import logging 4 | import math 5 | import time 6 | import datetime 7 | from collections import defaultdict 8 | import itertools 9 | from discord.ext import commands 10 | import discord 11 | 12 | from tle import constants 13 | from tle.util import cache_system2 14 | from tle.util import codeforces_api as cf 15 | from tle.util import db 16 | from tle.util import events 17 | 18 | logger = logging.getLogger(__name__) 19 | 20 | # Connection to database 21 | user_db = None 22 | 23 | # Cache system 24 | cache2 = None 25 | 26 | # Event system 27 | event_sys = events.EventSystem() 28 | 29 | _contest_id_to_writers_map = None 30 | 31 | _initialize_done = False 32 | 33 | active_groups = defaultdict(set) 34 | 35 | 36 | async def initialize(nodb): 37 | global cache2 38 | global user_db 39 | global event_sys 40 | global _contest_id_to_writers_map 41 | global _initialize_done 42 | 43 | if _initialize_done: 44 | # This happens if the bot loses connection to Discord and on_ready is triggered again 45 | # when it reconnects. 
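# Annotation (illustrative sketch, not part of the original code): because of this
# guard, the on_ready handler may call `initialize` unconditionally; only the first
# call builds the database connections and caches, later calls return here:
#
#     await cf_common.initialize(nodb=False)   # first call does the setup
#     await cf_common.initialize(nodb=False)   # after a reconnect: returns here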
46 | return 47 | 48 | await cf.initialize() 49 | 50 | if nodb: 51 | user_db = db.DummyUserDbConn() 52 | else: 53 | user_db = db.UserDbConn(constants.USER_DB_FILE_PATH) 54 | 55 | cache_db = db.CacheDbConn(constants.CACHE_DB_FILE_PATH) 56 | cache2 = cache_system2.CacheSystem(cache_db) 57 | await cache2.run() 58 | 59 | try: 60 | with open(constants.CONTEST_WRITERS_JSON_FILE_PATH) as f: 61 | data = json.load(f) 62 | _contest_id_to_writers_map = {contest['id']: [s.lower() for s in contest['writers']] for contest in data} 63 | logger.info('Contest writers loaded from JSON file') 64 | except FileNotFoundError: 65 | logger.warning('JSON file containing contest writers not found') 66 | 67 | _initialize_done = True 68 | 69 | 70 | # algmyr's guard idea: 71 | def user_guard(*, group, get_exception=None): 72 | active = active_groups[group] 73 | 74 | def guard(fun): 75 | @functools.wraps(fun) 76 | async def f(self, ctx, *args, **kwargs): 77 | user = ctx.message.author.id 78 | if user in active: 79 | logger.info(f'{user} repeatedly calls {group} group') 80 | if get_exception is not None: 81 | raise get_exception() 82 | return 83 | active.add(user) 84 | try: 85 | await fun(self, ctx, *args, **kwargs) 86 | finally: 87 | active.remove(user) 88 | 89 | return f 90 | 91 | return guard 92 | 93 | 94 | def is_contest_writer(contest_id, handle): 95 | if _contest_id_to_writers_map is None: 96 | return False 97 | writers = _contest_id_to_writers_map.get(contest_id) 98 | return writers and handle.lower() in writers 99 | 100 | 101 | _NONSTANDARD_CONTEST_INDICATORS = [ 102 | 'wild', 'fools', 'unrated', 'surprise', 'unknown', 'friday', 'q#', 'testing', 103 | 'marathon', 'kotlin', 'onsite', 'experimental', 'abbyy', 'icpc'] 104 | 105 | 106 | def is_nonstandard_contest(contest): 107 | return any(string in contest.name.lower() for string in _NONSTANDARD_CONTEST_INDICATORS) 108 | 109 | def is_nonstandard_problem(problem): 110 | return (is_nonstandard_contest(cache2.contest_cache.get_contest(problem.contestId)) or 111 | problem.matches_all_tags(['*special'])) 112 | 113 | 114 | async def get_visited_contests(handles : [str]): 115 | """ Returns a set of contest ids of contests that any of the given handles 116 | has at least one non-CE submission. 117 | """ 118 | user_submissions = [await cf.user.status(handle=handle) for handle in handles] 119 | problem_to_contests = cache2.problemset_cache.problem_to_contests 120 | 121 | contest_ids = [] 122 | for sub in itertools.chain.from_iterable(user_submissions): 123 | if sub.verdict == 'COMPILATION_ERROR': 124 | continue 125 | try: 126 | contest = cache2.contest_cache.get_contest(sub.problem.contestId) 127 | problem_id = (sub.problem.name, contest.startTimeSeconds) 128 | contest_ids += problem_to_contests[problem_id] 129 | except cache_system2.ContestNotFound: 130 | pass 131 | return set(contest_ids) 132 | 133 | # These are special rated-for-all contests which have a combined ranklist for onsite and online 134 | # participants. The onsite participants have their submissions marked as out of competition. Just 135 | # Codeforces things. 
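# Annotation (illustrative sketch, not part of the original code): callers are
# expected to consult `is_rated_for_onsite_contest` (defined just after the list)
# when deciding whether out-of-competition rows should still count as rated, e.g.
# something along these lines in a hypothetical caller:
#
#     if row.party.participantType == 'OUT_OF_COMPETITION':
#         counts_as_rated = is_rated_for_onsite_contest(contest)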
136 | _RATED_FOR_ONSITE_CONTEST_IDS = [ 137 | 86, # Yandex.Algorithm 2011 Round 2 https://codeforces.com/contest/86 138 | 173, # Croc Champ 2012 - Round 1 https://codeforces.com/contest/173 139 | 335, # MemSQL start[c]up Round 2 - online version https://codeforces.com/contest/335 140 | ] 141 | 142 | 143 | def is_rated_for_onsite_contest(contest): 144 | return contest.id in _RATED_FOR_ONSITE_CONTEST_IDS 145 | 146 | 147 | class ResolveHandleError(commands.CommandError): 148 | pass 149 | 150 | 151 | class HandleCountOutOfBoundsError(ResolveHandleError): 152 | def __init__(self, mincnt, maxcnt): 153 | super().__init__(f'Number of handles must be between {mincnt} and {maxcnt}') 154 | 155 | 156 | class FindMemberFailedError(ResolveHandleError): 157 | def __init__(self, member): 158 | super().__init__(f'Unable to convert `{member}` to a server member') 159 | 160 | 161 | class HandleNotRegisteredError(ResolveHandleError): 162 | def __init__(self, member): 163 | super().__init__(f'Codeforces handle for {member.mention} not found in database. ' 164 | 'Use ;handle identify (where needs to be replaced with your codeforces handle, e.g. ;handle identify tourist) to add yourself to the database') 165 | 166 | 167 | class HandleIsVjudgeError(ResolveHandleError): 168 | HANDLES = ('vjudge1 vjudge2 vjudge3 vjudge4 vjudge5 ' 169 | 'luogu_bot1 luogu_bot2 luogu_bot3 luogu_bot4 luogu_bot5').split() 170 | 171 | def __init__(self, handle): 172 | super().__init__(f"`{handle}`? I'm not doing that!\n\n(╯°□°)╯︵ ┻━┻") 173 | 174 | 175 | class FilterError(commands.CommandError): 176 | pass 177 | 178 | class ParamParseError(FilterError): 179 | pass 180 | 181 | def time_format(seconds): 182 | seconds = int(seconds) 183 | days, seconds = divmod(seconds, 86400) 184 | hours, seconds = divmod(seconds, 3600) 185 | minutes, seconds = divmod(seconds, 60) 186 | return days, hours, minutes, seconds 187 | 188 | 189 | def pretty_time_format(seconds, *, shorten=False, only_most_significant=False, always_seconds=False): 190 | days, hours, minutes, seconds = time_format(seconds) 191 | timespec = [ 192 | (days, 'day', 'days'), 193 | (hours, 'hour', 'hours'), 194 | (minutes, 'minute', 'minutes'), 195 | ] 196 | timeprint = [(cnt, singular, plural) for cnt, singular, plural in timespec if cnt] 197 | if not timeprint or always_seconds: 198 | timeprint.append((seconds, 'second', 'seconds')) 199 | if only_most_significant: 200 | timeprint = [timeprint[0]] 201 | 202 | def format_(triple): 203 | cnt, singular, plural = triple 204 | return f'{cnt}{singular[0]}' if shorten else f'{cnt} {singular if cnt == 1 else plural}' 205 | 206 | return ' '.join(map(format_, timeprint)) 207 | 208 | def get_start_and_end_of_month(time): 209 | time = time.replace(day=1, hour=0, minute=0, second=0, microsecond=0) 210 | start_time = int(time.timestamp()) 211 | if time.month == 12: 212 | time = time.replace(month=1,year=time.year+1) 213 | else: 214 | time = time.replace(month=time.month+1) 215 | end_time = int(time.timestamp()) 216 | return start_time, end_time 217 | 218 | def get_start_and_end_of_day(time): # needs editing -> TODO: make it return the start of day i.e. 
UTC 00:00 219 | time = time.replace(day=1, hour=0, minute=0, second=0, microsecond=0) 220 | start_time = int(time.timestamp()) 221 | if time.month == 12: 222 | time = time.replace(month=1,year=time.year+1) 223 | else: 224 | time = time.replace(month=time.month+1) 225 | end_time = int(time.timestamp()) 226 | return start_time, end_time 227 | 228 | 229 | def days_ago(t): 230 | days = (time.time() - t)/(60*60*24) 231 | if days < 1: 232 | return 'today' 233 | if days < 2: 234 | return 'yesterday' 235 | return f'{math.floor(days)} days ago' 236 | 237 | async def resolve_handles(ctx, converter, handles, *, mincnt=1, maxcnt=5, default_to_all_server=False): 238 | """Convert an iterable of strings to CF handles. A string beginning with ! indicates Discord username, 239 | otherwise it is a raw CF handle to be left unchanged.""" 240 | handles = set(handles) 241 | if default_to_all_server and not handles: 242 | handles.add('+server') 243 | if '+server' in handles: 244 | handles.remove('+server') 245 | guild_handles = {handle for discord_id, handle 246 | in user_db.get_handles_for_guild(ctx.guild.id)} 247 | handles.update(guild_handles) 248 | if len(handles) < mincnt or (maxcnt and maxcnt < len(handles)): 249 | raise HandleCountOutOfBoundsError(mincnt, maxcnt) 250 | resolved_handles = [] 251 | for handle in handles: 252 | if handle.startswith('!'): 253 | # ! denotes Discord user 254 | member_identifier = handle[1:] 255 | # suffix removal as quickfix for new username changes 256 | if member_identifier[-2:] == '#0': 257 | member_identifier = member_identifier[:-2] 258 | 259 | try: 260 | member = await converter.convert(ctx, member_identifier) 261 | except commands.errors.CommandError: 262 | raise FindMemberFailedError(member_identifier) 263 | handle = user_db.get_handle(member.id, ctx.guild.id) 264 | if handle is None: 265 | raise HandleNotRegisteredError(member) 266 | if handle in HandleIsVjudgeError.HANDLES: 267 | raise HandleIsVjudgeError(handle) 268 | resolved_handles.append(handle) 269 | return resolved_handles 270 | 271 | def members_to_handles(members: [discord.Member], guild_id): 272 | handles = [] 273 | for member in members: 274 | handle = user_db.get_handle(member.id, guild_id) 275 | if handle is None: 276 | raise HandleNotRegisteredError(member) 277 | handles.append(handle) 278 | return handles 279 | 280 | def filter_flags(args, params): 281 | args = list(args) 282 | flags = [False] * len(params) 283 | rest = [] 284 | for arg in args: 285 | try: 286 | flags[params.index(arg)] = True 287 | except ValueError: 288 | rest.append(arg) 289 | return flags, rest 290 | 291 | def negate_flags(*args): 292 | return [not x for x in args] 293 | 294 | def parse_date(arg): 295 | try: 296 | if len(arg) == 8: 297 | fmt = '%d%m%Y' 298 | elif len(arg) == 6: 299 | fmt = '%m%Y' 300 | elif len(arg) == 4: 301 | fmt = '%Y' 302 | else: 303 | raise ValueError 304 | return time.mktime(datetime.datetime.strptime(arg, fmt).timetuple()) 305 | except ValueError: 306 | raise ParamParseError(f'{arg} is an invalid date argument') 307 | 308 | 309 | def parse_tags(args, *, prefix): 310 | tags = [x[1:] for x in args if x[0] == prefix] 311 | return tags 312 | 313 | 314 | def parse_rating(args, default_value = None): 315 | for arg in args: 316 | if arg.isdigit(): 317 | return int(arg) 318 | return default_value 319 | 320 | def fix_urls(user: cf.User): 321 | if user.titlePhoto.startswith('//'): 322 | user = user._replace(titlePhoto = 'https:' + user.titlePhoto) 323 | return user 324 | 325 | 326 | class SubFilter: 327 | def 
__init__(self, rated=True): 328 | self.team = False 329 | self.rated = rated 330 | self.dlo, self.dhi = 0, 10**10 331 | self.rlo, self.rhi = 500, 3800 332 | self.types = [] 333 | self.tags = [] 334 | self.bantags = [] 335 | self.contests = [] 336 | self.indices = [] 337 | 338 | def parse(self, args): 339 | args = list(set(args)) 340 | rest = [] 341 | 342 | for arg in args: 343 | if arg == '+team': 344 | self.team = True 345 | elif arg == '+contest': 346 | self.types.append('CONTESTANT') 347 | elif arg =='+outof': 348 | self.types.append('OUT_OF_COMPETITION') 349 | elif arg == '+virtual': 350 | self.types.append('VIRTUAL') 351 | elif arg == '+practice': 352 | self.types.append('PRACTICE') 353 | elif arg[0:2] == 'c+': 354 | self.contests.append(arg[2:]) 355 | elif arg[0:2] == 'i+': 356 | self.indices.append(arg[2:]) 357 | elif arg[0] == '+': 358 | if len(arg) == 1: 359 | raise ParamParseError('Problem tag cannot be empty.') 360 | self.tags.append(arg[1:]) 361 | elif arg[0] == '~': 362 | if len(arg) == 1: 363 | raise ParamParseError('Problem tag cannot be empty.') 364 | self.bantags.append(arg[1:]) 365 | elif arg[0:2] == 'd<': 366 | self.dhi = min(self.dhi, parse_date(arg[2:])) 367 | elif arg[0:3] == 'd>=': 368 | self.dlo = max(self.dlo, parse_date(arg[3:])) 369 | elif arg[0:3] in ['r<=', 'r>=']: 370 | if len(arg) < 4: 371 | raise ParamParseError(f'{arg} is an invalid rating argument') 372 | elif arg[1] == '>': 373 | self.rlo = max(self.rlo, int(arg[3:])) 374 | else: 375 | self.rhi = min(self.rhi, int(arg[3:])) 376 | self.rated = True 377 | else: 378 | rest.append(arg) 379 | 380 | self.types = self.types or ['CONTESTANT', 'OUT_OF_COMPETITION', 'VIRTUAL', 'PRACTICE'] 381 | return rest 382 | 383 | @staticmethod 384 | def filter_solved(submissions): 385 | """Filters and keeps only solved submissions. If a problem is solved multiple times the first 386 | accepted submission is kept. The unique id for a problem is (problem name, contest start time). 
387 | """ 388 | submissions.sort(key=lambda sub: sub.creationTimeSeconds) 389 | problems = set() 390 | solved_subs = [] 391 | 392 | for submission in submissions: 393 | problem = submission.problem 394 | contest = cache2.contest_cache.contest_by_id.get(problem.contestId, None) 395 | if submission.verdict == 'OK': 396 | # Assume (name, contest start time) is a unique identifier for problems 397 | problem_key = (problem.name, contest.startTimeSeconds if contest else 0) 398 | if problem_key not in problems: 399 | solved_subs.append(submission) 400 | problems.add(problem_key) 401 | return solved_subs 402 | 403 | def filter_subs(self, submissions): 404 | submissions = SubFilter.filter_solved(submissions) 405 | filtered_subs = [] 406 | for submission in submissions: 407 | problem = submission.problem 408 | contest = cache2.contest_cache.contest_by_id.get(problem.contestId, None) 409 | type_ok = submission.author.participantType in self.types 410 | date_ok = self.dlo <= submission.creationTimeSeconds < self.dhi 411 | tag_ok = problem.matches_all_tags(self.tags) 412 | bantag_ok = not problem.matches_any_tag(self.bantags) 413 | index_ok = not self.indices or any(index.lower() == problem.index.lower() for index in self.indices) 414 | contest_ok = not self.contests or (contest and contest.matches(self.contests)) 415 | team_ok = self.team or len(submission.author.members) == 1 416 | if self.rated: 417 | problem_ok = contest and contest.id < cf.GYM_ID_THRESHOLD and not is_nonstandard_problem(problem) 418 | rating_ok = problem.rating and self.rlo <= problem.rating <= self.rhi 419 | else: 420 | # acmsguru and gym allowed 421 | problem_ok = (not contest or contest.id >= cf.GYM_ID_THRESHOLD 422 | or not is_nonstandard_problem(problem)) 423 | rating_ok = True 424 | if type_ok and date_ok and rating_ok and tag_ok and bantag_ok and team_ok and problem_ok and contest_ok and index_ok: 425 | filtered_subs.append(submission) 426 | return filtered_subs 427 | 428 | def filter_rating_changes(self, rating_changes): 429 | rating_changes = [change for change in rating_changes 430 | if self.dlo <= change.ratingUpdateTimeSeconds < self.dhi] 431 | return rating_changes 432 | -------------------------------------------------------------------------------- /tle/util/cses_scraper.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import aiohttp 4 | from lxml import html 5 | 6 | 7 | class CSESError(Exception): 8 | pass 9 | 10 | 11 | session = aiohttp.ClientSession() 12 | 13 | 14 | async def _fetch(url): 15 | async with session.get(url) as response: 16 | if response.status != 200: 17 | raise CSESError(f"Bad response from CSES, status code {status}") 18 | tree = html.fromstring(await response.read()) 19 | return tree 20 | 21 | 22 | async def get_problems(): 23 | tree = await _fetch('https://cses.fi/problemset/list/') 24 | links = [li.get('href') for li in tree.xpath('//*[@class="task"]/a')] 25 | ids = sorted(int(x.split('/')[-1]) for x in links) 26 | return ids 27 | 28 | 29 | async def get_problem_leaderboard(num): 30 | tree = await _fetch(f'https://cses.fi/problemset/stats/{num}/') 31 | fastest_table, shortest_table = tree.xpath( 32 | '//table[@class!="summary-table" and @class!="bot-killer"]') 33 | 34 | fastest = [a.text for a in fastest_table.xpath('.//a')] 35 | shortest = [a.text for a in shortest_table.xpath('.//a')] 36 | return fastest, shortest 37 | -------------------------------------------------------------------------------- /tle/util/db/__init__.py: 
-------------------------------------------------------------------------------- 1 | from .cache_db_conn import * 2 | from .user_db_conn import * 3 | -------------------------------------------------------------------------------- /tle/util/db/cache_db_conn.py: -------------------------------------------------------------------------------- 1 | import json 2 | import sqlite3 3 | 4 | from tle.util import codeforces_api as cf 5 | 6 | 7 | class CacheDbConn: 8 | def __init__(self, db_file): 9 | self.conn = sqlite3.connect(db_file) 10 | self.create_tables() 11 | 12 | def create_tables(self): 13 | # Table for contests from the contest.list endpoint. 14 | self.conn.execute( 15 | 'CREATE TABLE IF NOT EXISTS contest (' 16 | 'id INTEGER NOT NULL,' 17 | 'name TEXT,' 18 | 'start_time INTEGER,' 19 | 'duration INTEGER,' 20 | 'type TEXT,' 21 | 'phase TEXT,' 22 | 'prepared_by TEXT,' 23 | 'PRIMARY KEY (id)' 24 | ')' 25 | ) 26 | 27 | # Table for problems from the problemset.problems endpoint. 28 | self.conn.execute( 29 | 'CREATE TABLE IF NOT EXISTS problem (' 30 | 'contest_id INTEGER,' 31 | 'problemset_name TEXT,' 32 | '[index] TEXT,' 33 | 'name TEXT NOT NULL,' 34 | 'type TEXT,' 35 | 'points REAL,' 36 | 'rating INTEGER,' 37 | 'tags TEXT,' 38 | 'PRIMARY KEY (name)' 39 | ')' 40 | ) 41 | 42 | # Table for rating changes fetched from contest.ratingChanges endpoint for every contest. 43 | self.conn.execute( 44 | 'CREATE TABLE IF NOT EXISTS rating_change (' 45 | 'contest_id INTEGER NOT NULL,' 46 | 'handle TEXT NOT NULL,' 47 | 'rank INTEGER,' 48 | 'rating_update_time INTEGER,' 49 | 'old_rating INTEGER,' 50 | 'new_rating INTEGER,' 51 | 'UNIQUE (contest_id, handle)' 52 | ')' 53 | ) 54 | self.conn.execute('CREATE INDEX IF NOT EXISTS ix_rating_change_contest_id ' 55 | 'ON rating_change (contest_id)') 56 | self.conn.execute('CREATE INDEX IF NOT EXISTS ix_rating_change_handle ' 57 | 'ON rating_change (handle)') 58 | self.conn.execute('CREATE INDEX IF NOT EXISTS ix_rating_change_rating_update_time ' 59 | 'ON rating_change (handle ASC, rating_update_time DESC)') 60 | 61 | # Table for problems fetched from contest.standings endpoint for every contest. 62 | # This is separate from table problem as it contains the same problem twice if it 63 | # appeared in both Div 1 and Div 2 of some round. 
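# Annotation (illustrative sketch, not part of the original schema): a task shared by
# a Div. 1 / Div. 2 pair therefore appears as two rows in problem2, keyed by
# (contest_id, [index]), while the problem table above, keyed by name alone, keeps a
# single row. Schematically, with made-up ids and a made-up name:
#
#     (contest_id=1001, index='D', name='Shared Task', ...)   # the Div. 2 round
#     (contest_id=1002, index='B', name='Shared Task', ...)   # the Div. 1 round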
64 | self.conn.execute( 65 | 'CREATE TABLE IF NOT EXISTS problem2 (' 66 | 'contest_id INTEGER,' 67 | 'problemset_name TEXT,' 68 | '[index] TEXT,' 69 | 'name TEXT NOT NULL,' 70 | 'type TEXT,' 71 | 'points REAL,' 72 | 'rating INTEGER,' 73 | 'tags TEXT,' 74 | 'PRIMARY KEY (contest_id, [index])' 75 | ')' 76 | ) 77 | self.conn.execute('CREATE INDEX IF NOT EXISTS ix_problem2_contest_id ' 78 | 'ON problem2 (contest_id)') 79 | 80 | def cache_contests(self, contests): 81 | query = ('INSERT OR REPLACE INTO contest ' 82 | '(id, name, start_time, duration, type, phase, prepared_by) ' 83 | 'VALUES (?, ?, ?, ?, ?, ?, ?)') 84 | rc = self.conn.executemany(query, contests).rowcount 85 | self.conn.commit() 86 | return rc 87 | 88 | def fetch_contests(self): 89 | query = ('SELECT id, name, start_time, duration, type, phase, prepared_by ' 90 | 'FROM contest') 91 | res = self.conn.execute(query).fetchall() 92 | return [cf.Contest._make(contest) for contest in res] 93 | 94 | @staticmethod 95 | def _squish_tags(problem): 96 | return (problem.contestId, problem.problemsetName, problem.index, problem.name, 97 | problem.type, problem.points, problem.rating, json.dumps(problem.tags)) 98 | 99 | def cache_problems(self, problems): 100 | query = ('INSERT OR REPLACE INTO problem ' 101 | '(contest_id, problemset_name, [index], name, type, points, rating, tags) ' 102 | 'VALUES (?, ?, ?, ?, ?, ?, ?, ?)') 103 | rc = self.conn.executemany(query, list(map(self._squish_tags, problems))).rowcount 104 | self.conn.commit() 105 | return rc 106 | 107 | @staticmethod 108 | def _unsquish_tags(problem): 109 | args, tags = problem[:-1], json.loads(problem[-1]) 110 | return cf.Problem(*args, tags) 111 | 112 | def fetch_problems(self): 113 | query = ('SELECT contest_id, problemset_name, [index], name, type, points, rating, tags ' 114 | 'FROM problem') 115 | res = self.conn.execute(query).fetchall() 116 | return list(map(self._unsquish_tags, res)) 117 | 118 | def save_rating_changes(self, changes): 119 | change_tuples = [(change.contestId, 120 | change.handle, 121 | change.rank, 122 | change.ratingUpdateTimeSeconds, 123 | change.oldRating, 124 | change.newRating) for change in changes] 125 | query = ('INSERT OR REPLACE INTO rating_change ' 126 | '(contest_id, handle, rank, rating_update_time, old_rating, new_rating) ' 127 | 'VALUES (?, ?, ?, ?, ?, ?)') 128 | rc = self.conn.executemany(query, change_tuples).rowcount 129 | self.conn.commit() 130 | return rc 131 | 132 | def clear_rating_changes(self, contest_id=None): 133 | if contest_id is None: 134 | query = 'DELETE FROM rating_change' 135 | self.conn.execute(query) 136 | else: 137 | query = 'DELETE FROM rating_change WHERE contest_id = ?' 138 | self.conn.execute(query, (contest_id,)) 139 | self.conn.commit() 140 | 141 | def get_users_with_more_than_n_contests(self, time_cutoff, n): 142 | query = ('SELECT handle, COUNT(*) AS num_contests ' 143 | 'FROM rating_change GROUP BY handle HAVING num_contests >= ? 
' 144 | 'AND MAX(rating_update_time) >= ?') 145 | res = self.conn.execute(query, (n, time_cutoff,)).fetchall() 146 | return [user[0] for user in res] 147 | 148 | def get_all_rating_changes(self): 149 | query = ('SELECT contest_id, name, handle, rank, rating_update_time, old_rating, new_rating ' 150 | 'FROM rating_change r ' 151 | 'LEFT JOIN contest c ' 152 | 'ON r.contest_id = c.id ' 153 | 'ORDER BY rating_update_time') 154 | res = self.conn.execute(query) 155 | return (cf.RatingChange._make(change) for change in res) 156 | 157 | def get_rating_changes_for_contest(self, contest_id): 158 | query = ('SELECT contest_id, name, handle, rank, rating_update_time, old_rating, new_rating ' 159 | 'FROM rating_change r ' 160 | 'LEFT JOIN contest c ' 161 | 'ON r.contest_id = c.id ' 162 | 'WHERE r.contest_id = ?') 163 | res = self.conn.execute(query, (contest_id,)).fetchall() 164 | return [cf.RatingChange._make(change) for change in res] 165 | 166 | def has_rating_changes_saved(self, contest_id): 167 | query = ('SELECT contest_id ' 168 | 'FROM rating_change ' 169 | 'WHERE contest_id = ?') 170 | res = self.conn.execute(query, (contest_id,)).fetchone() 171 | return res is not None 172 | 173 | def get_rating_changes_for_handle(self, handle): 174 | query = ('SELECT contest_id, name, handle, rank, rating_update_time, old_rating, new_rating ' 175 | 'FROM rating_change r ' 176 | 'LEFT JOIN contest c ' 177 | 'ON r.contest_id = c.id ' 178 | 'WHERE r.handle = ?') 179 | res = self.conn.execute(query, (handle,)).fetchall() 180 | return [cf.RatingChange._make(change) for change in res] 181 | 182 | def get_all_ratings_before_timestamp(self, timestamp): 183 | query = ('SELECT contest_id, "Dummy", handle, rank, rating_update_time, old_rating, new_rating ' 184 | 'FROM rating_change ' 185 | 'WHERE rating_update_time < ? ' 186 | 'GROUP BY handle ' 187 | 'HAVING MAX(rating_update_time)') 188 | res = self.conn.execute(query, (timestamp,)).fetchall() 189 | return [cf.RatingChange._make(change) for change in res] 190 | 191 | def cache_problemset(self, problemset): 192 | query = ('INSERT OR REPLACE INTO problem2 ' 193 | '(contest_id, problemset_name, [index], name, type, points, rating, tags) ' 194 | 'VALUES (?, ?, ?, ?, ?, ?, ?, ?)') 195 | rc = self.conn.executemany(query, list(map(self._squish_tags, problemset))).rowcount 196 | self.conn.commit() 197 | return rc 198 | 199 | def fetch_problems2(self): 200 | query = ('SELECT contest_id, problemset_name, [index], name, type, points, rating, tags ' 201 | 'FROM problem2 ') 202 | res = self.conn.execute(query).fetchall() 203 | return list(map(self._unsquish_tags, res)) 204 | 205 | def clear_problemset(self, contest_id=None): 206 | if contest_id is None: 207 | query = 'DELETE FROM problem2' 208 | self.conn.execute(query) 209 | else: 210 | query = 'DELETE FROM problem2 WHERE contest_id = ?' 
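# Annotation (illustrative sketch, not part of the original code): the problem2
# helpers are expected to be used together by the problemset cache, roughly:
#
#     conn.clear_problemset(contest_id)          # drop stale rows for one contest
#     conn.cache_problemset(problems)            # INSERT OR REPLACE fresh rows
#     rows = conn.fetch_problemset(contest_id)   # read them back as cf.Problem objects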
211 | self.conn.execute(query, (contest_id,)) 212 | 213 | def fetch_problemset(self, contest_id): 214 | query = ('SELECT contest_id, problemset_name, [index], name, type, points, rating, tags ' 215 | 'FROM problem2 ' 216 | 'WHERE contest_id = ?') 217 | res = self.conn.execute(query, (contest_id,)).fetchall() 218 | return list(map(self._unsquish_tags, res)) 219 | 220 | def problemset_empty(self): 221 | query = 'SELECT 1 FROM problem2' 222 | res = self.conn.execute(query).fetchone() 223 | return res is None 224 | 225 | def close(self): 226 | self.conn.close() 227 | -------------------------------------------------------------------------------- /tle/util/discord_common.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import logging 3 | import functools 4 | import random 5 | 6 | import discord 7 | from discord.ext import commands 8 | 9 | from tle.util import codeforces_api as cf 10 | from tle.util import db 11 | from tle.util import tasks 12 | 13 | logger = logging.getLogger(__name__) 14 | 15 | _CF_COLORS = (0xFFCA1F, 0x198BCC, 0xFF2020) 16 | _SUCCESS_GREEN = 0x28A745 17 | _ALERT_AMBER = 0xFFBF00 18 | _BOT_PREFIX = ';' 19 | 20 | 21 | def embed_neutral(desc, color=None): 22 | return discord.Embed(description=str(desc), color=color) 23 | 24 | 25 | def embed_success(desc): 26 | return discord.Embed(description=str(desc), color=_SUCCESS_GREEN) 27 | 28 | 29 | def embed_alert(desc): 30 | return discord.Embed(description=str(desc), color=_ALERT_AMBER) 31 | 32 | 33 | def random_cf_color(): 34 | return random.choice(_CF_COLORS) 35 | 36 | 37 | def cf_color_embed(**kwargs): 38 | return discord.Embed(**kwargs, color=random_cf_color()) 39 | 40 | 41 | def set_same_cf_color(embeds): 42 | color = random_cf_color() 43 | for embed in embeds: 44 | embed.color=color 45 | 46 | 47 | def attach_image(embed, img_file): 48 | embed.set_image(url=f'attachment://{img_file.filename}') 49 | 50 | 51 | def set_author_footer(embed, user): 52 | embed.set_footer(text=f'Requested by {user}', icon_url=user.avatar) 53 | 54 | 55 | def send_error_if(*error_cls): 56 | """Decorator for `cog_command_error` methods. Decorated methods send the error in an alert embed 57 | when the error is an instance of one of the specified errors, otherwise the wrapped function is 58 | invoked. 59 | """ 60 | def decorator(func): 61 | @functools.wraps(func) 62 | async def wrapper(cog, ctx, error): 63 | if isinstance(error, error_cls): 64 | await ctx.send(embed=embed_alert(error)) 65 | error.handled = True 66 | else: 67 | await func(cog, ctx, error) 68 | return wrapper 69 | return decorator 70 | 71 | 72 | async def bot_error_handler(ctx, exception): 73 | if getattr(exception, 'handled', False): 74 | # Errors already handled in cogs should have .handled = True 75 | return 76 | 77 | if isinstance(exception, db.DatabaseDisabledError): 78 | await ctx.send(embed=embed_alert('Sorry, the database is not available. 
Some features are disabled.')) 79 | elif isinstance(exception, commands.NoPrivateMessage): 80 | await ctx.send(embed=embed_alert('Commands are disabled in private channels')) 81 | elif isinstance(exception, commands.DisabledCommand): 82 | await ctx.send(embed=embed_alert('Sorry, this command is temporarily disabled')) 83 | elif isinstance(exception, (cf.CodeforcesApiError, commands.UserInputError)): 84 | await ctx.send(embed=embed_alert(exception)) 85 | else: 86 | msg = 'Ignoring exception in command {}:'.format(ctx.command) 87 | exc_info = type(exception), exception, exception.__traceback__ 88 | extra = { 89 | "message_content": ctx.message.content, 90 | "jump_url": ctx.message.jump_url 91 | } 92 | logger.exception(msg, exc_info=exc_info, extra=extra) 93 | 94 | 95 | def once(func): 96 | """Decorator that wraps the given async function such that it is executed only once.""" 97 | first = True 98 | 99 | @functools.wraps(func) 100 | async def wrapper(*args, **kwargs): 101 | nonlocal first 102 | if first: 103 | first = False 104 | await func(*args, **kwargs) 105 | 106 | return wrapper 107 | 108 | 109 | def on_ready_event_once(bot): 110 | """Decorator that uses bot.event to set the given function as the bot's on_ready event handler, 111 | but does not execute it more than once. 112 | """ 113 | def register_on_ready(func): 114 | @bot.event 115 | @once 116 | async def on_ready(): 117 | await func() 118 | 119 | return register_on_ready 120 | 121 | async def presence(bot): 122 | await bot.change_presence(activity=discord.Activity( 123 | type=discord.ActivityType.listening, 124 | name='your commands')) 125 | await asyncio.sleep(60) 126 | 127 | @tasks.task(name='OrzUpdate', 128 | waiter=tasks.Waiter.fixed_delay(5*60)) 129 | async def presence_task(_): 130 | while True: 131 | target = random.choice([ 132 | "i_pranav","badal_arya","abhi_wd", 133 | "thoughtlessnerd","hikaku","HimanshuRaj","denjell","harsh rishi miglani","Melon King","glenJP123"]) 134 | await bot.change_presence(activity=discord.Game( 135 | name=f'{target} orz')) 136 | await asyncio.sleep(10 * 60) 137 | 138 | presence_task.start() 139 | 140 | class TleHelp(commands.DefaultHelpCommand): 141 | def add_command_formatting(self, command): 142 | """A utility function to format the non-indented block of commands and groups. 143 | 144 | Parameters 145 | ------------ 146 | command: :class:`Command` 147 | The command to format. 
148 | """ 149 | 150 | if command.description: 151 | self.paginator.add_line(command.description, empty=True) 152 | 153 | signature = _BOT_PREFIX + command.qualified_name 154 | if len(command.aliases) > 0: 155 | aliases = '|'.join(command.aliases) 156 | signature += '|'+aliases 157 | if command.usage: 158 | signature += " "+command.usage 159 | self.paginator.add_line(signature, empty=True) 160 | 161 | if command.help: 162 | try: 163 | self.paginator.add_line(command.help, empty=True) 164 | except RuntimeError: 165 | for line in command.help.splitlines(): 166 | self.paginator.add_line(line) 167 | self.paginator.add_line() 168 | 169 | -------------------------------------------------------------------------------- /tle/util/elo.py: -------------------------------------------------------------------------------- 1 | # ELO 2 | # python 3.4.3 3 | import math 4 | 5 | _ELO_CONSTANT = 60 6 | 7 | class ELOPlayer: 8 | def __init__(self): 9 | self.name = "" 10 | self.place = 0 11 | self.eloPre = 0 12 | self.eloPost = 0 13 | self.eloChange = 0 14 | 15 | 16 | class ELOMatch: 17 | def __init__(self): 18 | self.players = [] 19 | 20 | def addPlayer(self, name, place, elo): 21 | player = ELOPlayer() 22 | 23 | player.name = name 24 | player.place = place 25 | player.eloPre = elo 26 | 27 | self.players.append(player) 28 | 29 | def getELO(self, name): 30 | for p in self.players: 31 | if p.name == name: 32 | return p.eloPost 33 | 34 | return 1500 35 | 36 | def getELOChange(self, name): 37 | for p in self.players: 38 | if p.name == name: 39 | return p.eloChange 40 | 41 | return 0 42 | 43 | def calculateELOs(self): 44 | n = len(self.players) 45 | K = _ELO_CONSTANT / max(1, (n - 1)) 46 | 47 | for i in range(n): 48 | curPlace = self.players[i].place 49 | curELO = self.players[i].eloPre 50 | 51 | for j in range(n): 52 | if i != j: 53 | opponentPlace = self.players[j].place 54 | opponentELO = self.players[j].eloPre 55 | 56 | # work out S 57 | if curPlace < opponentPlace: 58 | S = 1.0 59 | elif curPlace == opponentPlace: 60 | S = 0.5 61 | else: 62 | S = 0.0 63 | 64 | # work out EA 65 | EA = 1 / (1.0 + math.pow(10.0, (opponentELO - curELO) / 400.0)) 66 | 67 | # calculate ELO change vs this one opponent, add it to our change bucket 68 | # I currently round at this point, this keeps rounding changes symetrical between EA and EB, but changes K more than it should 69 | self.players[i].eloChange += round(K * (S - EA)) 70 | 71 | # add accumulated change to initial ELO for final ELO 72 | 73 | self.players[i].eloPost = self.players[i].eloPre + self.players[i].eloChange 74 | 75 | -------------------------------------------------------------------------------- /tle/util/events.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import logging 3 | 4 | from discord.ext import commands 5 | 6 | 7 | # Event types 8 | 9 | class Event: 10 | """Base class for events.""" 11 | pass 12 | 13 | 14 | class ContestListRefresh(Event): 15 | def __init__(self, contests): 16 | self.contests = contests 17 | 18 | 19 | class RatingChangesUpdate(Event): 20 | def __init__(self, *, contest, rating_changes): 21 | self.contest = contest 22 | self.rating_changes = rating_changes 23 | 24 | 25 | # Event errors 26 | 27 | class EventError(commands.CommandError): 28 | pass 29 | 30 | 31 | class ListenerNotRegistered(EventError): 32 | def __init__(self, listener): 33 | super().__init__(f'Listener {listener.name} is not registered for event ' 34 | f'{listener.event_cls.__name__}.') 35 | 36 | 37 | # Event 
system 38 | 39 | class EventSystem: 40 | """Rudimentary event system.""" 41 | 42 | def __init__(self): 43 | self.listeners_by_event = {} 44 | self.futures_by_event = {} 45 | self.logger = logging.getLogger(self.__class__.__name__) 46 | 47 | def add_listener(self, listener): 48 | listeners = self.listeners_by_event.setdefault(listener.event_cls, set()) 49 | listeners.add(listener) 50 | 51 | def remove_listener(self, listener): 52 | try: 53 | self.listeners_by_event[listener.event_cls].remove(listener) 54 | except KeyError: 55 | raise ListenerNotRegistered(listener) 56 | 57 | async def wait_for(self, event_cls, *, timeout=None): 58 | future = asyncio.get_running_loop().create_future() 59 | futures = self.futures_by_event.setdefault(event_cls, []) 60 | futures.append(future) 61 | return await asyncio.wait_for(future, timeout) 62 | 63 | def dispatch(self, event_cls, *args, **kwargs): 64 | self.logger.info(f'Dispatching event `{event_cls.__name__}`') 65 | event = event_cls(*args, **kwargs) 66 | for listener in self.listeners_by_event.get(event_cls, []): 67 | listener.trigger(event) 68 | futures = self.futures_by_event.pop(event_cls, []) 69 | for future in futures: 70 | if not future.done(): 71 | future.set_result(event) 72 | 73 | 74 | # Listener 75 | 76 | def _ensure_coroutine_func(func): 77 | if not asyncio.iscoroutinefunction(func): 78 | raise TypeError('The listener function must be a coroutine function.') 79 | 80 | 81 | class Listener: 82 | """A listener for a particular event. A listener must have a name, the event it should listen 83 | to and a coroutine function `func` that is called when the event is dispatched. 84 | """ 85 | def __init__(self, name, event_cls, func, *, with_lock=False): 86 | """`with_lock` controls whether execution of `func` should be guarded by an asyncio.Lock.""" 87 | _ensure_coroutine_func(func) 88 | self.name = name 89 | self.event_cls = event_cls 90 | self.func = func 91 | self.lock = asyncio.Lock() if with_lock else None 92 | self.logger = logging.getLogger(self.__class__.__name__) 93 | 94 | def trigger(self, event): 95 | asyncio.create_task(self._trigger(event)) 96 | 97 | async def _trigger(self, event): 98 | try: 99 | if self.lock: 100 | async with self.lock: 101 | await self.func(event) 102 | else: 103 | await self.func(event) 104 | except asyncio.CancelledError: 105 | raise 106 | except: 107 | self.logger.exception(f'Exception in listener `{self.name}`.') 108 | 109 | def __eq__(self, other): 110 | return (isinstance(other, Listener) 111 | and (self.event_cls, self.func) == (other.event_cls, other.func)) 112 | 113 | def __hash__(self): 114 | return hash((self.event_cls, self.func)) 115 | 116 | 117 | class ListenerSpec: 118 | """A descriptor intended to be an interface between an instance and its listeners. It creates 119 | the expected listener when `__get__` is called from an instance for the first time. No two 120 | listener specs in the same class should have the same name. 
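    Typically created via the `listener_spec` decorator on a coroutine method; accessing the
    attribute on an instance then yields a per-instance `Listener` that can be registered with
    `EventSystem.add_listener`.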
121 | """ 122 | def __init__(self, name, event_cls, func, *, with_lock=False): 123 | """`with_lock` controls whether execution of `func` should be guarded by an asyncio.Lock.""" 124 | _ensure_coroutine_func(func) 125 | self.name = name 126 | self.event_cls = event_cls 127 | self.func = func 128 | self.with_lock = with_lock 129 | 130 | def __get__(self, instance, owner): 131 | if instance is None: 132 | return self 133 | try: 134 | listeners = getattr(instance, '___listeners___') 135 | except AttributeError: 136 | listeners = instance.___listeners___ = {} 137 | if self.name not in listeners: 138 | # In Python <=3.7 iscoroutinefunction returns False for async functions wrapped by 139 | # functools.partial. 140 | # TODO: Use functools.partial when we move to Python 3.8. 141 | async def wrapper(event): 142 | return await self.func(instance, event) 143 | 144 | listeners[self.name] = Listener(self.name, self.event_cls, wrapper, 145 | with_lock=self.with_lock) 146 | return listeners[self.name] 147 | 148 | 149 | def listener(*, name, event_cls, with_lock=False): 150 | """Returns a decorator that creates a `Listener` with the given options.""" 151 | 152 | def decorator(func): 153 | return Listener(name, event_cls, func, with_lock=with_lock) 154 | 155 | return decorator 156 | 157 | 158 | def listener_spec(*, name, event_cls, with_lock=False): 159 | """Returns a decorator that creates a `ListenerSpec` with the given options.""" 160 | 161 | def decorator(func): 162 | return ListenerSpec(name, event_cls, func, with_lock=with_lock) 163 | 164 | return decorator 165 | -------------------------------------------------------------------------------- /tle/util/font_downloader.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import urllib.request 4 | 5 | from zipfile import ZipFile 6 | from io import BytesIO 7 | 8 | from tle import constants 9 | 10 | URL_BASE = 'https://noto-website-2.storage.googleapis.com/pkgs/' 11 | FONTS = [constants.NOTO_SANS_CJK_BOLD_FONT_PATH, 12 | constants.NOTO_SANS_CJK_REGULAR_FONT_PATH] 13 | 14 | logger = logging.getLogger(__name__) 15 | 16 | 17 | def _unzip(font, archive): 18 | with ZipFile(archive) as zipfile: 19 | if font not in zipfile.namelist(): 20 | raise KeyError(f'Expected font file {font} not present in downloaded zip archive.') 21 | zipfile.extract(font, constants.FONTS_DIR) 22 | 23 | 24 | def _download(font_path): 25 | font = os.path.basename(font_path) 26 | logger.info(f'Downloading font `{font}`.') 27 | with urllib.request.urlopen(f'{URL_BASE}{font}.zip') as resp: 28 | _unzip(font, BytesIO(resp.read())) 29 | 30 | 31 | def maybe_download(): 32 | for font_path in FONTS: 33 | if not os.path.isfile(font_path): 34 | _download(font_path) 35 | -------------------------------------------------------------------------------- /tle/util/gemini_model_settings.py: -------------------------------------------------------------------------------- 1 | text_generation_config = { 2 | "temperature": 0.69, 3 | "top_p": 1, 4 | "top_k": 1, 5 | "max_output_tokens": 2048, 6 | } 7 | 8 | text_safety_settings = [ 9 | { 10 | "category": "HARM_CATEGORY_HARASSMENT", 11 | "threshold": "BLOCK_NONE" 12 | }, 13 | { 14 | "category": "HARM_CATEGORY_HATE_SPEECH", 15 | "threshold": "BLOCK_NONE" 16 | }, 17 | { 18 | "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", 19 | "threshold": "BLOCK_NONE" 20 | }, 21 | { 22 | "category": "HARM_CATEGORY_DANGEROUS_CONTENT", 23 | "threshold": "BLOCK_NONE" 24 | } 25 | ] 26 | 27 | image_generation_config 
= {
28 |     "temperature": 0.4,
29 |     "top_p": 1,
30 |     "top_k": 32,
31 |     "max_output_tokens": 4096,
32 | }
33 | 
34 | image_safety_settings = [
35 |     {
36 |         "category": "HARM_CATEGORY_HARASSMENT",
37 |         "threshold": "BLOCK_NONE"
38 |     },
39 |     {
40 |         "category": "HARM_CATEGORY_HATE_SPEECH",
41 |         "threshold": "BLOCK_NONE"
42 |     },
43 |     {
44 |         "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
45 |         "threshold": "BLOCK_NONE"
46 |     },
47 |     {
48 |         "category": "HARM_CATEGORY_DANGEROUS_CONTENT",
49 |         "threshold": "BLOCK_NONE"
50 |     }
51 | ]
52 | 
--------------------------------------------------------------------------------
/tle/util/graph_common.py:
--------------------------------------------------------------------------------
1 | import os
2 | import io
3 | import discord
4 | import time
5 | import matplotlib.font_manager
6 | import matplotlib
7 | matplotlib.use('agg') # Explicitly set the backend to avoid issues
8 | 
9 | from tle import constants
10 | from matplotlib import pyplot as plt
11 | from matplotlib import rcParams
12 | from cycler import cycler
13 | 
14 | rating_color_cycler = cycler('color', ['#5d4dff',
15 |                                        '#009ccc',
16 |                                        '#00ba6a',
17 |                                        '#b99d27',
18 |                                        '#cb2aff'])
19 | 
20 | fontprop = matplotlib.font_manager.FontProperties(fname=constants.NOTO_SANS_CJK_REGULAR_FONT_PATH)
21 | 
22 | 
23 | # String wrapper to avoid the underscore behavior in legends
24 | #
25 | # In legends, matplotlib ignores labels that begin with _
26 | # https://matplotlib.org/api/pyplot_api.html#matplotlib.pyplot.legend
27 | # However, this check is only done for actual string objects.
28 | class StrWrap:
29 |     def __init__(self, s):
30 |         self.string = s
31 |     def __str__(self):
32 |         return self.string
33 | 
34 | def get_current_figure_as_file():
35 |     filename = os.path.join(constants.TEMP_DIR, f'tempplot_{time.time()}.png')
36 |     plt.savefig(filename, facecolor=plt.gca().get_facecolor(), bbox_inches='tight', pad_inches=0.25)
37 | 
38 |     with open(filename, 'rb') as file:
39 |         discord_file = discord.File(io.BytesIO(file.read()), filename='plot.png')
40 | 
41 |     os.remove(filename)
42 |     return discord_file
43 | 
44 | def plot_rating_bg(ranks):
45 |     ymin, ymax = plt.gca().get_ylim()
46 |     bgcolor = plt.gca().get_facecolor()
47 |     for rank in ranks:
48 |         plt.axhspan(rank.low, rank.high, facecolor=rank.color_graph, alpha=0.8, edgecolor=bgcolor, linewidth=0.5)
49 | 
50 |     locs, labels = plt.xticks()
51 |     for loc in locs:
52 |         plt.axvline(loc, color=bgcolor, linewidth=0.5)
53 |     plt.ylim(ymin, ymax)
54 | 
--------------------------------------------------------------------------------
/tle/util/handledict.py:
--------------------------------------------------------------------------------
1 | """
2 | A case-insensitive dictionary with the bare minimum of functions required for handling usernames.
3 | """
4 | 
5 | 
6 | class HandleDict:
7 |     def __init__(self):
8 |         self._store = {}
9 | 
10 |     @staticmethod
11 |     def _getlower(key):
12 |         return key.lower() if type(key) == str else key
13 | 
14 |     def __setitem__(self, key, value):
15 |         # Use the lowercased key for lookups, but store the actual
16 |         # key alongside the value.
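        # For example, after d['Tourist'] = 7, both d['Tourist'] and d['tourist']
        # return 7, and d.get_correct_handle('tourist') gives back 'Tourist'.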
17 | self._store[self._getlower(key)] = (key, value) 18 | 19 | def __getitem__(self, key): 20 | return self._store[self._getlower(key)][1] 21 | 22 | # get correct handle irrespective of the input case of the handle (if the handle is present) 23 | def get_correct_handle(self, key): 24 | try: 25 | return self._store[self._getlower(key)][0] 26 | except KeyError: 27 | return "" 28 | 29 | def __delitem__(self, key): 30 | del self._store[self._getlower(key)] 31 | 32 | def __iter__(self): 33 | return (cased_key for cased_key, mapped_value in self._store.values()) 34 | 35 | def items(self): 36 | return dict([value for value in self._store.values()]).items() 37 | 38 | def __repr__(self): 39 | return str(self.items()) 40 | -------------------------------------------------------------------------------- /tle/util/paginator.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import functools 3 | 4 | _REACT_FIRST = '\N{BLACK LEFT-POINTING DOUBLE TRIANGLE WITH VERTICAL BAR}' 5 | _REACT_PREV = '\N{BLACK LEFT-POINTING TRIANGLE}' 6 | _REACT_NEXT = '\N{BLACK RIGHT-POINTING TRIANGLE}' 7 | _REACT_LAST = '\N{BLACK RIGHT-POINTING DOUBLE TRIANGLE WITH VERTICAL BAR}' 8 | 9 | 10 | def chunkify(sequence, chunk_size): 11 | """Utility method to split a sequence into fixed size chunks.""" 12 | return [sequence[i: i + chunk_size] for i in range(0, len(sequence), chunk_size)] 13 | 14 | 15 | class PaginatorError(Exception): 16 | pass 17 | 18 | 19 | class NoPagesError(PaginatorError): 20 | pass 21 | 22 | 23 | class InsufficientPermissionsError(PaginatorError): 24 | pass 25 | 26 | 27 | class Paginated: 28 | def __init__(self, pages): 29 | self.pages = pages 30 | self.cur_page = None 31 | self.message = None 32 | self.reaction_map = { 33 | _REACT_FIRST: functools.partial(self.show_page, 1), 34 | _REACT_PREV: self.prev_page, 35 | _REACT_NEXT: self.next_page, 36 | _REACT_LAST: functools.partial(self.show_page, len(pages)) 37 | } 38 | 39 | async def show_page(self, page_num): 40 | if 1 <= page_num <= len(self.pages): 41 | content, embed = self.pages[page_num - 1] 42 | await self.message.edit(content=content, embed=embed) 43 | self.cur_page = page_num 44 | 45 | async def prev_page(self): 46 | await self.show_page(self.cur_page - 1) 47 | 48 | async def next_page(self): 49 | await self.show_page(self.cur_page + 1) 50 | 51 | async def paginate(self, bot, channel, wait_time, delete_after:float = None): 52 | content, embed = self.pages[0] 53 | self.message = await channel.send(content, embed=embed, delete_after=delete_after) 54 | 55 | if len(self.pages) == 1: 56 | # No need to paginate. 
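            # A single page is sent as-is; the reaction-based navigation set up
            # below only applies when there are multiple pages.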
57 |             return
58 | 
59 |         self.cur_page = 1
60 |         for react in self.reaction_map.keys():
61 |             await self.message.add_reaction(react)
62 | 
63 |         def check(reaction, user):
64 |             return (bot.user != user and
65 |                     reaction.message.id == self.message.id and
66 |                     reaction.emoji in self.reaction_map)
67 | 
68 |         while True:
69 |             try:
70 |                 reaction, user = await bot.wait_for('reaction_add', timeout=wait_time, check=check)
71 |                 await reaction.remove(user)
72 |                 await self.reaction_map[reaction.emoji]()
73 |             except asyncio.TimeoutError:
74 |                 await self.message.clear_reactions()
75 |                 break
76 | 
77 | 
78 | def paginate(bot, channel, pages, *, wait_time, set_pagenum_footers=False, delete_after:float = None):
79 |     if not pages:
80 |         raise NoPagesError()
81 |     permissions = channel.permissions_for(channel.guild.me)
82 |     if not permissions.manage_messages:
83 |         raise InsufficientPermissionsError('Permission to manage messages required')
84 |     if len(pages) > 1 and set_pagenum_footers:
85 |         for i, (content, embed) in enumerate(pages):
86 |             embed.set_footer(text=f'Page {i + 1} / {len(pages)}')
87 |     paginated = Paginated(pages)
88 |     asyncio.create_task(paginated.paginate(bot, channel, wait_time, delete_after))
89 | 
--------------------------------------------------------------------------------
/tle/util/ranklist/__init__.py:
--------------------------------------------------------------------------------
1 | from .ranklist import *
2 | 
--------------------------------------------------------------------------------
/tle/util/ranklist/ranklist.py:
--------------------------------------------------------------------------------
1 | from discord.ext import commands
2 | 
3 | from tle.util.ranklist.rating_calculator import CodeforcesRatingCalculator
4 | from tle.util.handledict import HandleDict
5 | from tle.util.codeforces_api import make_from_dict, RanklistRow
6 | 
7 | 
8 | class RanklistError(commands.CommandError):
9 |     def __init__(self, contest, message=None):
10 |         if message is not None:
11 |             super().__init__(message)
12 |         self.contest = contest
13 | 
14 | 
15 | class ContestNotRatedError(RanklistError):
16 |     def __init__(self, contest):
17 |         super().__init__(contest, f'`{contest.name}` is not rated')
18 | 
19 | 
20 | class HandleNotPresentError(RanklistError):
21 |     def __init__(self, contest, handle):
22 |         super().__init__(contest, f'Handle `{handle}` not present in standings of `{contest.name}`')
23 |         self.handle = handle
24 | 
25 | 
26 | class DeltasNotPresentError(RanklistError):
27 |     def __init__(self, contest):
28 |         super().__init__(contest, f'Rating changes for `{contest.name}` not calculated or set.')
29 | 
30 | 
31 | class Ranklist:
32 |     def __init__(self, contest, problems, standings, fetch_time, *, is_rated):
33 |         self.contest = contest
34 |         self.problems = problems
35 |         self.standings = standings
36 |         self.fetch_time = fetch_time
37 |         self.is_rated = is_rated
38 |         self.delta_by_handle = None
39 |         self.deltas_status = None
40 |         self.standing_by_id = None
41 |         self._create_inverse_standings()
42 | 
43 |     def _create_inverse_standings(self):
44 |         self.standing_by_id = HandleDict()
45 |         for row in self.standings:
46 |             id_ = self.get_ranklist_lookup_key(row)
47 |             self.standing_by_id[id_] = row
48 | 
49 |     def remove_unofficial_contestants(self):
50 |         """
51 |         To be used for cases when the official ranklist contains unofficial contestants.
52 |         Currently this is seen in the Educational Contests ranklist, where div1 contestants are marked official in the API result.
53 |         """
54 | 
55 |         if self.delta_by_handle is None:
56 |             raise
DeltasNotPresentError(self.contest) 57 | 58 | official_standings = [] 59 | current_rated_rank = 1 60 | last_rated_rank = 0 61 | last_rated_score = (-1, -1) 62 | for contestant in self.standings: 63 | handle = self.get_ranklist_lookup_key(contestant) 64 | if handle in self.delta_by_handle: 65 | current_score = (contestant.points, contestant.penalty) 66 | standings_row = self.standing_by_id[handle]._asdict() 67 | standings_row['rank'] = current_rated_rank if current_score != last_rated_score else last_rated_rank 68 | official_standings.append(make_from_dict(RanklistRow, standings_row)) 69 | last_rated_rank = standings_row['rank'] 70 | last_rated_score = current_score 71 | current_rated_rank += 1 72 | 73 | self.standings = official_standings 74 | self._create_inverse_standings() 75 | 76 | def set_deltas(self, delta_by_handle): 77 | if not self.is_rated: 78 | raise ContestNotRatedError(self.contest) 79 | self.delta_by_handle = delta_by_handle.copy() 80 | self.deltas_status = 'Final' 81 | 82 | def predict(self, current_rating): 83 | if not self.is_rated: 84 | raise ContestNotRatedError(self.contest) 85 | standings = [(id_, row.points, row.penalty, current_rating[id_]) 86 | for id_, row in self.standing_by_id.items() if id_ in current_rating] 87 | if standings: 88 | self.delta_by_handle = CodeforcesRatingCalculator(standings).calculate_rating_changes() 89 | self.deltas_status = 'Predicted' 90 | 91 | def get_delta(self, handle): 92 | if not self.is_rated: 93 | raise ContestNotRatedError(self.contest) 94 | if handle not in self.standing_by_id: 95 | raise HandleNotPresentError(self.contest, handle) 96 | return self.delta_by_handle.get(handle) 97 | 98 | def get_standing_row(self, handle): 99 | try: 100 | return self.standing_by_id[handle] 101 | except KeyError: 102 | raise HandleNotPresentError(self.contest, handle) 103 | 104 | @staticmethod 105 | def get_ranklist_lookup_key(contestant): 106 | return contestant.party.teamName or contestant.party.members[0].handle 107 | -------------------------------------------------------------------------------- /tle/util/ranklist/rating_calculator.py: -------------------------------------------------------------------------------- 1 | """ 2 | Adapted from Codeforces code to recalculate ratings 3 | by Mike Mirzayanov (mirzayanovmr@gmail.com) at https://codeforces.com/contest/1/submission/13861109 4 | Updated to use the current rating formula. 
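In brief, the calculation below works as follows: the Elo win probability of a player
rated r_j against one rated r_i is 1 / (1 + 10 ** ((r_i - r_j) / 400)); a contestant's
seed is 1 plus the sum of every other contestant's win probability against them
(computed for all ratings at once via an FFT convolution of the rating histogram);
the performance rating is the rating whose seed equals sqrt(rank * seed); and the raw
delta is (performance - rating) / 2, before the zero-sum corrections applied in
_update_delta().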
5 | """ 6 | 7 | from dataclasses import dataclass 8 | 9 | import numpy as np 10 | from numpy.fft import fft, ifft 11 | 12 | 13 | def intdiv(x, y): 14 | return -(-x // y) if x < 0 else x // y 15 | 16 | 17 | @dataclass 18 | class Contestant: 19 | party: str 20 | points: float 21 | penalty: int 22 | rating: int 23 | need_rating: int = 0 24 | delta: int = 0 25 | rank: float = 0.0 26 | seed: float = 0.0 27 | 28 | 29 | class CodeforcesRatingCalculator: 30 | def __init__(self, standings): 31 | """Calculate Codeforces rating changes and seeds given contest and user information.""" 32 | self.contestants = [Contestant(handle, points, penalty, rating) 33 | for handle, points, penalty, rating in standings] 34 | self._precalc_seed() 35 | self._reassign_ranks() 36 | self._process() 37 | self._update_delta() 38 | 39 | def calculate_rating_changes(self): 40 | """Return a mapping between contestants and their corresponding delta.""" 41 | return {contestant.party: contestant.delta for contestant in self.contestants} 42 | 43 | def get_seed(self, rating, me=None): 44 | """Get seed given a rating and user.""" 45 | seed = self.seed[rating] 46 | if me: 47 | seed -= self.elo_win_prob[rating - me.rating] 48 | return seed 49 | 50 | def _precalc_seed(self): 51 | MAX = 6144 52 | 53 | # Precompute the ELO win probability for all possible rating differences. 54 | self.elo_win_prob = np.roll(1 / (1 + pow(10, np.arange(-MAX, MAX) / 400)), -MAX) 55 | 56 | # Compute the rating histogram. 57 | count = np.zeros(2 * MAX) 58 | for a in self.contestants: 59 | count[a.rating] += 1 60 | 61 | # Precompute the seed for all possible ratings using FFT. 62 | self.seed = 1 + ifft(fft(count) * fft(self.elo_win_prob)).real 63 | 64 | def _reassign_ranks(self): 65 | """Find the rank of each contestant.""" 66 | contestants = self.contestants 67 | contestants.sort(key=lambda o: (-o.points, o.penalty)) 68 | points = penalty = rank = None 69 | for i in reversed(range(len(contestants))): 70 | if contestants[i].points != points or contestants[i].penalty != penalty: 71 | rank = i + 1 72 | points = contestants[i].points 73 | penalty = contestants[i].penalty 74 | contestants[i].rank = rank 75 | 76 | def _process(self): 77 | """Process and assign approximate delta for each contestant.""" 78 | for a in self.contestants: 79 | a.seed = self.get_seed(a.rating, a) 80 | mid_rank = (a.rank * a.seed) ** 0.5 81 | a.need_rating = self._rank_to_rating(mid_rank, a) 82 | a.delta = intdiv(a.need_rating - a.rating, 2) 83 | 84 | def _rank_to_rating(self, rank, me): 85 | """Binary Search to find the performance rating for a given rank.""" 86 | left, right = 1, 8000 87 | while right - left > 1: 88 | mid = (left + right) // 2 89 | if self.get_seed(mid, me) < rank: 90 | right = mid 91 | else: 92 | left = mid 93 | return left 94 | 95 | def _update_delta(self): 96 | """Update the delta of each contestant.""" 97 | contestants = self.contestants 98 | n = len(contestants) 99 | 100 | contestants.sort(key=lambda o: -o.rating) 101 | correction = intdiv(-sum(c.delta for c in contestants), n) - 1 102 | for contestant in contestants: 103 | contestant.delta += correction 104 | 105 | zero_sum_count = min(4 * round(n ** 0.5), n) 106 | delta_sum = -sum(contestants[i].delta for i in range(zero_sum_count)) 107 | correction = min(0, max(-10, intdiv(delta_sum, zero_sum_count))) 108 | for contestant in contestants: 109 | contestant.delta += correction 110 | -------------------------------------------------------------------------------- /tle/util/table.py: 
-------------------------------------------------------------------------------- 1 | import unicodedata 2 | 3 | FULL_WIDTH = 1.66667 4 | WIDTH_MAPPING = {'F': FULL_WIDTH, 'H': 1, 'W': FULL_WIDTH, 'Na': 1, 'N': 1, 'A': 1} 5 | 6 | def width(s): 7 | return round(sum(WIDTH_MAPPING[unicodedata.east_asian_width(c)] for c in s)) 8 | 9 | 10 | class Content: 11 | def __init__(self, *args): 12 | self.data = args 13 | def sizes(self): 14 | return [width(str(x)) for x in self.data] 15 | def __len__(self): 16 | return len(self.data) 17 | 18 | class Header(Content): 19 | def layout(self, style): 20 | return style.format_header(self.data) 21 | 22 | class Data(Content): 23 | def layout(self, style): 24 | return style.format_body(self.data) 25 | 26 | class Line: 27 | def __init__(self, c='-'): 28 | self.c = c 29 | def layout(self, style): 30 | self.data = ['']*style.ncols 31 | return style.format_line(self.c) 32 | 33 | class Style: 34 | def __init__(self, body, header=None): 35 | self._body = body 36 | self._header = header or body 37 | self.ncols = body.count('}') 38 | 39 | def _pad(self, data, fmt): 40 | S = [] 41 | lastc = None 42 | size = iter(self.sizes) 43 | datum = iter(data) 44 | for c in fmt: 45 | if lastc == ':': 46 | dstr = str(next(datum)) 47 | sz = str(next(size) - (width(dstr) - len(dstr))) 48 | if c in '<>^': 49 | S.append(c + sz) 50 | else: 51 | S.append(sz + c) 52 | else: 53 | S.append(c) 54 | lastc = c 55 | return ''.join(S) 56 | 57 | def format_header(self, data): 58 | return self._pad(data, self._header).format(*data) 59 | 60 | def format_line(self, c): 61 | data = ['']*self.ncols 62 | return self._pad(data, self._header).replace(':', ':'+c).format(*data) 63 | 64 | def format_body(self, data): 65 | return self._pad(data, self._body).format(*data) 66 | 67 | def set_colwidths(self, sizes): 68 | self.sizes = sizes 69 | 70 | class Table: 71 | def __init__(self, style): 72 | self.style = style 73 | self.rows = [] 74 | 75 | def append(self, row): 76 | self.rows.append(row) 77 | return self 78 | __add__ = append 79 | 80 | def __repr__(self): 81 | sizes = [row.sizes() for row in self.rows if isinstance(row, Content)] 82 | max_colsize = [max(s[i] for s in sizes) for i in range(self.style.ncols)] 83 | self.style.set_colwidths(max_colsize) 84 | return '\n'.join(row.layout(self.style) for row in self.rows) 85 | __str__ = __repr__ 86 | -------------------------------------------------------------------------------- /tle/util/tasks.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import logging 3 | 4 | from discord.ext import commands 5 | 6 | import tle.util.codeforces_common as cf_common 7 | 8 | 9 | class TaskError(commands.CommandError): 10 | pass 11 | 12 | 13 | class WaiterRequired(TaskError): 14 | def __init__(self, name): 15 | super().__init__(f'No waiter set for task `{name}`') 16 | 17 | 18 | class TaskAlreadyRunning(TaskError): 19 | def __init__(self, name): 20 | super().__init__(f'Attempt to start task `{name}` which is already running') 21 | 22 | 23 | def _ensure_coroutine_func(func): 24 | if not asyncio.iscoroutinefunction(func): 25 | raise TypeError('The decorated function must be a coroutine function.') 26 | 27 | 28 | class Waiter: 29 | def __init__(self, func, *, run_first=False, needs_instance=False): 30 | """`run_first` denotes whether this waiter should be run before the task's `func` when 31 | run for the first time. `needs_instance` indicates whether a self argument is required by 32 | the `func`. 
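        For example, the fixed_delay helper below simply sleeps and returns the delay,
        while the for_event helper blocks until the event is dispatched on
        cf_common.event_sys and returns the event object; whatever the waiter returns
        is passed on to the task's `func`.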
33 | """ 34 | _ensure_coroutine_func(func) 35 | self.func = func 36 | self.run_first = run_first 37 | self.needs_instance = needs_instance 38 | 39 | async def wait(self, instance=None): 40 | if self.needs_instance: 41 | return await self.func(instance) 42 | else: 43 | return await self.func() 44 | 45 | @staticmethod 46 | def fixed_delay(delay, run_first=False): 47 | """Returns a waiter that always waits for the given time (in seconds) and returns the 48 | time waited. 49 | """ 50 | 51 | async def wait_func(): 52 | await asyncio.sleep(delay) 53 | return delay 54 | 55 | return Waiter(wait_func, run_first=run_first) 56 | 57 | @staticmethod 58 | def for_event(event_cls, run_first=True): 59 | """Returns a waiter that waits for the given event and returns the result of that 60 | event. 61 | """ 62 | 63 | async def wait_func(): 64 | return await cf_common.event_sys.wait_for(event_cls) 65 | 66 | return Waiter(wait_func, run_first=run_first) 67 | 68 | 69 | class ExceptionHandler: 70 | def __init__(self, func, *, needs_instance=False): 71 | """`needs_instance` indicates whether a self argument is required by the `func`.""" 72 | _ensure_coroutine_func(func) 73 | self.func = func 74 | self.needs_instance = needs_instance 75 | 76 | async def handle(self, exception, instance=None): 77 | if self.needs_instance: 78 | await self.func(instance, exception) 79 | else: 80 | await self.func(exception) 81 | 82 | 83 | class Task: 84 | """A task that repeats until stopped. A task must have a name, a coroutine function `func` to 85 | execute periodically and another coroutine function `waiter` to wait on between calls to `func`. 86 | The return value of `waiter` is passed to `func` in the next call. An optional coroutine 87 | function `exception_handler` may be provided to which exceptions will be reported. 88 | """ 89 | 90 | def __init__(self, name, func, waiter, exception_handler=None, *, instance=None): 91 | """`instance`, if present, is passed as the first argument to `func`.""" 92 | _ensure_coroutine_func(func) 93 | self.name = name 94 | self.func = func 95 | self._waiter = waiter 96 | self._exception_handler = exception_handler 97 | self.instance = instance 98 | self.asyncio_task = None 99 | self.logger = logging.getLogger(self.__class__.__name__) 100 | 101 | def waiter(self, run_first=False): 102 | """Returns a decorator that sets the decorated coroutine function as the waiter for this 103 | Task. 104 | """ 105 | 106 | def decorator(func): 107 | self._waiter = Waiter(func, run_first=run_first) 108 | return func 109 | 110 | return decorator 111 | 112 | def exception_handler(self): 113 | """Returns a decorator that sets the decorated coroutine function as the exception handler 114 | for this Task. 115 | """ 116 | 117 | def decorator(func): 118 | self._exception_handler = ExceptionHandler(func) 119 | return func 120 | 121 | return decorator 122 | 123 | @property 124 | def running(self): 125 | return self.asyncio_task is not None and not self.asyncio_task.done() 126 | 127 | def start(self): 128 | """Starts up the task.""" 129 | if self._waiter is None: 130 | raise WaiterRequired(self.name) 131 | if self.running: 132 | raise TaskAlreadyRunning(self.name) 133 | self.logger.info(f'Starting up task `{self.name}`.') 134 | self.asyncio_task = asyncio.create_task(self._task()) 135 | 136 | async def manual_trigger(self, arg=None): 137 | """Manually triggers the `func` with the optionally provided `arg`, which defaults to 138 | `None`. 
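        The waiter is bypassed entirely, and the periodic loop, if running, is not affected.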
139 | """ 140 | self.logger.info(f'Manually triggering task `{self.name}`.') 141 | await self._execute_func(arg) 142 | 143 | async def stop(self): 144 | """Stops the task, interrupting the currently running coroutines.""" 145 | if self.running: 146 | self.logger.info(f'Stopping task `{self.name}`.') 147 | self.asyncio_task.cancel() 148 | await asyncio.sleep(0) # To ensure cancellation if called from within the task itself. 149 | 150 | async def _task(self): 151 | arg = None 152 | if self._waiter.run_first: 153 | arg = await self._waiter.wait(self.instance) 154 | while True: 155 | await self._execute_func(arg) 156 | arg = await self._waiter.wait(self.instance) 157 | 158 | async def _execute_func(self, arg): 159 | try: 160 | if self.instance is not None: 161 | await self.func(self.instance, arg) 162 | else: 163 | await self.func(arg) 164 | except asyncio.CancelledError: 165 | raise 166 | except Exception as ex: 167 | self.logger.warning(f'Exception in task `{self.name}`, ignoring.', exc_info=True) 168 | if self._exception_handler is not None: 169 | await self._exception_handler.handle(ex, self.instance) 170 | 171 | 172 | class TaskSpec: 173 | """A descriptor intended to be an interface between an instance and its tasks. It creates 174 | the expected task when `__get__` is called from an instance for the first time. No two task 175 | specs in the same class should have the same name.""" 176 | 177 | def __init__(self, name, func, waiter=None, exception_handler=None): 178 | _ensure_coroutine_func(func) 179 | self.name = name 180 | self.func = func 181 | self._waiter = waiter 182 | self._exception_handler = exception_handler 183 | 184 | def waiter(self, run_first=False, needs_instance=True): 185 | """Returns a decorator that sets the decorated coroutine function as the waiter for this 186 | TaskSpec. 187 | """ 188 | 189 | def decorator(func): 190 | self._waiter = Waiter(func, run_first=run_first, needs_instance=needs_instance) 191 | return func 192 | 193 | return decorator 194 | 195 | def exception_handler(self, needs_instance=True): 196 | """Returns a decorator that sets the decorated coroutine function as the exception handler 197 | for this TaskSpec. 198 | """ 199 | 200 | def decorator(func): 201 | self._exception_handler = ExceptionHandler(func, needs_instance=needs_instance) 202 | return func 203 | 204 | return decorator 205 | 206 | def __get__(self, instance, owner): 207 | if instance is None: 208 | return self 209 | try: 210 | tasks = getattr(instance, '___tasks___') 211 | except AttributeError: 212 | tasks = instance.___tasks___ = {} 213 | if self.name not in tasks: 214 | tasks[self.name] = Task(self.name, self.func, self._waiter, self._exception_handler, 215 | instance=instance) 216 | return tasks[self.name] 217 | 218 | 219 | def task(*, name, waiter=None, exception_handler=None): 220 | """Returns a decorator that creates a `Task` with the given options.""" 221 | 222 | def decorator(func): 223 | return Task(name, func, waiter, exception_handler, instance=None) 224 | 225 | return decorator 226 | 227 | 228 | def task_spec(*, name, waiter=None, exception_handler=None): 229 | """Returns a decorator that creates a `TaskSpec` descriptor with the given options.""" 230 | 231 | def decorator(func): 232 | return TaskSpec(name, func, waiter, exception_handler) 233 | 234 | return decorator 235 | --------------------------------------------------------------------------------
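
The task and event utilities above are meant to be composed inside cogs. The sketch below is illustrative only and is not part of this repository: `ExampleCog`, `ExampleRefresh`, `ExampleRatingChanges`, and `refresh_cache` are hypothetical names, and it assumes the bot's startup code has already initialised `cf_common` (so that `cf_common.event_sys` exists) by the time `on_ready` fires, as the existing cogs do.

```python
# Illustrative sketch: binding task_spec/listener_spec descriptors to instance
# methods and starting them once the bot is ready. Names marked "Example*" and
# refresh_cache() are hypothetical and do not exist in this repository.
import logging

from discord.ext import commands

from tle.util import codeforces_common as cf_common
from tle.util import events, tasks

logger = logging.getLogger(__name__)


class ExampleCog(commands.Cog):
    def __init__(self, bot):
        self.bot = bot

    @commands.Cog.listener()
    async def on_ready(self):
        # Task.start() raises TaskAlreadyRunning if the task is still running,
        # so guard against duplicate on_ready events after reconnects.
        if not self._refresh_task.running:
            self._refresh_task.start()
        # Registering the same per-instance Listener twice is a no-op because
        # EventSystem stores listeners in a set.
        cf_common.event_sys.add_listener(self._on_rating_changes)

    # Runs once immediately, then every 30 minutes; the waiter's return value
    # (the delay) is passed as the second argument on later iterations.
    @tasks.task_spec(name='ExampleRefresh',
                     waiter=tasks.Waiter.fixed_delay(30 * 60))
    async def _refresh_task(self, _):
        await self.refresh_cache()

    # Triggered whenever RatingChangesUpdate is dispatched through
    # cf_common.event_sys; with_lock=True serialises overlapping dispatches.
    @events.listener_spec(name='ExampleRatingChanges',
                          event_cls=events.RatingChangesUpdate,
                          with_lock=True)
    async def _on_rating_changes(self, event):
        logger.info(f'{len(event.rating_changes)} rating changes for '
                    f'{event.contest.name}')

    async def refresh_cache(self):
        # Placeholder for whatever periodic work the cog performs.
        pass
```

Module-level tasks, like the presence updater in `tle/util/discord_common.py`, use the plain `tasks.task` decorator instead, since there is no instance to bind.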