├── VERSION ├── ircrobots ├── py.typed ├── matching │ ├── __init__.py │ ├── responses.py │ └── params.py ├── contexts.py ├── __init__.py ├── struct.py ├── security.py ├── formatting.py ├── glob.py ├── asyncs.py ├── bot.py ├── params.py ├── transport.py ├── interface.py ├── scram.py ├── ircv3.py ├── sasl.py └── server.py ├── test ├── __init__.py └── glob.py ├── MANIFEST.in ├── requirements.txt ├── .travis.yml ├── README.md ├── examples ├── sasl.py ├── simple.py └── factoids.py ├── LICENSE ├── setup.py └── .gitignore /VERSION: -------------------------------------------------------------------------------- 1 | 0.7.2 2 | -------------------------------------------------------------------------------- /ircrobots/py.typed: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /test/__init__.py: -------------------------------------------------------------------------------- 1 | from .glob import * 2 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include VERSION 2 | include requirements.txt 3 | -------------------------------------------------------------------------------- /ircrobots/matching/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from .responses import * 3 | from .params import * 4 | -------------------------------------------------------------------------------- /ircrobots/contexts.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from .interface import IServer 3 | 4 | @dataclass 5 | class ServerContext(object): 6 | server: IServer 7 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | anyio ~=2.0.2 2 | asyncio-rlock ~=0.1.0 3 | asyncio-throttle ~=1.0.1 4 | ircstates ~=0.13.0 5 | async_stagger ~=0.3.0 6 | async_timeout ~=4.0.2 7 | -------------------------------------------------------------------------------- /ircrobots/__init__.py: -------------------------------------------------------------------------------- 1 | from .bot import Bot 2 | from .server import Server 3 | from .params import (ConnectionParams, SASLUserPass, SASLExternal, SASLSCRAM, 4 | STSPolicy, ResumePolicy) 5 | from .ircv3 import Capability 6 | from .security import TLS 7 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: python 2 | cache: pip 3 | python: 4 | - "3.7" 5 | - "3.8" 6 | - "3.9" 7 | install: 8 | - pip3 install mypy -r requirements.txt 9 | script: 10 | - pip3 freeze 11 | - mypy ircrobots examples --ignore-missing-imports 12 | - python3 -m unittest test 13 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ircrobots 2 | 3 | ## rationale 4 | I wanted a very-bare-bones IRC bot framework that deals with most of the 5 | concerns one would deal with in scheduling and awaiting async stuff, e.g. 6 | creating and awaiting a new task for each server while dynamically being able 7 | to add/remove servers. 8 | 9 | ## usage 10 | see [examples/](examples/) for some usage demonstration. 11 | 12 | ## contact 13 | 14 | Come say hi at `#irctokens` on irc.libera.chat 15 | -------------------------------------------------------------------------------- /test/glob.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from ircrobots import glob 3 | 4 | class GlobTestCollapse(unittest.TestCase): 5 | def test(self): 6 | c1 = glob.collapse("**?*") 7 | self.assertEqual(c1, "?*") 8 | 9 | c2 = glob.collapse("a**?a*") 10 | self.assertEqual(c2, "a?*a*") 11 | 12 | c3 = glob.collapse("?*?*?*?*a") 13 | self.assertEqual(c3, "????*a") 14 | 15 | c4 = glob.collapse("a*?*a?**") 16 | self.assertEqual(c4, "a?*a?*") 17 | -------------------------------------------------------------------------------- /ircrobots/struct.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional 2 | from dataclasses import dataclass 3 | 4 | from ircstates import ChannelUser 5 | 6 | class Whois(object): 7 | server: Optional[str] = None 8 | server_info: Optional[str] = None 9 | operator: bool = False 10 | 11 | secure: bool = False 12 | 13 | signon: Optional[int] = None 14 | idle: Optional[int] = None 15 | 16 | channels: Optional[List[ChannelUser]] = None 17 | 18 | nickname: str = "" 19 | username: str = "" 20 | hostname: str = "" 21 | realname: str = "" 22 | account: Optional[str] = None 23 | 24 | -------------------------------------------------------------------------------- /ircrobots/security.py: -------------------------------------------------------------------------------- 1 | import ssl 2 | from dataclasses import dataclass 3 | from typing import Optional, Tuple 4 | 5 | @dataclass 6 | class TLS: 7 | client_keypair: Optional[Tuple[str, str]] = None 8 | 9 | # tls without verification 10 | class TLSNoVerify(TLS): 11 | pass 12 | 13 | # verify via CAs 14 | class TLSVerifyChain(TLS): 15 | pass 16 | 17 | # verify by a pinned hash 18 | class TLSVerifyHash(TLSNoVerify): 19 | def __init__(self, sum: str): 20 | self.sum = sum.lower() 21 | class TLSVerifySHA512(TLSVerifyHash): 22 | pass 23 | 24 | def tls_context(verify: bool=True) -> ssl.SSLContext: 25 | ctx = ssl.create_default_context() 26 | if not verify: 27 | ctx.check_hostname = False 28 | ctx.verify_mode = ssl.CERT_NONE 29 | return ctx 30 | -------------------------------------------------------------------------------- /examples/sasl.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | 3 | from irctokens import build, Line 4 | from ircrobots import Bot as BaseBot 5 | from ircrobots import Server as BaseServer 6 | from ircrobots import ConnectionParams, SASLUserPass, SASLSCRAM 7 | 8 | class Server(BaseServer): 9 | async def line_read(self, line: Line): 10 | print(f"{self.name} < {line.format()}") 11 | async def line_send(self, line: Line): 12 | print(f"{self.name} > {line.format()}") 13 | 14 | class Bot(BaseBot): 15 | def create_server(self, name: str): 16 | return Server(self, name) 17 | 18 | async def main(): 19 | bot = Bot() 20 | 21 | sasl_params = SASLUserPass("myusername", "invalidpassword") 22 | params = ConnectionParams( 23 | "MyNickname", 24 | host = "chat.freenode.invalid", 25 | port = 6697, 26 | sasl = sasl_params) 27 | 28 | await bot.add_server("freenode", params) 29 | await bot.run() 30 | 31 | if __name__ == "__main__": 32 | asyncio.run(main()) 33 | -------------------------------------------------------------------------------- /examples/simple.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | 3 | from irctokens import build, Line 4 | from ircrobots import Bot as BaseBot 5 | from ircrobots import Server as BaseServer 6 | from ircrobots import ConnectionParams 7 | 8 | SERVERS = [ 9 | ("freenode", "chat.freenode.invalid") 10 | ] 11 | 12 | class Server(BaseServer): 13 | async def line_read(self, line: Line): 14 | print(f"{self.name} < {line.format()}") 15 | if line.command == "001": 16 | print(f"connected to {self.isupport.network}") 17 | await self.send(build("JOIN", ["#testchannel"])) 18 | async def line_send(self, line: Line): 19 | print(f"{self.name} > {line.format()}") 20 | 21 | class Bot(BaseBot): 22 | def create_server(self, name: str): 23 | return Server(self, name) 24 | 25 | async def main(): 26 | bot = Bot() 27 | for name, host in SERVERS: 28 | params = ConnectionParams("BitBotNewTest", host, 6697) 29 | await bot.add_server(name, params) 30 | 31 | await bot.run() 32 | 33 | if __name__ == "__main__": 34 | asyncio.run(main()) 35 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 jesopo 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /ircrobots/formatting.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | BOLD = "\x02" 4 | COLOR = "\x03" 5 | INVERT = "\x16" 6 | ITALIC = "\x1D" 7 | UNDERLINE = "\x1F" 8 | RESET = "\x0F" 9 | 10 | FORMATTERS = [ 11 | BOLD, 12 | INVERT, 13 | ITALIC, 14 | UNDERLINE, 15 | RESET 16 | ] 17 | 18 | def tokens(s: str) -> List[str]: 19 | tokens: List[str] = [] 20 | 21 | s_copy = list(s) 22 | while s_copy: 23 | token = s_copy.pop(0) 24 | if token == COLOR: 25 | for i in range(2): 26 | if s_copy and s_copy[0].isdigit(): 27 | token += s_copy.pop(0) 28 | if (len(s_copy) > 1 and 29 | s_copy[0] == "," and 30 | s_copy[1].isdigit()): 31 | token += s_copy.pop(0) 32 | token += s_copy.pop(0) 33 | if s_copy and s_copy[0].isdigit(): 34 | token += s_copy.pop(0) 35 | 36 | tokens.append(token) 37 | elif token in FORMATTERS: 38 | tokens.append(token) 39 | return tokens 40 | 41 | def strip(s: str): 42 | for token in tokens(s): 43 | s = s.replace(token, "", 1) 44 | return s 45 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import find_namespace_packages, setup 2 | 3 | with open("README.md", "r") as fh: 4 | long_description = fh.read() 5 | with open("VERSION", "r") as version_file: 6 | version = version_file.read().strip() 7 | with open("requirements.txt", "r") as requirements_file: 8 | install_requires = requirements_file.read().splitlines() 9 | 10 | setup( 11 | name="ircrobots", 12 | version=version, 13 | author="jesopo", 14 | author_email="pip@jesopo.uk", 15 | description="Asyncio IRC bot framework", 16 | long_description=long_description, 17 | long_description_content_type="text/markdown", 18 | url="https://github.com/jesopo/ircrobots", 19 | packages=["ircrobots"] + find_namespace_packages(include=["ircrobots.*"]), 20 | package_data={"ircrobots": ["py.typed"]}, 21 | classifiers=[ 22 | "Programming Language :: Python :: 3", 23 | "License :: OSI Approved :: MIT License", 24 | "Operating System :: OS Independent", 25 | "Operating System :: POSIX", 26 | "Operating System :: Microsoft :: Windows", 27 | "Topic :: Communications :: Chat :: Internet Relay Chat" 28 | ], 29 | python_requires='>=3.7', 30 | install_requires=install_requires 31 | ) 32 | -------------------------------------------------------------------------------- /ircrobots/glob.py: -------------------------------------------------------------------------------- 1 | 2 | def collapse(pattern: str) -> str: 3 | out = "" 4 | i = 0 5 | while i < len(pattern): 6 | seen_ast = False 7 | while pattern[i:] and pattern[i] in ["*", "?"]: 8 | if pattern[i] == "?": 9 | out += "?" 10 | elif pattern[i] == "*": 11 | seen_ast = True 12 | i += 1 13 | if seen_ast: 14 | out += "*" 15 | 16 | if pattern[i:]: 17 | out += pattern[i] 18 | i += 1 19 | return out 20 | 21 | def _match(pattern: str, s: str): 22 | i, j = 0, 0 23 | 24 | i_backup = -1 25 | j_backup = -1 26 | while j < len(s): 27 | p = (pattern[i:] or [None])[0] 28 | 29 | if p == "*": 30 | i += 1 31 | i_backup = i 32 | j_backup = j 33 | 34 | elif p in ["?", s[j]]: 35 | i += 1 36 | j += 1 37 | 38 | else: 39 | if i_backup == -1: 40 | return False 41 | else: 42 | j_backup += 1 43 | j = j_backup 44 | i = i_backup 45 | 46 | return i == len(pattern) 47 | 48 | class Glob(object): 49 | def __init__(self, pattern: str): 50 | self._pattern = pattern 51 | def match(self, s: str) -> bool: 52 | return _match(self._pattern, s) 53 | def compile(pattern: str) -> Glob: 54 | return Glob(collapse(pattern)) 55 | -------------------------------------------------------------------------------- /ircrobots/asyncs.py: -------------------------------------------------------------------------------- 1 | from asyncio import Future 2 | from typing import (Any, Awaitable, Callable, Generator, Generic, Optional, 3 | TypeVar) 4 | 5 | from irctokens import Line 6 | from .matching import IMatchResponse 7 | from .interface import IServer 8 | from .ircv3 import TAG_LABEL 9 | 10 | TEvent = TypeVar("TEvent") 11 | class MaybeAwait(Generic[TEvent]): 12 | def __init__(self, func: Callable[[], Awaitable[TEvent]]): 13 | self._func = func 14 | 15 | def __await__(self) -> Generator[Any, None, TEvent]: 16 | coro = self._func() 17 | return coro.__await__() 18 | 19 | class WaitFor(object): 20 | def __init__(self, 21 | response: IMatchResponse, 22 | deadline: float): 23 | self.response = response 24 | self.deadline = deadline 25 | self._label: Optional[str] = None 26 | self._our_fut: "Future[Line]" = Future() 27 | 28 | def __await__(self) -> Generator[Any, None, Line]: 29 | return self._our_fut.__await__() 30 | 31 | def with_label(self, label: str): 32 | self._label = label 33 | 34 | def match(self, server: IServer, line: Line): 35 | if (self._label is not None and 36 | line.tags is not None): 37 | label = TAG_LABEL.get(line.tags) 38 | if (label is not None and 39 | label == self._label): 40 | return True 41 | return self.response.match(server, line) 42 | 43 | def resolve(self, line: Line): 44 | self._our_fut.set_result(line) 45 | -------------------------------------------------------------------------------- /ircrobots/bot.py: -------------------------------------------------------------------------------- 1 | import asyncio, traceback 2 | import anyio 3 | from typing import Dict 4 | 5 | from ircstates.server import ServerDisconnectedException 6 | 7 | from .server import ConnectionParams, Server 8 | from .transport import TCPTransport 9 | from .interface import IBot, IServer, ITCPTransport 10 | 11 | class Bot(IBot): 12 | def __init__(self): 13 | self.servers: Dict[str, Server] = {} 14 | self._server_queue: asyncio.Queue[Server] = asyncio.Queue() 15 | 16 | def create_server(self, name: str): 17 | return Server(self, name) 18 | 19 | async def disconnected(self, server: IServer): 20 | if (server.name in self.servers and 21 | server.params is not None and 22 | server.disconnected): 23 | 24 | reconnect = server.params.reconnect 25 | 26 | while True: 27 | await asyncio.sleep(reconnect) 28 | try: 29 | await self.add_server(server.name, server.params) 30 | except Exception as e: 31 | traceback.print_exc() 32 | # let's try again, exponential backoff up to 5 mins 33 | reconnect = min(reconnect*2, 300) 34 | else: 35 | break 36 | 37 | async def disconnect(self, server: IServer): 38 | del self.servers[server.name] 39 | await server.disconnect() 40 | 41 | async def add_server(self, 42 | name: str, 43 | params: ConnectionParams, 44 | transport: ITCPTransport = TCPTransport()) -> Server: 45 | server = self.create_server(name) 46 | self.servers[name] = server 47 | await server.connect(transport, params) 48 | await self._server_queue.put(server) 49 | return server 50 | 51 | async def _run_server(self, server: Server): 52 | try: 53 | async with anyio.create_task_group() as tg: 54 | await tg.spawn(server._read_lines) 55 | await tg.spawn(server._send_lines) 56 | except ServerDisconnectedException: 57 | server.disconnected = True 58 | 59 | await self.disconnected(server) 60 | 61 | async def run(self): 62 | async with anyio.create_task_group() as tg: 63 | while not tg.cancel_scope.cancel_called: 64 | server = await self._server_queue.get() 65 | await tg.spawn(self._run_server, server) 66 | -------------------------------------------------------------------------------- /ircrobots/params.py: -------------------------------------------------------------------------------- 1 | from re import compile as re_compile 2 | from typing import List, Optional 3 | from dataclasses import dataclass, field 4 | 5 | from .security import TLS, TLSNoVerify, TLSVerifyChain 6 | 7 | class SASLParams(object): 8 | mechanism: str 9 | 10 | @dataclass 11 | class _SASLUserPass(SASLParams): 12 | username: str 13 | password: str 14 | 15 | class SASLUserPass(_SASLUserPass): 16 | mechanism = "USERPASS" 17 | class SASLSCRAM(_SASLUserPass): 18 | mechanism = "SCRAM" 19 | class SASLExternal(SASLParams): 20 | mechanism = "EXTERNAL" 21 | 22 | @dataclass 23 | class STSPolicy(object): 24 | created: int 25 | port: int 26 | duration: int 27 | preload: bool 28 | 29 | @dataclass 30 | class ResumePolicy(object): 31 | address: str 32 | token: str 33 | 34 | RE_IPV6HOST = re_compile(r"\[([a-fA-F0-9:]+)\]") 35 | 36 | _TLS_TYPES = { 37 | "+": TLSVerifyChain, 38 | "~": TLSNoVerify, 39 | } 40 | @dataclass 41 | class ConnectionParams(object): 42 | nickname: str 43 | host: str 44 | port: int 45 | tls: Optional[TLS] = field(default_factory=TLSVerifyChain) 46 | 47 | username: Optional[str] = None 48 | realname: Optional[str] = None 49 | bindhost: Optional[str] = None 50 | 51 | password: Optional[str] = None 52 | sasl: Optional[SASLParams] = None 53 | 54 | sts: Optional[STSPolicy] = None 55 | resume: Optional[ResumePolicy] = None 56 | 57 | reconnect: int = 10 # seconds 58 | alt_nicknames: List[str] = field(default_factory=list) 59 | 60 | autojoin: List[str] = field(default_factory=list) 61 | 62 | @staticmethod 63 | def from_hoststring( 64 | nickname: str, 65 | hoststring: str 66 | ) -> "ConnectionParams": 67 | 68 | ipv6host = RE_IPV6HOST.search(hoststring) 69 | if ipv6host is not None and ipv6host.start() == 0: 70 | host = ipv6host.group(1) 71 | port_s = hoststring[ipv6host.end()+1:] 72 | else: 73 | host, _, port_s = hoststring.strip().partition(":") 74 | 75 | tls_type: Optional[TLS] = None 76 | if not port_s: 77 | port_s = "6667" 78 | else: 79 | tls_type = _TLS_TYPES.get(port_s[0], lambda: None)() 80 | if tls_type is not None: 81 | port_s = port_s[1:] or "6697" 82 | 83 | return ConnectionParams(nickname, host, int(port_s), tls_type) 84 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /ircrobots/matching/responses.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional, Sequence, Union 2 | from irctokens import Line 3 | from ..interface import (IServer, IMatchResponse, IMatchResponseParam, 4 | IMatchResponseHostmask) 5 | from .params import * 6 | 7 | TYPE_PARAM = Union[str, IMatchResponseParam] 8 | class Responses(IMatchResponse): 9 | def __init__(self, 10 | commands: Sequence[str], 11 | params: Sequence[TYPE_PARAM]=[], 12 | source: Optional[IMatchResponseHostmask]=None): 13 | self._commands = commands 14 | self._source = source 15 | 16 | self._params: Sequence[IMatchResponseParam] = [] 17 | for param in params: 18 | if isinstance(param, str): 19 | self._params.append(Literal(param)) 20 | elif isinstance(param, IMatchResponseParam): 21 | self._params.append(param) 22 | 23 | def __repr__(self) -> str: 24 | return f"Responses({self._commands!r}: {self._params!r})" 25 | 26 | def match(self, server: IServer, line: Line) -> bool: 27 | for command in self._commands: 28 | if (line.command == command and ( 29 | self._source is None or ( 30 | line.hostmask is not None and 31 | self._source.match(server, line.hostmask) 32 | ))): 33 | 34 | for i, param in enumerate(self._params): 35 | if (i >= len(line.params) or 36 | not param.match(server, line.params[i])): 37 | break 38 | else: 39 | return True 40 | else: 41 | return False 42 | 43 | class Response(Responses): 44 | def __init__(self, 45 | command: str, 46 | params: Sequence[TYPE_PARAM]=[], 47 | source: Optional[IMatchResponseHostmask]=None): 48 | super().__init__([command], params, source=source) 49 | 50 | def __repr__(self) -> str: 51 | return f"Response({self._commands[0]}: {self._params!r})" 52 | 53 | class ResponseOr(IMatchResponse): 54 | def __init__(self, *responses: IMatchResponse): 55 | self._responses = responses 56 | def __repr__(self) -> str: 57 | return f"ResponseOr({self._responses!r})" 58 | def match(self, server: IServer, line: Line) -> bool: 59 | for response in self._responses: 60 | if response.match(server, line): 61 | return True 62 | else: 63 | return False 64 | -------------------------------------------------------------------------------- /ircrobots/transport.py: -------------------------------------------------------------------------------- 1 | from hashlib import sha512 2 | from ssl import SSLContext 3 | from typing import Optional, Tuple 4 | from asyncio import StreamReader, StreamWriter 5 | from async_stagger import open_connection 6 | 7 | from .interface import ITCPTransport, ITCPReader, ITCPWriter 8 | from .security import (tls_context, TLS, TLSNoVerify, TLSVerifyHash, 9 | TLSVerifySHA512) 10 | 11 | class TCPReader(ITCPReader): 12 | def __init__(self, reader: StreamReader): 13 | self._reader = reader 14 | 15 | async def read(self, byte_count: int) -> bytes: 16 | return await self._reader.read(byte_count) 17 | class TCPWriter(ITCPWriter): 18 | def __init__(self, writer: StreamWriter): 19 | self._writer = writer 20 | 21 | def get_peer(self) -> Tuple[str, int]: 22 | address, port, *_ = self._writer.transport.get_extra_info("peername") 23 | return (address, port) 24 | 25 | def write(self, data: bytes): 26 | self._writer.write(data) 27 | 28 | async def drain(self): 29 | await self._writer.drain() 30 | 31 | async def close(self): 32 | self._writer.close() 33 | await self._writer.wait_closed() 34 | 35 | class TCPTransport(ITCPTransport): 36 | async def connect(self, 37 | hostname: str, 38 | port: int, 39 | tls: Optional[TLS], 40 | bindhost: Optional[str]=None 41 | ) -> Tuple[ITCPReader, ITCPWriter]: 42 | 43 | cur_ssl: Optional[SSLContext] = None 44 | if tls is not None: 45 | cur_ssl = tls_context(not isinstance(tls, TLSNoVerify)) 46 | if tls.client_keypair is not None: 47 | (client_cert, client_key) = tls.client_keypair 48 | cur_ssl.load_cert_chain(client_cert, keyfile=client_key) 49 | 50 | local_addr: Optional[Tuple[str, int]] = None 51 | if not bindhost is None: 52 | local_addr = (bindhost, 0) 53 | 54 | server_hostname = hostname if tls else None 55 | 56 | reader, writer = await open_connection( 57 | hostname, 58 | port, 59 | server_hostname=server_hostname, 60 | ssl =cur_ssl, 61 | local_addr =local_addr) 62 | 63 | if isinstance(tls, TLSVerifyHash): 64 | cert: bytes = writer.transport.get_extra_info( 65 | "ssl_object" 66 | ).getpeercert(True) 67 | if isinstance(tls, TLSVerifySHA512): 68 | sum = sha512(cert).hexdigest() 69 | else: 70 | raise ValueError(f"unknown hash pinning {type(tls)}") 71 | 72 | if not sum == tls.sum: 73 | raise ValueError( 74 | f"pinned hash for {hostname} does not match ({sum})" 75 | ) 76 | 77 | return (TCPReader(reader), TCPWriter(writer)) 78 | 79 | -------------------------------------------------------------------------------- /ircrobots/interface.py: -------------------------------------------------------------------------------- 1 | from asyncio import Future 2 | from typing import Awaitable, Iterable, List, Optional, Set, Tuple, Union 3 | from enum import IntEnum 4 | 5 | from ircstates import Server, Emit 6 | from irctokens import Line, Hostmask 7 | 8 | from .params import ConnectionParams, SASLParams, STSPolicy, ResumePolicy 9 | from .security import TLS 10 | 11 | class ITCPReader(object): 12 | async def read(self, byte_count: int): 13 | pass 14 | class ITCPWriter(object): 15 | def write(self, data: bytes): 16 | pass 17 | 18 | def get_peer(self) -> Tuple[str, int]: 19 | pass 20 | 21 | async def drain(self): 22 | pass 23 | async def close(self): 24 | pass 25 | 26 | class ITCPTransport(object): 27 | async def connect(self, 28 | hostname: str, 29 | port: int, 30 | tls: Optional[TLS], 31 | bindhost: Optional[str]=None 32 | ) -> Tuple[ITCPReader, ITCPWriter]: 33 | pass 34 | 35 | class SendPriority(IntEnum): 36 | HIGH = 0 37 | MEDIUM = 10 38 | LOW = 20 39 | DEFAULT = MEDIUM 40 | 41 | class SentLine(object): 42 | def __init__(self, 43 | id: int, 44 | priority: int, 45 | line: Line): 46 | self.id = id 47 | self.priority = priority 48 | self.line = line 49 | self.future: "Future[SentLine]" = Future() 50 | 51 | def __lt__(self, other: "SentLine") -> bool: 52 | return self.priority < other.priority 53 | 54 | class ICapability(object): 55 | def available(self, capabilities: Iterable[str]) -> Optional[str]: 56 | pass 57 | 58 | def match(self, capability: str) -> bool: 59 | pass 60 | 61 | def copy(self) -> "ICapability": 62 | pass 63 | 64 | class IMatchResponse(object): 65 | def match(self, server: "IServer", line: Line) -> bool: 66 | pass 67 | class IMatchResponseParam(object): 68 | def match(self, server: "IServer", arg: str) -> bool: 69 | pass 70 | class IMatchResponseValueParam(IMatchResponseParam): 71 | def value(self, server: "IServer"): 72 | pass 73 | def set_value(self, value: str): 74 | pass 75 | class IMatchResponseHostmask(object): 76 | def match(self, server: "IServer", hostmask: Hostmask) -> bool: 77 | pass 78 | 79 | class IServer(Server): 80 | bot: "IBot" 81 | disconnected: bool 82 | params: ConnectionParams 83 | desired_caps: Set[ICapability] 84 | last_read: float 85 | 86 | def send_raw(self, line: str, priority=SendPriority.DEFAULT 87 | ) -> Awaitable[SentLine]: 88 | pass 89 | def send(self, line: Line, priority=SendPriority.DEFAULT 90 | ) -> Awaitable[SentLine]: 91 | pass 92 | 93 | def wait_for(self, 94 | response: Union[IMatchResponse, Set[IMatchResponse]] 95 | ) -> Awaitable[Line]: 96 | pass 97 | 98 | def set_throttle(self, rate: int, time: float): 99 | pass 100 | 101 | def server_address(self) -> Tuple[str, int]: 102 | pass 103 | 104 | async def connect(self, 105 | transport: ITCPTransport, 106 | params: ConnectionParams): 107 | pass 108 | async def disconnect(self): 109 | pass 110 | 111 | def line_preread(self, line: Line): 112 | pass 113 | def line_presend(self, line: Line): 114 | pass 115 | async def line_read(self, line: Line): 116 | pass 117 | async def line_send(self, line: Line): 118 | pass 119 | async def sts_policy(self, sts: STSPolicy): 120 | pass 121 | async def resume_policy(self, resume: ResumePolicy): 122 | pass 123 | 124 | def cap_agreed(self, capability: ICapability) -> bool: 125 | pass 126 | def cap_available(self, capability: ICapability) -> Optional[str]: 127 | pass 128 | 129 | async def sasl_auth(self, sasl: SASLParams) -> bool: 130 | pass 131 | 132 | class IBot(object): 133 | def create_server(self, name: str) -> IServer: 134 | pass 135 | async def disconnected(self, server: IServer): 136 | pass 137 | 138 | async def disconnect(self, server: IServer): 139 | pass 140 | 141 | async def add_server(self, name: str, params: ConnectionParams) -> IServer: 142 | pass 143 | 144 | async def run(self): 145 | pass 146 | -------------------------------------------------------------------------------- /ircrobots/matching/params.py: -------------------------------------------------------------------------------- 1 | from re import compile as re_compile 2 | from typing import Optional, Pattern, Union 3 | from irctokens import Hostmask 4 | from ..interface import (IMatchResponseParam, IMatchResponseValueParam, 5 | IMatchResponseHostmask, IServer) 6 | from ..glob import Glob, compile as glob_compile 7 | from .. import formatting 8 | 9 | class Any(IMatchResponseParam): 10 | def __repr__(self) -> str: 11 | return "Any()" 12 | def match(self, server: IServer, arg: str) -> bool: 13 | return True 14 | ANY = Any() 15 | 16 | # NOT 17 | # FORMAT FOLD 18 | # REGEX 19 | # LITERAL 20 | 21 | class Literal(IMatchResponseValueParam): 22 | def __init__(self, value: str): 23 | self._value = value 24 | def __repr__(self) -> str: 25 | return f"{self._value!r}" 26 | 27 | def value(self, server: IServer) -> str: 28 | return self._value 29 | def set_value(self, value: str): 30 | self._value = value 31 | def match(self, server: IServer, arg: str) -> bool: 32 | return arg == self._value 33 | 34 | TYPE_MAYBELIT = Union[str, IMatchResponseParam] 35 | TYPE_MAYBELIT_VALUE = Union[str, IMatchResponseValueParam] 36 | def _assure_lit(value: TYPE_MAYBELIT_VALUE) -> IMatchResponseValueParam: 37 | if isinstance(value, str): 38 | return Literal(value) 39 | else: 40 | return value 41 | 42 | class Not(IMatchResponseParam): 43 | def __init__(self, param: IMatchResponseParam): 44 | self._param = param 45 | def __repr__(self) -> str: 46 | return f"Not({self._param!r})" 47 | def match(self, server: IServer, arg: str) -> bool: 48 | return not self._param.match(server, arg) 49 | 50 | class ParamValuePassthrough(IMatchResponseValueParam): 51 | _value: IMatchResponseValueParam 52 | def value(self, server: IServer): 53 | return self._value.value(server) 54 | def set_value(self, value: str): 55 | self._value.set_value(value) 56 | 57 | class Folded(ParamValuePassthrough): 58 | def __init__(self, value: TYPE_MAYBELIT_VALUE): 59 | self._value = _assure_lit(value) 60 | self._folded = False 61 | def __repr__(self) -> str: 62 | return f"Folded({self._value!r})" 63 | def match(self, server: IServer, arg: str) -> bool: 64 | if not self._folded: 65 | value = self.value(server) 66 | folded = server.casefold(value) 67 | self.set_value(folded) 68 | self._folded = True 69 | 70 | return self._value.match(server, server.casefold(arg)) 71 | 72 | class Formatless(IMatchResponseParam): 73 | def __init__(self, value: TYPE_MAYBELIT_VALUE): 74 | self._value = _assure_lit(value) 75 | def __repr__(self) -> str: 76 | return f"Formatless({self._value!r})" 77 | def match(self, server: IServer, arg: str) -> bool: 78 | strip = formatting.strip(arg) 79 | return self._value.match(server, strip) 80 | 81 | class Regex(IMatchResponseParam): 82 | def __init__(self, value: str): 83 | self._value = value 84 | self._pattern: Optional[Pattern] = None 85 | def match(self, server: IServer, arg: str) -> bool: 86 | if self._pattern is None: 87 | self._pattern = re_compile(self._value) 88 | return bool(self._pattern.search(arg)) 89 | 90 | class Self(IMatchResponseParam): 91 | def __repr__(self) -> str: 92 | return "Self()" 93 | def match(self, server: IServer, arg: str) -> bool: 94 | return server.casefold(arg) == server.nickname_lower 95 | SELF = Self() 96 | 97 | class MaskSelf(IMatchResponseHostmask): 98 | def __repr__(self) -> str: 99 | return "MaskSelf()" 100 | def match(self, server: IServer, hostmask: Hostmask): 101 | return server.casefold(hostmask.nickname) == server.nickname_lower 102 | MASK_SELF = MaskSelf() 103 | 104 | class Nick(IMatchResponseHostmask): 105 | def __init__(self, nickname: str): 106 | self._nickname = nickname 107 | self._folded: Optional[str] = None 108 | def __repr__(self) -> str: 109 | return f"Nick({self._nickname!r})" 110 | def match(self, server: IServer, hostmask: Hostmask): 111 | if self._folded is None: 112 | self._folded = server.casefold(self._nickname) 113 | return self._folded == server.casefold(hostmask.nickname) 114 | 115 | class Mask(IMatchResponseHostmask): 116 | def __init__(self, mask: str): 117 | self._mask = mask 118 | self._compiled: Optional[Glob] 119 | def __repr__(self) -> str: 120 | return f"Mask({self._mask!r})" 121 | def match(self, server: IServer, hostmask: Hostmask): 122 | if self._compiled is None: 123 | self._compiled = glob_compile(self._mask) 124 | return self._compiled.match(str(hostmask)) 125 | -------------------------------------------------------------------------------- /ircrobots/scram.py: -------------------------------------------------------------------------------- 1 | import base64, hashlib, hmac, os 2 | from enum import Enum 3 | from typing import Dict 4 | 5 | # IANA Hash Function Textual Names 6 | # https://tools.ietf.org/html/rfc5802#section-4 7 | # https://www.iana.org/assignments/hash-function-text-names/ 8 | # MD2 has been removed as it's unacceptably weak 9 | class SCRAMAlgorithm(Enum): 10 | MD5 = "MD5" 11 | SHA_1 = "SHA1" 12 | SHA_224 = "SHA224" 13 | SHA_256 = "SHA256" 14 | SHA_384 = "SHA384" 15 | SHA_512 = "SHA512" 16 | 17 | SCRAM_ERRORS = [ 18 | "invalid-encoding", 19 | "extensions-not-supported", # unrecognized 'm' value 20 | "invalid-proof", 21 | "channel-bindings-dont-match", 22 | "server-does-support-channel-binding", 23 | "channel-binding-not-supported", 24 | "unsupported-channel-binding-type", 25 | "unknown-user", 26 | "invalid-username-encoding", # invalid utf8 or bad SASLprep 27 | "no-resources" 28 | ] 29 | 30 | def _scram_nonce() -> bytes: 31 | return base64.b64encode(os.urandom(32)) 32 | def _scram_escape(s: bytes) -> bytes: 33 | return s.replace(b"=", b"=3D").replace(b",", b"=2C") 34 | def _scram_unescape(s: bytes) -> bytes: 35 | return s.replace(b"=3D", b"=").replace(b"=2C", b",") 36 | def _scram_xor(s1: bytes, s2: bytes) -> bytes: 37 | return bytes(a ^ b for a, b in zip(s1, s2)) 38 | 39 | class SCRAMState(Enum): 40 | NONE = 0 41 | CLIENT_FIRST = 1 42 | CLIENT_FINAL = 2 43 | SUCCESS = 3 44 | FAILURE = 4 45 | VERIFY_FAILURE = 5 46 | 47 | class SCRAMError(Exception): 48 | pass 49 | 50 | class SCRAMContext(object): 51 | def __init__(self, algo: SCRAMAlgorithm, 52 | username: str, 53 | password: str): 54 | self._algo = algo 55 | self._username = username.encode("utf8") 56 | self._password = password.encode("utf8") 57 | 58 | self.state = SCRAMState.NONE 59 | self.error = "" 60 | self.raw_error = "" 61 | 62 | self._client_first = b"" 63 | self._client_nonce = b"" 64 | 65 | self._salted_password = b"" 66 | self._auth_message = b"" 67 | 68 | def _get_pieces(self, data: bytes) -> Dict[bytes, bytes]: 69 | pieces = (piece.split(b"=", 1) for piece in data.split(b",")) 70 | return dict((piece[0], piece[1]) for piece in pieces) 71 | 72 | def _hmac(self, key: bytes, msg: bytes) -> bytes: 73 | return hmac.new(key, msg, self._algo.value).digest() 74 | def _hash(self, msg: bytes) -> bytes: 75 | return hashlib.new(self._algo.value, msg).digest() 76 | 77 | def _constant_time_compare(self, b1: bytes, b2: bytes): 78 | return hmac.compare_digest(b1, b2) 79 | 80 | def _fail(self, error: str): 81 | self.raw_error = error 82 | if error in SCRAM_ERRORS: 83 | self.error = error 84 | else: 85 | self.error = "other-error" 86 | self.state = SCRAMState.FAILURE 87 | 88 | def client_first(self) -> bytes: 89 | self.state = SCRAMState.CLIENT_FIRST 90 | self._client_nonce = _scram_nonce() 91 | self._client_first = b"n=%s,r=%s" % ( 92 | _scram_escape(self._username), self._client_nonce) 93 | 94 | # n,,n=,r= 95 | return b"n,,%s" % self._client_first 96 | 97 | def _assert_error(self, pieces: Dict[bytes, bytes]) -> bool: 98 | if b"e" in pieces: 99 | error = pieces[b"e"].decode("utf8") 100 | self._fail(error) 101 | return True 102 | else: 103 | return False 104 | 105 | def server_first(self, data: bytes) -> bytes: 106 | self.state = SCRAMState.CLIENT_FINAL 107 | 108 | pieces = self._get_pieces(data) 109 | if self._assert_error(pieces): 110 | return b"" 111 | 112 | nonce = pieces[b"r"] # server combines your nonce with it's own 113 | if (not nonce.startswith(self._client_nonce) or 114 | nonce == self._client_nonce): 115 | self._fail("nonce-unacceptable") 116 | return b"" 117 | 118 | salt = base64.b64decode(pieces[b"s"]) # salt is b64encoded 119 | iterations = int(pieces[b"i"]) 120 | 121 | salted_password = hashlib.pbkdf2_hmac(self._algo.value, 122 | self._password, salt, iterations, dklen=None) 123 | self._salted_password = salted_password 124 | 125 | client_key = self._hmac(salted_password, b"Client Key") 126 | stored_key = self._hash(client_key) 127 | 128 | channel = base64.b64encode(b"n,,") 129 | auth_noproof = b"c=%s,r=%s" % (channel, nonce) 130 | auth_message = b"%s,%s,%s" % (self._client_first, data, auth_noproof) 131 | self._auth_message = auth_message 132 | 133 | client_signature = self._hmac(stored_key, auth_message) 134 | client_proof_xor = _scram_xor(client_key, client_signature) 135 | client_proof = base64.b64encode(client_proof_xor) 136 | 137 | # c=,r=,p= 138 | return b"%s,p=%s" % (auth_noproof, client_proof) 139 | 140 | def server_final(self, data: bytes) -> bool: 141 | pieces = self._get_pieces(data) 142 | if self._assert_error(pieces): 143 | return False 144 | 145 | verifier = base64.b64decode(pieces[b"v"]) 146 | 147 | server_key = self._hmac(self._salted_password, b"Server Key") 148 | server_signature = self._hmac(server_key, self._auth_message) 149 | 150 | if server_signature == verifier: 151 | self.state = SCRAMState.SUCCESS 152 | return True 153 | else: 154 | self.state = SCRAMState.VERIFY_FAILURE 155 | return False 156 | -------------------------------------------------------------------------------- /examples/factoids.py: -------------------------------------------------------------------------------- 1 | import asyncio, re 2 | from argparse import ArgumentParser 3 | from typing import Dict, List, Optional 4 | 5 | from irctokens import build, Line 6 | from ircrobots import Bot as BaseBot 7 | from ircrobots import Server as BaseServer 8 | from ircrobots import ConnectionParams 9 | 10 | TRIGGER = "!" 11 | 12 | def _delims(s: str, delim: str): 13 | s_copy = list(s) 14 | while s_copy: 15 | char = s_copy.pop(0) 16 | if char == delim: 17 | if not s_copy: 18 | yield len(s)-(len(s_copy)+1) 19 | elif not s_copy.pop(0) == delim: 20 | yield len(s)-(len(s_copy)+2) 21 | 22 | def _sed(sed: str, s: str) -> Optional[str]: 23 | if len(sed) > 1: 24 | delim = sed[1] 25 | last = 0 26 | parts: List[str] = [] 27 | for i in _delims(sed, delim): 28 | parts.append(sed[last:i]) 29 | last = i+1 30 | if len(parts) == 4: 31 | break 32 | if last < (len(sed)): 33 | parts.append(sed[last:]) 34 | 35 | _, pattern, replace, *args = parts 36 | flags_s = (args or [""])[0] 37 | 38 | flags = re.I if "i" in flags_s else 0 39 | count = 0 if "g" in flags_s else 1 40 | 41 | for i in reversed(list(_delims(replace, "&"))): 42 | replace = replace[:i] + "\\g<0>" + replace[i+1:] 43 | 44 | try: 45 | compiled = re.compile(pattern, flags) 46 | except: 47 | return None 48 | return re.sub(compiled, replace, s, count) 49 | else: 50 | return None 51 | 52 | class Database: 53 | def __init__(self): 54 | self._settings: Dict[str, str] = {} 55 | 56 | async def get(self, context: str, setting: str) -> Optional[str]: 57 | return self._settings.get(setting, None) 58 | async def set(self, context: str, setting: str, value: str): 59 | self._settings[setting] = value 60 | async def rem(self, context: str, setting: str): 61 | if setting in self._settings: 62 | del self._settings[setting] 63 | 64 | class Server(BaseServer): 65 | def __init__(self, bot: Bot, name: str, channel: str, database: Database): 66 | super().__init__(bot, name) 67 | self._channel = channel 68 | self._database = database 69 | 70 | async def line_send(self, line: Line): 71 | print(f"> {line.format()}") 72 | 73 | async def line_read(self, line: Line): 74 | print(f"< {line.format()}") 75 | 76 | me = self.nickname_lower 77 | if line.command == "001": 78 | await self.send(build("JOIN", [self._channel])) 79 | 80 | if ( 81 | line.command == "PRIVMSG" and 82 | self.has_channel(line.params[0]) and 83 | not line.hostmask is None and 84 | not self.casefold(line.hostmask.nickname) == me and 85 | self.has_user(line.hostmask.nickname) and 86 | line.params[1].startswith(TRIGGER)): 87 | 88 | channel = self.channels[self.casefold(line.params[0])] 89 | user = self.users[self.casefold(line.hostmask.nickname)] 90 | cuser = channel.users[user.nickname_lower] 91 | text = line.params[1].replace(TRIGGER, "", 1) 92 | db_context = f"{self.name}:{channel.name}" 93 | 94 | name, _, text = text.partition(" ") 95 | action, _, text = text.partition(" ") 96 | name = name.lower() 97 | key = f"factoid-{name}" 98 | 99 | 100 | out = "" 101 | if not action or action == "@": 102 | value = await self._database.get(db_context, key) 103 | if not value is None: 104 | out = f"({name}) {value}" 105 | if action == "@" and text: 106 | target, _, _ = text.partition(" ") 107 | out = f"{target}: {out}" 108 | else: 109 | out = f"{user.nickname}: '{name}' not found" 110 | 111 | elif action in ["==", "~="]: 112 | if "o" in cuser.modes: 113 | value, _, _ = text.partition(" ") 114 | if action == "==": 115 | if value: 116 | await self._database.set(db_context, key, value) 117 | out = f"{user.nickname}: added factoid {name}" 118 | else: 119 | await self._database.rem(db_context, key) 120 | out = f"{user.nickname}: removed factoid {name}" 121 | elif action == "~=": 122 | current = await self._database.get(db_context, key) 123 | if current is None: 124 | out = f"{user.nickname}: '{name}' not found" 125 | elif value: 126 | changed = _sed(value, current) 127 | if not changed is None: 128 | await self._database.set( 129 | db_context, key, changed) 130 | out = (f"{user.nickname}: " 131 | f"changed '{name}' factoid") 132 | else: 133 | out = f"{user.nickname}: invalid sed" 134 | else: 135 | out = f"{user.nickname}: please provide a sed" 136 | else: 137 | out = f"{user.nickname}: you are not an op" 138 | 139 | 140 | else: 141 | out = f"{user.nickname}: unknown action '{action}'" 142 | await self.send(build("PRIVMSG", [line.params[0], out])) 143 | 144 | class Bot(BaseBot): 145 | def __init__(self, channel: str): 146 | super().__init__() 147 | self._channel = channel 148 | def create_server(self, name: str): 149 | return Server(self, name, self._channel, Database()) 150 | 151 | async def main(hostname: str, channel: str, nickname: str): 152 | bot = Bot(channel) 153 | 154 | params = ConnectionParams( 155 | nickname, 156 | hostname, 157 | 6697 158 | ) 159 | await bot.add_server("freenode", params) 160 | await bot.run() 161 | 162 | if __name__ == "__main__": 163 | parser = ArgumentParser(description="A simple IRC bot for factoids") 164 | parser.add_argument("hostname") 165 | parser.add_argument("channel") 166 | parser.add_argument("nickname") 167 | args = parser.parse_args() 168 | 169 | asyncio.run(main(args.hostname, args.channel, args.nickname)) 170 | -------------------------------------------------------------------------------- /ircrobots/ircv3.py: -------------------------------------------------------------------------------- 1 | from time import time 2 | from typing import Dict, Iterable, List, Optional, Tuple 3 | from dataclasses import dataclass 4 | from irctokens import build 5 | from ircstates.server import ServerDisconnectedException 6 | 7 | from .contexts import ServerContext 8 | from .matching import Response, ANY 9 | from .interface import ICapability 10 | from .params import ConnectionParams, STSPolicy, ResumePolicy 11 | from .security import TLSVerifyChain 12 | 13 | class Capability(ICapability): 14 | def __init__(self, 15 | ratified_name: Optional[str], 16 | draft_name: Optional[str]=None, 17 | alias: Optional[str]=None, 18 | depends_on: List[str]=[]): 19 | self.name = ratified_name 20 | self.draft = draft_name 21 | self.alias = alias or ratified_name 22 | self.depends_on = depends_on.copy() 23 | 24 | self._caps = [ratified_name, draft_name] 25 | 26 | def match(self, capability: str) -> bool: 27 | return capability in self._caps 28 | 29 | def available(self, capabilities: Iterable[str] 30 | ) -> Optional[str]: 31 | for cap in self._caps: 32 | if not cap is None and cap in capabilities: 33 | return cap 34 | else: 35 | return None 36 | 37 | def copy(self): 38 | return Capability( 39 | self.name, 40 | self.draft, 41 | alias=self.alias, 42 | depends_on=self.depends_on[:]) 43 | 44 | class MessageTag(object): 45 | def __init__(self, 46 | name: Optional[str], 47 | draft_name: Optional[str]=None): 48 | self.name = name 49 | self.draft = draft_name 50 | self._tags = [self.name, self.draft] 51 | 52 | def available(self, tags: Iterable[str]) -> Optional[str]: 53 | for tag in self._tags: 54 | if tag is not None and tag in tags: 55 | return tag 56 | else: 57 | return None 58 | 59 | def get(self, tags: Dict[str, str]) -> Optional[str]: 60 | name = self.available(tags) 61 | if name is not None: 62 | return tags[name] 63 | else: 64 | return None 65 | 66 | CAP_SASL = Capability("sasl") 67 | CAP_ECHO = Capability("echo-message") 68 | CAP_STS = Capability("sts", "draft/sts") 69 | CAP_RESUME = Capability(None, "draft/resume-0.5", alias="resume") 70 | 71 | CAP_LABEL = Capability("labeled-response", "draft/labeled-response-0.2") 72 | TAG_LABEL = MessageTag("label", "draft/label") 73 | LABEL_TAG_MAP = { 74 | "draft/labeled-response-0.2": "draft/label", 75 | "labeled-response": "label" 76 | } 77 | 78 | CAPS: List[ICapability] = [ 79 | Capability("multi-prefix"), 80 | Capability("chghost"), 81 | Capability("away-notify"), 82 | 83 | Capability("invite-notify"), 84 | Capability("account-tag"), 85 | Capability("account-notify"), 86 | Capability("extended-join"), 87 | 88 | Capability("message-tags", "draft/message-tags-0.2"), 89 | Capability("cap-notify"), 90 | Capability("batch"), 91 | 92 | Capability(None, "draft/rename", alias="rename"), 93 | Capability("setname", "draft/setname"), 94 | CAP_RESUME 95 | ] 96 | 97 | def _cap_dict(s: str) -> Dict[str, str]: 98 | d: Dict[str, str] = {} 99 | for token in s.split(","): 100 | key, _, value = token.partition("=") 101 | d[key] = value 102 | return d 103 | 104 | async def sts_transmute(params: ConnectionParams): 105 | if not params.sts is None and params.tls is None: 106 | now = time() 107 | since = (now-params.sts.created) 108 | if since <= params.sts.duration: 109 | params.port = params.sts.port 110 | params.tls = TLSVerifyChain() 111 | async def resume_transmute(params: ConnectionParams): 112 | if params.resume is not None: 113 | params.host = params.resume.address 114 | 115 | class HandshakeCancel(Exception): 116 | pass 117 | 118 | class CAPContext(ServerContext): 119 | async def on_ls(self, tokens: Dict[str, str]): 120 | await self._sts(tokens) 121 | 122 | caps = list(self.server.desired_caps)+CAPS 123 | 124 | if (not self.server.params.sasl is None and 125 | not CAP_SASL in caps): 126 | caps.append(CAP_SASL) 127 | 128 | matched = (c.available(tokens) for c in caps) 129 | cap_names = [name for name in matched if not name is None] 130 | 131 | if cap_names: 132 | await self.server.send(build("CAP", ["REQ", " ".join(cap_names)])) 133 | 134 | while cap_names: 135 | line = await self.server.wait_for({ 136 | Response("CAP", [ANY, "ACK"]), 137 | Response("CAP", [ANY, "NAK"]) 138 | }) 139 | 140 | current_caps = line.params[2].split(" ") 141 | for cap in current_caps: 142 | if cap in cap_names: 143 | cap_names.remove(cap) 144 | if CAP_RESUME.available(current_caps): 145 | await self.resume_token() 146 | 147 | if (self.server.cap_agreed(CAP_SASL) and 148 | not self.server.params.sasl is None): 149 | await self.server.sasl_auth(self.server.params.sasl) 150 | 151 | async def resume_token(self): 152 | line = await self.server.wait_for(Response("RESUME", ["TOKEN", ANY])) 153 | token = line.params[1] 154 | address, port = self.server.server_address() 155 | resume_policy = ResumePolicy(address, token) 156 | 157 | previous_policy = self.server.params.resume 158 | self.server.params.resume = resume_policy 159 | await self.server.resume_policy(resume_policy) 160 | 161 | if previous_policy is not None and not self.server.registered: 162 | await self.server.send(build("RESUME", [previous_policy.token])) 163 | line = await self.server.wait_for({ 164 | Response("RESUME", ["SUCCESS"]), 165 | Response("FAIL", ["RESUME"]) 166 | }) 167 | if line.command == "RESUME": 168 | raise HandshakeCancel() 169 | 170 | async def handshake(self): 171 | try: 172 | await self.on_ls(self.server.available_caps) 173 | except HandshakeCancel: 174 | return 175 | else: 176 | await self.server.send(build("CAP", ["END"])) 177 | 178 | async def _sts(self, tokens: Dict[str, str]): 179 | cap_sts = CAP_STS.available(tokens) 180 | if not cap_sts is None: 181 | sts_dict = _cap_dict(tokens[cap_sts]) 182 | params = self.server.params 183 | if not params.tls: 184 | if "port" in sts_dict: 185 | params.port = int(sts_dict["port"]) 186 | params.tls = TLSVerifyChain() 187 | 188 | await self.server.bot.disconnect(self.server) 189 | await self.server.bot.add_server(self.server.name, params) 190 | raise ServerDisconnectedException() 191 | 192 | elif "duration" in sts_dict: 193 | policy = STSPolicy( 194 | int(time()), 195 | params.port, 196 | int(sts_dict["duration"]), 197 | "preload" in sts_dict) 198 | await self.server.sts_policy(policy) 199 | 200 | -------------------------------------------------------------------------------- /ircrobots/sasl.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | from enum import Enum 3 | from base64 import b64decode, b64encode 4 | from irctokens import build 5 | from ircstates.numerics import * 6 | 7 | from .matching import Responses, Response, ANY 8 | from .contexts import ServerContext 9 | from .params import SASLParams, SASLUserPass, SASLSCRAM, SASLExternal 10 | from .scram import SCRAMContext, SCRAMAlgorithm 11 | 12 | SASL_SCRAM_MECHANISMS = [ 13 | "SCRAM-SHA-512", 14 | "SCRAM-SHA-256", 15 | "SCRAM-SHA-1", 16 | ] 17 | SASL_USERPASS_MECHANISMS = SASL_SCRAM_MECHANISMS+["PLAIN"] 18 | 19 | class SASLResult(Enum): 20 | NONE = 0 21 | SUCCESS = 1 22 | FAILURE = 2 23 | ALREADY = 3 24 | 25 | class SASLError(Exception): 26 | pass 27 | class SASLUnknownMechanismError(SASLError): 28 | pass 29 | 30 | AUTH_BYTE_MAX = 400 31 | 32 | AUTHENTICATE_ANY = Response("AUTHENTICATE", [ANY]) 33 | 34 | NUMERICS_FAIL = Response(ERR_SASLFAIL) 35 | NUMERICS_INITIAL = Responses([ 36 | ERR_SASLFAIL, ERR_SASLALREADY, RPL_SASLMECHS, ERR_SASLABORTED 37 | ]) 38 | NUMERICS_LAST = Responses([RPL_SASLSUCCESS, ERR_SASLFAIL]) 39 | 40 | def _b64e(s: str): 41 | return b64encode(s.encode("utf8")).decode("ascii") 42 | 43 | def _b64eb(s: bytes) -> str: 44 | # encode-from-bytes 45 | return b64encode(s).decode("ascii") 46 | def _b64db(s: str) -> bytes: 47 | # decode-to-bytes 48 | return b64decode(s) 49 | 50 | class SASLContext(ServerContext): 51 | async def from_params(self, params: SASLParams) -> SASLResult: 52 | if isinstance(params, SASLUserPass): 53 | return await self.userpass(params.username, params.password) 54 | elif isinstance(params, SASLSCRAM): 55 | return await self.scram(params.username, params.password) 56 | elif isinstance(params, SASLExternal): 57 | return await self.external() 58 | else: 59 | raise SASLUnknownMechanismError( 60 | "SASLParams given with unknown mechanism " 61 | f"{params.mechanism!r}") 62 | 63 | async def external(self) -> SASLResult: 64 | await self.server.send(build("AUTHENTICATE", ["EXTERNAL"])) 65 | line = await self.server.wait_for({ 66 | AUTHENTICATE_ANY, 67 | NUMERICS_INITIAL 68 | }) 69 | 70 | if line.command == "907": 71 | # we've done SASL already. cleanly abort 72 | return SASLResult.ALREADY 73 | elif line.command == "908": 74 | available = line.params[1].split(",") 75 | raise SASLUnknownMechanismError( 76 | "Server does not support SASL EXTERNAL " 77 | f"(it supports {available}") 78 | elif line.command == "AUTHENTICATE" and line.params[0] == "+": 79 | await self.server.send(build("AUTHENTICATE", ["+"])) 80 | 81 | line = await self.server.wait_for(NUMERICS_LAST) 82 | if line.command == "903": 83 | return SASLResult.SUCCESS 84 | return SASLResult.FAILURE 85 | 86 | async def plain(self, username: str, password: str) -> SASLResult: 87 | return await self.userpass(username, password, ["PLAIN"]) 88 | 89 | async def scram(self, username: str, password: str) -> SASLResult: 90 | return await self.userpass(username, password, SASL_SCRAM_MECHANISMS) 91 | 92 | async def userpass(self, 93 | username: str, 94 | password: str, 95 | mechanisms: List[str]=SASL_USERPASS_MECHANISMS 96 | ) -> SASLResult: 97 | def _common(server_mechs) -> List[str]: 98 | mechs: List[str] = [] 99 | for our_mech in mechanisms: 100 | if our_mech in server_mechs: 101 | mechs.append(our_mech) 102 | 103 | if mechs: 104 | return mechs 105 | else: 106 | raise SASLUnknownMechanismError( 107 | "No matching SASL mechanims. " 108 | f"(we want: {mechanisms} " 109 | f"server has: {server_mechs})") 110 | 111 | if self.server.available_caps["sasl"]: 112 | # CAP v3.2 tells us what mechs it supports 113 | available = self.server.available_caps["sasl"].split(",") 114 | match = _common(available) 115 | else: 116 | # CAP v3.1 does not. pick the pick and wait for 907 to inform us of 117 | # what mechanisms are supported 118 | match = mechanisms 119 | 120 | while match: 121 | await self.server.send(build("AUTHENTICATE", [match[0]])) 122 | line = await self.server.wait_for({ 123 | AUTHENTICATE_ANY, 124 | NUMERICS_INITIAL 125 | }) 126 | 127 | if line.command == "907": 128 | # we've done SASL already. cleanly abort 129 | return SASLResult.ALREADY 130 | elif line.command == "908": 131 | # prior to CAP v3.2 - ERR telling us which mechs are supported 132 | available = line.params[1].split(",") 133 | match = _common(available) 134 | await self.server.wait_for(NUMERICS_FAIL) 135 | elif line.command == "AUTHENTICATE" and line.params[0] == "+": 136 | auth_text = "" 137 | 138 | if match[0] == "PLAIN": 139 | auth_text = f"{username}\0{username}\0{password}" 140 | elif match[0].startswith("SCRAM-SHA-"): 141 | auth_text = await self._scram( 142 | match[0], username, password) 143 | 144 | if not auth_text == "+": 145 | auth_text = _b64e(auth_text) 146 | 147 | if auth_text: 148 | await self._send_auth_text(auth_text) 149 | 150 | line = await self.server.wait_for(NUMERICS_LAST) 151 | if line.command == "903": 152 | return SASLResult.SUCCESS 153 | elif line.command == "904": 154 | match.pop(0) 155 | else: 156 | break 157 | 158 | return SASLResult.FAILURE 159 | 160 | async def _scram(self, algo_str: str, 161 | username: str, 162 | password: str) -> str: 163 | algo_str_prep = algo_str.replace("SCRAM-", "", 1 164 | ).replace("-", "").upper() 165 | try: 166 | algo = SCRAMAlgorithm(algo_str_prep) 167 | except ValueError: 168 | raise ValueError("Unknown SCRAM algorithm '%s'" % algo_str_prep) 169 | scram = SCRAMContext(algo, username, password) 170 | 171 | client_first = _b64eb(scram.client_first()) 172 | await self._send_auth_text(client_first) 173 | line = await self.server.wait_for(AUTHENTICATE_ANY) 174 | 175 | server_first = _b64db(line.params[0]) 176 | client_final = _b64eb(scram.server_first(server_first)) 177 | if not client_final == "": 178 | await self._send_auth_text(client_final) 179 | line = await self.server.wait_for(AUTHENTICATE_ANY) 180 | 181 | server_final = _b64db(line.params[0]) 182 | verified = scram.server_final(server_final) 183 | #TODO PANIC if verified is false! 184 | return "+" 185 | else: 186 | return "" 187 | 188 | async def _send_auth_text(self, text: str): 189 | n = AUTH_BYTE_MAX 190 | chunks = [text[i:i+n] for i in range(0, len(text), n)] 191 | if len(chunks[-1]) == 400: 192 | chunks.append("+") 193 | 194 | for chunk in chunks: 195 | await self.server.send(build("AUTHENTICATE", [chunk])) 196 | -------------------------------------------------------------------------------- /ircrobots/server.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | from asyncio import Future, PriorityQueue 3 | from typing import (AsyncIterable, Awaitable, Deque, Dict, Iterable, List, 4 | Optional, Set, Tuple, Union) 5 | from collections import deque 6 | from time import monotonic 7 | 8 | import anyio 9 | from asyncio_rlock import RLock 10 | from asyncio_throttle import Throttler 11 | from async_timeout import timeout as timeout_ 12 | from ircstates import Emit, Channel, ChannelUser 13 | from ircstates.numerics import * 14 | from ircstates.server import ServerDisconnectedException 15 | from ircstates.names import Name 16 | from irctokens import build, Line, tokenise 17 | 18 | from .ircv3 import (CAPContext, sts_transmute, CAP_ECHO, CAP_SASL, 19 | CAP_LABEL, LABEL_TAG_MAP, resume_transmute) 20 | from .sasl import SASLContext, SASLResult 21 | from .matching import (ResponseOr, Responses, Response, ANY, SELF, MASK_SELF, 22 | Folded) 23 | from .asyncs import MaybeAwait, WaitFor 24 | from .struct import Whois 25 | from .params import ConnectionParams, SASLParams, STSPolicy, ResumePolicy 26 | from .interface import (IBot, ICapability, IServer, SentLine, SendPriority, 27 | IMatchResponse) 28 | from .interface import ITCPTransport, ITCPReader, ITCPWriter 29 | 30 | THROTTLE_RATE = 4 # lines 31 | THROTTLE_TIME = 2 # seconds 32 | PING_TIMEOUT = 60 # seconds 33 | WAIT_TIMEOUT = 20 # seconds 34 | 35 | JOIN_ERR_FIRST = [ 36 | ERR_NOSUCHCHANNEL, 37 | ERR_BADCHANNAME, 38 | ERR_UNAVAILRESOURCE, 39 | ERR_TOOMANYCHANNELS, 40 | ERR_BANNEDFROMCHAN, 41 | ERR_INVITEONLYCHAN, 42 | ERR_BADCHANNELKEY, 43 | ERR_NEEDREGGEDNICK, 44 | ERR_THROTTLE 45 | ] 46 | 47 | class Server(IServer): 48 | _reader: ITCPReader 49 | _writer: ITCPWriter 50 | params: ConnectionParams 51 | 52 | def __init__(self, bot: IBot, name: str): 53 | super().__init__(name) 54 | self.bot = bot 55 | 56 | self.disconnected = False 57 | 58 | self.throttle = Throttler(rate_limit=100, period=1) 59 | 60 | self.sasl_state = SASLResult.NONE 61 | self.last_read = monotonic() 62 | 63 | self._sent_count: int = 0 64 | self._send_queue: PriorityQueue[SentLine] = PriorityQueue() 65 | self.desired_caps: Set[ICapability] = set([]) 66 | 67 | self._read_queue: Deque[Line] = deque() 68 | self._process_queue: Deque[Tuple[Line, Optional[Emit]]] = deque() 69 | 70 | self._ping_sent = False 71 | self._read_lguard = RLock() 72 | self.read_lock = self._read_lguard 73 | self._read_lwork = asyncio.Lock() 74 | self._wait_for = asyncio.Event() 75 | 76 | self._pending_who: Deque[str] = deque() 77 | self._alt_nicks: List[str] = [] 78 | 79 | def hostmask(self) -> str: 80 | hostmask = self.nickname 81 | if not self.username is None: 82 | hostmask += f"!{self.username}" 83 | if not self.hostname is None: 84 | hostmask += f"@{self.hostname}" 85 | return hostmask 86 | 87 | def send_raw(self, line: str, priority=SendPriority.DEFAULT 88 | ) -> Awaitable[SentLine]: 89 | return self.send(tokenise(line), priority) 90 | def send(self, 91 | line: Line, 92 | priority=SendPriority.DEFAULT 93 | ) -> Awaitable[SentLine]: 94 | 95 | self.line_presend(line) 96 | sent_line = SentLine(self._sent_count, priority, line) 97 | self._sent_count += 1 98 | 99 | label = self.cap_available(CAP_LABEL) 100 | if not label is None: 101 | tag = LABEL_TAG_MAP[label] 102 | if line.tags is None or not tag in line.tags: 103 | if line.tags is None: 104 | line.tags = {} 105 | line.tags[tag] = str(sent_line.id) 106 | 107 | self._send_queue.put_nowait(sent_line) 108 | 109 | return sent_line.future 110 | 111 | def set_throttle(self, rate: int, time: float): 112 | self.throttle.rate_limit = rate 113 | self.throttle.period = time 114 | 115 | def server_address(self) -> Tuple[str, int]: 116 | return self._writer.get_peer() 117 | 118 | async def connect(self, 119 | transport: ITCPTransport, 120 | params: ConnectionParams): 121 | await sts_transmute(params) 122 | await resume_transmute(params) 123 | 124 | reader, writer = await transport.connect( 125 | params.host, 126 | params.port, 127 | tls =params.tls, 128 | bindhost =params.bindhost) 129 | 130 | self._reader = reader 131 | self._writer = writer 132 | 133 | self.params = params 134 | await self.handshake() 135 | async def disconnect(self): 136 | if not self._writer is None: 137 | await self._writer.close() 138 | self._writer = None 139 | self._read_queue.clear() 140 | 141 | async def handshake(self): 142 | nickname = self.params.nickname 143 | username = self.params.username or nickname 144 | realname = self.params.realname or nickname 145 | 146 | alt_nicks = self.params.alt_nicknames 147 | if not alt_nicks: 148 | alt_nicks = [nickname+"_"*i for i in range(1, 4)] 149 | self._alt_nicks = alt_nicks 150 | 151 | # these must remain non-awaited; reading hasn't started yet 152 | if not self.params.password is None: 153 | self.send(build("PASS", [self.params.password])) 154 | self.send(build("CAP", ["LS", "302"])) 155 | self.send(build("NICK", [nickname])) 156 | self.send(build("USER", [username, "0", "*", realname])) 157 | 158 | # to be overridden 159 | def line_preread(self, line: Line): 160 | pass 161 | def line_presend(self, line: Line): 162 | pass 163 | async def line_read(self, line: Line): 164 | pass 165 | async def line_send(self, line: Line): 166 | pass 167 | async def sts_policy(self, sts: STSPolicy): 168 | pass 169 | async def resume_policy(self, resume: ResumePolicy): 170 | pass 171 | # /to be overriden 172 | 173 | async def _on_read(self, line: Line, emit: Optional[Emit]): 174 | if line.command == "PING": 175 | await self.send(build("PONG", line.params)) 176 | 177 | elif line.command == RPL_ENDOFWHO: 178 | chan = self.casefold(line.params[1]) 179 | if (self._pending_who and 180 | self._pending_who[0] == chan): 181 | self._pending_who.popleft() 182 | await self._next_who() 183 | elif (line.command in { 184 | ERR_NICKNAMEINUSE, ERR_ERRONEUSNICKNAME, ERR_UNAVAILRESOURCE 185 | } and not self.registered): 186 | if self._alt_nicks: 187 | nick = self._alt_nicks.pop(0) 188 | await self.send(build("NICK", [nick])) 189 | else: 190 | await self.send(build("QUIT")) 191 | 192 | elif line.command in [RPL_ENDOFMOTD, ERR_NOMOTD]: 193 | # we didn't get the nickname we wanted. watch for it if we can 194 | if not self.nickname == self.params.nickname: 195 | target = self.params.nickname 196 | if self.isupport.monitor is not None: 197 | await self.send(build("MONITOR", ["+", target])) 198 | elif self.isupport.watch is not None: 199 | await self.send(build("WATCH", [f"+{target}"])) 200 | 201 | # has someone just stopped using the nickname we want? 202 | elif line.command == RPL_LOGOFF: 203 | await self._check_regain([line.params[1]]) 204 | elif line.command == RPL_MONOFFLINE: 205 | await self._check_regain(line.params[1].split(",")) 206 | elif (line.command in ["NICK", "QUIT"] and 207 | line.source is not None): 208 | await self._check_regain([line.hostmask.nickname]) 209 | 210 | elif emit is not None: 211 | if emit.command == RPL_WELCOME: 212 | await self.send(build("WHO", [self.nickname])) 213 | self.set_throttle(THROTTLE_RATE, THROTTLE_TIME) 214 | 215 | if self.params.autojoin: 216 | await self._batch_joins(self.params.autojoin) 217 | 218 | elif emit.command == "CAP": 219 | if emit.subcommand == "NEW": 220 | await self._cap_ls(emit) 221 | elif (emit.subcommand == "LS" and 222 | emit.finished): 223 | if not self.registered: 224 | await CAPContext(self).handshake() 225 | else: 226 | await self._cap_ls(emit) 227 | 228 | elif emit.command == "JOIN": 229 | if emit.self and not emit.channel is None: 230 | chan = emit.channel.name_lower 231 | await self.send(build("MODE", [chan])) 232 | 233 | modes = "".join(self.isupport.chanmodes.a_modes) 234 | await self.send(build("MODE", [chan, f"+{modes}"])) 235 | 236 | self._pending_who.append(chan) 237 | if len(self._pending_who) == 1: 238 | await self._next_who() 239 | 240 | await self.line_read(line) 241 | 242 | async def _check_regain(self, nicks: List[str]): 243 | for nick in nicks: 244 | if (self.casefold_equals(nick, self.params.nickname) and 245 | not self.nickname == self.params.nickname): 246 | await self.send(build("NICK", [self.params.nickname])) 247 | 248 | async def _batch_joins(self, 249 | channels: List[str], 250 | batch_n: int=10): 251 | #TODO: do as many JOINs in one line as we can fit 252 | #TODO: channel keys 253 | 254 | for i in range(0, len(channels), batch_n): 255 | batch = channels[i:i+batch_n] 256 | await self.send(build("JOIN", [",".join(batch)])) 257 | 258 | async def _next_who(self): 259 | if self._pending_who: 260 | chan = self._pending_who[0] 261 | if self.isupport.whox: 262 | await self.send(self.prepare_whox(chan)) 263 | else: 264 | await self.send(build("WHO", [chan])) 265 | 266 | async def _read_line(self, timeout: float) -> Optional[Line]: 267 | while True: 268 | if self._read_queue: 269 | return self._read_queue.popleft() 270 | 271 | try: 272 | async with timeout_(timeout): 273 | data = await self._reader.read(1024) 274 | except asyncio.TimeoutError: 275 | return None 276 | 277 | self.last_read = monotonic() 278 | lines = self.recv(data) 279 | for line in lines: 280 | self.line_preread(line) 281 | self._read_queue.append(line) 282 | 283 | async def _read_lines(self): 284 | while True: 285 | async with self._read_lguard: 286 | pass 287 | 288 | if not self._process_queue: 289 | async with self._read_lwork: 290 | read_aw = asyncio.create_task(self._read_line(PING_TIMEOUT)) 291 | wait_aw = asyncio.create_task(self._wait_for.wait()) 292 | dones, notdones = await asyncio.wait( 293 | [read_aw, wait_aw], 294 | return_when=asyncio.FIRST_COMPLETED 295 | ) 296 | self._wait_for.clear() 297 | 298 | for done in dones: 299 | if isinstance(done.result(), Line): 300 | self._ping_sent = False 301 | line = done.result() 302 | emit = self.parse_tokens(line) 303 | self._process_queue.append((line, emit)) 304 | elif done.result() is None: 305 | if not self._ping_sent: 306 | await self.send(build("PING", ["hello"])) 307 | self._ping_sent = True 308 | else: 309 | await self.disconnect() 310 | raise ServerDisconnectedException() 311 | for notdone in notdones: 312 | notdone.cancel() 313 | 314 | else: 315 | line, emit = self._process_queue.popleft() 316 | await self._on_read(line, emit) 317 | 318 | async def wait_for(self, 319 | response: Union[IMatchResponse, Set[IMatchResponse]], 320 | sent_aw: Optional[Awaitable[SentLine]]=None, 321 | timeout: float=WAIT_TIMEOUT 322 | ) -> Line: 323 | 324 | response_obj: IMatchResponse 325 | if isinstance(response, set): 326 | response_obj = ResponseOr(*response) 327 | else: 328 | response_obj = response 329 | 330 | async with self._read_lguard: 331 | self._wait_for.set() 332 | async with self._read_lwork: 333 | async with timeout_(timeout): 334 | while True: 335 | line = await self._read_line(timeout) 336 | if line: 337 | self._ping_sent = False 338 | emit = self.parse_tokens(line) 339 | self._process_queue.append((line, emit)) 340 | if response_obj.match(self, line): 341 | return line 342 | 343 | async def _on_send_line(self, line: Line): 344 | if (line.command in ["PRIVMSG", "NOTICE", "TAGMSG"] and 345 | not self.cap_agreed(CAP_ECHO)): 346 | new_line = line.with_source(self.hostmask()) 347 | self._read_queue.append(new_line) 348 | 349 | async def _send_lines(self): 350 | while True: 351 | lines: List[SentLine] = [] 352 | 353 | while (not lines or 354 | (len(lines) < 5 and self._send_queue.qsize() > 0)): 355 | prio_line = await self._send_queue.get() 356 | lines.append(prio_line) 357 | 358 | for line in lines: 359 | async with self.throttle: 360 | self._writer.write( 361 | f"{line.line.format()}\r\n".encode("utf8")) 362 | 363 | await self._writer.drain() 364 | 365 | for line in lines: 366 | await self._on_send_line(line.line) 367 | await self.line_send(line.line) 368 | line.future.set_result(line) 369 | 370 | # CAP-related 371 | def cap_agreed(self, capability: ICapability) -> bool: 372 | return bool(self.cap_available(capability)) 373 | def cap_available(self, capability: ICapability) -> Optional[str]: 374 | return capability.available(self.agreed_caps) 375 | 376 | async def _cap_ls(self, emit: Emit): 377 | if not emit.tokens is None: 378 | tokens: Dict[str, str] = {} 379 | for token in emit.tokens: 380 | key, _, value = token.partition("=") 381 | tokens[key] = value 382 | await CAPContext(self).on_ls(tokens) 383 | 384 | async def sasl_auth(self, params: SASLParams) -> bool: 385 | if (self.sasl_state == SASLResult.NONE and 386 | self.cap_agreed(CAP_SASL)): 387 | 388 | res = await SASLContext(self).from_params(params) 389 | self.sasl_state = res 390 | return True 391 | else: 392 | return False 393 | # /CAP-related 394 | 395 | def send_nick(self, new_nick: str) -> Awaitable[bool]: 396 | fut = self.send(build("NICK", [new_nick])) 397 | async def _assure() -> bool: 398 | line = await self.wait_for({ 399 | Response("NICK", [Folded(new_nick)], source=MASK_SELF), 400 | Responses([ 401 | ERR_BANNICKCHANGE, 402 | ERR_NICKTOOFAST, 403 | ERR_CANTCHANGENICK 404 | ], [ANY]), 405 | Responses([ 406 | ERR_NICKNAMEINUSE, 407 | ERR_ERRONEUSNICKNAME, 408 | ERR_UNAVAILRESOURCE 409 | ], [ANY, Folded(new_nick)]) 410 | }, fut) 411 | return line.command == "NICK" 412 | return MaybeAwait(_assure) 413 | 414 | def send_join(self, 415 | name: str, 416 | key: Optional[str]=None 417 | ) -> Awaitable[Channel]: 418 | fut = self.send_joins([name], [] if key is None else [key]) 419 | 420 | async def _assure(): 421 | channels = await fut 422 | return channels[0] 423 | return MaybeAwait(_assure) 424 | def send_part(self, name: str): 425 | fut = self.send(build("PART", [name])) 426 | 427 | async def _assure(): 428 | line = await self.wait_for( 429 | Response("PART", [Folded(name)], source=MASK_SELF), 430 | fut 431 | ) 432 | return 433 | return MaybeAwait(_assure) 434 | 435 | def send_joins(self, 436 | names: List[str], 437 | keys: List[str]=[] 438 | ) -> Awaitable[List[Channel]]: 439 | 440 | folded_names = [self.casefold(name) for name in names] 441 | 442 | if not keys: 443 | fut = self.send(build("JOIN", [",".join(names)])) 444 | else: 445 | fut = self.send(build("JOIN", [",".join(names)]+keys)) 446 | 447 | async def _assure(): 448 | channels: List[Channel] = [] 449 | 450 | while folded_names: 451 | line = await self.wait_for({ 452 | Response(RPL_CHANNELMODEIS, [ANY, ANY]), 453 | Responses(JOIN_ERR_FIRST, [ANY, ANY]), 454 | Response(ERR_USERONCHANNEL, [ANY, SELF, ANY]), 455 | Response(ERR_LINKCHANNEL, [ANY, ANY, ANY]) 456 | }, fut) 457 | 458 | chan: Optional[str] = None 459 | if line.command == RPL_CHANNELMODEIS: 460 | chan = line.params[1] 461 | elif line.command in JOIN_ERR_FIRST: 462 | chan = line.params[1] 463 | elif line.command == ERR_USERONCHANNEL: 464 | chan = line.params[2] 465 | elif line.command == ERR_LINKCHANNEL: 466 | #XXX i dont like this 467 | chan = line.params[2] 468 | await self.wait_for( 469 | Response(RPL_CHANNELMODEIS, [ANY, Folded(chan)]) 470 | ) 471 | channels.append(self.channels[self.casefold(chan)]) 472 | continue 473 | 474 | if chan is not None: 475 | folded = self.casefold(chan) 476 | if folded in folded_names: 477 | folded_names.remove(folded) 478 | channels.append(self.channels[folded]) 479 | 480 | return channels 481 | return MaybeAwait(_assure) 482 | 483 | def send_message(self, target: str, message: str 484 | ) -> Awaitable[Optional[str]]: 485 | fut = self.send(build("PRIVMSG", [target, message])) 486 | async def _assure(): 487 | line = await self.wait_for( 488 | Response("PRIVMSG", [Folded(target), ANY], source=MASK_SELF), 489 | fut 490 | ) 491 | if line.command == "PRIVMSG": 492 | return line.params[1] 493 | else: 494 | return None 495 | return MaybeAwait(_assure) 496 | 497 | def send_whois(self, 498 | target: str, 499 | remote: bool=False 500 | ) -> Awaitable[Optional[Whois]]: 501 | args = [target] 502 | if remote: 503 | args.append(target) 504 | 505 | fut = self.send(build("WHOIS", args)) 506 | async def _assure() -> Optional[Whois]: 507 | folded = self.casefold(target) 508 | params = [ANY, Folded(folded)] 509 | 510 | obj = Whois() 511 | while True: 512 | line = await self.wait_for(Responses([ 513 | ERR_NOSUCHNICK, 514 | ERR_NOSUCHSERVER, 515 | RPL_WHOISUSER, 516 | RPL_WHOISSERVER, 517 | RPL_WHOISOPERATOR, 518 | RPL_WHOISIDLE, 519 | RPL_WHOISCHANNELS, 520 | RPL_WHOISHOST, 521 | RPL_WHOISACCOUNT, 522 | RPL_WHOISSECURE, 523 | RPL_ENDOFWHOIS 524 | ], params), fut) 525 | if line.command in [ERR_NOSUCHNICK, ERR_NOSUCHSERVER]: 526 | return None 527 | elif line.command == RPL_WHOISUSER: 528 | nick, user, host, _, real = line.params[1:] 529 | obj.nickname = nick 530 | obj.username = user 531 | obj.hostname = host 532 | obj.realname = real 533 | elif line.command == RPL_WHOISIDLE: 534 | idle, signon, _ = line.params[2:] 535 | obj.idle = int(idle) 536 | obj.signon = int(signon) 537 | elif line.command == RPL_WHOISACCOUNT: 538 | obj.account = line.params[2] 539 | elif line.command == RPL_WHOISCHANNELS: 540 | channels = list(filter(bool, line.params[2].split(" "))) 541 | if obj.channels is None: 542 | obj.channels = [] 543 | 544 | for i, channel in enumerate(channels): 545 | symbols = "" 546 | while channel[0] in self.isupport.prefix.prefixes: 547 | symbols += channel[0] 548 | channel = channel[1:] 549 | 550 | channel_user = ChannelUser( 551 | Name(obj.nickname, folded), 552 | Name(channel, self.casefold(channel)) 553 | ) 554 | for symbol in symbols: 555 | mode = self.isupport.prefix.from_prefix(symbol) 556 | if mode is not None: 557 | channel_user.modes.add(mode) 558 | 559 | obj.channels.append(channel_user) 560 | elif line.command == RPL_ENDOFWHOIS: 561 | return obj 562 | return MaybeAwait(_assure) 563 | --------------------------------------------------------------------------------