├── aoc_helper ├── __main__.py ├── __init__.py ├── day_template.py ├── formatting.py ├── data.py ├── types.py ├── main.py ├── interface.py └── utils.py ├── upload.py ├── pyproject.toml ├── LICENSE ├── .gitignore └── README.md /aoc_helper/__main__.py: -------------------------------------------------------------------------------- 1 | from .main import cli 2 | 3 | if __name__ == "__main__": 4 | cli() 5 | -------------------------------------------------------------------------------- /aoc_helper/__init__.py: -------------------------------------------------------------------------------- 1 | from .interface import fetch, lazy_submit, lazy_test, submit 2 | from .utils import * 3 | -------------------------------------------------------------------------------- /upload.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | from pypi_config import PASSWORD, USERNAME 5 | 6 | os.system(f"{sys.executable} -m build -n") 7 | os.system(f"twine upload dist/* -u {USERNAME} -p {PASSWORD} --skip-existing") 8 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | build-backend = "setuptools.build_meta" 3 | requires = ["setuptools>=61.2"] 4 | 5 | [project] 6 | authors = [{name = "Starwort"}, {name = "salt-die"}] 7 | classifiers = [ 8 | "Programming Language :: Python :: 3", 9 | "License :: OSI Approved :: MIT License", 10 | "Operating System :: OS Independent", 11 | ] 12 | dependencies = [ 13 | "requests", 14 | "beautifulsoup4", 15 | "typing_extensions", 16 | ] 17 | description = "A helper package for Advent of Code" 18 | name = "aoc_helper" 19 | requires-python = ">=3.9" 20 | urls = {Homepage = "https://github.com/Starwort/aoc_helper"} 21 | version = "1.14.0" 22 | 23 | [project.readme] 24 | content-type = "text/markdown" 25 | file = "README.md" 26 | 27 | [project.scripts] 28 | aoc = "aoc_helper.main:cli" 29 | 30 | [project.optional-dependencies] 31 | cli = ["click", "click-aliases"] 32 | fancy = ["rich"] 33 | full = ["click", "click-aliases", "rich"] 34 | 35 | [tool.setuptools] 36 | include-package-data = false 37 | 38 | [tool.setuptools.packages] 39 | find = {namespaces = false} 40 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2020 Starwort, salt-die 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /aoc_helper/day_template.py: -------------------------------------------------------------------------------- 1 | from collections import Counter, defaultdict, deque 2 | 3 | import aoc_helper 4 | from aoc_helper import ( 5 | Grid, 6 | PrioQueue, 7 | SparseGrid, 8 | decode_text, 9 | extract_ints, 10 | extract_iranges, 11 | extract_ranges, 12 | extract_uints, 13 | frange, 14 | irange, 15 | iter, 16 | list, 17 | map, 18 | multirange, 19 | range, 20 | search, 21 | tail_call, 22 | ) 23 | 24 | raw = aoc_helper.fetch({day}, {year}) 25 | 26 | 27 | def parse_raw(raw: str): 28 | return ... 29 | 30 | 31 | data = parse_raw(raw) 32 | 33 | 34 | # providing this default is somewhat of a hack - there isn't any other way to 35 | # force type inference to happen, AFAIK - but this won't work with standard 36 | # collections (list, set, dict, tuple) 37 | def part_one(data=data): 38 | ... 39 | 40 | 41 | aoc_helper.lazy_test(day={day}, year={year}, parse=parse_raw, solution=part_one) 42 | 43 | 44 | # providing this default is somewhat of a hack - there isn't any other way to 45 | # force type inference to happen, AFAIK - but this won't work with standard 46 | # collections (list, set, dict, tuple) 47 | def part_two(data=data): 48 | ... 49 | 50 | 51 | aoc_helper.lazy_test(day={day}, year={year}, parse=parse_raw, solution=part_two) 52 | 53 | aoc_helper.lazy_submit(day={day}, year={year}, solution=part_one, data=data) 54 | aoc_helper.lazy_submit(day={day}, year={year}, solution=part_two, data=data) 55 | -------------------------------------------------------------------------------- /aoc_helper/formatting.py: -------------------------------------------------------------------------------- 1 | import builtins 2 | import time 3 | import typing 4 | from builtins import print as print_raw 5 | 6 | T = typing.TypeVar("T") 7 | U = typing.TypeVar("U") 8 | 9 | try: 10 | from rich import get_console, print, progress 11 | except ImportError: 12 | print = print_raw # suppress a Pylance warning 13 | 14 | RED = "" 15 | YELLOW = "" 16 | GREEN = "" 17 | BLUE = "" 18 | GOLD = "" 19 | RESET = "" 20 | 21 | def wait(msg: str, secs: float) -> None: 22 | print(msg) 23 | time.sleep(secs) 24 | 25 | def work(msg: str, worker: typing.Callable[[U], T], data: U) -> T: 26 | print(msg) 27 | return worker(data) 28 | 29 | else: 30 | RED = "[red]" 31 | YELLOW = "[yellow]" 32 | GREEN = "[green]" 33 | BLUE = "[blue]" 34 | GOLD = "[gold1]" 35 | RESET = "[/]" 36 | 37 | get_console()._highlight = False 38 | 39 | def wait(msg: str, secs: float) -> None: 40 | for _ in progress.track( 41 | builtins.range(int(10 * secs)), 42 | description=msg, 43 | show_speed=False, 44 | transient=True, 45 | ): 46 | time.sleep(0.1) 47 | 48 | def _rich_work(msg: str, worker: typing.Callable[[U], T], data: U) -> T: 49 | with progress.Progress( 50 | progress.TextColumn("{task.description}"), 51 | progress.SpinnerColumn(), 52 | progress.TimeElapsedColumn(), 53 | transient=True, 54 | ) as bar: 55 | task = bar.add_task(msg) 56 | val = worker(data) 57 | bar.advance(task) 58 | return val 59 | 60 | work = _rich_work 61 | 62 | __all__ = ( 63 | "RED", 64 | "YELLOW", 65 | "GREEN", 66 | "BLUE", 67 | "GOLD", 68 | "RESET", 69 | 
"wait", 70 | "work", 71 | "print", 72 | "print_raw", 73 | ) 74 | -------------------------------------------------------------------------------- /aoc_helper/data.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import pathlib 3 | import re 4 | import typing 5 | 6 | try: 7 | import importlib_metadata as metadata # type: ignore 8 | except ImportError: 9 | from importlib import metadata # type: ignore 10 | 11 | DATA_DIR = pathlib.Path.home() / ".config" / "aoc_helper" 12 | if not DATA_DIR.exists(): 13 | DATA_DIR.mkdir(parents=True) 14 | PRACTICE_DATA_DIR = DATA_DIR / "practice" 15 | if not PRACTICE_DATA_DIR.exists(): 16 | PRACTICE_DATA_DIR.mkdir(parents=True) 17 | 18 | DEFAULT_YEAR = datetime.datetime.today().year 19 | TODAY = datetime.datetime.today().day 20 | LEADERBOARD_URL = "https://adventofcode.com/{year}/leaderboard/day/{day}" 21 | URL = "https://adventofcode.com/{year}/day/{day}" 22 | WAIT_TIME = re.compile(r"You have (?:(\d+)m )?(\d+)s left to wait.") 23 | RANK = re.compile(r"You (?:got|achieved) rank (\d+) on this star's leaderboard.") 24 | 25 | HEADERS = { 26 | "User-Agent": ( 27 | f"github.com/starwort/aoc_helper v{metadata.version('aoc_helper')} contact:" 28 | " Discord @starwort Github https://github.com/Starwort/aoc_helper/issues" 29 | ) 30 | } 31 | 32 | 33 | @typing.overload 34 | def get_cookie(missing_ok: typing.Literal[False]) -> dict[str, str]: ... 35 | @typing.overload 36 | def get_cookie() -> dict[str, str]: ... 37 | @typing.overload 38 | def get_cookie( 39 | missing_ok: typing.Literal[True], 40 | ) -> typing.Optional[typing.Dict[str, str]]: ... 41 | @typing.overload 42 | def get_cookie( 43 | missing_ok: bool, 44 | ) -> typing.Optional[typing.Dict[str, str]]: ... 45 | 46 | 47 | def get_cookie(missing_ok=False): 48 | token_file = DATA_DIR / "token.txt" 49 | if token_file.exists(): 50 | return {"session": token_file.read_text().strip("\n")} 51 | if missing_ok: 52 | return None 53 | token = input("Could not find configuration file. Please enter your token\n>>> ") 54 | token_file.write_text(token) 55 | return {"session": token} 56 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | # VSCode 132 | .vscode/ 133 | pypi_config.py -------------------------------------------------------------------------------- /aoc_helper/types.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Protocol, TypeAlias, TypeVar, Union 2 | 3 | # mostly stolen from typeshed 4 | 5 | _T_contra = TypeVar("_T_contra", contravariant=True) 6 | _T_co = TypeVar("_T_co", covariant=True) 7 | 8 | 9 | class SupportsDunderLT(Protocol[_T_contra]): 10 | def __lt__(self, __other: _T_contra) -> bool: 11 | ... 12 | 13 | 14 | class SupportsDunderGT(Protocol[_T_contra]): 15 | def __gt__(self, __other: _T_contra) -> bool: 16 | ... 17 | 18 | 19 | class SupportsAdd(Protocol[_T_contra, _T_co]): 20 | def __add__(self, __x: _T_contra) -> _T_co: 21 | ... 22 | 23 | 24 | class SupportsRAdd(Protocol[_T_contra, _T_co]): 25 | def __radd__(self, __x: _T_contra) -> _T_co: 26 | ... 27 | 28 | 29 | class SupportsSub(Protocol[_T_contra, _T_co]): 30 | def __sub__(self, __x: _T_contra) -> _T_co: 31 | ... 32 | 33 | 34 | class SupportsMul(Protocol[_T_contra, _T_co]): 35 | def __mul__(self, __x: _T_contra) -> _T_co: 36 | ... 37 | 38 | 39 | class SupportsRMul(Protocol[_T_contra, _T_co]): 40 | def __rmul__(self, __x: _T_contra) -> _T_co: 41 | ... 42 | 43 | 44 | class SupportsDiv(Protocol[_T_contra, _T_co]): 45 | def __div__(self, __x: _T_contra) -> _T_co: 46 | ... 47 | 48 | 49 | class _SupportsSumWithNoDefaultGiven( 50 | SupportsAdd[Any, Any], SupportsRAdd[int, Any], Protocol 51 | ): 52 | ... 53 | 54 | 55 | class SupportsHash(Protocol): 56 | def __hash__(self) -> Any: 57 | ... 
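# Editor's note: the TypeVars defined below are bounded by the protocols above;
# aoc_helper.utils imports them to annotate its generic helper functions.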
58 | 59 | 60 | SupportsSumNoDefaultT = TypeVar( 61 | "SupportsSumNoDefaultT", bound=_SupportsSumWithNoDefaultGiven 62 | ) 63 | 64 | AddableT = TypeVar("AddableT", bound=SupportsAdd[Any, Any]) 65 | AddableU = TypeVar("AddableU", bound=SupportsAdd[Any, Any]) 66 | HashableU = TypeVar("HashableU", bound=SupportsHash) 67 | 68 | SubtractableT = TypeVar("SubtractableT", bound=SupportsSub[Any, Any]) 69 | 70 | 71 | class _SupportsProdWithNoDefaultGiven( 72 | SupportsMul[Any, Any], SupportsRMul[int, Any], Protocol 73 | ): 74 | ... 75 | 76 | 77 | SupportsProdNoDefaultT = TypeVar( 78 | "SupportsProdNoDefaultT", bound=_SupportsProdWithNoDefaultGiven 79 | ) 80 | 81 | MultipliableT = TypeVar("MultipliableT", bound=SupportsMul[Any, Any]) 82 | MultipliableU = TypeVar("MultipliableU", bound=SupportsMul[Any, Any]) 83 | 84 | 85 | class _SupportsMean(_SupportsSumWithNoDefaultGiven, SupportsDiv[Any, Any], Protocol): 86 | ... 87 | 88 | 89 | SupportsMean = TypeVar("SupportsMean", bound=_SupportsMean) 90 | 91 | SupportsRichComparison: TypeAlias = Union[SupportsDunderLT[Any], SupportsDunderGT[Any]] 92 | SupportsRichComparisonT = TypeVar( 93 | "SupportsRichComparisonT", bound=SupportsRichComparison 94 | ) # noqa: Y001 95 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # `aoc_helper` 2 | 3 | [![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black) 4 | 5 | @salt-die's aoc_helper package, rewritten from the ground up 6 | 7 | ## Automation 8 | 9 | This project aims to be compliant with the [Advent of Code Automation Guidelines](https://www.reddit.com/r/adventofcode/wiki/faqs/automation). Here are the strategies it uses: 10 | 11 | - Once inputs are downloaded, they are cached in `~/.config/aoc_helper/YEAR/DAY.in` (or a similar path for Windows users) - [`interface.fetch`](https://github.com/Starwort/aoc_helper/blob/master/aoc_helper/interface.py#L97-L155) (lines [107-108](https://github.com/Starwort/aoc_helper/blob/master/aoc_helper/interface.py#L107-L108), [152](https://github.com/Starwort/aoc_helper/blob/master/aoc_helper/interface.py#L152)) 12 | - The `User-Agent` header declares the package name, version, and my contact info - [`data.HEADERS`](https://github.com/Starwort/aoc_helper/blob/master/aoc_helper/data.py#L20-L25), [used](https://github.com/Starwort/aoc_helper/blob/master/aoc_helper/interface.py#L118) [in](https://github.com/Starwort/aoc_helper/blob/master/aoc_helper/interface.py#L139) [every](https://github.com/Starwort/aoc_helper/blob/master/aoc_helper/interface.py#L210) [outbound](https://github.com/Starwort/aoc_helper/blob/master/aoc_helper/interface.py#L266) [request](https://github.com/Starwort/aoc_helper/blob/master/aoc_helper/interface.py#L341) 13 | - If requesting input before the puzzle unlocks, [the library will wait for unlock before sending any requests](https://github.com/Starwort/aoc_helper/blob/master/aoc_helper/interface.py#L130-L133) (except on day 1, where it will send a request to validate the session token) 14 | - If sending an answer too soon after an incorrect one, [the library will wait the cooldown specified in the response](https://github.com/Starwort/aoc_helper/blob/master/aoc_helper/interface.py#L231-L234) (sending only one extra request; it *is* however possible for a user to send multiple requests in quick succession, by repeatedly calling `submit` before the cooldown is over) 15 | - Advent of 
Code will not be queried at all [if the puzzle has already been solved](https://github.com/Starwort/aoc_helper/blob/master/aoc_helper/interface.py#L182-L190) or [if an answer has already been submitted](https://github.com/Starwort/aoc_helper/blob/master/aoc_helper/interface.py#L193-L199) 16 | - If, for some reason, the user decides they wish to clear their cache (for example, if they believe their input to be corrupted), they can do so by using the [`aoc clean`](https://github.com/Starwort/aoc_helper/blob/master/aoc_helper/main.py#L91-L121) command. 17 | 18 | ## Installation 19 | 20 | Install `aoc_helper` with `pip`! 21 | 22 | ```bash 23 | pip install aoc_helper 24 | # install the dependencies required for the Command Line Interface 25 | pip install aoc_helper[cli] 26 | # install the dependencies required for colour 27 | pip install aoc_helper[fancy] 28 | # install all additional dependencies 29 | pip install aoc_helper[cli,fancy] 30 | # or 31 | pip install aoc_helper[full] 32 | ``` 33 | 34 | ## Configuration 35 | 36 | When you first use any function that interfaces with Advent of Code, you will be prompted to enter your session token. 37 | 38 | Your session token is stored as an *HTTPOnly cookie*. This means there is no way of extracting it with JavaScript; you must either 39 | use a browser extension such as [EditThisCookie](http://www.editthiscookie.com/), or follow [this guide](https://github.com/wimglenn/advent-of-code-wim/issues/1). 40 | 41 | This token is stored in `~/.config/aoc_helper/token.txt` (`C:\Users\YOUR_USERNAME\.config\aoc_helper\token.txt` on Windows, 42 | probably), and other `aoc_helper` data is stored in this folder (such as your input and submission caches). 43 | 44 | If, for whatever reason, you feel the need to clear your caches, you can do so by deleting the relevant folders in `aoc_helper`'s 45 | configuration folder. 46 | 47 | ## Command Line Interface 48 | 49 | `aoc_helper` has a command line interface, accessed by running `python -m aoc_helper` or `aoc` followed by the command line arguments. Its commands are detailed below: 50 | 51 | ### `fetch` 52 | 53 | `aoc fetch DAY [--year YEAR]` 54 | 55 | Fetch your input for a given day. 56 | YEAR is the current year by default. 57 | 58 | Examples (written during 2020): 59 | 60 | ```bash 61 | aoc fetch 2 # fetches input for 2020 day 2 62 | aoc fetch 24 --year 2019 # fetches input for 2019 day 24 63 | ``` 64 | 65 | ### `submit` 66 | 67 | `aoc submit DAY PART ANSWER [--year YEAR]` 68 | 69 | Submit your answer for a given day and part. 70 | YEAR is the current year by default. 71 | 72 | Examples (written during 2020): 73 | 74 | ```bash 75 | aoc submit 2 1 643 # submits 643 as the answer for 2020 day 2 part 1 76 | aoc submit 24 1 12531574 --year 2019 # submits 12531574 as the answer for 2019 day 24 part 1 77 | ``` 78 | 79 | ### `template` 80 | 81 | `aoc template DAYS [--year YEAR]` 82 | 83 | Generate templates for your Advent of Code folder. 84 | YEAR is the current year by default. 85 | 86 | `DAYS` must be a comma-separated list of day ranges, each of which may be one of: 87 | 88 | - `all`, to generate a template for every day in the year 89 | - A single integer in the range \[1, 25] to generate a single template 90 | - A pair of integers in the range \[1, 25], separated with a hyphen, to generate a template for each day in the range 91 | 92 | Note that even if a day is included in the list of days more than once (including implicitly, within ranges or `all`), it will only be generated once. 
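For illustration, here is a minimal sketch of how such a `DAYS` specifier can be parsed. It is an editor's example only, not `aoc_helper`'s own implementation (the real parsing lives in `parse_range` in `aoc_helper/main.py`, shown further below):

```py
# Hypothetical helper mirroring the DAYS grammar described above:
# comma-separated single days, a-b ranges, or "all", de-duplicated and sorted.
def parse_days(spec: str) -> list[int]:
    days: set[int] = set()
    for part in spec.split(","):
        if part == "all":
            days |= set(range(1, 26))
        elif "-" in part:
            lo, hi = map(int, part.split("-"))
            days |= set(range(lo, hi + 1))
        else:
            days.add(int(part))
    return sorted(day for day in days if 1 <= day <= 25)


print(parse_days("3-5,7,9,9-10"))  # [3, 4, 5, 7, 9, 10]
```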
93 | 94 | Filenames are formatted as `day_NUMBER.py` where `NUMBER` is the 2-digit day number. 95 | 96 | Examples (written during 2020): 97 | 98 | ```bash 99 | aoc template all # generates day_01.py to day_25.py, with aoc_helper methods referencing 2020, in the current folder 100 | aoc template 3 --year 2019 # generates day_03.py, with aoc_helper methods referencing 2019, in the current folder 101 | aoc template 3-5 --year 2017 # generates day_03.py to day_05.py, with aoc_helper methods referencing 2017, in the current folder 102 | aoc template 3-5,7,9,9-10 # generates files for days 3, 4, 5, 7, 9, and 10 103 | ``` 104 | -------------------------------------------------------------------------------- /aoc_helper/main.py: -------------------------------------------------------------------------------- 1 | import io 2 | import json 3 | import os 4 | import pathlib 5 | import re 6 | import sys 7 | import typing 8 | from contextlib import redirect_stdout 9 | from datetime import date, timedelta 10 | from itertools import zip_longest 11 | 12 | from .utils import chunk_default 13 | 14 | try: 15 | import click 16 | from click_aliases import ClickAliasedGroup 17 | except ImportError: 18 | print("Missing dependencies for the CLI. Please `pip install aoc_helper[cli]`") 19 | exit(1) 20 | 21 | from .data import ( 22 | DATA_DIR, 23 | DEFAULT_YEAR, 24 | HEADERS, 25 | PRACTICE_DATA_DIR, 26 | RANK, 27 | URL, 28 | get_cookie, 29 | ) 30 | from .interface import _estimate_practice_rank, _format_timedelta, days_in_year 31 | from .interface import fetch as fetch_input 32 | from .interface import submit as submit_answer 33 | from .interface import validate_token as validate_token_ 34 | 35 | TEMPLATE = (pathlib.Path(__file__).parent / "day_template.py").read_text() 36 | 37 | RANGE_REGEX = re.compile(r"(2[0-5]|1[0-9]|0?[1-9])-(2[0-5]|1[0-9]|0?[2-9])") 38 | 39 | 40 | def parse_range(_, __, value: str) -> list[int]: 41 | ranges = value.split(",") 42 | days: set[int] = set() 43 | for range_ in ranges: 44 | match = RANGE_REGEX.match(range_) 45 | if match: 46 | lb = int(match[1]) 47 | ub = int(match[2]) + 1 48 | days |= set(range(lb, ub)) 49 | elif range_.isnumeric() and 1 <= int(range_) <= 25: 50 | days.add(int(range_)) 51 | elif range_ == "all": 52 | days = set(range(1, 26)) 53 | else: 54 | raise click.BadParameter( 55 | "every part must be a single day, a range of days in the form " 56 | "a-b, or the word 'all'" 57 | ) 58 | return sorted(days) 59 | 60 | 61 | @click.group(cls=ClickAliasedGroup) 62 | def cli(): 63 | pass 64 | 65 | 66 | @cli.command(aliases=["f", "get", "get-input", "download"]) 67 | @click.argument("day", type=int) 68 | @click.option("--year", type=int, default=DEFAULT_YEAR) 69 | def fetch(day: int, year: int): 70 | """Fetch and print the input for day DAY of --year""" 71 | print(fetch_input(day, year)) 72 | 73 | 74 | @cli.command(aliases=["read-puzzle", "puzzle"]) 75 | @click.argument("day", type=int) 76 | @click.option("--year", type=int, default=DEFAULT_YEAR) 77 | @click.option( 78 | "-c", 79 | "--colour", 80 | "--color", 81 | type=click.Choice(["auto", "always", "never"]), 82 | default="auto", 83 | ) 84 | def read(day: int, year: int, colour: typing.Literal["auto", "always", "never"]): 85 | """Read the puzzle for day DAY or --year in your terminal""" 86 | import sys 87 | from os import getenv 88 | 89 | import requests 90 | from bs4 import BeautifulSoup, Tag 91 | 92 | try: 93 | from rich.console import Console, ConsoleOptions, RenderResult 94 | from rich.markdown import TextElement 95 | from 
rich.segment import Segment 96 | except ImportError: 97 | print( 98 | "Missing dependency rich. Please `pip install rich`," 99 | " `pip install aoc_helper[full]` or `pip install aoc_helper[fancy]`", 100 | file=sys.stderr, 101 | ) 102 | return 103 | puzzle_info = requests.get( 104 | URL.format(year=year, day=day), headers=HEADERS, cookies=get_cookie() 105 | ) 106 | soup = BeautifulSoup(puzzle_info.text, "html.parser") 107 | puzzle: Tag = soup.find("main") # type: ignore 108 | for emphasis in puzzle.find_all("em"): 109 | emphasis.string.replace_with(f"[bold gold1]{emphasis.text}[/]") 110 | terminal = Console() 111 | if colour == "auto": 112 | pager = getenv("PAGER") or "" 113 | colour = ( 114 | "always" 115 | if terminal.is_terminal 116 | and any( 117 | flag in pager 118 | for flag in ("-r", "-R", "--raw-control-chars", "--RAW-CONTROL-CHARS") 119 | ) 120 | else "never" 121 | ) 122 | 123 | class CodeBlock(TextElement): 124 | def __init__(self, text: str): 125 | self.text = text 126 | 127 | def __rich_console__( 128 | self, console: Console, options: ConsoleOptions 129 | ) -> RenderResult: 130 | lines = self.text.split("\n") 131 | line_count = len(lines) 132 | width = len(str(line_count)) 133 | render_options = options.update(width=console.width - 7 - width) 134 | for line_no, line in enumerate(lines, 1): 135 | inner_lines = console.render_lines(line, render_options) 136 | for i, line in enumerate(inner_lines): 137 | yield Segment(" ") 138 | yield Segment( 139 | f"{line_no:>{width}} │ " if i == 0 else " " * (width) + " │ " 140 | ) 141 | yield from line 142 | yield Segment("\n") 143 | 144 | class BulletItem(TextElement): 145 | def __init__(self, text: str): 146 | self.text = text 147 | 148 | def __rich_console__( 149 | self, console: Console, options: ConsoleOptions 150 | ) -> RenderResult: 151 | render_options = options.update(width=console.width - 2) 152 | lines = console.render_lines(self.text, render_options) 153 | for i, line in enumerate(lines): 154 | yield Segment("- " if i == 0 else " ") 155 | yield from line 156 | yield Segment("\n") 157 | 158 | class NumberedItem(TextElement): 159 | def __init__(self, text: str, number: int, width: int): 160 | self.text = text 161 | self.number = number 162 | self.width = width 163 | 164 | def __rich_console__( 165 | self, console: Console, options: ConsoleOptions 166 | ) -> RenderResult: 167 | render_options = options.update(width=console.width - self.width - 2) 168 | lines = console.render_lines(self.text, render_options) 169 | for i, line in enumerate(lines): 170 | yield ( 171 | f"{self.number:>{self.width}}. 
" 172 | if i == 0 173 | else (" " * self.width + " ") 174 | ) 175 | yield from line 176 | yield "\n" 177 | 178 | with terminal.pager(styles=colour == "always") as pager: 179 | first = True 180 | for el in puzzle.children: 181 | if not isinstance(el, Tag): 182 | continue 183 | if el.name == "article": 184 | for part_el in el.children: 185 | if not isinstance(part_el, Tag): 186 | continue 187 | if not first: 188 | terminal.print() 189 | first = False 190 | if part_el.name == "h2": 191 | terminal.rule( 192 | "[bold gold1 underline]" + part_el.text.strip("- "), 193 | style="bold gold1", 194 | ) 195 | elif part_el.name == "p": 196 | terminal.print(part_el.text) 197 | elif part_el.name == "pre": 198 | terminal.print(CodeBlock(part_el.text.strip("\n"))) 199 | elif part_el.name == "ul": 200 | for li in part_el.find_all("li"): 201 | terminal.print(BulletItem(li.text)) 202 | elif part_el.name == "ol": 203 | lis = list(part_el.find_all("li")) 204 | width = len(str(len(lis))) 205 | for i, li in enumerate(lis, 1): 206 | terminal.print(NumberedItem(li.text, i, width)) 207 | 208 | elif el.name == "p": 209 | if el.text.startswith("Your puzzle answer was"): 210 | terminal.print() 211 | terminal.print() 212 | terminal.print(el.text) 213 | terminal.print() 214 | elif el.text.startswith("Both parts of this puzzle are complete!"): 215 | terminal.print(f"[bold gold1]{el.text}[/]") 216 | return 217 | 218 | 219 | @cli.command(aliases=["s", "send"]) 220 | @click.argument("day", type=int) 221 | @click.argument("part", type=click.Choice(["1", "2"])) 222 | @click.argument("answer") 223 | @click.option("--year", type=int, default=DEFAULT_YEAR) 224 | @click.option("--practice", is_flag=True) 225 | def submit( 226 | day: int, part: typing.Literal["1", "2"], answer: str, year: int, practice: bool 227 | ): 228 | """Submit the answer for day DAY part PART of --year""" 229 | _ = practice # used via sys.argv 230 | submit_answer(day, int(part), answer, year) 231 | 232 | 233 | @cli.command(aliases=["t", "create"]) 234 | @click.argument("days", callback=parse_range) 235 | @click.option("--year", type=int, default=DEFAULT_YEAR) 236 | def template(days: list[int], year: int): 237 | """Generate an answer stub for every day of DAYS in --year""" 238 | invalid = [day for day in days if day > days_in_year(year)] 239 | if invalid: 240 | raise click.BadParameter( 241 | "values out of range: " + ", ".join(map(str, invalid)), 242 | param_hint="days", 243 | ) 244 | for day in days: 245 | print(f"Generating day_{day:0>2}.py") 246 | pathlib.Path(f"day_{day:0>2}.py").write_text( 247 | TEMPLATE.format(day=day, year=year) 248 | ) 249 | 250 | 251 | @cli.command(aliases=["get-browser", "set-browser"]) 252 | @click.argument("state", type=bool, default=None, required=False) 253 | def browser(state: typing.Optional[bool]): 254 | """Enable, disable, or check browser automation""" 255 | file = DATA_DIR / ".nobrowser" 256 | if state is None: 257 | print(f"Web browser automation is {'dis' if file.exists() else 'en'}abled.") 258 | elif state: 259 | file.unlink(True) 260 | print("Enabled web browser automation") 261 | else: 262 | file.touch() 263 | print("Disabled web browser automation") 264 | 265 | 266 | @cli.command( 267 | aliases=[ 268 | "clear", 269 | "purge", 270 | "delete", 271 | "clear-cache", 272 | "clean-cache", 273 | "purge-cache", 274 | "delete-cache", 275 | ] 276 | ) 277 | @click.argument("days", callback=parse_range) 278 | @click.argument("year", type=int, default=DEFAULT_YEAR) 279 | @click.option( 280 | "--type", 281 | type=click.Choice( 
282 | ["input", "submissions", "solutions", "1", "2", "tests", "practice", "all"] 283 | ), 284 | help="What to delete", 285 | default="input", 286 | ) 287 | def clean(days: list[int], year: int, type: str): 288 | """Clean the cached --type data for DAYS of YEAR""" 289 | for day in days: 290 | if type in ("input", "all"): 291 | (DATA_DIR / f"{year}" / f"{day}.in").unlink(True) 292 | if type in ("submissions", "all"): 293 | file = DATA_DIR / f"{year}" / f"{day}" / "submissions.json" 294 | if ( 295 | not file.exists() 296 | or not RANK.search(file.read_text()) 297 | or click.confirm( 298 | f"Are you sure you want to delete your submissions for {year} day" 299 | f" {day}? Your cached rank will be forgotten" 300 | ) 301 | ): 302 | file.unlink(True) 303 | if type in ("practice", "all"): 304 | folder = PRACTICE_DATA_DIR / f"{year}" / f"{day}" 305 | if not folder.exists() or click.confirm( 306 | f"Are you sure you want to delete your practice data for {year} day" 307 | f" {day}? {len(os.listdir(folder))} entries will be forgotten" 308 | ): 309 | for file in os.listdir(folder): 310 | (folder / file).unlink(True) 311 | folder.rmdir() 312 | if type in ("solutions", "all", "1"): 313 | (DATA_DIR / f"{year}" / f"{day}" / "1.solution").unlink(True) 314 | if type in ("solutions", "all", "2"): 315 | (DATA_DIR / f"{year}" / f"{day}" / "2.solution").unlink(True) 316 | if type in ("tests", "all"): 317 | (DATA_DIR / f"{year}" / f"{day}" / "tests.json").unlink(True) 318 | 319 | 320 | @cli.command( 321 | name="visualise-cache", 322 | aliases=["cache", "visualize-cache", "list-cache", "view-cache"], 323 | ) 324 | @click.option( 325 | "-c", 326 | "--colour", 327 | "--color", 328 | type=click.Choice(["auto", "always", "never"]), 329 | default="auto", 330 | ) 331 | @click.option( 332 | "--validate-token/--no-validate-token", 333 | default=True, 334 | ) 335 | @click.option( 336 | "--update-token-on-invalid/--no-update-token-on-invalid", 337 | default=False, 338 | help=( 339 | "Whether to prompt for a new token if the current one is invalid." 340 | " Only takes effect if --validate-token is set" 341 | ), 342 | ) 343 | def visualise_cache( 344 | colour: typing.Literal["auto", "always", "never"], 345 | validate_token: bool, 346 | update_token_on_invalid: bool, 347 | ): 348 | """Get a visual overview of the aoc_helper cache""" 349 | from .formatting import print 350 | 351 | try: 352 | from rich.console import Console 353 | except ImportError: 354 | if colour != "never": 355 | print( 356 | "Missing dependency rich. Colour will be disabled. 
Please", 357 | "`pip install rich`, `pip install aoc_helper[full]`, or", 358 | "`pip install aoc_helper[fancy]`", 359 | file=sys.stderr, 360 | ) 361 | 362 | def rule(title: str): 363 | width = 73 - 2 - len(title) 364 | left = width // 2 365 | right = width - left 366 | print(f"{'=' * left} {title} {'=' * right}") 367 | 368 | real_width = 73 369 | else: 370 | terminal = Console() 371 | if colour == "auto": 372 | colour = "always" if terminal.is_terminal else "never" 373 | if colour == "never": 374 | terminal.no_color = True 375 | else: 376 | Console._environ["FORCE_COLOR"] = "1" # type: ignore 377 | 378 | real_width = terminal.width if terminal.is_terminal else 73 379 | terminal.width = 73 380 | rule = terminal.rule 381 | 382 | if colour == "always": 383 | from .formatting import GREEN, RED, RESET, YELLOW 384 | else: 385 | GREEN = RED = YELLOW = RESET = "" 386 | 387 | token = get_cookie(missing_ok=not update_token_on_invalid) 388 | cached_years = sorted(DATA_DIR.iterdir()) 389 | did_print = False 390 | 391 | if token is not None: 392 | valid = None 393 | if validate_token: 394 | valid = validate_token_(update_token_on_invalid) 395 | if colour == "always" or valid is not False: 396 | print( 397 | { 398 | None: f"{YELLOW}Token (not validated){RESET}", 399 | True: f"{GREEN}Token{RESET}", 400 | False: f"{RED}Token{RESET}", 401 | }[valid] 402 | ) 403 | did_print = True 404 | elif colour == "always": 405 | print(f"{RED}Token{RESET}") 406 | did_print = True 407 | 408 | def format( 409 | val: str, exists: bool, success_colour: str = GREEN, fail_colour: str = RED 410 | ): 411 | if colour == "always": 412 | text_col = success_colour if exists else fail_colour 413 | return f"{text_col}{val}{text_col and RESET}" 414 | else: 415 | return val if exists else (" " * len(val)) 416 | 417 | years = [year for year in cached_years if year.is_dir() and year.name.isnumeric()] 418 | # input tests 419 | # solutions 1 2 420 | # practice data 421 | days = range(1, 26) 422 | blocks_that_fit = max((real_width + 4) // 77, 1) 423 | 424 | blocks = [] 425 | for year in years: 426 | with io.StringIO() as buf, redirect_stdout(buf): 427 | rule(year.name) 428 | has_input = [(year / f"{day}.in").exists() for day in days] 429 | has_tests = [(year / f"{day}" / "tests.json").exists() for day in days] 430 | has_solution_1 = [(year / f"{day}" / "1.solution").exists() for day in days] 431 | has_solution_2 = [(year / f"{day}" / "2.solution").exists() for day in days] 432 | has_any_solution = [a or b for a, b in zip(has_solution_1, has_solution_2)] 433 | has_practice = [ 434 | (PRACTICE_DATA_DIR / year.name / f"{day}").exists() for day in days 435 | ] 436 | for block in range(5): 437 | print( 438 | *(f"{day:^13}" for day in range(block * 5 + 1, block * 5 + 6)), 439 | sep=" ", 440 | ) 441 | print( 442 | *( 443 | f"{format('Input', has_input[day])} {format('Tests', has_tests[day])}" 444 | for day in range(block * 5, block * 5 + 5) 445 | ), 446 | sep=" ", 447 | ) 448 | print( 449 | *( 450 | f"{format('Solutions', has_any_solution[day], '', '')}" 451 | f" {format('1', has_solution_1[day])}" 452 | f" {format('2', has_solution_2[day])}" 453 | for day in range(block * 5, block * 5 + 5) 454 | ), 455 | sep=" ", 456 | ) 457 | print( 458 | *( 459 | format("Practice data", has_practice[day]) 460 | for day in range(block * 5, block * 5 + 5) 461 | ), 462 | sep=" ", 463 | ) 464 | blocks.append(buf.getvalue()) 465 | 466 | import builtins 467 | 468 | for blocks in chunk_default(blocks, blocks_that_fit, ""): 469 | if did_print: 470 | print() 471 | 
blocks = [block.splitlines() for block in blocks] 472 | for lines in zip_longest(*blocks, fillvalue=""): 473 | builtins.print(" ".join(lines)) 474 | 475 | 476 | @cli.command(name="validate-token", aliases=["token", "check-token"]) 477 | def validate_token(): 478 | """ 479 | Validate the stored session token. Will prompt for a new token if the 480 | current one is invalid. 481 | """ 482 | if validate_token_(): 483 | print("Token is valid") 484 | 485 | 486 | @cli.command(name="practice-results") 487 | @click.argument("day", type=int) 488 | @click.option("--year", type=int, default=DEFAULT_YEAR) 489 | def practice_results(day: int, year: int): 490 | """Show all practice results for day DAY of --year""" 491 | import locale 492 | 493 | locale.setlocale( 494 | locale.LC_TIME, "" 495 | ) # https://github.com/python/cpython/issues/73643 496 | folder = PRACTICE_DATA_DIR / f"{year}" / f"{day}" 497 | if not folder.exists(): 498 | print("No practice results found") 499 | return 500 | 501 | def format_result(result: typing.Optional[tuple[int, int, int]]): 502 | if not result: 503 | return "no rank" 504 | estimated, best, worst = result 505 | if best == worst: 506 | return f"rank {best}" 507 | if worst > 100: 508 | worst = "100+" 509 | return f"approximately rank {estimated} - {best} to {worst}" 510 | 511 | for file in sorted(folder.iterdir()): 512 | attempt_year, attempt_month, attempt_day = map(int, file.stem.split("-")) 513 | attempt_date = date(attempt_year, attempt_month, attempt_day).strftime("%x") 514 | results: list[float] = json.loads(file.read_text()) 515 | if len(results) == 1: 516 | solve_time = timedelta(seconds=results[0]) 517 | result = format_result(_estimate_practice_rank(day, 1, year, solve_time)) 518 | print( 519 | f"{attempt_date} - Part 1: {_format_timedelta(solve_time)} ({result})," 520 | " Part 2: (unsolved)" 521 | ) 522 | elif len(results) == 2: 523 | solve_time_1 = timedelta(seconds=results[0]) 524 | solve_time_2 = timedelta(seconds=results[1]) 525 | result_1 = format_result( 526 | _estimate_practice_rank(day, 1, year, solve_time_1) 527 | ) 528 | result_2 = format_result( 529 | _estimate_practice_rank(day, 2, year, solve_time_2) 530 | ) 531 | print( 532 | f"{attempt_date} - " 533 | f"Part 1: {_format_timedelta(solve_time_1)} ({result_1}), " 534 | f"Part 2: {_format_timedelta(solve_time_2)} ({result_2})" 535 | ) 536 | -------------------------------------------------------------------------------- /aoc_helper/interface.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import json 3 | import pathlib 4 | import sys 5 | import typing 6 | import webbrowser 7 | from warnings import warn 8 | 9 | import requests 10 | from bs4 import BeautifulSoup as Soup 11 | 12 | from .data import ( 13 | DATA_DIR, 14 | DEFAULT_YEAR, 15 | HEADERS, 16 | LEADERBOARD_URL, 17 | PRACTICE_DATA_DIR, 18 | RANK, 19 | TODAY, 20 | URL, 21 | WAIT_TIME, 22 | get_cookie, 23 | ) 24 | from .formatting import ( 25 | BLUE, 26 | GOLD, 27 | GREEN, 28 | RED, 29 | RESET, 30 | YELLOW, 31 | print, 32 | print_raw, 33 | wait, 34 | work, 35 | ) 36 | 37 | T = typing.TypeVar("T") 38 | U = typing.TypeVar("U") 39 | 40 | 41 | def days_in_year(year: int) -> int: 42 | if year < 2015: 43 | raise ValueError("Advent of Code started in 2015.") 44 | if year < 2025: 45 | return 25 46 | return 12 # https://adventofcode.com/2025/about#faq_num_days 47 | 48 | 49 | def _open_page(page: str) -> None: 50 | """Open the page if the user hasn't opted out""" 51 | if not (DATA_DIR / 
".nobrowser").exists(): 52 | webbrowser.open(page) 53 | 54 | 55 | def _make(folder: pathlib.Path) -> None: 56 | """Create folder if it doesn't exist.""" 57 | if not folder.exists(): 58 | folder.mkdir(parents=True) 59 | 60 | 61 | def _pretty_print(message: str) -> None: 62 | """Analyse and print message""" 63 | if message.startswith("That's the"): 64 | print(GREEN + message + RESET) 65 | elif message.startswith("You don't"): 66 | print(YELLOW + message + RESET) 67 | elif message.startswith("That's not"): 68 | print(RED + message + RESET) 69 | elif message.startswith("You got rank"): 70 | print(GOLD + message + RESET) 71 | else: 72 | raise ValueError("Failed to parse response.") 73 | 74 | 75 | def validate_token(allow_prompt: bool = True) -> bool: 76 | cookie = get_cookie(not allow_prompt) 77 | if cookie is None: 78 | return False 79 | resp = requests.get( 80 | URL.format(day=1, year=2015) + "/input", 81 | cookies=get_cookie(), 82 | headers=HEADERS, 83 | ) 84 | if not resp.ok: 85 | if resp.status_code != 400: 86 | raise ValueError("Received bad response") 87 | if not allow_prompt: 88 | return False 89 | token_file = DATA_DIR / "token.txt" 90 | print(f"{RED}Your token has expired. Please enter your new" f" token.{RESET}") 91 | token = input(">>> ") 92 | token_file.write_text(token) 93 | return validate_token(True) 94 | return True 95 | 96 | 97 | def fetch(day: int = TODAY, year: int = DEFAULT_YEAR, never_print: bool = False) -> str: 98 | """Fetch and return the input for `day` of `year`. 99 | 100 | If `--practice` is provided on the command line, pretend that today is the 101 | day of the puzzle and wait for puzzle unlock accordingly. 'today' is 102 | determined by UTC; from 0:00 to 5:00 UTC, this puzzle will block until 5:00 103 | UTC - after that, until 0:00 UTC the next day, input fetching will be 104 | instant. 105 | 106 | All inputs are cached in `aoc_helper.DATA_DIR`.""" 107 | import sys 108 | 109 | day_ = str(day) 110 | year_ = str(year) 111 | 112 | _make(DATA_DIR / year_) 113 | input_path = DATA_DIR / year_ / (day_ + ".in") 114 | 115 | if input_path.exists(): 116 | should_print = False 117 | if "--practice" in sys.argv: 118 | now = datetime.datetime.utcnow() 119 | unlock = datetime.datetime(now.year, now.month, now.day, 5) 120 | if now < unlock: 121 | should_print = True 122 | wait( 123 | f"{YELLOW}Waiting for puzzle unlock...{RESET}", 124 | (unlock - now).total_seconds(), 125 | ) 126 | print(GREEN + "Fetching input!" + RESET) 127 | _open_page(URL.format(day=day, year=year)) 128 | input_data = input_path.read_text() 129 | if "--practice" in sys.argv and should_print: 130 | print(input_data) 131 | return input_data 132 | else: 133 | unlock = datetime.datetime(year, 12, day, 5) 134 | now = datetime.datetime.utcnow() 135 | if "--practice" in sys.argv: 136 | unlock = unlock.replace(year=now.year, month=now.month, day=now.day) 137 | if now < unlock: 138 | if day == 1: 139 | # If we're waiting for the first puzzle, the user's token may 140 | # have expired. We should check that now. 141 | validate_token() 142 | now = datetime.datetime.utcnow() 143 | wait( 144 | f"{YELLOW}Waiting for puzzle unlock...{RESET}", 145 | (unlock - now).total_seconds(), 146 | ) 147 | print(GREEN + "Fetching input!" 
+ RESET) 148 | _open_page(URL.format(day=day, year=year)) 149 | resp = requests.get( 150 | URL.format(day=day, year=year) + "/input", 151 | cookies=get_cookie(), 152 | headers=HEADERS, 153 | ) 154 | if not resp.ok: 155 | if resp.status_code == 400: 156 | token_file = DATA_DIR / "token.txt" 157 | print( 158 | f"{RED}Your token has expired. Please enter your new token.{RESET}" 159 | ) 160 | token = input(">>> ") 161 | token_file.write_text(token) 162 | return fetch(day, year, never_print) 163 | raise ValueError("Received bad response") 164 | data = resp.text.strip("\n") 165 | input_path.write_text(data) 166 | if not never_print: 167 | print_raw(data) 168 | return data 169 | 170 | 171 | def _load_leaderboard_times( 172 | day: int, year: int = DEFAULT_YEAR 173 | ) -> tuple[list[datetime.timedelta], list[datetime.timedelta]]: 174 | day_ = str(day) 175 | year_ = str(year) 176 | 177 | day_dir = DATA_DIR / year_ / day_ 178 | _make(day_dir) 179 | 180 | # Load cached leaderboards 181 | leaderboards = day_dir / "leaderboards.json" 182 | if leaderboards.exists(): 183 | with leaderboards.open() as f: 184 | # ([seconds], [seconds]) 185 | data: list[list[int]] = json.load(f) 186 | return [datetime.timedelta(seconds=t) for t in data[0]], [datetime.timedelta(seconds=t) for t in data[1]] # type: ignore 187 | else: 188 | leaderboard_page = requests.get( 189 | LEADERBOARD_URL.format(day=day, year=year), headers=HEADERS 190 | ) 191 | soup = Soup(leaderboard_page.text, "html.parser") 192 | times = soup.select(".leaderboard-entry") 193 | part_1_times: list[datetime.timedelta] = [] 194 | part_2_times: list[datetime.timedelta] = [] 195 | in_part_2 = False 196 | for leaderboard_time in times: 197 | if leaderboard_time.span.text == " 1)": # type: ignore 198 | in_part_2 = not in_part_2 199 | time_to_solve = datetime.datetime.strptime( 200 | leaderboard_time.select_one(".leaderboard-time").text, # type: ignore 201 | "%b %d %H:%M:%S", 202 | ) - datetime.datetime(1900, 12, day) 203 | if in_part_2: 204 | part_2_times.append(time_to_solve) 205 | else: 206 | part_1_times.append(time_to_solve) 207 | if not part_1_times: 208 | # no part 2 leaderboard, so boards were read in backwards 209 | part_2_times, part_1_times = part_1_times, part_2_times 210 | if len(part_1_times) == len(part_2_times) == 100: 211 | # both leaderboards are full, cache them 212 | with leaderboards.open("w") as f: 213 | json.dump( 214 | ( 215 | [t.total_seconds() for t in part_1_times], 216 | [t.total_seconds() for t in part_2_times], 217 | ), 218 | f, 219 | ) 220 | return part_1_times, part_2_times 221 | 222 | 223 | def _practice_result_for(day: int, year: int) -> list[int]: 224 | practice_data_dir = PRACTICE_DATA_DIR / str(year) / str(day) 225 | _make(practice_data_dir) 226 | try: 227 | with open( 228 | practice_data_dir 229 | / f"{datetime.datetime.utcnow().year:04}-{datetime.datetime.utcnow().month:02}-{datetime.datetime.utcnow().day:02}.json", 230 | "r", 231 | ) as f: 232 | return json.load(f) 233 | except FileNotFoundError: 234 | return [] 235 | 236 | 237 | def _calculate_practice_result(day: int, part: int, year: int) -> None: 238 | if "--practice" not in sys.argv: 239 | return 240 | now = datetime.datetime.utcnow() 241 | solve_time = datetime.timedelta( 242 | hours=now.hour - 5, 243 | minutes=now.minute, 244 | seconds=now.second, 245 | microseconds=now.microsecond, 246 | ) 247 | practice_data_dir = PRACTICE_DATA_DIR / str(year) / str(day) 248 | _make(practice_data_dir) 249 | 250 | filename = f"{now.year:04}-{now.month:02}-{now.day:02}.json" 251 | 
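    # Load today's practice log for this day (if it exists and contains valid JSON),
    # append the new solve time in seconds, and write the list back out.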
252 | try: 253 | with open(practice_data_dir / filename) as f: 254 | data: list[float] = json.load(f) 255 | except (json.decoder.JSONDecodeError, FileNotFoundError): 256 | data = [] 257 | with open(practice_data_dir / filename, "w") as f: 258 | data.append(solve_time.total_seconds()) 259 | json.dump(data, f) 260 | _report_practice_result(day, part, year, solve_time) 261 | 262 | 263 | def _estimate_practice_rank( 264 | day: int, part: int, year: int, solve_time: datetime.timedelta 265 | ) -> typing.Optional[tuple[int, int, int]]: 266 | import bisect 267 | 268 | leaderboard = _load_leaderboard_times(day, year)[part - 1] 269 | # aoc truncates solve times, so we do too for the purpose of sorting 270 | truncated_solve_time = datetime.timedelta(seconds=int(solve_time.total_seconds())) 271 | best_possible_rank = bisect.bisect_left(leaderboard, truncated_solve_time) + 1 272 | worst_possible_rank = bisect.bisect_right(leaderboard, truncated_solve_time) + 1 273 | if best_possible_rank > 100: 274 | return None 275 | span = worst_possible_rank - best_possible_rank 276 | approx_rank = best_possible_rank + round(span * solve_time.microseconds / 1_000_000) 277 | return approx_rank, best_possible_rank, worst_possible_rank 278 | 279 | 280 | def _format_timedelta(solve_time: datetime.timedelta) -> str: 281 | minutes, seconds = divmod(solve_time.seconds, 60) 282 | hours, minutes = divmod(minutes, 60) 283 | if hours > 0: 284 | return f"{hours:02}:{minutes:02}:{seconds:02}" 285 | else: 286 | return f"{minutes:02}:{seconds:02}.{solve_time.microseconds // 10_000:02}" 287 | 288 | 289 | def _report_practice_result( 290 | day: int, part: int, year: int, solve_time: datetime.timedelta 291 | ) -> None: 292 | print( 293 | f"{GREEN}You solved the puzzle in" 294 | f" {BLUE}{_format_timedelta(solve_time)}{GREEN}!{RESET}" 295 | ) 296 | result = _estimate_practice_rank(day, part, year, solve_time) 297 | if not result: 298 | print(f"{YELLOW}You would not have achieved a leaderboard position.{RESET}") 299 | else: 300 | likely_rank, best_possible_rank, worst_possible_rank = result 301 | if best_possible_rank == worst_possible_rank: 302 | print(f"{GOLD}You would have achieved rank {best_possible_rank}!{RESET}") 303 | else: 304 | if worst_possible_rank > 100: 305 | worst_possible_rank = "100+" 306 | print( 307 | f"{GOLD}You would have achieved approximately rank" 308 | f" {likely_rank} ({best_possible_rank} to" 309 | f" {worst_possible_rank})!{RESET}" 310 | ) 311 | 312 | 313 | def submit(day: int, part: int, answer: typing.Any, year: int = DEFAULT_YEAR) -> None: 314 | """Submit a solution. 315 | 316 | Submissions are cached; submitting an already submitted solution will return the 317 | previous response. 
318 | """ 319 | day_ = str(day) 320 | year_ = str(year) 321 | part_ = str(part) 322 | answer_ = str(answer) 323 | 324 | submission_dir = DATA_DIR / year_ / day_ 325 | _make(submission_dir) 326 | 327 | # Load cached solutions 328 | submissions = submission_dir / "submissions.json" 329 | if submissions.exists(): 330 | with submissions.open() as f: 331 | solutions = json.load(f) 332 | else: 333 | solutions = {"1": {}, "2": {}} 334 | 335 | # Check if solved 336 | solution_file = submission_dir / f"{part}.solution" 337 | if solution_file.exists(): 338 | solution = solution_file.read_text() 339 | if "--practice" in sys.argv: 340 | if solution == answer_: 341 | _calculate_practice_result(day, part, year) 342 | else: 343 | print( 344 | f"{RED}Submitted {BLUE}{answer_}{RESET}; that's not the right" 345 | f" answer.{RESET}" 346 | ) 347 | return 348 | if "--force-run" in sys.argv: 349 | if solution != answer_: 350 | print( 351 | f"{RED}[Day {BLUE}{day}{RESET} part {BLUE}{part}{RESET}]" 352 | f" Solution produced incorrect answer {BLUE}{answer_}{RESET}!" 353 | f" (Correct answer: {BLUE}{solution}{RESET}){RESET}" 354 | ) 355 | else: 356 | print( 357 | f"{GREEN}[Day {BLUE}{day}{RESET} part {BLUE}{part}{RESET}]" 358 | f" Solution produced correct answer {BLUE}{answer_}{RESET}!{RESET}" 359 | ) 360 | _print_rank(solutions[part_][solution]) 361 | return 362 | print( 363 | f"Day {BLUE}{day}{RESET} part {BLUE}{part}{RESET} " 364 | "has already been solved.\nThe solution was: " 365 | f"{BLUE}{solution}{RESET}" 366 | ) 367 | _print_rank(solutions[part_][solution]) 368 | return 369 | 370 | # Check if answer has already been submitted 371 | if answer_ in solutions[part_]: 372 | print( 373 | f"{YELLOW}Solution: {BLUE}{answer}{RESET} to part " 374 | f"{BLUE}{part}{RESET} has already been submitted.\n" 375 | f"Response was:{RESET}" 376 | ) 377 | return _pretty_print(solutions[part_][answer_]) 378 | 379 | while True: 380 | print( 381 | f"Submitting {BLUE}{answer}{RESET} as the solution to part " 382 | f"{BLUE}{part}{RESET}..." 383 | ) 384 | resp = requests.post( 385 | url=URL.format(day=day, year=year) + "/answer", 386 | cookies=get_cookie(), 387 | data={"level": part_, "answer": answer_}, 388 | headers=HEADERS, 389 | ) 390 | if not resp.ok: 391 | if resp.status_code == 400: 392 | token_file = DATA_DIR / "token.txt" 393 | token = input( 394 | "Your token has expired. 
Please enter your new token\n>>> " 395 | ) 396 | token_file.write_text(token) 397 | return submit(day, part, answer, year) 398 | raise ValueError("Received bad response") 399 | 400 | article = Soup(resp.text, "html.parser").article 401 | assert article is not None 402 | msg = article.text 403 | 404 | if msg.startswith("You gave"): 405 | print(RED + msg + RESET) 406 | wait_match = WAIT_TIME.search(msg) 407 | assert wait_match is not None 408 | pause = 60 * int(wait_match[1] or 0) + int(wait_match[2]) 409 | wait( 410 | f"{YELLOW}Waiting {BLUE}{pause}{RESET} seconds to retry...{RESET}", 411 | pause, 412 | ) 413 | else: 414 | break 415 | 416 | if msg.startswith("That's the"): 417 | _print_rank(msg) 418 | solution_file.write_text(answer_) 419 | _calculate_practice_result(day, part, year) 420 | if part == 1: 421 | if not resp.url.endswith("#part2"): 422 | resp.url += "#part2" # scroll to part 2 423 | _open_page(resp.url) # open part 2 in the user's browser 424 | else: 425 | _pretty_print(msg) 426 | 427 | # Cache submission 428 | solutions[part_][answer_] = msg 429 | with submissions.open("w") as f: 430 | json.dump(solutions, f) 431 | 432 | 433 | def _print_rank(msg: str) -> None: 434 | match = RANK.search(msg) 435 | if match: 436 | _pretty_print(f"You got rank {match.group(1)} for this puzzle") 437 | else: 438 | _pretty_print(msg) 439 | 440 | 441 | def submit_final(year: int): 442 | day = days_in_year(year) 443 | print(f"{GREEN}Finishing Advent of Code {BLUE}{year}{RESET}!{RESET}") 444 | resp = requests.post( 445 | url=URL.format(day=day, year=year) + "/answer", 446 | cookies=get_cookie(), 447 | data={"level": "2", "answer": "0"}, 448 | headers=HEADERS, 449 | ) 450 | if not resp.ok: 451 | if resp.status_code == 400: 452 | token_file = DATA_DIR / "token.txt" 453 | token = input("Your token has expired. Please enter your new token\n>>> ") 454 | token_file.write_text(token) 455 | return submit_final(year) 456 | raise ValueError("Received bad response") 457 | 458 | print("Response from the server:") 459 | article = Soup(resp.text, "html.parser").article 460 | assert article is not None 461 | print(article.text.strip()) 462 | if len(_practice_result_for(day, year)) < 2: 463 | _calculate_practice_result(day, 2, year) 464 | 465 | 466 | def lazy_submit( 467 | day: int, 468 | solution: typing.Callable[[U], typing.Any], 469 | data: U, 470 | year: int = DEFAULT_YEAR, 471 | ) -> None: 472 | """Run the function only if we haven't seen a solution. 473 | 474 | Will also run the solution if `--force-run` or `--practice` is passed on the 475 | command line. 
476 | 477 | solution is expected to be named 'part_one' or 'part_two' 478 | """ 479 | import sys 480 | 481 | part = 1 if solution.__name__ == "part_one" else 2 482 | submission_dir = DATA_DIR / str(year) / str(day) 483 | if day == days_in_year(year) and part == 2: 484 | # Don't try to submit part 2 if part 1 isn't solved 485 | if (submission_dir / "1.solution").exists(): 486 | submit_final(year) 487 | else: 488 | return 489 | solution_file = submission_dir / f"{part}.solution" 490 | # Check if solved 491 | if ( 492 | solution_file.exists() 493 | and "--force-run" not in sys.argv 494 | and ( 495 | "--practice" not in sys.argv # not in practice mode 496 | or len(_practice_result_for(day, year)) 497 | >= part # or solved today in practice 498 | ) 499 | ): 500 | # Load cached solutions 501 | submissions = submission_dir / "submissions.json" 502 | with submissions.open() as f: 503 | solutions = json.load(f) 504 | 505 | solution_ = solution_file.read_text() 506 | print( 507 | f"Day {BLUE}{day}{RESET} part {BLUE}{part}{RESET} " 508 | "has already been solved.\nThe solution was: " 509 | f"{BLUE}{solution_}{RESET}" 510 | ) 511 | _print_rank(solutions[str(part)][solution_]) 512 | else: 513 | answer = work( 514 | f"{YELLOW}Running part" 515 | f" {RESET}{BLUE}{part}{RESET}{YELLOW} solution...{RESET}", 516 | solution, 517 | data, 518 | ) 519 | if answer is not None: 520 | submit(day, part, answer, year) 521 | 522 | 523 | def get_sample_input( 524 | day: int, part: int, year: int = DEFAULT_YEAR 525 | ) -> typing.Optional[tuple[str, str]]: 526 | """Retrieves the example input and answer for the corresponding AOC challenge.""" 527 | testing_dir = DATA_DIR / str(year) / str(day) 528 | _make(testing_dir) 529 | testing_file = testing_dir / "tests.json" 530 | 531 | if testing_file.exists(): 532 | test_info: dict[str, typing.Optional[tuple[str, str]]] = json.loads( 533 | testing_file.read_text() 534 | ) 535 | else: 536 | test_info: dict[str, typing.Optional[tuple[str, str]]] = {} 537 | 538 | if str(part) in test_info: 539 | return test_info[str(part)] 540 | 541 | resp = requests.post( 542 | url=URL.format(day=day, year=year), cookies=get_cookie(), headers=HEADERS 543 | ) 544 | soup = Soup(resp.text, "html.parser") 545 | 546 | example_test_inputs = [] 547 | # Find the example test input for that day. 548 | for possible_test_input in soup.find_all("pre"): 549 | preceding_text = ( 550 | possible_test_input.previous_element.previous_element.text.lower() 551 | ) 552 | if ( 553 | "for example" in preceding_text 554 | or "consider" in preceding_text 555 | or "given" in preceding_text 556 | ) and ":" in preceding_text: 557 | example_test_inputs.append(possible_test_input.text.strip()) 558 | 559 | if not example_test_inputs: 560 | test_info[str(part)] = None 561 | testing_file.write_text(json.dumps(test_info)) 562 | warn( 563 | f"An issue occurred while fetching test data for {year} day" 564 | f" {day} part {part}. You may either ignore this message, or pass" 565 | " custom test data to lazy_test.", 566 | RuntimeWarning, 567 | ) 568 | return 569 | 570 | try: 571 | test_input = example_test_inputs[-1] 572 | 573 | # Attempt to retrieve answer to said example data from puzzle part. 
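        # Heuristic: the expected answer is usually the last <code> tag in the
        # second-to-last paragraph of the relevant <article>; if that tag contains
        # no emphasis, the last <em> in the paragraph is tried instead.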
574 | current_part = soup.find_all("article")[part - 1] 575 | last_sentence = current_part.find_all("p")[-2] 576 | answer = last_sentence.find_all("code")[-1] 577 | except IndexError: 578 | test_info[str(part)] = None 579 | testing_file.write_text(json.dumps(test_info)) 580 | warn( 581 | f"An issue occurred while fetching test data for {year} day" 582 | f" {day} part {part}. You may either ignore this message, or pass" 583 | " custom test data to lazy_test.", 584 | RuntimeWarning, 585 | ) 586 | return 587 | if not answer.em: 588 | try: 589 | answer = last_sentence.find_all("em")[-1] 590 | except IndexError: 591 | pass 592 | 593 | try: 594 | answer = answer.text.strip("\n").split()[-1] 595 | except IndexError: 596 | test_info[str(part)] = None 597 | testing_file.write_text(json.dumps(test_info)) 598 | warn( 599 | f"An issue occurred while fetching test data for {year} day" 600 | f" {day} part {part}. You may either ignore this message, or pass" 601 | " custom test data to lazy_test.", 602 | RuntimeWarning, 603 | ) 604 | return 605 | 606 | test_data = test_input, answer 607 | test_info[str(part)] = test_data 608 | testing_file.write_text(json.dumps(test_info)) 609 | return test_data 610 | 611 | 612 | def _test(part: int, answer: str, expected_answer: str) -> None: 613 | assert answer == expected_answer, ( 614 | f"The expected answer for the example test input was {expected_answer} but" 615 | f" your answer was {answer}." 616 | ) 617 | print( 618 | f"{GREEN}Test for part {BLUE}{part}{RESET} succeeded!" 619 | f" The answer for part {BLUE}{part}{RESET} with the test data was:" 620 | f" {BLUE}{answer}{RESET}{RESET}" 621 | ) 622 | 623 | 624 | def lazy_test( 625 | day: int, 626 | parse: typing.Callable[[str], T], 627 | solution: typing.Callable[[T], typing.Any], 628 | year: int = DEFAULT_YEAR, 629 | test_data: typing.Optional[tuple[str, typing.Any]] = None, 630 | ) -> None: 631 | """Test the function with AOC's example data only if we haven't tested it already. 
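For example, `lazy_test(day=1, parse=my_parser, solution=part_one)` parses the scraped example input with `my_parser` (an illustrative name for your parser), runs `part_one` on the result, and checks it against the scraped example answer.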
632 | 633 | Solution is expected to be named 'part_one' or 'part_two' 634 | """ 635 | part = 1 if solution.__name__ == "part_one" else 2 636 | testing_dir = DATA_DIR / str(year) / str(day) 637 | _make(testing_dir) 638 | 639 | # If this is part 2, skip fetching/running tests if part 1 hasn't been submitted 640 | if part == 2 and not (testing_dir / "1.solution").exists(): 641 | return 642 | 643 | # If this part has been submitted, skip running tests 644 | if not (testing_dir / f"{part}.solution").exists(): 645 | if test_data is None: # No test data passed (most common) 646 | test_data = get_sample_input(day, part, year) 647 | if test_data is None: # No test data scraped (uncommon) 648 | return 649 | test_input, test_answer = test_data 650 | 651 | answer = work( 652 | f"{YELLOW}Running the test for part {BLUE}{part}{RESET} solution...{RESET}", 653 | solution, 654 | parse(test_input), 655 | ) 656 | if answer is not None: 657 | _test(part, str(answer).strip(), str(test_answer).strip()) 658 | -------------------------------------------------------------------------------- /aoc_helper/utils.py: -------------------------------------------------------------------------------- 1 | import builtins 2 | import collections 3 | import copy 4 | import functools 5 | import itertools 6 | import math 7 | import operator 8 | import re 9 | import sys 10 | import typing 11 | from collections import Counter, UserList, deque 12 | from collections.abc import Hashable 13 | from heapq import heapify, heappop, heappush, nlargest, nsmallest 14 | 15 | from typing_extensions import ParamSpec, TypeVarTuple, Unpack 16 | 17 | from aoc_helper.types import ( 18 | AddableT, 19 | AddableU, 20 | HashableU, 21 | MultipliableT, 22 | MultipliableU, 23 | SupportsMean, 24 | SupportsProdNoDefaultT, 25 | SupportsRichComparison, 26 | SupportsRichComparisonT, 27 | SupportsSumNoDefaultT, 28 | ) 29 | 30 | T = typing.TypeVar("T") 31 | T_Co = typing.TypeVar("T_Co", covariant=True) 32 | SpecialisationT = typing.TypeVar("SpecialisationT", covariant=True) 33 | Ts = TypeVarTuple("Ts") 34 | U = typing.TypeVar("U") 35 | GenericU = typing.Generic[T] 36 | P = ParamSpec("P") 37 | 38 | 39 | Iterable = typing.Union["iter[T]", typing.Iterable[T]] 40 | AnyIterable = typing.Union[Iterable[T], builtins.list[T], tuple[T, ...], "list[T]"] 41 | MaybeIterator = typing.Union[T, Iterable["MaybeIterator[T]"]] 42 | 43 | 44 | def extract_ints(raw: str) -> "list[int]": 45 | """Utility function to extract all integers from some string. 46 | 47 | Some inputs can be directly parsed with this function. 48 | """ 49 | return list(map(int, re.findall(r"((?:-|\+)?\d+)", raw))) 50 | 51 | 52 | def extract_uints(raw: str) -> "list[int]": 53 | """Utility function to extract all integers from some string. 54 | 55 | Minus signs will be *ignored*; the output integers will all be positive. 56 | 57 | Some inputs can be directly parsed with this function. 58 | """ 59 | return list(map(int, re.findall(r"(\d+)", raw))) 60 | 61 | 62 | def _range_from_match(match: tuple[str, str]) -> "range": 63 | if match[1]: 64 | return range(int(match[0]), int(match[1])) 65 | else: 66 | return range(int(match[0]), int(match[0])) 67 | 68 | 69 | def _irange_from_match(match: tuple[str, str]) -> "range": 70 | if match[1]: 71 | return range(int(match[0]), int(match[1]) + 1) 72 | else: 73 | return range(int(match[0]), int(match[0]) + 1) 74 | 75 | 76 | def extract_ranges(raw: str) -> "list[range]": 77 | """Utility function to extract all ranges from some string. 
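For example, `extract_ranges("2-8,10-14")` would give `list([range(2, 8), range(10, 14)])` (using this module's `range` and `list` classes).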
78 | 79 | Ranges are interpreted as `start-stop` and are not inclusive. 80 | 81 | Some inputs can be directly parsed with this function. 82 | """ 83 | return list(map(_range_from_match, re.findall(r"(\d+)(?:-(\d+))?", raw))) 84 | 85 | 86 | def extract_iranges(raw: str) -> "list[range]": 87 | """Utility function to extract all ranges from some string. 88 | 89 | Ranges are interpreted as `start-stop` and are inclusive. 90 | 91 | Some inputs can be directly parsed with this function. 92 | """ 93 | return list(map(_irange_from_match, re.findall(r"(\d+)(?:-(\d+))?", raw))) 94 | 95 | 96 | @typing.overload 97 | def chunk( 98 | iterable: Iterable[T], chunk_size: typing.Literal[2] 99 | ) -> "typing.Iterator[tuple[T, T]]": ... 100 | 101 | 102 | @typing.overload 103 | def chunk( 104 | iterable: Iterable[T], chunk_size: typing.Literal[3] 105 | ) -> "typing.Iterator[tuple[T, T, T]]": ... 106 | 107 | 108 | @typing.overload 109 | def chunk( 110 | iterable: Iterable[T], chunk_size: typing.Literal[4] 111 | ) -> "typing.Iterator[tuple[T, T, T, T]]": ... 112 | 113 | 114 | @typing.overload 115 | def chunk( 116 | iterable: Iterable[T], chunk_size: typing.Literal[5] 117 | ) -> "typing.Iterator[tuple[T, T, T, T, T]]": ... 118 | 119 | 120 | @typing.overload 121 | def chunk( 122 | iterable: Iterable[T], chunk_size: typing.Literal[6] 123 | ) -> "typing.Iterator[tuple[T, T, T, T, T, T]]": ... 124 | 125 | 126 | @typing.overload 127 | def chunk( 128 | iterable: Iterable[T], chunk_size: int 129 | ) -> "typing.Iterator[tuple[T, ...]]": ... 130 | 131 | 132 | def chunk(iterable: Iterable[T], chunk_size: int) -> "typing.Iterator[tuple[T, ...]]": 133 | """Utility function to chunk an iterable into chunks of a given size. 134 | 135 | If there are not enough elements in the iterable to fill the last chunk, 136 | the last chunk will be dropped. 137 | """ 138 | return zip(*[builtins.iter(iterable)] * chunk_size) 139 | 140 | 141 | @typing.overload 142 | def chunk_default( 143 | iterable: Iterable[T], chunk_size: typing.Literal[2], default: T 144 | ) -> "typing.Iterator[tuple[T, T]]": ... 145 | 146 | 147 | @typing.overload 148 | def chunk_default( 149 | iterable: Iterable[T], chunk_size: typing.Literal[3], default: T 150 | ) -> "typing.Iterator[tuple[T, T, T]]": ... 151 | 152 | 153 | @typing.overload 154 | def chunk_default( 155 | iterable: Iterable[T], chunk_size: typing.Literal[4], default: T 156 | ) -> "typing.Iterator[tuple[T, T, T, T]]": ... 157 | 158 | 159 | @typing.overload 160 | def chunk_default( 161 | iterable: Iterable[T], chunk_size: typing.Literal[5], default: T 162 | ) -> "typing.Iterator[tuple[T, T, T, T, T]]": ... 163 | 164 | 165 | @typing.overload 166 | def chunk_default( 167 | iterable: Iterable[T], chunk_size: typing.Literal[6], default: T 168 | ) -> "typing.Iterator[tuple[T, T, T, T, T, T]]": ... 169 | 170 | 171 | @typing.overload 172 | def chunk_default( 173 | iterable: Iterable[T], chunk_size: int, default: T 174 | ) -> "typing.Iterator[tuple[T, ...]]": ... 175 | 176 | 177 | def chunk_default( 178 | iterable: Iterable[T], chunk_size: int, default: T 179 | ) -> Iterable[tuple[T, ...]]: 180 | """Utility function to chunk an iterable into chunks of a given size. 181 | 182 | If there are not enough elements in the iterable to fill the last chunk, 183 | the missing elements will be replaced with the default value. 
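For example, `chunk_default([1, 2, 3, 4, 5], 2, 0)` would yield `(1, 2)`, `(3, 4)`, `(5, 0)`.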
184 | """ 185 | return itertools.zip_longest( 186 | *[builtins.iter(iterable)] * chunk_size, fillvalue=default 187 | ) 188 | 189 | 190 | class list(typing.Generic[T], UserList[T]): 191 | """Smart/fluent list class""" 192 | 193 | _SENTINEL = object() 194 | 195 | def iter(self) -> "iter[T]": 196 | """Return an iterator over the list.""" 197 | return iter(self) 198 | 199 | def mapped(self, func: typing.Callable[[T], U]) -> "list[U]": 200 | """Return a list containing the result of calling func on each 201 | element in the list. The function is called on each element immediately. 202 | """ 203 | return list(map(func, self)) 204 | 205 | def starmapped( 206 | self: typing.Union["list[AnyIterable[Ts]]", "list[tuple[Unpack[Ts]]]"], 207 | func: typing.Callable[[Unpack[Ts]], U], 208 | ) -> "list[U]": 209 | """Return a list containing the result of calling func on each 210 | element in the list. The function is called on each element immediately. 211 | """ 212 | return list(itertools.starmap(func, self)) 213 | 214 | def mapped_each( 215 | self: "list[AnyIterable[SpecialisationT]]", 216 | func: typing.Callable[[SpecialisationT], U], 217 | ) -> "list[list[U]]": 218 | """Return a list containing the results of mapping each element of self 219 | with func. The function is called on each element immediately. 220 | """ 221 | return self.mapped(lambda i: list(map(func, i))) 222 | 223 | def starmapped_each( 224 | self: typing.Union[ 225 | "list[AnyIterable[AnyIterable[Ts]]]", 226 | "list[AnyIterable[tuple[Unpack[Ts]]]]", 227 | ], 228 | func: typing.Callable[[Unpack[Ts]], U], 229 | ) -> "list[list[U]]": 230 | """Return a list containing the results of mapping each element of self 231 | with func. The function is called on each element immediately. 232 | """ 233 | return self.mapped(lambda i: list(itertools.starmap(func, i))) 234 | 235 | def filtered( 236 | self, pred: typing.Union[typing.Callable[[T], bool], T, None] = None 237 | ) -> "list[T]": 238 | """Return a list containing only the elements for which pred 239 | returns True. 240 | 241 | If pred is None, return a list containing only elements that are 242 | truthy. 243 | 244 | If pred is a T (and T is not a callable or None), return a list 245 | containing only elements that compare equal to pred. 246 | """ 247 | if not callable(pred) and pred is not None: 248 | pred = (lambda j: lambda i: i == j)(pred) 249 | return list(filter(pred, self)) 250 | 251 | def find( 252 | self, pred: typing.Union[typing.Callable[[T], bool], T, None] = None 253 | ) -> typing.Optional[T]: 254 | """Return the first element of self for which pred returns True. 255 | 256 | If pred is None, return the first element which is truthy. 257 | 258 | If pred is a T (and T is not a callable or None), return the first element 259 | that compares equal to pred. 260 | 261 | If no such element exists, return None. 262 | """ 263 | if pred is None: 264 | pred = bool 265 | elif not callable(pred): 266 | pred = (lambda j: lambda i: i == j)(pred) 267 | for i in self: 268 | if pred(i): 269 | return i 270 | 271 | def any(self, pred: typing.Union[typing.Callable[[T], bool], T] = bool) -> bool: 272 | """Return True if any element of this list satisfies the given predicate. 273 | The default predicate is bool; therefore by default this method returns 274 | True if any element is truthy. 
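For example, `list([0, 0, 3]).any()` is True, while `list([1, 2, 3]).any(5)` is False.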
275 | """ 276 | if not callable(pred): 277 | pred = (lambda j: lambda i: i == j)(pred) 278 | return any(pred(item) for item in self) 279 | 280 | def all(self, pred: typing.Union[typing.Callable[[T], bool], T] = bool) -> bool: 281 | """Return True if all elements of this list satisfy the given predicate. 282 | The default predicate is bool; therefore by default this method returns 283 | True if all elements are truthy. 284 | """ 285 | if not callable(pred): 286 | pred = (lambda j: lambda i: i == j)(pred) 287 | return all(pred(item) for item in self) 288 | 289 | def none(self, pred: typing.Union[typing.Callable[[T], bool], T] = bool) -> bool: 290 | """Return True if no element of this list satisfies the given predicate. 291 | The default predicate is bool; therefore by default this method returns 292 | True if no element is truthy. 293 | """ 294 | if not callable(pred): 295 | pred = (lambda j: lambda i: i == j)(pred) 296 | return not any(pred(item) for item in self) 297 | 298 | @typing.overload 299 | def windowed(self, window_size: typing.Literal[2]) -> "list[tuple[T, T]]": ... 300 | 301 | @typing.overload 302 | def windowed(self, window_size: typing.Literal[3]) -> "list[tuple[T, T, T]]": ... 303 | 304 | @typing.overload 305 | def windowed(self, window_size: typing.Literal[4]) -> "list[tuple[T, T, T, T]]": ... 306 | 307 | @typing.overload 308 | def windowed( 309 | self, window_size: typing.Literal[5] 310 | ) -> "list[tuple[T, T, T, T, T]]": ... 311 | 312 | @typing.overload 313 | def windowed( 314 | self, window_size: typing.Literal[6] 315 | ) -> "list[tuple[T, T, T, T, T, T]]": ... 316 | 317 | @typing.overload 318 | def windowed(self, window_size: int) -> "list[tuple[T, ...]]": ... 319 | 320 | def windowed(self, window_size): 321 | """Return an list containing the elements of this list in 322 | a sliding window of size window_size. If there are not enough elements 323 | to create a full window, the list will be empty. 324 | """ 325 | return list(self.iter().window(window_size)) 326 | 327 | def shifted_zip(self, shift: int = 1) -> "iter[tuple[T, T]]": 328 | """Return an iterator containing pairs of elements separated by shift. 329 | 330 | If there are fewer than shift elements, the iterator will be empty. 331 | """ 332 | return iter(zip(self, self[shift:])) 333 | 334 | @typing.overload 335 | def reduce(self, func: typing.Callable[[T, T], T]) -> T: ... 336 | 337 | @typing.overload 338 | def reduce(self, func: typing.Callable[[U, T], U], initial: U) -> U: ... 339 | 340 | def reduce(self, func, initial=_SENTINEL): 341 | """Reduce the list to a single value, using the reduction 342 | function provided. 343 | """ 344 | if initial is self._SENTINEL: 345 | return functools.reduce(func, self) 346 | return functools.reduce(func, self, initial) 347 | 348 | @typing.overload 349 | def accumulated(self) -> "list[T]": ... 350 | 351 | @typing.overload 352 | def accumulated(self, func: typing.Callable[[T, T], T]) -> "list[T]": ... 353 | 354 | @typing.overload 355 | def accumulated( 356 | self, func: typing.Callable[[T, T], T], initial: T 357 | ) -> "list[T]": ... 358 | 359 | @typing.overload 360 | def accumulated( 361 | self, func: typing.Callable[[U, T], U], initial: U 362 | ) -> "list[U]": ... 363 | 364 | def accumulated(self, func=operator.add, initial=_SENTINEL): 365 | """Return the accumulated results of calling func on the elements in 366 | this list. 367 | 368 | initial is only usable on versions of Python equal to or greater than 3.8. 
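For example, `list([1, 2, 3, 4]).accumulated()` would give `list([1, 3, 6, 10])`, since the default function is `operator.add`.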
369 | """ 370 | if initial is self._SENTINEL: 371 | return list(itertools.accumulate(self, func)) 372 | return list(itertools.accumulate(self, func, initial)) # type: ignore 373 | 374 | @typing.overload 375 | def chunked(self, n: typing.Literal[2]) -> "list[tuple[T, T]]": ... 376 | 377 | @typing.overload 378 | def chunked(self, n: typing.Literal[3]) -> "list[tuple[T, T, T]]": ... 379 | 380 | @typing.overload 381 | def chunked(self, n: typing.Literal[4]) -> "list[tuple[T, T, T, T]]": ... 382 | 383 | @typing.overload 384 | def chunked(self, n: typing.Literal[5]) -> "list[tuple[T, T, T, T, T]]": ... 385 | 386 | @typing.overload 387 | def chunked(self, n: typing.Literal[6]) -> "list[tuple[T, T, T, T, T, T]]": ... 388 | 389 | @typing.overload 390 | def chunked(self, n: int) -> "list[tuple[T, ...]]": ... 391 | 392 | def chunked(self, n): 393 | """Return a list containing the elements of this list in chunks 394 | of size n. If there are not enough elements to fill the last chunk, it 395 | will be dropped. 396 | """ 397 | return list(chunk(self, n)) 398 | 399 | def chunked_default(self, n: int, default: T) -> "list[tuple[T, ...]]": 400 | """Return a list containing the elements of this list in chunks 401 | of size n. If there are not enough elements to fill the last chunk, the 402 | missing elements will be replaced with the default value. 403 | """ 404 | return list(chunk_default(self, n, default)) 405 | 406 | @typing.overload 407 | def sum( 408 | self: "list[SupportsSumNoDefaultT]", 409 | ) -> typing.Union[SupportsSumNoDefaultT, typing.Literal[0]]: ... 410 | 411 | @typing.overload 412 | def sum( 413 | self: "list[AddableT]", initial: AddableU 414 | ) -> typing.Union[AddableT, AddableU]: ... 415 | 416 | def sum(self, initial=_SENTINEL): 417 | """Return the sum of all elements in this list. 418 | 419 | If initial is provided, it is used as the initial value. 420 | """ 421 | # Pylance *hates* this method because the specialisation isn't provided on the implementation 422 | if initial is self._SENTINEL: 423 | return sum(self) # type: ignore 424 | return sum(self, initial) # type: ignore 425 | 426 | @typing.overload 427 | def prod( 428 | self: "list[SupportsProdNoDefaultT]", 429 | ) -> typing.Union[T, typing.Literal[1]]: ... 430 | 431 | @typing.overload 432 | def prod( 433 | self: "list[MultipliableT]", initial: MultipliableU 434 | ) -> typing.Union[MultipliableT, MultipliableU]: ... 435 | 436 | def prod(self, initial=_SENTINEL): 437 | """Return the product of all elements in this list. 438 | 439 | If initial is provided, it is used as the initial value. 440 | """ 441 | # Pylance *hates* this method because the specialisation isn't provided on the implementation 442 | if initial is self._SENTINEL: 443 | return math.prod(self) # type: ignore 444 | # math.prod isn't actually guaranteed to run for non-numerics, so we 445 | # have to ignore the type error here. 446 | return math.prod(self, start=initial) # type: ignore 447 | 448 | @typing.overload 449 | def sorted( 450 | self: "list[SupportsRichComparisonT]", 451 | *, 452 | reverse: bool = False, 453 | ) -> "list[SupportsRichComparisonT]": ... 454 | 455 | @typing.overload 456 | def sorted( 457 | self, 458 | key: typing.Callable[[T], SupportsRichComparison], 459 | reverse: bool = False, 460 | ) -> "list[T]": ... 461 | 462 | def sorted(self, key=None, reverse=False): # type: ignore 463 | """Return a list containing the elements of this list sorted 464 | according to the given key and reverse parameters. 
465 | """ 466 | # I hate working with specialisations I should have just written a pyi 467 | result: builtins.list[T] = sorted(self, key=key, reverse=reverse) # type: ignore 468 | return list(result) 469 | 470 | def reversed(self) -> "list[T]": 471 | """Return a list containing the elements of this list in 472 | reverse order. 473 | """ 474 | return list(reversed(self)) 475 | 476 | @typing.overload 477 | def min( 478 | self: "list[SupportsRichComparisonT]", 479 | ) -> T: ... 480 | 481 | @typing.overload 482 | def min( 483 | self, 484 | key: typing.Callable[[T], SupportsRichComparisonT], 485 | ) -> T: ... 486 | 487 | def min(self, key=None) -> T: 488 | """Return the minimum element of this list, according to the given 489 | key. 490 | """ 491 | return min(self, key=key) # type: ignore 492 | 493 | @typing.overload 494 | def max( 495 | self: "list[SupportsRichComparisonT]", 496 | ) -> T: ... 497 | 498 | @typing.overload 499 | def max( 500 | self, 501 | key: typing.Callable[[T], SupportsRichComparisonT], 502 | ) -> T: ... 503 | 504 | def max(self, key=None) -> T: 505 | """Return the maximum element of this list, according to the given 506 | key. 507 | """ 508 | return max(self, key=key) # type: ignore 509 | 510 | def len(self) -> int: 511 | """Return the length of this list.""" 512 | return len(self) 513 | 514 | def mean(self: "list[SupportsMean]") -> SupportsMean: 515 | """Statistical mean of this list. 516 | 517 | T must be summable and divisible by an integer, 518 | and there must be at least one element in this list. 519 | """ 520 | if self.len() == 0: 521 | raise ValueError("Called mean() on an empty list") 522 | return self.sum() / self.len() # type: ignore 523 | 524 | @typing.overload 525 | def median(self: "list[SupportsRichComparisonT]") -> T: ... 526 | 527 | @typing.overload 528 | def median(self, key: typing.Callable[[T], SupportsRichComparisonT]) -> T: ... 529 | 530 | def median(self, key=None) -> T: 531 | """Statistical median of this list. 532 | 533 | T must be orderable and there must be at least one 534 | element in this list. 535 | Furthermore, if this list contains an even number 536 | of elements, T must also be summable and divisible 537 | by an integer. 538 | """ 539 | if self.len() == 0: 540 | raise ValueError("Called median() on an empty list") 541 | if self.len() % 2: 542 | return self.sorted(key=key)[self.len() // 2] # type: ignore 543 | else: 544 | sorted_self = self.sorted(key=key) # type: ignore 545 | return (sorted_self[self.len() // 2] + sorted_self[self.len() // 2 - 1]) / 2 # type: ignore 546 | 547 | def mode(self) -> "list[T]": 548 | """Statistical mode of this list. 549 | 550 | T must be hashable and there must be at least one 551 | element in this list. 552 | """ 553 | if self.len() == 0: 554 | raise ValueError("Called mode() on an empty list") 555 | counted = Counter(self).most_common() 556 | n_ties = max(i[1] for i in counted) 557 | return list(i[0] for i in counted if i[1] == n_ties) 558 | 559 | @typing.overload 560 | def flat(self: "list[Iterable[SpecialisationT]]") -> "list[SpecialisationT]": ... 561 | 562 | @typing.overload 563 | def flat( 564 | self: "list[Iterable[SpecialisationT]]", 565 | recursive: typing.Literal[False] = False, 566 | ) -> "list[SpecialisationT]": ... 567 | 568 | @typing.overload 569 | def flat( 570 | self: "list[Iterable[MaybeIterator[SpecialisationT]]]", 571 | recursive: typing.Literal[True] = True, 572 | ) -> "list[SpecialisationT]": ...
573 | 574 | def flat(self: "list[Iterable[typing.Any]]", recursive=False): # type: ignore 575 | """Flattened version of this list. 576 | 577 | If recursive is specified, flattens recursively instead 578 | of by one layer. 579 | """ 580 | if not recursive: 581 | return list(item for list in self for item in list) 582 | return list( 583 | subitem 584 | for item in self 585 | for subitem in ( 586 | item.tee(1)[0].flatten(True) # type: ignore 587 | if isinstance(item, iter) 588 | else ( 589 | list(item).flat(True) # type: ignore 590 | if isinstance(item, (builtins.list, list)) 591 | else [item] 592 | ) 593 | ) 594 | ) 595 | 596 | def enumerated(self, start: int = 0) -> "list[tuple[int, T]]": 597 | return list(enumerate(self, start)) 598 | 599 | def deepcopy(self) -> "list[T]": 600 | return copy.deepcopy(self) 601 | 602 | def nlargest(self, n: int) -> "list[T]": 603 | """Return the n largest elements of self.""" 604 | return list(nlargest(n, self)) 605 | 606 | def nsmallest(self, n: int) -> "list[T]": 607 | """Return the n smallest elements of self.""" 608 | return list(nsmallest(n, self)) 609 | 610 | @typing.overload 611 | def transposition( 612 | self: "list[list[SpecialisationT]]", 613 | ) -> "list[list[SpecialisationT]]": ... 614 | 615 | @typing.overload 616 | def transposition( 617 | self: "list[Iterable[SpecialisationT]]", 618 | ) -> "list[list[SpecialisationT]]": ... 619 | 620 | @typing.overload 621 | def transposition( 622 | self: "list[tuple[SpecialisationT, ...]]", 623 | ) -> "list[list[SpecialisationT]]": ... 624 | 625 | @typing.overload 626 | def transposition( 627 | self: "list[builtins.list[SpecialisationT]]", 628 | ) -> "list[list[SpecialisationT]]": ... 629 | 630 | def transposition( 631 | self, 632 | ): 633 | """Return the transposition of this list, which is assumed to be 634 | rectangular, not ragged. If this list was ragged, then it will be 635 | cropped to the largest rectangle that is fully populated. 636 | 637 | This operation looks similar to a 90° rotation followed by a reflection: 638 | 639 | ABC 640 | DEF 641 | HIJ 642 | KLM 643 | 644 | transposes to: 645 | 646 | ADHK 647 | BEIL 648 | CFJM 649 | """ 650 | return list(zip(*self)).mapped(list) 651 | 652 | def into_grid(self: "list[list[SpecialisationT]]") -> "Grid[SpecialisationT]": 653 | """Convert this list, which is assumed to be rectangular, not ragged, 654 | into a Grid. 655 | 656 | This function converts directly; it doesn't copy - expect strange 657 | behaviour if you continue using self. 658 | """ 659 | return Grid(self) 660 | 661 | def into_queue(self) -> "PrioQueue[T]": 662 | """Convert this list into a PrioQueue. 663 | 664 | This function converts directly; it doesn't copy - expect strange 665 | behaviour if you continue using self. 666 | """ 667 | return PrioQueue(self.into_builtin()) 668 | 669 | def into_builtin(self) -> builtins.list[T]: 670 | """Unwrap this list into a builtins.list. 671 | 672 | This function converts directly; it doesn't copy - expect strange 673 | behaviour if you continue using self. 674 | """ 675 | return self.data 676 | 677 | def combinations(self, r: int) -> "list[tuple[T, ...]]": 678 | """Return a list over the combinations, without replacement, of 679 | length r of the elements of this list. 680 | """ 681 | return list(itertools.combinations(self, r)) 682 | 683 | def combinations_with_replacement(self, r: int) -> "list[tuple[T, ...]]": 684 | """Return a list over the combinations, with replacement, of 685 | length r of the elements of this list. 
686 | """ 687 | return list(itertools.combinations_with_replacement(self, r)) 688 | 689 | def permutations(self, r: typing.Union[int, None] = None) -> "list[tuple[T, ...]]": 690 | """Return a list over the permutations of the elements of this 691 | list. 692 | 693 | If r is provided, the returned list will only contain permutations 694 | of size r. 695 | """ 696 | return list(itertools.permutations(self, r)) 697 | 698 | def divide(self, n: int) -> "list[list[T]]": 699 | """Divide this list into n equal-sized chunks.""" 700 | assert self.len() % n == 0 701 | chunk_size = self.len() // n 702 | return list(chunk(self, chunk_size)).mapped(list) 703 | 704 | def __repr__(self) -> str: 705 | return f"list({super().__repr__()})" 706 | 707 | 708 | class iter(typing.Generic[T_Co], typing.Iterator[T_Co], typing.Iterable[T_Co]): 709 | """Smart/fluent iterator class""" 710 | 711 | _SENTINEL = object() 712 | 713 | def __init__(self, it: Iterable[T_Co]) -> None: 714 | self.it = builtins.iter(it) 715 | 716 | def __iter__(self) -> typing.Iterator[T_Co]: 717 | return self.it.__iter__() 718 | 719 | def __next__(self) -> T_Co: 720 | return next(self.it) 721 | 722 | def map(self, func: typing.Callable[[T_Co], U]) -> "iter[U]": 723 | """Return an iterator containing the result of calling func on each 724 | element in this iterator. 725 | """ 726 | return iter(map(func, self)) 727 | 728 | def starmap( 729 | self: typing.Union["iter[Iterable[Ts]]", "iter[tuple[Unpack[Ts]]]"], 730 | func: typing.Callable[[Unpack[Ts]], U], 731 | ) -> "iter[U]": 732 | """Return an iterator containing the result of calling func on each 733 | element in this iterator. 734 | """ 735 | return iter(itertools.starmap(func, self)) 736 | 737 | def map_each( 738 | self: "iter[Iterable[SpecialisationT]]", 739 | func: typing.Callable[[SpecialisationT], U], 740 | ) -> "iter[iter[U]]": 741 | """Return an iterator containing the result of calling func on each 742 | element in each element in this iterator. 743 | """ 744 | return iter(self.map(lambda i: iter(i).map(func))) 745 | 746 | def starmap_each( 747 | self: typing.Union[ 748 | "iter[Iterable[AnyIterable[Ts]]]", 749 | "iter[Iterable[tuple[Unpack[Ts]]]]", 750 | ], 751 | func: typing.Callable[[Unpack[Ts]], U], 752 | ) -> "iter[iter[U]]": 753 | """Return an iterator containing the result of calling func on each 754 | element in each element in this iterator. 755 | """ 756 | return iter(self.map(lambda i: iter(i).starmap(func))) 757 | 758 | def filter( 759 | self, pred: typing.Union[typing.Callable[[T_Co], bool], T_Co] = bool 760 | ) -> "iter[T_Co]": 761 | """Return an iterator containing only the elements for which pred 762 | returns True. 763 | 764 | If pred is a T (and T is not callable), return an iterator 765 | containing only elements that compare equal to pred. 766 | """ 767 | if not callable(pred) and pred is not None: 768 | pred = (lambda j: lambda i: i == j)(pred) 769 | return iter(filter(pred, self)) 770 | 771 | def find( 772 | self, pred: typing.Union[typing.Callable[[T_Co], bool], T_Co, None] = None 773 | ) -> typing.Optional[T_Co]: 774 | """Return the first element of self for which pred returns True. 775 | 776 | If pred is None, return the first element which is truthy. 777 | 778 | If pred is a T (and T is not a callable or None), return the first element 779 | that compares equal to pred. 780 | 781 | If no such element exists, return None. 
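For example, `iter([1, 4, 9]).find(lambda x: x > 3)` returns 4, while `iter([1, 4, 9]).find(7)` returns None.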
782 | """ 783 | if pred is None: 784 | pred = bool 785 | elif not callable(pred): 786 | pred = (lambda j: lambda i: i == j)(pred) 787 | for i in self: 788 | if pred(i): 789 | return i 790 | 791 | def any( 792 | self, pred: typing.Union[typing.Callable[[T_Co], bool], T_Co] = bool 793 | ) -> bool: 794 | """Consume this iterator and return True if any element satisfies the 795 | given predicate. The default predicate is bool; therefore by default this 796 | method returns True if any element is truthy. 797 | """ 798 | if not callable(pred): 799 | pred = (lambda j: lambda i: i == j)(pred) 800 | return any(pred(item) for item in self) 801 | 802 | def all( 803 | self, pred: typing.Union[typing.Callable[[T_Co], bool], T_Co] = bool 804 | ) -> bool: 805 | """Consume this iterator and return True if all elements satisfy the 806 | given predicate. The default predicate is bool; therefore by default this 807 | method returns True if all elements are truthy. 808 | """ 809 | if not callable(pred): 810 | pred = (lambda j: lambda i: i == j)(pred) 811 | return all(pred(item) for item in self) 812 | 813 | def none( 814 | self, pred: typing.Union[typing.Callable[[T_Co], bool], T_Co] = bool 815 | ) -> bool: 816 | """Consume this iterator and return True if no element satisfies the 817 | given predicate. The default predicate is bool; therefore by default this 818 | method returns True if no element is truthy. 819 | """ 820 | if not callable(pred): 821 | pred = (lambda j: lambda i: i == j)(pred) 822 | return not any(pred(item) for item in self) 823 | 824 | @typing.overload 825 | def reduce(self, func: typing.Callable[[T_Co, T_Co], T_Co]) -> T_Co: ... 826 | 827 | @typing.overload 828 | def reduce(self, func: typing.Callable[[U, T_Co], U], initial: U) -> U: ... 829 | 830 | def reduce(self, func, initial=_SENTINEL): 831 | """Reduce the iterator to a single value, using the reduction 832 | function provided. 833 | """ 834 | if initial is self._SENTINEL: 835 | return functools.reduce(func, self) 836 | return functools.reduce(func, self, initial) 837 | 838 | @typing.overload 839 | def accumulate(self) -> "iter[T_Co]": ... 840 | 841 | @typing.overload 842 | def accumulate(self, func: typing.Callable[[T_Co, T_Co], T_Co]) -> "iter[T_Co]": ... 843 | 844 | @typing.overload 845 | def accumulate( 846 | self, 847 | func: typing.Callable[[T_Co, T_Co], T_Co], 848 | initial: T_Co, # type: ignore 849 | ) -> "iter[T_Co]": ... 850 | 851 | @typing.overload 852 | def accumulate( 853 | self, func: typing.Callable[[U, T_Co], U], initial: U 854 | ) -> "iter[U]": ... 855 | 856 | def accumulate(self, func=operator.add, initial=_SENTINEL): 857 | """Return the accumulated results of calling func on the elements in 858 | this iterator. 859 | 860 | initial is only usable on versions of Python equal to or greater than 3.8. 861 | """ 862 | if initial is self._SENTINEL: 863 | return iter(itertools.accumulate(self, func)) 864 | return iter(itertools.accumulate(self, func, initial)) # type: ignore 865 | 866 | def foreach(self, func: typing.Callable[[T_Co], typing.Any]) -> None: 867 | """Run func on every value in this iterator, immediately.""" 868 | for el in self: 869 | func(el) 870 | 871 | @typing.overload 872 | def chunk(self, n: typing.Literal[2]) -> "iter[tuple[T_Co, T_Co]]": ... 873 | 874 | @typing.overload 875 | def chunk(self, n: typing.Literal[3]) -> "iter[tuple[T_Co, T_Co, T_Co]]": ... 876 | 877 | @typing.overload 878 | def chunk(self, n: typing.Literal[4]) -> "iter[tuple[T_Co, T_Co, T_Co, T_Co]]": ... 
879 | 880 | @typing.overload 881 | def chunk( 882 | self, n: typing.Literal[5] 883 | ) -> "iter[tuple[T_Co, T_Co, T_Co, T_Co, T_Co]]": ... 884 | 885 | @typing.overload 886 | def chunk( 887 | self, n: typing.Literal[6] 888 | ) -> "iter[tuple[T_Co, T_Co, T_Co, T_Co, T_Co, T_Co]]": ... 889 | 890 | @typing.overload 891 | def chunk(self, n: int) -> "iter[tuple[T_Co, ...]]": ... 892 | 893 | def chunk(self, n): 894 | """Return an iterator containing the elements of this iterator in chunks 895 | of size n. If there are not enough elements to fill the last chunk, it 896 | will be dropped. 897 | """ 898 | return iter(chunk(self, n)) 899 | 900 | def chunk_default(self, n: int, default: T_Co) -> "iter[tuple[T_Co, ...]]": # type: ignore 901 | """Return an iterator containing the elements of this iterator in chunks 902 | of size n. If there are not enough elements to fill the last chunk, the 903 | missing elements will be replaced with the default value. 904 | """ 905 | return iter(chunk_default(self, n, default)) 906 | 907 | def _window( 908 | self, window_size: int 909 | ) -> typing.Generator[tuple[T_Co, ...], None, None]: 910 | elements: typing.Deque[T_Co] = deque() 911 | for _ in range(window_size): 912 | try: 913 | elements.append(self.next()) 914 | except StopIteration: 915 | return 916 | 917 | yield tuple(elements) 918 | 919 | for el in self: 920 | elements.popleft() 921 | elements.append(el) 922 | yield tuple(elements) 923 | 924 | @typing.overload 925 | def window(self, window_size: typing.Literal[2]) -> "iter[tuple[T_Co, T_Co]]": ... 926 | 927 | @typing.overload 928 | def window( 929 | self, window_size: typing.Literal[3] 930 | ) -> "iter[tuple[T_Co, T_Co, T_Co]]": ... 931 | 932 | @typing.overload 933 | def window( 934 | self, window_size: typing.Literal[4] 935 | ) -> "iter[tuple[T_Co, T_Co, T_Co, T_Co]]": ... 936 | 937 | @typing.overload 938 | def window( 939 | self, window_size: typing.Literal[5] 940 | ) -> "iter[tuple[T_Co, T_Co, T_Co, T_Co, T_Co]]": ... 941 | 942 | @typing.overload 943 | def window( 944 | self, window_size: typing.Literal[6] 945 | ) -> "iter[tuple[T_Co, T_Co, T_Co, T_Co, T_Co, T_Co]]": ... 946 | 947 | @typing.overload 948 | def window(self, window_size: int) -> "iter[tuple[T_Co, ...]]": ... 949 | 950 | def window(self, window_size): 951 | """Return an iterator containing the elements of this iterator in 952 | a sliding window of size window_size. If there are not enough elements 953 | to create a full window, the iterator will be empty. 954 | """ 955 | return iter(self._window(window_size)) 956 | 957 | def shifted_zip(self, shift: int = 1) -> "iter[tuple[T_Co, T_Co]]": 958 | """Return an iterator containing pairs of elements separated by shift. 959 | 960 | If there are fewer than shift elements, the iterator will be empty. 961 | """ 962 | return self.window(shift + 1).map(lambda x: (x[0], x[-1])) 963 | 964 | def next(self) -> T_Co: 965 | """Return the next element in the iterator, or raise StopIteration.""" 966 | return next(self) 967 | 968 | @typing.overload 969 | def next_or(self, default: T_Co) -> T_Co: # type: ignore 970 | ... 971 | 972 | @typing.overload 973 | def next_or(self, default: U) -> typing.Union[T_Co, U]: ... 974 | 975 | def next_or(self, default): 976 | """Return the next element in the iterator, or default.""" 977 | try: 978 | return next(self, default) 979 | except StopIteration: 980 | return default 981 | 982 | def skip(self, n: int = 1) -> "iter[T_Co]": 983 | """Skip and discard n elements from this iterator. 
984 | 985 | Raises StopIteration if there are not enough elements. 986 | """ 987 | for _ in builtins.range(n): 988 | self.next() 989 | return self 990 | 991 | def nth(self, n: int) -> T_Co: 992 | """Return the nth element of this iterator. 993 | 994 | Discards all elements up to the nth element, and raises StopIteration 995 | if there are not enough elements. 996 | """ 997 | self.skip(n) 998 | return self.next() 999 | 1000 | def take(self, n: int) -> "iter[T_Co]": 1001 | """Return the next n elements of this iterator as an iterator.""" 1002 | return iter(self.next() for _ in builtins.range(n)) 1003 | 1004 | @typing.overload 1005 | def collect(self) -> list[T_Co]: ... 1006 | 1007 | @typing.overload # TODO: why doesn't this work? 1008 | def collect(self, collection_type: typing.Type[U]) -> "U[T_Co]": # type: ignore 1009 | ... 1010 | 1011 | def collect(self, collection_type=None): 1012 | """Return a list containing all remaining elements of this iterator.""" 1013 | if collection_type is None: 1014 | collection_type = list 1015 | return collection_type(self) 1016 | 1017 | def chain(self, other: Iterable[T_Co]) -> "iter[T_Co]": 1018 | """Return an iterator containing the elements of this iterator followed 1019 | by the elements of other. 1020 | """ 1021 | return iter(itertools.chain(self, other)) 1022 | 1023 | @typing.overload 1024 | def sum( 1025 | self: "iter[SupportsSumNoDefaultT]", 1026 | ) -> typing.Union[SupportsSumNoDefaultT, typing.Literal[0]]: ... 1027 | 1028 | @typing.overload 1029 | def sum( 1030 | self: "iter[AddableT]", initial: AddableU 1031 | ) -> typing.Union[AddableT, AddableU]: ... 1032 | 1033 | def sum(self, initial=_SENTINEL): 1034 | """Return the sum of all elements in this iterator. 1035 | 1036 | If initial is provided, it is used as the initial value. 1037 | """ 1038 | if initial is self._SENTINEL: 1039 | return sum(self) # type: ignore 1040 | # sum isn't actually guaranteed to run for non-numerics, so we have to 1041 | # ignore the type error here. 1042 | return sum(self, initial) # type: ignore 1043 | 1044 | @typing.overload 1045 | def prod( 1046 | self: "iter[SupportsProdNoDefaultT]", 1047 | ) -> typing.Union[T_Co, typing.Literal[1]]: ... 1048 | 1049 | @typing.overload 1050 | def prod( 1051 | self: "iter[MultipliableT]", initial: MultipliableU 1052 | ) -> typing.Union[MultipliableT, MultipliableU]: ... 1053 | 1054 | def prod(self, initial=_SENTINEL): 1055 | """Return the product of all elements in this iterator. 1056 | 1057 | If initial is provided, it is used as the initial value. 1058 | """ 1059 | if initial is self._SENTINEL: 1060 | return math.prod(self) # type: ignore 1061 | # math.prod isn't actually guaranteed to run for non-numerics, so we 1062 | # have to ignore the type error here. 1063 | return math.prod(self, start=initial) # type: ignore 1064 | 1065 | @typing.overload 1066 | def sorted( 1067 | self: "iter[SupportsRichComparisonT]", 1068 | *, 1069 | reverse: bool = False, 1070 | ) -> "list[SupportsRichComparisonT]": ... 1071 | 1072 | @typing.overload 1073 | def sorted( 1074 | self, 1075 | key: typing.Callable[[T_Co], SupportsRichComparison], 1076 | reverse: bool = False, 1077 | ) -> "list[T_Co]": ... 1078 | 1079 | def sorted(self, key=None, reverse=False): # type: ignore 1080 | """Return a list containing the elements of this iterator sorted 1081 | according to the given key and reverse parameters. 
1082 | """ 1083 | result: builtins.list[T_Co] = sorted(self, key=key, reverse=reverse) # type: ignore 1084 | return list(result) 1085 | 1086 | def reversed(self) -> "iter[T_Co]": 1087 | """Return an iterator containing the elements of this iterator in 1088 | reverse order. 1089 | """ 1090 | return iter(reversed(list(self))) 1091 | 1092 | @typing.overload 1093 | def min( 1094 | self: "iter[SupportsRichComparisonT]", 1095 | ) -> T_Co: ... 1096 | 1097 | @typing.overload 1098 | def min( 1099 | self, 1100 | key: typing.Callable[[T_Co], SupportsRichComparisonT], 1101 | ) -> T_Co: ... 1102 | 1103 | def min(self, key=None) -> T_Co: 1104 | """Return the minimum element of this iterator, according to the given 1105 | key. 1106 | """ 1107 | return min(self, key=key) # type: ignore 1108 | 1109 | @typing.overload 1110 | def max( 1111 | self: "iter[SupportsRichComparisonT]", 1112 | ) -> T_Co: ... 1113 | 1114 | @typing.overload 1115 | def max( 1116 | self, 1117 | key: typing.Callable[[T_Co], SupportsRichComparisonT], 1118 | ) -> T_Co: ... 1119 | 1120 | def max(self, key=None) -> T_Co: 1121 | """Return the maximum element of this iterator, according to the given 1122 | key. 1123 | """ 1124 | return max(self, key=key) # type: ignore 1125 | 1126 | def tee(self, n: int = 2) -> tuple["iter[T_Co]", ...]: 1127 | """Return a tuple of n iterators containing the elements of this 1128 | iterator. 1129 | """ 1130 | self.it, *iterators = itertools.tee(self, n + 1) 1131 | return tuple(iter(iterator) for iterator in iterators) 1132 | 1133 | def permutations( 1134 | self, r: typing.Union[int, None] = None 1135 | ) -> "iter[tuple[T_Co, ...]]": 1136 | """Return an iterator over the permutations of the elements of this 1137 | iterator. 1138 | 1139 | If r is provided, the returned iterator will only contain permutations 1140 | of size r. 1141 | """ 1142 | return iter(itertools.permutations(self, r)) 1143 | 1144 | def combinations(self, r: int) -> "iter[tuple[T_Co, ...]]": 1145 | """Return an iterator over the combinations, without replacement, of 1146 | length r of the elements of this iterator. 1147 | """ 1148 | return iter(itertools.combinations(self, r)) 1149 | 1150 | def combinations_with_replacement(self, r: int) -> "iter[tuple[T_Co, ...]]": 1151 | """Return an iterator over the combinations, with replacement, of 1152 | length r of the elements of this iterator. 1153 | """ 1154 | return iter(itertools.combinations_with_replacement(self, r)) 1155 | 1156 | @typing.overload 1157 | def flatten( 1158 | self: "iter[Iterable[SpecialisationT]]", 1159 | ) -> "iter[SpecialisationT]": ... 1160 | 1161 | @typing.overload 1162 | def flatten( 1163 | self: "iter[Iterable[SpecialisationT]]", 1164 | recursive: typing.Literal[False] = False, 1165 | ) -> "iter[SpecialisationT]": ... 1166 | 1167 | @typing.overload 1168 | def flatten( 1169 | self: "iter[Iterable[MaybeIterator[SpecialisationT]]]", 1170 | recursive: typing.Literal[True] = True, 1171 | ) -> "iter[SpecialisationT]": ... 1172 | 1173 | def flatten( 1174 | self: "iter[Iterable[MaybeIterator[SpecialisationT]]]", 1175 | recursive: bool = False, 1176 | ) -> "iter[typing.Any]": 1177 | """Flatten this iterator. 1178 | 1179 | If recursive is specified, flattens recursively instead 1180 | of by one layer. 
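For example, `iter([[1, 2], [3]]).flatten().collect()` would give `list([1, 2, 3])`.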
1181 | """ 1182 | if not recursive: 1183 | return iter(item for iterator in self for item in iterator) 1184 | return iter( 1185 | item 1186 | for iterator in self 1187 | for item in ( 1188 | iterator.flatten(True) # type: ignore 1189 | if isinstance(iterator, iter) 1190 | else ( 1191 | list(iterator).flat(True) # type: ignore 1192 | if isinstance(iterator, (builtins.list, list)) 1193 | else [iterator] 1194 | ) 1195 | ) 1196 | ) 1197 | 1198 | def enumerate(self, start: int = 0) -> "iter[tuple[int, T_Co]]": 1199 | """Return an iterator over the elements of this iterator, paired with 1200 | their index, starting at start. 1201 | """ 1202 | return iter(enumerate(self, start)) 1203 | 1204 | def count(self) -> int: 1205 | """Consume this iterator and return the number of elements it contained.""" 1206 | return self.map(lambda _: 1).sum() 1207 | 1208 | def nlargest(self, n: int) -> list[T_Co]: 1209 | """Consume this iterator and return the n largest elements.""" 1210 | return list(nlargest(n, self)) 1211 | 1212 | def nsmallest(self, n: int) -> list[T_Co]: 1213 | """Consume this iterator and return the n smallest elements.""" 1214 | return list(nsmallest(n, self)) 1215 | 1216 | def __repr__(self) -> str: 1217 | return f"iter({self.it!r})" 1218 | 1219 | 1220 | class range(iter[int]): 1221 | _SENTINEL = object() 1222 | 1223 | def __init__( 1224 | self, 1225 | start: int, 1226 | stop: int = _SENTINEL, # type: ignore 1227 | step: int = 1, 1228 | ): 1229 | if step == 0: 1230 | raise ValueError("Step size must not be 0") 1231 | if stop is range._SENTINEL: 1232 | stop = start 1233 | start = 0 1234 | self.start = start 1235 | self.stop = stop 1236 | self.step = step 1237 | 1238 | def min(self) -> int: 1239 | if not self: 1240 | raise ValueError("Called min() on an empty range") 1241 | if self.step > 0: 1242 | return self.start 1243 | else: 1244 | x = (self.stop - self.start) // self.step 1245 | return self.start + x * self.step 1246 | 1247 | def __iter__(self) -> typing.Iterator[int]: 1248 | return builtins.iter(builtins.range(self.start, self.stop, self.step)) 1249 | 1250 | def __next__(self) -> int: 1251 | raise ValueError("range object is not an iterator") 1252 | 1253 | def __len__(self) -> int: 1254 | return builtins.len(builtins.range(self.start, self.stop, self.step)) 1255 | 1256 | def __contains__(self, item: int) -> bool: 1257 | return item in builtins.range(self.start, self.stop, self.step) 1258 | 1259 | def __and__(self, other: "range") -> "range": 1260 | if not isinstance(other, range): 1261 | return NotImplemented 1262 | if self.step != other.step: 1263 | raise ValueError("Step sizes must match") 1264 | if not (self.start in other or other.start in self): 1265 | return range(0, 0) 1266 | if self.step > 0: 1267 | return range( 1268 | max(self.start, other.start), min(self.stop, other.stop), self.step 1269 | ) 1270 | else: 1271 | return range( 1272 | min(self.start, other.start), max(self.stop, other.stop), self.step 1273 | ) 1274 | 1275 | def __or__(self, other: "range") -> typing.Union["range", "multirange"]: 1276 | if not isinstance(other, range): 1277 | return NotImplemented 1278 | return self + other 1279 | 1280 | def __xor__(self, other: "range") -> "multirange": 1281 | if not isinstance(other, range): 1282 | return NotImplemented 1283 | if self.step != other.step: 1284 | raise ValueError("Step sizes must match") 1285 | return multirange(self - other, other - self) 1286 | 1287 | def __sub__( 1288 | self, other: typing.Union[int, "range"] 1289 | ) -> typing.Union["range", "multirange"]: 
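# Subtracting an int shifts the range; subtracting another range (with the same step) removes its elements, e.g. range(0, 10) - range(3, 6) == multirange(range(0, 3), range(6, 10)), while range(0, 10) - 3 == range(-3, 7).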
1290 | if isinstance(other, int): 1291 | return range(self.start - other, self.stop - other, self.step) 1292 | elif isinstance(other, range): 1293 | if self.step != other.step: 1294 | raise ValueError("Step sizes must match") 1295 | if not (self.start in other or other.start in self): 1296 | # no intersection 1297 | return self 1298 | elif self.start == other.start: 1299 | if (self.stop <= other.stop and self.step > 0) or ( 1300 | self.stop >= other.stop and self.step < 0 1301 | ): 1302 | return multirange() 1303 | else: 1304 | return range(other.stop, self.stop, self.step) 1305 | if self.step > 0: 1306 | if self.start < other.start: 1307 | if self.stop <= other.stop: 1308 | # self.start other.start self.stop other.stop 1309 | return range(self.start, other.start, self.step) 1310 | else: 1311 | # self.start other.start other.stop self.stop 1312 | return multirange( 1313 | range(self.start, other.start, self.step), 1314 | range(other.stop, self.stop, self.step), 1315 | ) 1316 | elif self.stop >= other.stop: 1317 | # other.start self.start other.stop self.stop 1318 | return range(other.stop, self.stop, self.step) 1319 | else: 1320 | # other.start self.start self.stop other.stop 1321 | return multirange() 1322 | 1323 | else: 1324 | if self.start > other.start: 1325 | if self.stop >= other.stop: 1326 | # other.stop self.stop other.start self.start 1327 | return range(self.start, other.start, self.step) 1328 | else: 1329 | # other.stop self.stop self.start other.start 1330 | return multirange() 1331 | elif self.stop <= other.stop: 1332 | # other.stop self.stop other.start self.start 1333 | return range(other.stop, self.stop, self.step) 1334 | else: 1335 | # self.stop other.stop other.start self.start 1336 | return multirange( 1337 | range(self.start, other.start, self.step), 1338 | range(other.stop, self.stop, self.step), 1339 | ) 1340 | else: 1341 | return NotImplemented 1342 | 1343 | def __add__( 1344 | self, other: typing.Union[int, "range"] 1345 | ) -> typing.Union["range", "multirange"]: 1346 | if isinstance(other, int): 1347 | return range(self.start + other, self.stop + other, self.step) 1348 | elif isinstance(other, range): 1349 | if self.step != other.step: 1350 | raise ValueError("Step sizes must match") 1351 | if not ( 1352 | self.start in other 1353 | or other.start in self 1354 | or (self.stop == other.start and other.start - self.step in self) 1355 | or (self.start == other.stop and self.start - self.step in other) 1356 | ): 1357 | # no intersection 1358 | return multirange(self, other) 1359 | if self.step > 0: 1360 | return range( 1361 | min(self.start, other.start), max(self.stop, other.stop), self.step 1362 | ) 1363 | else: 1364 | return range( 1365 | max(self.start, other.start), min(self.stop, other.stop), self.step 1366 | ) 1367 | else: 1368 | return NotImplemented 1369 | 1370 | def __bool__(self) -> bool: 1371 | return bool(builtins.range(self.start, self.stop, self.step)) 1372 | 1373 | def __eq__(self, other: object) -> bool: 1374 | if not isinstance(other, range): 1375 | return NotImplemented 1376 | if not self: 1377 | return not other 1378 | return ( 1379 | self.start == other.start 1380 | and self.step == other.step 1381 | and len(self) == len(other) 1382 | ) 1383 | 1384 | def __repr__(self) -> str: 1385 | return f"range({self.start}, {self.stop}" + ( 1386 | f", {self.step})" if self.step != 1 else ")" 1387 | ) 1388 | 1389 | 1390 | if typing.TYPE_CHECKING: 1391 | # make type-checkers allow interchangeable use of range and multirange 1392 | __multirange_base = range 1393 
| else: 1394 | __multirange_base = iter[T] 1395 | 1396 | 1397 | class multirange(__multirange_base): 1398 | """Multirange class. Represents many disjoint ranges of integers. Step sizes 1399 | must always be 1. 1400 | """ 1401 | 1402 | def __init__(self, *ranges: typing.Union[range, "multirange"]): 1403 | self.ranges: builtins.list[range] = [] 1404 | for the_range in ranges: 1405 | if isinstance(the_range, multirange): 1406 | self.ranges.extend(the_range.ranges) 1407 | else: 1408 | self.ranges.append(the_range) 1409 | self.simplify_ranges() 1410 | 1411 | def __iter__(self) -> typing.Iterator[int]: 1412 | return itertools.chain(*self.ranges) 1413 | 1414 | def __next__(self) -> int: 1415 | raise ValueError("multirange object is not an iterator") 1416 | 1417 | def __len__(self) -> int: 1418 | return sum(len(range) for range in self.ranges) 1419 | 1420 | def __contains__(self, item: int) -> bool: 1421 | return any(item in r for r in self.ranges) 1422 | 1423 | def simplify_ranges(self): 1424 | self.ranges.sort(key=lambda range: range.start) 1425 | (*self.ranges,) = filter(None, self.ranges) 1426 | if not self.ranges: 1427 | return 1428 | last_range = self.ranges[0] 1429 | out_ranges = [last_range] 1430 | for range_ in self.ranges: 1431 | if range_.step != 1: 1432 | raise ValueError("Step sizes must be 1 for all ranges in a multirange") 1433 | if range_.start >= range_.stop: 1434 | continue 1435 | if range_.start <= last_range.stop: 1436 | last_range.stop = max(range_.stop, last_range.stop) 1437 | out_ranges[-1] = last_range 1438 | else: 1439 | last_range = range_ 1440 | out_ranges.append(last_range) 1441 | self.ranges = out_ranges 1442 | 1443 | def min(self) -> int: 1444 | if self.ranges: 1445 | return self.ranges[0].start 1446 | else: 1447 | raise ValueError("Called min() on an empty multirange") 1448 | 1449 | def __and__(self, other: typing.Union["multirange", range]) -> "multirange": 1450 | if isinstance(other, multirange): 1451 | return multirange(*(r & s for r in self.ranges for s in other.ranges)) 1452 | elif isinstance(other, range): 1453 | return multirange(*(r & other for r in self.ranges)) 1454 | else: 1455 | return NotImplemented 1456 | 1457 | def __rand__(self, other: typing.Union["multirange", range]) -> "multirange": 1458 | return self & other 1459 | 1460 | def __or__(self, other: typing.Union["multirange", range]) -> "multirange": 1461 | if isinstance(other, multirange): 1462 | return multirange(*self.ranges, *other.ranges) 1463 | elif isinstance(other, range): 1464 | return multirange(*self.ranges, other) 1465 | else: 1466 | return NotImplemented 1467 | 1468 | def __ror__(self, other: typing.Union["multirange", range]) -> "multirange": 1469 | return self | other 1470 | 1471 | def __xor__(self, other: typing.Union["multirange", range]) -> "multirange": 1472 | if isinstance(other, (range, multirange)): 1473 | return multirange(self - other, other - self) 1474 | else: 1475 | return NotImplemented 1476 | 1477 | def __rxor__(self, other: typing.Union["multirange", range]) -> "multirange": 1478 | return self ^ other 1479 | 1480 | def __sub__(self, other: typing.Union[int, "multirange", range]) -> "multirange": 1481 | if isinstance(other, int): 1482 | return multirange(*(r - other for r in self.ranges)) 1483 | elif isinstance(other, multirange): 1484 | result = [] 1485 | for r in self.ranges: 1486 | for s in other.ranges: 1487 | r = r - s 1488 | result.append(r) 1489 | return multirange(*result) 1490 | elif isinstance(other, range): 1491 | return multirange(*(r - other for r in 
self.ranges)) 1492 | else: 1493 | return NotImplemented 1494 | 1495 | def __add__(self, other: typing.Union[int, "multirange", range]) -> "multirange": 1496 | if isinstance(other, int): 1497 | return multirange(*(r + other for r in self.ranges)) 1498 | elif isinstance(other, multirange): 1499 | return multirange(*self.ranges, *other.ranges) 1500 | elif isinstance(other, range): 1501 | return multirange(*self.ranges, other) 1502 | else: 1503 | return NotImplemented 1504 | 1505 | def __radd__(self, other: typing.Union[int, "multirange", range]) -> "multirange": 1506 | return self + other 1507 | 1508 | def __bool__(self) -> bool: 1509 | return any(self.ranges) 1510 | 1511 | def __eq__(self, other: object) -> bool: 1512 | if isinstance(other, multirange): 1513 | return self.ranges == other.ranges 1514 | elif isinstance(other, range): 1515 | return self.ranges == [other] 1516 | else: 1517 | return NotImplemented 1518 | 1519 | def __repr__(self) -> str: 1520 | return f"multirange({self.ranges})" 1521 | 1522 | 1523 | if not typing.TYPE_CHECKING: 1524 | range = functools.wraps(builtins.range, updated=())(range) 1525 | 1526 | 1527 | @functools.wraps(builtins.map) 1528 | def map(*args, **kw): 1529 | return iter(builtins.map(*args, **kw)) 1530 | 1531 | 1532 | def irange(start: int, stop: int) -> range: 1533 | """Inclusive range. Returns an iterator that 1534 | yields values from start to stop, including both 1535 | endpoints, stepping by one. Works even when 1536 | stop < start (the iterator will step backwards). 1537 | """ 1538 | if start <= stop: 1539 | return range(start, stop + 1) 1540 | else: 1541 | return range(start, stop - 1, -1) 1542 | 1543 | 1544 | def _frange( 1545 | start: float, stop: float, step: float 1546 | ) -> typing.Generator[float, None, None]: 1547 | if step == 0.0: 1548 | raise ValueError("frange() arg 3 must not be zero") 1549 | if step > 0: 1550 | while start < stop: 1551 | yield start 1552 | start += step 1553 | else: 1554 | while start > stop: 1555 | yield start 1556 | start += step 1557 | 1558 | 1559 | def frange(start: float, stop: float, step: float = 0.1) -> iter[float]: 1560 | """Float range. Returns an iterator that yields values 1561 | from start (inclusive) to stop (exclusive), changing by step. 1562 | """ 1563 | return iter(_frange(start, stop, step)) 1564 | 1565 | 1566 | class TailRecursionDetected(Exception): 1567 | def __init__(self, args, kwargs): 1568 | self.args = args 1569 | self.kwargs = kwargs 1570 | 1571 | 1572 | def tail_call(func: typing.Callable[P, U]) -> typing.Callable[P, U]: 1573 | """Add tail recursion optimisation to func. 1574 | 1575 | Useful for avoiding RecursionErrors. 1576 | 1577 | This is done by throwing an exception 1578 | if the wrapper is its own grandparent (i.e. the wrapped 1579 | function would be its own parent), and catching such 1580 | exceptions to fake the tail call optimisation. 1581 | 1582 | func will behave strangely if the decorated 1583 | function recurses in a non-tail context. 
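For example (an illustrative sketch, not part of this module): decorating `def countdown(n): return 0 if n == 0 else countdown(n - 1)` with `@tail_call` lets `countdown(100_000)` complete without raising RecursionError.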
1584 | """ 1585 | 1586 | @functools.wraps(func) 1587 | def wrapped(*args: P.args, **kwargs: P.kwargs): 1588 | f = sys._getframe() 1589 | if f.f_back and f.f_back.f_back and f.f_back.f_back.f_code == f.f_code: 1590 | raise TailRecursionDetected(args, kwargs) 1591 | else: 1592 | while 1: 1593 | try: 1594 | return func(*args, **kwargs) 1595 | except TailRecursionDetected as e: 1596 | args = e.args # type: ignore 1597 | kwargs = e.kwargs # type: ignore 1598 | raise Exception("unreachable") 1599 | 1600 | return wrapped 1601 | 1602 | 1603 | LetterRow = tuple[ 1604 | bool, 1605 | bool, 1606 | bool, 1607 | bool, 1608 | bool, 1609 | ] 1610 | Letter = tuple[ 1611 | LetterRow, 1612 | LetterRow, 1613 | LetterRow, 1614 | LetterRow, 1615 | LetterRow, 1616 | LetterRow, 1617 | ] 1618 | 1619 | 1620 | def encode_letter(dots: Letter) -> int: 1621 | """Encode a matrix of dots to an integer for efficient 1622 | storage and lookup. Not expected to be used outside of 1623 | this module and contributions to the lookup table. 1624 | 1625 | The matrix of dots should be 6 tall and 5 wide. 1626 | """ 1627 | # Letters are 4 dots wide; the 5th column should always be empty. 1628 | # This function assumes that input is not malformed; any dots in the 1629 | # 5th column are treated as if they are in the 1st column of the next 1630 | # row. 1631 | # If something includes the 5th column it is malformed, but this 1632 | # function will not check. 1633 | out = 0 1634 | for y, row in enumerate(dots): 1635 | for x, dot in enumerate(row): 1636 | if dot: 1637 | out |= 1 << (x + 4 * y) 1638 | return out 1639 | 1640 | 1641 | LETTERS: dict[int, str] = { 1642 | # todo: fill in this lookup table 1643 | 0: " ", 1644 | 10090902: "A", 1645 | 7968663: "B", 1646 | 6885782: "C", 1647 | 15800095: "E", 1648 | 1120031: "F", 1649 | 15323542: "G", 1650 | 10067865: "H", 1651 | 14959694: "I", 1652 | 6916236: "J", 1653 | 9786201: "K", 1654 | 15798545: "L", 1655 | 6920598: "O", 1656 | 1145239: "P", 1657 | 9795991: "R", 1658 | 7889182: "S", 1659 | 6920601: "U", 1660 | # 4475409: "Y", 1661 | 15803535: "Z", 1662 | } 1663 | 1664 | 1665 | def decode_letter(dots: Letter) -> str: 1666 | """Decode a matrix of dots to a single letter. 1667 | 1668 | The matrix of dots should be 6 tall and 5 wide. 1669 | """ 1670 | encoded = encode_letter(dots) 1671 | if encoded not in LETTERS: 1672 | print("Unrecognised letter:", encoded) 1673 | for row in dots: 1674 | for dot in row: 1675 | print(" #"[dot], end="") 1676 | print() 1677 | print("Please consider contributing this to the lookup table:") 1678 | print("https://github.com/starwort/aoc_helper") 1679 | return "?" 1680 | return LETTERS[encoded] 1681 | 1682 | 1683 | def decode_text(dots: builtins.list[builtins.list[bool]]) -> str: 1684 | """Decode a matrix of dots to text. 1685 | 1686 | The matrix of dots should be 6 tall and 5n - 1 wide. 1687 | """ 1688 | broken_rows = [list(chunk_default(row, 5, False)) for row in dots] 1689 | letters = list(zip(*broken_rows)) 1690 | out = "".join(decode_letter(letter) for letter in letters) 1691 | assert "?" not in out, f"Output {out} contained unrecognised letters!" 1692 | return out 1693 | 1694 | 1695 | def _default_classifier(char: str, /) -> int: 1696 | if char in "0123456789": 1697 | return int(char) 1698 | elif char in ".#": 1699 | return ".#".index(char) 1700 | else: 1701 | raise ValueError(f"Could not classify {char}. 
Please use a custom classifier.") 1702 | 1703 | 1704 | class Grid(typing.Generic[T]): 1705 | data: list[list[T]] 1706 | 1707 | def __init__(self, data: list[list[T]]) -> None: 1708 | self.data = data 1709 | 1710 | @classmethod 1711 | def from_string( 1712 | cls, data: str, classify: typing.Callable[[str], U] = _default_classifier 1713 | ) -> "Grid[U]": 1714 | """Create a grid from a string (e.g. a puzzle input). 1715 | 1716 | Can take a classifier to use a custom classification. The default will 1717 | map numbers from 0 to 9 to themselves, and . and # to 0 and 1 respectively. 1718 | """ 1719 | return Grid(list(data.splitlines()).mapped(lambda i: list(i).mapped(classify))) 1720 | 1721 | @property 1722 | def width(self) -> int: 1723 | """ 1724 | Return the width of the grid. Will not be correct if the underlying 1725 | store is ragged. 1726 | """ 1727 | return len(self.data[0]) 1728 | 1729 | @property 1730 | def height(self) -> int: 1731 | """Return the height of the grid.""" 1732 | return len(self.data) 1733 | 1734 | def find_all( 1735 | self, other: "typing.Union[SparseGrid[T], T]" 1736 | ) -> iter[tuple[int, int]]: 1737 | """Find all occurrences of other in self.""" 1738 | 1739 | if not isinstance(other, SparseGrid): 1740 | _other = other 1741 | other = SparseGrid(lambda: _other) 1742 | other[0, 0] = _other 1743 | 1744 | def find(): 1745 | for y, row in enumerate(self.data): 1746 | for x, _ in enumerate(row): 1747 | if self.contains_at( 1748 | x, 1749 | y, 1750 | other, 1751 | ): 1752 | yield (x, y) 1753 | 1754 | return iter(find()) 1755 | 1756 | def contains_at(self, x: int, y: int, other: "SparseGrid[T]") -> bool: 1757 | """Check if other is contained at the given position in self.""" 1758 | sentinel = object() 1759 | for (ox, oy), value in other.items(): 1760 | if self.get(x + ox, y + oy, sentinel) != value: 1761 | return False 1762 | return True 1763 | 1764 | def get(self, x: int, y: int, default: U = None) -> typing.Union[T, U]: 1765 | """Get the value at the given position in the grid.""" 1766 | if 0 <= y < len(self.data) and 0 <= x < len(self.data[y]): 1767 | return self.data[y][x] 1768 | return default 1769 | 1770 | def to_sparse(self, default_factory: typing.Callable[[], T]) -> "SparseGrid[T]": 1771 | """Convert this grid to a sparse grid.""" 1772 | out = SparseGrid(default_factory) 1773 | for y, row in enumerate(self.data): 1774 | for x, value in enumerate(row): 1775 | out[x, y] = value 1776 | return out 1777 | 1778 | def vertical_chunks(self, n: int) -> list["Grid[T]"]: 1779 | """Create a list of grids formed by splitting this grid every n rows. 1780 | 1781 | Any extra rows that cannot form a group of n will be lost (see 1782 | vertical_chunks_default) 1783 | """ 1784 | chunked_rows = self.data.chunked(n) 1785 | return chunked_rows.mapped(list).mapped(Grid) 1786 | 1787 | def vertical_chunks_default(self, n: int, fill_value: T) -> list["Grid[T]"]: 1788 | """Create a list of grids formed by splitting this grid every n rows. 1789 | 1790 | Grids will be padded out to have n rows, where every cell in the padded 1791 | rows is fill_value. 
1792 | """ 1793 | if self.data.len() == 0: 1794 | return list() 1795 | fill_row = list(fill_value for _ in self.data[0]) 1796 | chunked_rows = self.data.chunked_default(n, fill_row) 1797 | result = chunked_rows.mapped(list) 1798 | result[-1] = result[-1].mapped(lambda i: i.deepcopy() if i is fill_row else i) 1799 | return result.mapped(Grid) 1800 | 1801 | def horizontal_chunks(self, n: int) -> list["Grid[T]"]: 1802 | """Create a list of grids formed by splitting this grid every n columns. 1803 | 1804 | Any extra columns that cannot form a group of n will be lost (see 1805 | horizontal_chunks_default) 1806 | """ 1807 | chunked_data = [list(chunk(row, n)) for row in self.data] 1808 | return list(zip(*chunked_data)).mapped(list).mapped(Grid) 1809 | 1810 | def horizontal_chunks_default(self, n: int, fill_value: T) -> list["Grid[T]"]: 1811 | """Create a list of grids formed by splitting this grid every n columns. 1812 | 1813 | Rows will be padded out to have n values, where every cell in the padded 1814 | columns is fill_value. 1815 | """ 1816 | chunked_data = [list(chunk_default(row, n, fill_value)) for row in self.data] 1817 | return list(zip(*chunked_data)).mapped(list).mapped(Grid) 1818 | 1819 | def transpose(self) -> "Grid[T]": 1820 | """Create a grid that is the transposition of this grid. 1821 | 1822 | This operation looks similar to a 90° rotation followed by a reflection: 1823 | 1824 | ``` 1825 | ABC 1826 | DEF 1827 | HIJ 1828 | KLM 1829 | ``` 1830 | 1831 | transposes to: 1832 | 1833 | ``` 1834 | ADHK 1835 | BEIL 1836 | CFJM 1837 | ``` 1838 | """ 1839 | return Grid(self.data.transposition()) 1840 | 1841 | def rotate_clockwise(self) -> "Grid[T]": 1842 | """Create a new grid that is the clockwise rotation of this grid. 1843 | 1844 | self[0][0] is considered to be the top-left corner. 1845 | """ 1846 | return Grid(self.data[::-1].transposition()) 1847 | 1848 | def rotate_anticlockwise(self) -> "Grid[T]": 1849 | """Create a new grid that is the anti-clockwise rotation of this grid. 1850 | 1851 | self[0][0] is considered to be the top-left corner. 1852 | """ 1853 | return Grid(self.data.transposition()[::-1]) 1854 | 1855 | def to_bool_grid(self, convert: typing.Callable[[T], bool] = bool) -> "Grid[bool]": 1856 | """Create a new grid of booleans by using the given conversion function 1857 | on self. The default conversion function is bool, converting via 1858 | truthiness value. 1859 | """ 1860 | # Would love to replace this with a mapped_each call but it doesn't type-check 1861 | return Grid(self.data.mapped(lambda i: i.mapped(convert))) 1862 | 1863 | def decode_as_text(self: "Grid[bool]") -> str: 1864 | """Decode self as a grid of letters using decode_text. 1865 | 1866 | This method will check that self is the correct dimensions and raise an 1867 | AssertionError if not. 1868 | """ 1869 | self = self.trim_to_content() 1870 | assert ( 1871 | len(self.data) == 6 1872 | ), f"Expected a height of 6, found height of {len(self.data)}" 1873 | assert len(self.data[0]) % 5 == 4, ( 1874 | f"Expected a width of 5n + 4, found width of {len(self.data[0])} (5n +" 1875 | f" {len(self.data[0]) % 5})" 1876 | ) 1877 | return decode_text([[i for i in row] for row in self.data]) 1878 | 1879 | def trim_to_content(self, keep: typing.Callable[[T], bool] = bool) -> "Grid[T]": 1880 | """Create a new grid of booleans by using the given conversion function 1881 | on self. 
The default conversion function is bool, converting via 1882 | truthiness value.""" 1883 | trim_rows = self.data.mapped(lambda i: i.none(keep)) 1884 | if trim_rows.all(): # Trim out the entire grid 1885 | return Grid(list()) 1886 | trim_cols = self.transpose().data.mapped(lambda i: i.none(keep)) 1887 | if trim_rows.none() and trim_cols.none(): # Nothing to trim 1888 | return self.deepcopy() 1889 | trim_cols = trim_cols.enumerated() 1890 | trim_rows = trim_rows.enumerated() 1891 | top = expect(trim_rows.find(lambda i: not i[1]))[0] 1892 | bottom = expect(trim_rows[::-1].find(lambda i: not i[1]))[0] 1893 | left = expect(trim_cols.find(lambda i: not i[1]))[0] 1894 | right = expect(trim_cols[::-1].find(lambda i: not i[1]))[0] 1895 | return Grid(self.data[top : bottom + 1].mapped(lambda i: i[left : right + 1])) 1896 | 1897 | def neighbours(self, x: int, y: int) -> list[tuple[tuple[int, int], T]]: 1898 | """Return the neighbours of a point in the grid (but not the point itself). 1899 | 1900 | Examples below: 1901 | - A is the point (x, y) 1902 | - * are points returned 1903 | - . are other points in the grid 1904 | 1905 | ``` 1906 | ........... 1907 | ..***...... 1908 | ..*A*...... 1909 | ..***...... 1910 | ``` 1911 | 1912 | ``` 1913 | A*......... 1914 | **......... 1915 | ........... 1916 | ........... 1917 | ``` 1918 | """ 1919 | return ( 1920 | irange(max(y - 1, 0), min(y + 1, len(self.data) - 1)) 1921 | .map( 1922 | lambda y_: irange(max(x - 1, 0), min(x + 1, len(self.data[0]) - 1)) 1923 | .filter(lambda x_: (x, y) != (x_, y_)) 1924 | .map(lambda x: ((x, y_), self.data[y_][x])) 1925 | ) 1926 | .flatten(False) 1927 | ).collect() 1928 | 1929 | def orthogonal_neighbours(self, x: int, y: int) -> list[tuple[tuple[int, int], T]]: 1930 | """Return the orthogonal neighbours of a point in the grid (but not the 1931 | point itself). 1932 | 1933 | Examples below: 1934 | - A is the point (x, y) 1935 | - * are points returned 1936 | - . are other points in the grid 1937 | 1938 | ``` 1939 | ........... 1940 | ...*....... 1941 | ..*A*...... 1942 | ...*....... 1943 | ``` 1944 | 1945 | ``` 1946 | A*......... 1947 | *.......... 1948 | ........... 1949 | ........... 1950 | ``` 1951 | """ 1952 | rv = list() 1953 | if x > 0: 1954 | rv.append(((x - 1, y), self.data[y][x - 1])) 1955 | if x < len(self.data[0]) - 1: 1956 | rv.append(((x + 1, y), self.data[y][x + 1])) 1957 | if y > 0: 1958 | rv.append(((x, y - 1), self.data[y - 1][x])) 1959 | if y < len(self.data) - 1: 1960 | rv.append(((x, y + 1), self.data[y + 1][x])) 1961 | return rv 1962 | 1963 | def region( 1964 | self, 1965 | start: tuple[int, int], 1966 | is_in_region: typing.Callable[ 1967 | [tuple[int, int], T, tuple[int, int], T], bool 1968 | ] = lambda from_pos, from_cell, to_pos, to_cell: (from_cell == to_cell), 1969 | neighbour_type: typing.Literal["ortho", "full"] = "ortho", 1970 | ) -> tuple[set[tuple[int, int]], list[set[tuple[float, float]]]]: 1971 | """Return the region of points, and the walls enclosing that region, that 1972 | form a contiguous region containing `start`, where a point is part of 1973 | the region if it is adjacent to a cell in the region and `is_in_region` 1974 | returns True when called on that cell's position and value as well as 1975 | the candidate cell's position and value. 1976 | 1977 | The walls are returned as a list of sets of points, where each set in 1978 | the list is a continuous wall. 
Each point in the wall is the midpoint 1979 | between two cells in the grid; either every point in the wall will be 1980 | of the form (x + 0.5, y) or every point will be of the form (x, y + 0.5). 1981 | 1982 | If `neighbour_type` is "full", diagonal-only connections will cause some 1983 | of the returned walls to intersect: 1984 | 1985 | ``` 1986 | - 1987 | A. |A|. 1988 | .A -> -+- 1989 | .|A| 1990 | _ 1991 | ``` 1992 | """ 1993 | region = set() 1994 | walls = list[set[tuple[float, float]]]() 1995 | q = deque([start]) 1996 | 1997 | def is_boundary(x, y, x2, y2): 1998 | return not ( 1999 | 0 <= x2 < self.width 2000 | and 0 <= y2 < self.height 2001 | and is_in_region((x, y), self[x, y], (x2, y2), self[x2, y2]) 2002 | ) 2003 | 2004 | def scan_wall(x, y, dx, dy): 2005 | wall = set() 2006 | # scan rightwards 2007 | x_, y_ = x, y 2008 | x2, y2 = x, y 2009 | while 0 <= x2 < self.width and 0 <= y2 < self.height: 2010 | if not is_in_region((x_, y_), self[x_, y_], (x2, y2), self[x2, y2]): 2011 | break 2012 | x_, y_ = x2, y2 2013 | if not is_boundary(x_, y_, x_ + dx, y_ + dy): 2014 | break 2015 | wall.add((x_ + dx / 2, y_ + dy / 2)) 2016 | x2, y2 = x_ - dy, y_ + dx 2017 | # scan leftwards 2018 | x_, y_ = x, y 2019 | x2, y2 = x, y 2020 | while 0 <= x2 < self.width and 0 <= y2 < self.height: 2021 | if not is_in_region((x_, y_), self[x_, y_], (x2, y2), self[x2, y2]): 2022 | break 2023 | x_, y_ = x2, y2 2024 | if not is_boundary(x_, y_, x_ + dx, y_ + dy): 2025 | break 2026 | wall.add((x_ + dx / 2, y_ + dy / 2)) 2027 | x2, y2 = x_ + dy, y_ - dx 2028 | return wall 2029 | 2030 | while q: 2031 | pos = q.popleft() 2032 | if pos in region: 2033 | continue 2034 | region.add(pos) 2035 | x, y = pos 2036 | for dx, dy in ((0, 1), (1, 0), (0, -1), (-1, 0)): 2037 | if walls.none(lambda wall: (x + dx / 2, y + dy / 2) in wall) and ( 2038 | wall := scan_wall(x, y, dx, dy) 2039 | ): 2040 | walls.append(wall) 2041 | for neighbour_pos, neighbour_cell in ( 2042 | self.neighbours 2043 | if neighbour_type == "full" 2044 | else self.orthogonal_neighbours 2045 | )(*pos): 2046 | if is_in_region(pos, self[pos], neighbour_pos, neighbour_cell): 2047 | q.append(neighbour_pos) 2048 | return region, walls 2049 | 2050 | def regions( 2051 | self, 2052 | is_in_region: typing.Callable[ 2053 | [tuple[int, int], T, tuple[int, int], T], bool 2054 | ] = lambda from_pos, from_cell, to_pos, to_cell: (from_cell == to_cell), 2055 | neighbour_type: typing.Literal["ortho", "full"] = "ortho", 2056 | ) -> iter[tuple[set[tuple[int, int]], list[set[tuple[float, float]]]]]: 2057 | """Return all regions in the grid. 
See `region()` for more details.""" 2058 | 2059 | def gen(): 2060 | seen = set() 2061 | for y, row in enumerate(self.data): 2062 | for x, cell in enumerate(row): 2063 | if (x, y) in seen: 2064 | continue 2065 | region, walls = self.region( 2066 | (x, y), is_in_region=is_in_region, neighbour_type=neighbour_type 2067 | ) 2068 | seen |= region 2069 | yield region, walls 2070 | 2071 | return iter(gen()) 2072 | 2073 | def explore( 2074 | self, 2075 | can_move: typing.Callable[[tuple[int, int], T, tuple[int, int], T], bool], 2076 | return_path_when: typing.Callable[ 2077 | [tuple[int, int], T], bool 2078 | ] = lambda pos, cell: True, 2079 | start: tuple[int, int] = (0, 0), 2080 | neighbour_type: typing.Literal["ortho", "full"] = "ortho", 2081 | unique_paths: bool = False, 2082 | ) -> iter[tuple[tuple[int, int], ...]]: 2083 | def explore(): 2084 | neighbours = ( 2085 | self.neighbours 2086 | if neighbour_type == "full" 2087 | else self.orthogonal_neighbours 2088 | ) 2089 | seen = set() 2090 | q = deque([(start, self[start], tuple[tuple[int, int], ...]((start,)))]) 2091 | while q: 2092 | pos, cell, path = q.popleft() 2093 | if not unique_paths: 2094 | if pos in seen: 2095 | continue 2096 | seen.add(pos) 2097 | if return_path_when(pos, cell): 2098 | yield path 2099 | for neighbour_pos, neighbour_cell in neighbours(*pos): 2100 | if can_move(pos, cell, neighbour_pos, neighbour_cell): 2101 | q.append( 2102 | (neighbour_pos, neighbour_cell, path + (neighbour_pos,)) 2103 | ) 2104 | 2105 | return iter(explore()) 2106 | 2107 | def pathfind( 2108 | self: "Grid[AddableT]", 2109 | start: tuple[int, int] = (0, 0), 2110 | end: typing.Optional[tuple[int, int]] = None, 2111 | initial_state: HashableU = (), 2112 | is_valid_end: typing.Callable[[HashableU], bool] = lambda _: True, 2113 | next_state: typing.Callable[ 2114 | [HashableU, int, int, AddableT, AddableT], typing.Optional[HashableU] 2115 | ] = lambda old, dx, dy, i, j: (), 2116 | cost_function: typing.Callable[[AddableT, AddableT], AddableT] = ( 2117 | lambda i, j: j - i # type: ignore 2118 | ), 2119 | neighbour_type: typing.Literal["ortho", "full"] = "ortho", 2120 | initial_cost: AddableT = 0, 2121 | heuristic_multiplier: float = 1, 2122 | ) -> typing.Optional[AddableT]: 2123 | """Use the A* algorithm to find the best path from start to end, and 2124 | return the total cost. 2125 | 2126 | start defaults to the top left, and end defaults to the bottom right. 2127 | 2128 | initial_state is for custom state for the pathfinding algorithm (e.g. 2129 | extra restrictions on the path). State must be a hashable type 2130 | 2131 | is_valid_end is a function that takes in a state and returns whether the 2132 | search can end with that state. It will only be called if the target 2133 | position has been found. 2134 | 2135 | next_state is a function that takes the current state, the change in the 2136 | x and y coordinates, the previous cell value, and the current cell 2137 | value, and returns either the next state or None if the traversal cannot 2138 | be performed. 2139 | 2140 | cost_function is a function that takes the start value and the end value 2141 | of a traversal, and returns the cost of that traversal. The default is 2142 | that the cost is the difference between the two values. 2143 | 2144 | neighbour_type is either "ortho" or "full", and determines whether 2145 | diagonal traversals are considered. The default is "ortho", meaning no 2146 | diagonal traversals will be considered in the solution. 
2147 | 2148 | initial_cost is the zero-value of the cell type. You should only need to 2149 | modify this if you're using a non-numeric type. 2150 | 2151 | heuristic_multiplier is a multiplier applied to the heuristic function. 2152 | The heuristic function will be either Manhattan distance from the 2153 | current state to the goal (in "ortho" mode) or Euclidean distance from 2154 | the current state to the goal (in "full" mode). The default is 1, which 2155 | means that the heuristic function will be used as-is. A value of 0 will 2156 | devolve the search to Dijkstra's algorithm, and a value higher than 1 2157 | may improve search time, potentially at the cost of accuracy. 2158 | """ 2159 | if neighbour_type not in ("ortho", "full"): 2160 | raise ValueError( 2161 | f"neighbour_type must be one of 'ortho' or 'full', not {neighbour_type}" 2162 | ) 2163 | # DEPRECATED: start should never be None, but as it was previously accepted, 2164 | # I'll leave this in for now 2165 | if start is None: 2166 | from warnings import warn 2167 | 2168 | warn( 2169 | "`start` argument to pathfind() should not be None", DeprecationWarning 2170 | ) 2171 | start = 0, 0 2172 | to_visit = PrioQueue([(initial_cost, initial_cost, start, initial_state)]) 2173 | visited = set() 2174 | if end is None: 2175 | target = len(self.data[0]) - 1, len(self.data) - 1 2176 | else: 2177 | target = end 2178 | 2179 | neighbours = ( 2180 | self.orthogonal_neighbours if neighbour_type == "ortho" else self.neighbours 2181 | ) 2182 | heuristic: typing.Callable[[int, int], float] = ( 2183 | ( 2184 | (lambda x, y: abs(x - target[0]) + abs(y - target[1])) 2185 | if neighbour_type == "ortho" 2186 | else ( 2187 | lambda x, y: math.sqrt((x - target[0]) ** 2 + (y - target[1]) ** 2) 2188 | ) 2189 | ) 2190 | if heuristic_multiplier != 0 2191 | else (lambda x, y: 0) 2192 | ) # don't bother with expensive sqrt if we're not using it 2193 | 2194 | for _heuristic_cost, cost, (x, y), state in to_visit: 2195 | if (x, y) == target and is_valid_end(state): 2196 | return cost 2197 | if (x, y, state) in visited: 2198 | continue 2199 | visited.add((x, y, state)) 2200 | for neighbour, value in neighbours(x, y): 2201 | new_state = next_state( 2202 | state, neighbour[0] - x, neighbour[1] - y, self.data[y][x], value 2203 | ) 2204 | if new_state is not None: 2205 | next_cost = cost + cost_function(self.data[y][x], value) 2206 | to_visit.push( 2207 | ( 2208 | next_cost + heuristic(*neighbour) * heuristic_multiplier, 2209 | next_cost, 2210 | neighbour, 2211 | new_state, 2212 | ) 2213 | ) 2214 | 2215 | dijkstras = pathfind 2216 | 2217 | def deepcopy(self) -> "Grid[T]": 2218 | return Grid(self.data.deepcopy()) 2219 | 2220 | @typing.overload 2221 | def __getitem__(self, index: tuple[int, int]) -> T: ... 2222 | 2223 | @typing.overload 2224 | def __getitem__(self, index: int) -> list[T]: ... 
2225 | 2226 | def __getitem__(self, index): 2227 | if isinstance(index, tuple): 2228 | x, y = index 2229 | return self.data[y][x] 2230 | return self.data[index] 2231 | 2232 | def __setitem__(self, index: tuple[int, int], value: T) -> None: 2233 | x, y = index 2234 | self.data[y][x] = value 2235 | 2236 | def __repr_row(self, row: list[T]) -> str: 2237 | # Specialise output for empty, bool, and int 2238 | if row.len() == 0: 2239 | return " [],\n" 2240 | elif narrow_list(row, bool): 2241 | return " " + "".join(row.mapped("_█".__getitem__)) + "\n" 2242 | elif narrow_list(row, int): 2243 | if not hasattr(self, "_cached_int_width"): 2244 | self._cached_int_width = ( 2245 | typing.cast(Grid[int], self) 2246 | .data.mapped(lambda i: i.mapped(str).mapped(len).max()) 2247 | .max() 2248 | ) 2249 | return ( 2250 | " " 2251 | + " ".join(f"{{: >{self._cached_int_width}}}".format(i) for i in row) 2252 | + "\n" 2253 | ) 2254 | else: 2255 | return " " + repr(row.data) + ",\n" 2256 | 2257 | def __repr__(self) -> str: 2258 | if self.data.len() == 0: 2259 | return "Grid([])" 2260 | else: 2261 | out = "Grid([\n" 2262 | for row in self.data: 2263 | out += self.__repr_row(row) 2264 | return out + "])" 2265 | 2266 | 2267 | @typing.overload 2268 | def clamp( 2269 | val: SupportsRichComparisonT, max: SupportsRichComparisonT, / 2270 | ) -> SupportsRichComparisonT: ... 2271 | 2272 | 2273 | @typing.overload 2274 | def clamp( 2275 | val: SupportsRichComparisonT, 2276 | min: SupportsRichComparisonT, 2277 | max: SupportsRichComparisonT, 2278 | /, 2279 | ) -> SupportsRichComparisonT: ... 2280 | 2281 | 2282 | _SENTINEL = object() 2283 | 2284 | 2285 | def clamp( 2286 | val, 2287 | min_, 2288 | max_=_SENTINEL, # type: ignore 2289 | /, 2290 | ): 2291 | """Clamp a value between two bounds (with a single bound m, clamps to [-m, m]).""" 2292 | if max_ is _SENTINEL: 2293 | return max(min(val, min_), -min_) # single-bound form: clamp to [-min_, min_] 2294 | return max(min(val, max_), min_) # type: ignore 2295 | 2296 | 2297 | def points_between( 2298 | start: tuple[int, int], end: tuple[int, int] 2299 | ) -> iter[tuple[int, int]]: 2300 | """Return an iterator of points between start and end, inclusive. 2301 | Start to end must be horizontal, vertical, or a perfect diagonal. 2302 | """ 2303 | dx = end[0] - start[0] 2304 | dy = end[1] - start[1] 2305 | assert abs(dx) == abs(dy) or dx == 0 or dy == 0 2306 | if dx == 0: 2307 | return iter(zip(itertools.repeat(start[0]), irange(start[1], end[1]))) 2308 | elif dy == 0: 2309 | return iter(zip(irange(start[0], end[0]), itertools.repeat(start[1]))) 2310 | else: 2311 | return iter(zip(irange(start[0], end[0]), irange(start[1], end[1]))) 2312 | 2313 | 2314 | class SparseGrid(typing.Generic[T]): 2315 | data: typing.DefaultDict[tuple[int, int], T] 2316 | 2317 | def __init__(self, default_factory: typing.Callable[[], T]) -> None: 2318 | self.data = collections.defaultdict(default_factory) 2319 | 2320 | @classmethod 2321 | def from_string( 2322 | cls, 2323 | data: str, 2324 | default_factory: typing.Callable[[], U], 2325 | classify: typing.Callable[[str], U] = _default_classifier, 2326 | empty_char: str = ".", 2327 | ) -> "SparseGrid[U]": 2328 | """Create a grid from a string (e.g. a puzzle input). 2329 | 2330 | Can take a classifier to use a custom classification. The default will 2331 | map numbers from 0 to 9 to themselves, and . and # to 0 and 1 respectively. 
2332 | """ 2333 | out = SparseGrid(default_factory) 2334 | for y, row in enumerate(data.splitlines()): 2335 | for x, char in enumerate(row): 2336 | if char != empty_char: 2337 | out[x, y] = classify(char) 2338 | return out 2339 | 2340 | def _new_of_type(self) -> "SparseGrid[T]": 2341 | return SparseGrid(self.data.default_factory) # type: ignore 2342 | 2343 | def shear_horizontal(self, row_height: int = 1) -> "SparseGrid[T]": 2344 | """Shear the grid horizontally keeping rows of a given height. 2345 | 2346 | e.g. with row_height = 2, assuming E is the centre: 2347 | 2348 | ``` 2349 | ABC ABC 2350 | DEF -> DEF 2351 | GHI GHI 2352 | ``` 2353 | """ 2354 | out = self._new_of_type() 2355 | for (x, y), value in self.items(): 2356 | out[x + y // row_height, y] = value 2357 | return out 2358 | 2359 | def shear_vertical(self, column_width: int = 1) -> "SparseGrid[T]": 2360 | """Shear the grid vertically keeping columns of a given width. 2361 | 2362 | e.g. with column_width = 2, assuming E is the centre: 2363 | 2364 | ``` 2365 | ABC BC 2366 | DEF -> AEF 2367 | GHI DHI 2368 | G 2369 | ``` 2370 | """ 2371 | out = self._new_of_type() 2372 | for (x, y), value in self.items(): 2373 | out[x, y + x // column_width] = value 2374 | return out 2375 | 2376 | def rotate_45_clockwise(self) -> "SparseGrid[T]": 2377 | """Rotate the grid 45° clockwise. 2378 | 2379 | This is a shear rotation, so the output may look strange: 2380 | 2381 | ``` 2382 | ABC DAB 2383 | DEF -> GEC 2384 | GHI HIF 2385 | ``` 2386 | """ 2387 | out = self._new_of_type() 2388 | for (x, y), value in self.items(): 2389 | out[ 2390 | clamp(x - y, -max(abs(x), abs(y)), max(abs(x), abs(y))), 2391 | clamp(x + y, -max(abs(x), abs(y)), max(abs(x), abs(y))), 2392 | ] = value 2393 | return out 2394 | 2395 | def rotate_45_anticlockwise(self) -> "SparseGrid[T]": 2396 | """Rotate the grid 45° clockwise. 2397 | 2398 | This is a shear rotation, so the output may look strange: 2399 | 2400 | ``` 2401 | ABC BCF 2402 | DEF -> AEI 2403 | GHI DGH 2404 | ``` 2405 | """ 2406 | out = self._new_of_type() 2407 | for (x, y), value in self.items(): 2408 | out[ 2409 | clamp(x + y, -max(abs(x), abs(y)), max(abs(x), abs(y))), 2410 | clamp(y - x, -max(abs(x), abs(y)), max(abs(x), abs(y))), 2411 | ] = value 2412 | return out 2413 | 2414 | def rotations(self) -> "list[SparseGrid[T]]": 2415 | """Return a list of all 45° rotations of the grid.""" 2416 | out = list() 2417 | for _ in range(8): 2418 | self = self.rotate_45_clockwise() 2419 | out.append(self) 2420 | return out 2421 | 2422 | def cardinal_rotations(self) -> "list[SparseGrid[T]]": 2423 | """Return a list of all 90° rotations of the grid.""" 2424 | out = list() 2425 | for _ in range(4): 2426 | self = self.rotate_45_clockwise().rotate_45_clockwise() 2427 | out.append(self) 2428 | return out 2429 | 2430 | def draw_line( 2431 | self, 2432 | start: tuple[int, int], 2433 | end: tuple[int, int], 2434 | value: T, 2435 | ) -> None: 2436 | """Draw a line on a sparse grid, setting all points between start and end 2437 | to value. 2438 | """ 2439 | for x, y in points_between(start, end): 2440 | self[x, y] = value 2441 | 2442 | def draw_lines( 2443 | self, 2444 | lines: Iterable[tuple[int, int]], 2445 | value: T, 2446 | ) -> None: 2447 | """Draw a series of lines on a sparse grid, setting all points between 2448 | each pair of points to value. 
2449 | """ 2450 | _lines: list[tuple[int, int]] = list(lines) 2451 | if _lines: 2452 | x, y = _lines[0] # allows for lists to be used instead of tuples 2453 | self[x, y] = value 2454 | for start, end in _lines.windowed(2): 2455 | self.draw_line(start, end, value) 2456 | 2457 | def bounds(self, empty: builtins.list[T]) -> tuple[int, int, int, int]: 2458 | """Return the bounds of a sparse grid, as a tuple of (min_x, min_y, max_x, max_y).""" 2459 | if len(self) == 0: 2460 | return 0, 0, 0, 0 2461 | else: 2462 | return ( 2463 | min(x for x, _ in filter(lambda i: self[i] not in empty, self)), 2464 | min(y for _, y in filter(lambda i: self[i] not in empty, self)), 2465 | max(x for x, _ in filter(lambda i: self[i] not in empty, self)), 2466 | max(y for _, y in filter(lambda i: self[i] not in empty, self)), 2467 | ) 2468 | 2469 | def pretty_print( 2470 | self, to_char: typing.Callable[[T], str], empty: builtins.list[T] 2471 | ) -> None: 2472 | """Print a sparse grid to the console.""" 2473 | min_x, min_y, max_x, max_y = self.bounds(empty) 2474 | max_y_width = max(len(str(max_y)), len(str(min_y))) 2475 | max_x_width = max(len(str(max_x)), len(str(min_x))) 2476 | x_labels = [ 2477 | f"{x:={max_x_width}}" if x % 2 == 0 else (" " * max_x_width) 2478 | for x in irange(min_x, max_x) 2479 | ] 2480 | for char in range(max_x_width): 2481 | print(" " * max_y_width, end=" ") 2482 | for label in x_labels: 2483 | print(label[char], end="") 2484 | print() 2485 | for y in irange(min_y, max_y): 2486 | print(f"{y:={max_y_width}}", end=" ") 2487 | for x in irange(min_x, max_x): 2488 | print(to_char(self[x, y]), end="") 2489 | print() 2490 | 2491 | def __getitem__(self, index: tuple[int, int]) -> T: 2492 | return self.data[index] 2493 | 2494 | def __setitem__(self, index: tuple[int, int], value: T) -> None: 2495 | self.data[index] = value 2496 | 2497 | def __delitem__(self, index: tuple[int, int]) -> None: 2498 | del self.data[index] 2499 | 2500 | def __len__(self) -> int: 2501 | return len(self.data) 2502 | 2503 | def __iter__(self) -> iter[tuple[int, int]]: 2504 | return iter(self.data) 2505 | 2506 | def __repr__(self) -> str: 2507 | return f"SparseGrid({self.data})" 2508 | 2509 | def keys(self) -> iter[tuple[int, int]]: 2510 | return iter(self.data.keys()) 2511 | 2512 | def values(self) -> iter[T]: 2513 | return iter(self.data.values()) 2514 | 2515 | def items(self) -> iter[tuple[tuple[int, int], T]]: 2516 | return iter(self.data.items()) 2517 | 2518 | 2519 | def expect(val: typing.Optional[T]) -> T: 2520 | """Expect that a value is not None.""" 2521 | assert val is not None 2522 | return val 2523 | 2524 | 2525 | def narrow_list(list: list, type: typing.Type[T]) -> typing.TypeGuard[list[T]]: 2526 | """Narrow the type of list based on the passed type. 2527 | 2528 | Assumes that list is homogenous. 2529 | """ 2530 | return isinstance(list[0], type) 2531 | 2532 | 2533 | def pathfind( 2534 | grid: builtins.list[builtins.list[int]], 2535 | start: tuple[int, int] = (0, 0), 2536 | end: typing.Optional[tuple[int, int]] = None, 2537 | ) -> int: 2538 | """Use the A* algorithm to find the best path from start to end, and 2539 | return the total cost. 2540 | 2541 | start defaults to the top left, and end defaults to the bottom right. 2542 | 2543 | grid is assumed to be a rectangular 2D array of integers, *not* a ragged 2544 | array. Bad things will happen if you pass a ragged array. 
2545 | """ 2546 | max_x = len(grid[-1]) - 1 2547 | max_y = len(grid) - 1 2548 | if end is None: 2549 | end = max_x, max_y 2550 | return search( 2551 | start, 2552 | lambda state: state == end, 2553 | lambda state: filter( 2554 | None, 2555 | [ 2556 | (state[0] - 1, state[1]) if state[0] > 0 else None, 2557 | (state[0] + 1, state[1]) if state[0] < max_x else None, 2558 | (state[0], state[1] - 1) if state[1] > 0 else None, 2559 | (state[0], state[1] + 1) if state[1] < max_y else None, 2560 | ], 2561 | ), 2562 | heuristic=lambda state: abs(state[0] - end[0]) + abs(state[1] - end[1]), 2563 | )[0] 2564 | 2565 | 2566 | dijkstras = pathfind 2567 | 2568 | 2569 | class PrioQueue(typing.Generic[T], typing.Iterator[T], typing.Iterable[T]): 2570 | _data: builtins.list[T] 2571 | 2572 | def __init__(self, data: builtins.list[T]) -> None: 2573 | self._data = data 2574 | heapify(self._data) 2575 | 2576 | def __next__(self) -> T: 2577 | if not self._data: 2578 | raise StopIteration 2579 | return heappop(self._data) 2580 | 2581 | def __iter__(self): 2582 | return self 2583 | 2584 | def __bool__(self) -> bool: 2585 | return bool(self._data) 2586 | 2587 | def next(self) -> T: 2588 | return next(self) 2589 | 2590 | def push(self, val: T) -> None: 2591 | heappush(self._data, val) 2592 | 2593 | def __repr__(self) -> str: 2594 | return f"PrioQueue({self._data})" 2595 | 2596 | 2597 | def rsearch( 2598 | pattern: typing.Union[str, typing.Pattern[str]], 2599 | text: str, 2600 | ) -> typing.Optional[typing.Match]: 2601 | """ 2602 | Search for the rightmost occurrence of a pattern in a string. 2603 | 2604 | This is *not* the same as re.findall(pattern, text)[-1], as that will not 2605 | detect the rightmost match if it overlaps with a previous match. 2606 | 2607 | Be aware that this function is not very efficient, and so should not be 2608 | used with complex patterns. 2609 | """ 2610 | start = len(text) - 1 2611 | match = None 2612 | while match is None and start >= 0: 2613 | match = re.match(pattern, text[start:]) 2614 | start -= 1 2615 | return match 2616 | 2617 | 2618 | def search( 2619 | state: T, 2620 | finished: typing.Callable[[T], bool], 2621 | next_states: typing.Callable[[T], typing.Iterable[T]], 2622 | heuristic: typing.Callable[[T], float] = lambda i: 0, 2623 | freeze: typing.Callable[[T], Hashable] | None = None, 2624 | ) -> tuple[int, builtins.list[T]]: 2625 | """Perform A* (or Dijkstra if heuristic is not provided) search on a state 2626 | space, returning the number of steps and all states in the chosen path 2627 | (including both start and end point). 2628 | 2629 | Will optimise to avoid revisiting seen states if the state type is hashable, 2630 | or if the freeze function is provided. The freeze function will take 2631 | priority over the default hashing behaviour for the state type, if present. 2632 | 2633 | It is probably a good idea to either make your state type hashable, or 2634 | provide a freeze function. 2635 | """ 2636 | 2637 | queue = PrioQueue([(heuristic(state), int(), state, [state])]) 2638 | visited = set() 2639 | for _, steps, state, history in queue: 2640 | if finished(state): 2641 | return steps, history 2642 | # check for freeze first, as it allows for caller customisation without 2643 | # having to make a custom class in order to modify hash behaviour of 2644 | # their state type (might be useful for e.g. 
making certain parts of 2645 | # a state equivalent when by default they wouldn't be) 2646 | if freeze is not None: 2647 | frozen = freeze(state) 2648 | if frozen in visited: 2649 | continue 2650 | visited.add(frozen) 2651 | elif isinstance(state, Hashable): 2652 | if state in visited: 2653 | continue 2654 | visited.add(state) 2655 | for next_state in next_states(state): 2656 | queue.push( 2657 | ( 2658 | steps + 1 + heuristic(next_state), 2659 | steps + 1, 2660 | next_state, 2661 | history + [next_state], 2662 | ) 2663 | ) 2664 | raise ValueError("No path found; ran out of states to visit") 2665 | 2666 | 2667 | def chinese_remainder_theorem( 2668 | moduli: builtins.list[int], residues: builtins.list[int] 2669 | ) -> int: 2670 | """Given the numbers N % modulus_i = residue_i, return N % prod(modulus_i). 2671 | 2672 | Moduli must be pairwise coprime (i.e. no pair of moduli may share a factor 2673 | other than 1) - violating this constraint will produce an undefined result. 2674 | """ 2675 | from math import prod 2676 | 2677 | N = prod(moduli) 2678 | 2679 | return ( 2680 | sum( 2681 | (div := (N // modulus)) * pow(div, -1, modulus) * residue 2682 | for modulus, residue in zip(moduli, residues) 2683 | ) 2684 | % N 2685 | ) 2686 | --------------------------------------------------------------------------------
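
A quick usage sketch of a few of the helpers defined in aoc_helper/utils.py. This example is illustrative only and is not part of the repository: it assumes the package is installed and importable as `aoc_helper`, and the expected values in the comments are inferred from the definitions above.

```
# Illustrative only; not part of the repository. Expected values are
# inferred from the definitions in aoc_helper/utils.py.
from aoc_helper.utils import (
    chinese_remainder_theorem,
    irange,
    multirange,
    range,
    search,
)

# irange is an inclusive range, stepping backwards when stop < start.
print(list(irange(1, 5)))  # [1, 2, 3, 4, 5]
print(list(irange(3, 1)))  # [3, 2, 1]

# multirange stores disjoint step-1 ranges, merging overlaps; | unions in
# another range, and len()/`in` work across all of its pieces.
spans = multirange(range(0, 5), range(7, 9)) | range(4, 8)
print(list(spans), len(spans), 6 in spans)  # [0, 1, 2, 3, 4, 5, 6, 7, 8] 9 True

# search runs Dijkstra (or A*, if a heuristic is supplied) over an arbitrary
# state space: here, the fewest +1 / *2 steps needed to turn 1 into 10.
steps, path = search(1, lambda n: n == 10, lambda n: (n + 1, n * 2))
print(steps, path)  # 4 [1, 2, 4, 5, 10]

# chinese_remainder_theorem: N % 3 == 2 and N % 5 == 3  =>  N % 15 == 8.
print(chinese_remainder_theorem([3, 5], [2, 3]))  # 8
```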