├── .gitignore ├── .mergify.yml ├── DEVELOPMENT.md ├── LICENSE ├── README.md ├── deploykit ├── __init__.py └── py.typed ├── examples └── basic.py ├── flake.lock ├── flake.nix ├── nix ├── default.nix └── shell.nix ├── pyproject.toml ├── setup.cfg ├── setup.py ├── tests ├── __init__.py ├── command.py ├── conftest.py ├── getpwnam-preload.c ├── ports.py ├── root.py ├── sshd.py ├── test_local.py └── test_ssh.py └── treefmt.nix /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | .idea 3 | *.log 4 | tmp/ 5 | 6 | *.py[cod] 7 | *.egg 8 | build 9 | htmlcov 10 | -------------------------------------------------------------------------------- /.mergify.yml: -------------------------------------------------------------------------------- 1 | queue_rules: 2 | - name: default 3 | merge_conditions: 4 | - check-success=Evaluate flake.nix 5 | - check-success=check treefmt [x86_64-linux] 6 | - check-success=devShell default [x86_64-linux] 7 | - check-success=package default [x86_64-linux] 8 | - check-success=package deploykit [x86_64-linux] 9 | defaults: 10 | actions: 11 | queue: 12 | allow_merging_configuration_change: true 13 | method: rebase 14 | pull_request_rules: 15 | - name: merge using the merge queue 16 | conditions: 17 | - base=main 18 | - label~=merge-queue|dependencies 19 | actions: 20 | queue: {} 21 | -------------------------------------------------------------------------------- /DEVELOPMENT.md: -------------------------------------------------------------------------------- 1 | # Development 2 | 3 | You will need python3 and openssh installed at a minimum. 
4 | Optionally, the following Python tools are used: 5 | 6 | - flake8 7 | - black 8 | - pytest 9 | - mypy 10 | 11 | Clone the project: 12 | 13 | ```console 14 | $ git clone git@github.com:numtide/deploykit.git 15 | ``` 16 | 17 | To run the tests, you need to install [pytest](https://pytest.org): 18 | 19 | ```console 20 | $ pytest ./tests 21 | ``` 22 | 23 | The project is also fully type-checked with [mypy](http://www.mypy-lang.org/). 24 | You can run the type check like this: 25 | 26 | ```console 27 | $ MYPYPATH=$(pwd):$(pwd)/tests mypy --strict --namespace-packages --explicit-package-bases . 28 | ``` 29 | 30 | Furthermore, all code is formatted with black: 31 | 32 | ```console 33 | $ black . 34 | ``` 35 | 36 | and linted with flake8: 37 | 38 | ```console 39 | $ flake8 . 40 | ``` 41 | 42 | ## Logging 43 | 44 | We use Python 3's `logging` library. 45 | DeployHost-related logging starting with `[hostname]` is handled by a logger called `deploykit.command`, other logging is handled by the `deploykit.main` logger. 46 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Numtide 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Deploykit 2 | 3 | A Python library that executes commands in parallel, locally and remotely, 4 | over a group of hosts. 5 | 6 | This library has been extracted from existing projects where it was used as a 7 | basis to create deployment scripts. It's a bit like a mini Ansible, without 8 | the YAML overhead, and usable as a simple composable library. 9 | 10 | Here are some important facts about deploykit: 11 | 12 | - Local commands and user-defined functions: In contrast to many other libraries 13 | in the space, deploykit also allows running commands and user-defined 14 | functions locally in parallel for each host. 15 | - OpenSSH: To retain compatibility with existing configuration, deploykit uses 16 | the openssh executable for running the commands remotely. 17 | - Threaded: Deploykit starts a thread per target host to run commands and 18 | user-defined functions and collects their results for inspection. To run 19 | commands, deploykit wraps around python's subprocess API. 20 | - Clean output: command outputs are prefixed by the hostname of the target. 
21 | 22 | ## Example 23 | 24 | ```python 25 | from deploykit import parse_hosts 26 | import subprocess 27 | 28 | hosts = parse_hosts("server1,server2,server3") 29 | runs = hosts.run("uptime", stdout=subprocess.PIPE) 30 | for r in runs: 31 | print(f"The uptime of {r.host.host} is {r.result.stdout}") 32 | ``` 33 | 34 | A more comprehensive example explaining all the concepts of the API can be found 35 | [here](https://github.com/numtide/deploykit/blob/main/examples/basic.py). 36 | 37 | ## Differences to other libraries and tools 38 | 39 | - [Fabric](http://fabfile.org): 40 | - Deploykit took inspiration from fabric and addresses its limitation that 41 | local commands cannot be executed in parallel for a range of hosts (e.g. rsync). 42 | By also allowing a function to be run per host, it provides higher flexibility. Fabric 43 | uses [pyinvoke](https://www.pyinvoke.org/) as a task runner frontend, so deploykit can also be used in 44 | combination with the same. 45 | - [Ansible](https://ansible.org): 46 | - Deploykit is more lightweight and has a faster startup time. 47 | - Using python for task definitions allows for more flexibility than YAML. 48 | - Use ansible if you need declarative configuration management. Use deploykit 49 | if you want to imperatively and quickly execute a series of commands on a number 50 | of hosts. 51 | 52 | ## Contributing 53 | 54 | Contributions and discussions are welcome. Please make sure to send a WIP PR or 55 | issue before doing large refactors, so your work doesn't get wasted (in case of 56 | disagreement). In our [development](DEVELOPMENT.md) documentation we explain how 57 | to get started on deploykit. 58 | 59 | ## License 60 | 61 | This project is copyright Numtide and contributors, and licensed under the 62 | [MIT](LICENSE) license. 
class CommandFormatter(logging.Formatter):
    """
    Formatter for the command logger.

    Renders every record as ``[<command_prefix>] <message>``. Unless colors are
    disabled, errors are printed in red, warnings in yellow, and every distinct
    command prefix gets its own color.
    """

    def __init__(self) -> None:
        super().__init__(
            "%(prefix_color)s[%(command_prefix)s]%(color_reset)s %(color)s%(message)s%(color_reset)s"
        )
        # Prefixes seen so far; the position in this list selects the prefix color.
        self.hostnames: List[str] = []
        self.hostname_color_offset = 1  # first host shouldn't get aggressive red

    def formatMessage(self, record: logging.LogRecord) -> str:
        """
        Fill in the color fields (and a default prefix) and format the record.

        @record the log record to format
        @return the formatted line
        """
        # The format string always references %(command_prefix)s. Records that
        # were logged without `extra=dict(command_prefix=...)` would otherwise
        # fail to format, so store an empty prefix for them.
        command_prefix = getattr(record, "command_prefix", "")
        setattr(record, "command_prefix", command_prefix)

        colorcode = 0
        if record.levelno == logging.ERROR:
            colorcode = 31  # red
        elif record.levelno == logging.WARNING:
            colorcode = 33  # yellow

        color, prefix_color, color_reset = "", "", ""
        if not DISABLE_COLOR:
            color = ansi_color(colorcode)
            prefix_color = ansi_color(self.hostname_colorcode(command_prefix))
            color_reset = "\x1b[0m"

        setattr(record, "color", color)
        setattr(record, "prefix_color", prefix_color)
        setattr(record, "color_reset", color_reset)

        return super().formatMessage(record)

    def hostname_colorcode(self, hostname: str) -> int:
        """
        Return the ANSI foreground color code (31-37) assigned to a prefix.

        Prefixes are numbered in the order they are first seen; the colors
        repeat after seven distinct prefixes.

        @hostname the command prefix to look up (registered on first use)
        """
        if hostname not in self.hostnames:
            self.hostnames.append(hostname)
        index = self.hostnames.index(hostname)
        return 31 + (index + self.hostname_color_offset) % 7
class DeployHost:
    """
    A single deployment target.

    Commands can be run for the host either locally (`run_local`) or on the
    host itself through the `ssh` executable (`run`). Output that is not
    captured is logged line by line, prefixed with `command_prefix`.
    """

    def __init__(
        self,
        host: str,
        user: Optional[str] = None,
        port: Optional[int] = None,
        key: Optional[str] = None,
        forward_agent: bool = False,
        command_prefix: Optional[str] = None,
        host_key_check: HostKeyCheck = HostKeyCheck.STRICT,
        meta: Optional[Dict[str, Any]] = None,
        verbose_ssh: bool = False,
        extra_ssh_opts: Optional[List[str]] = None,
    ) -> None:
        """
        Creates a DeployHost
        @host the hostname to connect to via ssh
        @user the user to connect as; if None only the hostname is passed to ssh
        @port the port to connect to via ssh
        @key identity file passed to ssh with `-i`
        @forward_agent: whether to forward the ssh agent
        @command_prefix: string to prefix each line of the command output with, defaults to host
        @host_key_check: whether to check ssh host keys
        @meta: meta attributes associated with the host. Those can be accessed in custom functions passed to `run_function`.
               Defaults to a fresh empty dict per host.
        @verbose_ssh: Enables verbose logging on ssh connections
        @extra_ssh_opts: Additional SSH options to use while connecting. Defaults to a fresh empty list per host.
        """
        self.host = host
        self.user = user
        self.port = port
        self.key = key
        if command_prefix:
            self.command_prefix = command_prefix
        else:
            self.command_prefix = host
        self.forward_agent = forward_agent
        self.host_key_check = host_key_check
        # Never share one default container between hosts: both attributes are
        # public and may be mutated by callers.
        self.meta: Dict[str, Any] = {} if meta is None else meta
        self.verbose_ssh = verbose_ssh
        self.extra_ssh_opts: List[str] = (
            [] if extra_ssh_opts is None else extra_ssh_opts
        )

    def _log_lines(self, text: str, is_err: bool) -> None:
        """Log every line of `text` with this host's prefix (error level if `is_err`, else info)."""
        emit = cmdlog.error if is_err else cmdlog.info
        for line in text.rstrip("\n").split("\n"):
            emit(line, extra=dict(command_prefix=self.command_prefix))

    def _prefix_output(
        self,
        displayed_cmd: str,
        print_std_fd: Optional[IO[str]],
        print_err_fd: Optional[IO[str]],
        stdout: Optional[IO[str]],
        stderr: Optional[IO[str]],
        timeout: float = math.inf,
    ) -> Tuple[str, str]:
        """
        Pump the output pipes of a running command until they are all closed.

        @displayed_cmd command text used in the "still waiting" message
        @print_std_fd read end whose content is logged line by line at info level
        @print_err_fd read end whose content is logged line by line at error level
        @stdout read end whose content is captured and returned
        @stderr read end whose content is captured and returned
        @timeout stop pumping once this many seconds passed since the last logged output

        @return tuple of the captured (stdout, stderr) text
        """
        import codecs  # function-scope import: only needed for incremental decoding

        rlist = [
            fd for fd in (print_std_fd, print_err_fd, stdout, stderr) if fd is not None
        ]
        # One incremental decoder per pipe, so a multi-byte UTF-8 character that
        # is split across two os.read() chunks is decoded correctly.
        new_decoder = codecs.getincrementaldecoder("utf-8")
        decoders = {fd.fileno(): new_decoder() for fd in rlist}

        def read_chunk(fd: IO[Any]) -> Tuple[str, bool]:
            # Read once from fd; on EOF stop watching it. Returns (text, eof).
            data = os.read(fd.fileno(), 4096)
            eof = len(data) == 0
            if eof:
                rlist.remove(fd)
            return decoders[fd.fileno()].decode(data, final=eof), eof

        def print_from(print_fd: IO[str], print_buf: str, is_err: bool) -> str:
            # Log all complete lines; keep an unfinished trailing line buffered
            # until its newline arrives or the stream is drained, so that log
            # lines are not broken at read boundaries.
            text, eof = read_chunk(print_fd)
            print_buf += text
            if eof:
                if print_buf:
                    self._log_lines(print_buf, is_err)
                return ""
            if "\n" in print_buf:
                complete, _, print_buf = print_buf.rpartition("\n")
                self._log_lines(complete, is_err)
            return print_buf

        print_std_buf = ""
        print_err_buf = ""
        stdout_chunks: List[str] = []
        stderr_chunks: List[str] = []

        start = time.time()
        last_output = time.time()
        while len(rlist) != 0:
            r, _, _ = select.select(rlist, [], [], min(timeout, NO_OUTPUT_TIMEOUT))

            if print_std_fd is not None and print_std_fd in r:
                print_std_buf = print_from(print_std_fd, print_std_buf, False)
                last_output = time.time()
            if print_err_fd is not None and print_err_fd in r:
                print_err_buf = print_from(print_err_fd, print_err_buf, True)
                last_output = time.time()

            now = time.time()
            elapsed = now - start
            if now - last_output > NO_OUTPUT_TIMEOUT:
                elapsed_msg = time.strftime("%H:%M:%S", time.gmtime(elapsed))
                cmdlog.warning(
                    f"still waiting for '{displayed_cmd}' to finish... ({elapsed_msg} elapsed)",
                    extra=dict(command_prefix=self.command_prefix),
                )

            if stdout is not None and stdout in r:
                stdout_chunks.append(read_chunk(stdout)[0])
            if stderr is not None and stderr in r:
                stderr_chunks.append(read_chunk(stderr)[0])

            if now - last_output >= timeout:
                break
        return "".join(stdout_chunks), "".join(stderr_chunks)

    def _run(
        self,
        cmd: List[str],
        displayed_cmd: str,
        shell: bool,
        stdout: FILE = None,
        stderr: FILE = None,
        extra_env: Dict[str, str] = {},
        cwd: Union[None, str, Path] = None,
        check: bool = True,
        timeout: float = math.inf,
    ) -> subprocess.CompletedProcess[str]:
        """
        Start `cmd` as a subprocess and wait for it while relaying its output.

        @cmd argument list handed to subprocess.Popen
        @displayed_cmd human readable form of the command used in log messages
        @shell passed through to subprocess.Popen
        @stdout None to log stdout line by line, subprocess.PIPE to capture it
        @stderr None to log stderr line by line, subprocess.PIPE to capture it
        @extra_env environment variables added on top of os.environ
        @cwd working directory of the subprocess
        @check raise subprocess.CalledProcessError on a non-zero exit code;
               if False the failure is only logged as a warning
        @timeout seconds to wait; the process is killed and
                 subprocess.TimeoutExpired is raised when it is exceeded

        @return subprocess.CompletedProcess with the captured output
        """
        with ExitStack() as stack:
            read_std_fd, write_std_fd = (None, None)
            read_err_fd, write_err_fd = (None, None)

            # Pipes whose content gets logged with the host prefix.
            if stdout is None or stderr is None:
                read_std_fd, write_std_fd = stack.enter_context(_pipe())
                read_err_fd, write_err_fd = stack.enter_context(_pipe())

            if stdout is None:
                stdout_read = None
                stdout_write = write_std_fd
            elif stdout == subprocess.PIPE:
                stdout_read, stdout_write = stack.enter_context(_pipe())
            else:
                raise Exception(f"unsupported value for stdout parameter: {stdout}")

            if stderr is None:
                stderr_read = None
                stderr_write = write_err_fd
            elif stderr == subprocess.PIPE:
                stderr_read, stderr_write = stack.enter_context(_pipe())
            else:
                raise Exception(f"unsupported value for stderr parameter: {stderr}")

            env = os.environ.copy()
            env.update(extra_env)

            with subprocess.Popen(
                cmd,
                text=True,
                shell=shell,
                stdout=stdout_write,
                stderr=stderr_write,
                env=env,
                cwd=cwd,
            ) as p:
                # Close our copies of the write ends, otherwise the read ends
                # never report EOF once the child exits.
                if write_std_fd is not None:
                    write_std_fd.close()
                if write_err_fd is not None:
                    write_err_fd.close()
                if stdout == subprocess.PIPE:
                    assert stdout_write is not None
                    stdout_write.close()
                if stderr == subprocess.PIPE:
                    assert stderr_write is not None
                    stderr_write.close()

                start = time.time()
                stdout_data, stderr_data = self._prefix_output(
                    displayed_cmd,
                    read_std_fd,
                    read_err_fd,
                    stdout_read,
                    stderr_read,
                    timeout,
                )
                try:
                    # Only the part of the timeout not yet used up while
                    # pumping output is left for the final wait.
                    ret = p.wait(timeout=max(0, timeout - (time.time() - start)))
                except subprocess.TimeoutExpired:
                    p.kill()
                    raise
                if ret != 0:
                    if check:
                        raise subprocess.CalledProcessError(
                            ret, cmd=cmd, output=stdout_data, stderr=stderr_data
                        )
                    else:
                        cmdlog.warning(
                            f"[Command failed: {ret}] {displayed_cmd}",
                            extra=dict(command_prefix=self.command_prefix),
                        )
                return subprocess.CompletedProcess(
                    cmd, ret, stdout=stdout_data, stderr=stderr_data
                )
        raise RuntimeError("unreachable")

    def run_local(
        self,
        cmd: Union[str, List[str]],
        stdout: FILE = None,
        stderr: FILE = None,
        extra_env: Dict[str, str] = {},
        cwd: Union[None, str, Path] = None,
        check: bool = True,
        timeout: float = math.inf,
    ) -> subprocess.CompletedProcess[str]:
        """
        Command to run locally for the host

        @cmd the command to run; a string is run through the shell, a list is executed directly
        @stdout None to log stdout with the host prefix, subprocess.PIPE to capture it
        @stderr None to log stderr with the host prefix, subprocess.PIPE to capture it
        @extra_env environment variables to override when running the command
        @cwd current working directory to run the process in
        @check raise subprocess.CalledProcessError if the command exits non-zero
        @timeout: Timeout in seconds for the command to complete

        @return subprocess.CompletedProcess result of the command
        """
        shell = False
        if isinstance(cmd, str):
            # Show shell strings verbatim; shlex.join() would wrap the whole
            # string in quotes.
            displayed_cmd = cmd
            cmd = [cmd]
            shell = True
        else:
            displayed_cmd = shlex.join(cmd)
        cmdlog.info(
            f"$ {displayed_cmd}", extra=dict(command_prefix=self.command_prefix)
        )
        return self._run(
            cmd,
            displayed_cmd,
            shell=shell,
            stdout=stdout,
            stderr=stderr,
            extra_env=extra_env,
            cwd=cwd,
            check=check,
            timeout=timeout,
        )

    def run(
        self,
        cmd: Union[str, List[str]],
        stdout: FILE = None,
        stderr: FILE = None,
        become_root: bool = False,
        extra_env: Dict[str, str] = {},
        cwd: Union[None, str, Path] = None,
        check: bool = True,
        verbose_ssh: bool = False,
        timeout: float = math.inf,
    ) -> subprocess.CompletedProcess[str]:
        """
        Command to run on the host via ssh

        @cmd the command to run; a string is interpreted by bash on the host, a list is exec'ed as-is
        @stdout None to log stdout with the host prefix, subprocess.PIPE to capture it
        @stderr None to log stderr with the host prefix, subprocess.PIPE to capture it
        @become_root if the ssh user is not root then sudo is prepended
        @extra_env environment variables exported on the host before running the command
        @cwd working directory of the local ssh client process
        @check raise subprocess.CalledProcessError if the command exits non-zero
        @verbose_ssh: Enables verbose logging on ssh connections
        @timeout: Timeout in seconds for the command to complete

        @return subprocess.CompletedProcess result of the ssh command
        """
        sudo = ""
        if become_root and self.user != "root":
            sudo = "sudo -- "

        # Every assignment is shell-quoted exactly once here. They are joined
        # with plain spaces: quoting the already quoted text again would hand
        # literal quote characters to the remote shell.
        env_assignments = [
            f"{shlex.quote(k)}={shlex.quote(v)}" for k, v in extra_env.items()
        ]
        export_cmd = ""
        if env_assignments:
            export_cmd = f"export {' '.join(env_assignments)}; "

        displayed_cmd = export_cmd
        if isinstance(cmd, list):
            displayed_cmd += shlex.join(cmd)
        else:
            displayed_cmd += cmd
        cmdlog.info(
            f"$ {displayed_cmd}", extra=dict(command_prefix=self.command_prefix)
        )

        if self.user is not None:
            ssh_target = f"{self.user}@{self.host}"
        else:
            ssh_target = self.host

        ssh_opts = ["-A"] if self.forward_agent else []
        if self.port:
            ssh_opts.extend(["-p", str(self.port)])
        if self.key:
            ssh_opts.extend(["-i", self.key])

        # NOTE(review): TOFU and NONE both use StrictHostKeyChecking=no, which
        # also tolerates changed host keys; `accept-new` would be a stricter
        # trust-on-first-use. Confirm the minimum OpenSSH version before switching.
        if self.host_key_check != HostKeyCheck.STRICT:
            ssh_opts.extend(["-o", "StrictHostKeyChecking=no"])
        if self.host_key_check == HostKeyCheck.NONE:
            ssh_opts.extend(["-o", "UserKnownHostsFile=/dev/null"])
        if verbose_ssh or self.verbose_ssh:
            ssh_opts.extend(["-v"])

        bash_cmd = export_cmd
        bash_args = []
        if isinstance(cmd, list):
            # The list is passed as positional arguments and exec'ed without
            # being re-parsed by the shell.
            bash_cmd += 'exec "$@"'
            bash_args += cmd
        else:
            bash_cmd += cmd
        # FIXME we assume bash to be present here? Should be documented...
        ssh_cmd = (
            ["ssh", ssh_target]
            + ssh_opts
            + self.extra_ssh_opts
            + [
                "--",
                f"{sudo}bash -c {quote(bash_cmd)} -- {shlex.join(bash_args)}",
            ]
        )
        return self._run(
            ssh_cmd,
            displayed_cmd,
            shell=False,
            stdout=stdout,
            stderr=stderr,
            cwd=cwd,
            check=check,
            timeout=timeout,
        )
stderr=stderr, 561 | extra_env=extra_env, 562 | cwd=cwd, 563 | check=check, 564 | timeout=timeout, 565 | ) 566 | results.append(HostResult(host, proc)) 567 | except Exception as e: 568 | kitlog.exception(e) 569 | results.append(HostResult(host, e)) 570 | 571 | def _run_remote( 572 | self, 573 | cmd: Union[str, List[str]], 574 | host: DeployHost, 575 | results: DeployResults, 576 | stdout: FILE = None, 577 | stderr: FILE = None, 578 | extra_env: Dict[str, str] = {}, 579 | cwd: Union[None, str, Path] = None, 580 | check: bool = True, 581 | verbose_ssh: bool = False, 582 | timeout: float = math.inf, 583 | ) -> None: 584 | try: 585 | proc = host.run( 586 | cmd, 587 | stdout=stdout, 588 | stderr=stderr, 589 | extra_env=extra_env, 590 | cwd=cwd, 591 | check=check, 592 | verbose_ssh=verbose_ssh, 593 | timeout=timeout, 594 | ) 595 | results.append(HostResult(host, proc)) 596 | except Exception as e: 597 | kitlog.exception(e) 598 | results.append(HostResult(host, e)) 599 | 600 | def _reraise_errors(self, results: List[HostResult[Any]]) -> None: 601 | errors = 0 602 | for result in results: 603 | e = result.error 604 | if e: 605 | cmdlog.error( 606 | f"failed with: {e}", 607 | extra=dict(command_prefix=result.host.command_prefix), 608 | ) 609 | errors += 1 610 | if errors > 0: 611 | raise Exception( 612 | f"{errors} hosts failed with an error. 
Check the logs above" 613 | ) 614 | 615 | def _run( 616 | self, 617 | cmd: Union[str, List[str]], 618 | local: bool = False, 619 | stdout: FILE = None, 620 | stderr: FILE = None, 621 | extra_env: Dict[str, str] = {}, 622 | cwd: Union[None, str, Path] = None, 623 | check: bool = True, 624 | verbose_ssh: bool = False, 625 | timeout: float = math.inf, 626 | ) -> DeployResults: 627 | results: DeployResults = [] 628 | threads = [] 629 | for host in self.hosts: 630 | fn = self._run_local if local else self._run_remote 631 | thread = Thread( 632 | target=fn, 633 | kwargs=dict( 634 | results=results, 635 | cmd=cmd, 636 | host=host, 637 | stdout=stdout, 638 | stderr=stderr, 639 | extra_env=extra_env, 640 | cwd=cwd, 641 | check=check, 642 | verbose_ssh=verbose_ssh, 643 | timeout=timeout, 644 | ), 645 | ) 646 | thread.start() 647 | threads.append(thread) 648 | 649 | for thread in threads: 650 | thread.join() 651 | 652 | if check: 653 | self._reraise_errors(results) 654 | 655 | return results 656 | 657 | def run( 658 | self, 659 | cmd: Union[str, List[str]], 660 | stdout: FILE = None, 661 | stderr: FILE = None, 662 | extra_env: Dict[str, str] = {}, 663 | cwd: Union[None, str, Path] = None, 664 | check: bool = True, 665 | verbose_ssh: bool = False, 666 | timeout: float = math.inf, 667 | ) -> DeployResults: 668 | """ 669 | Command to run on the remote host via ssh 670 | @stdout if not None stdout of the command will be redirected to this file i.e. stdout=subprocss.PIPE 671 | @stderr if not None stderr of the command will be redirected to this file i.e. 
stderr=subprocess.PIPE 672 | @cwd current working directory to run the process in 673 | @verbose_ssh: Enables verbose logging on ssh connections 674 | @timeout: Timeout in seconds for the command to complete 675 | 676 | @return a lists of tuples containing DeployNode and the result of the command for this DeployNode 677 | """ 678 | return self._run( 679 | cmd, 680 | stdout=stdout, 681 | stderr=stderr, 682 | extra_env=extra_env, 683 | cwd=cwd, 684 | check=check, 685 | verbose_ssh=verbose_ssh, 686 | timeout=timeout, 687 | ) 688 | 689 | def run_local( 690 | self, 691 | cmd: Union[str, List[str]], 692 | stdout: FILE = None, 693 | stderr: FILE = None, 694 | extra_env: Dict[str, str] = {}, 695 | cwd: Union[None, str, Path] = None, 696 | check: bool = True, 697 | timeout: float = math.inf, 698 | ) -> DeployResults: 699 | """ 700 | Command to run locally for each host in the group in parallel 701 | @cmd the commmand to run 702 | @stdout if not None stdout of the command will be redirected to this file i.e. stdout=subprocss.PIPE 703 | @stderr if not None stderr of the command will be redirected to this file i.e. 
def run_function(
    self, func: Callable[[DeployHost], T], check: bool = True
) -> List[HostResult[T]]:
    """
    Function to run for each host in the group in parallel

    @func the function to call; it receives the DeployHost as its only argument
    @check raise an Exception after all threads finished if any call failed

    @return one HostResult per host, in the order of the group's hosts
    """
    # Pre-fill every slot with a placeholder error; each worker replaces the
    # slot of its own host.
    results: List[HostResult[T]] = []
    for position, member in enumerate(self.hosts):
        placeholder = Exception(f"No result set for thread {position}")
        results.append(HostResult(member, placeholder))

    workers = [
        Thread(target=_worker, args=(func, member, results, position))
        for position, member in enumerate(self.hosts)
    ]
    for worker in workers:
        worker.start()
    for worker in workers:
        worker.join()

    if check:
        self._reraise_errors(results)
    return results
766 | 767 | 768 | @overload 769 | def run( 770 | cmd: Union[List[str], str], 771 | text: Literal[False], 772 | stdout: FILE = ..., 773 | stderr: FILE = ..., 774 | extra_env: Dict[str, str] = ..., 775 | cwd: Union[None, str, Path] = ..., 776 | check: bool = ..., 777 | ) -> subprocess.CompletedProcess[bytes]: 778 | ... 779 | 780 | 781 | def run( 782 | cmd: Union[List[str], str], 783 | text: bool = True, 784 | stdout: FILE = None, 785 | stderr: FILE = None, 786 | extra_env: Dict[str, str] = {}, 787 | cwd: Union[None, str, Path] = None, 788 | check: bool = True, 789 | ) -> subprocess.CompletedProcess[Any]: 790 | """ 791 | Run command locally 792 | 793 | @cmd if this parameter is a string the command is interpreted as a shell command, 794 | otherwise if it is a list, than the first list element is the command 795 | and the remaining list elements are passed as arguments to the 796 | command. 797 | @text when true, file objects for stdout and stderr are opened in text mode. 798 | @stdout if not None stdout of the command will be redirected to this file i.e. stdout=subprocss.PIPE 799 | @stderr if not None stderr of the command will be redirected to this file i.e. stderr=subprocess.PIPE 800 | @extra_env environment variables to override whe running the command 801 | @cwd current working directory to run the process in 802 | @check If check is true, and the process exits with a non-zero exit code, a 803 | CalledProcessError exception will be raised. Attributes of that exception 804 | hold the arguments, the exit code, and stdout and stderr if they were 805 | captured. 
def parse_hosts(
    hosts: str,
    host_key_check: HostKeyCheck = HostKeyCheck.STRICT,
    key: Optional[str] = None,
    forward_agent: bool = False,
    domain_suffix: str = "",
    default_user: Optional[str] = None,
) -> DeployGroup:
    """
    Parse comma separated string of hosts

    @hosts A comma separated list of hostnames with optional username and port,
           i.e. admin@node1.example.com,admin@node2.example.com:2222.
           Whitespace around entries and empty entries are ignored.
    @host_key_check whether to check ssh host keys
    @key ssh identity file used for every host
    @forward_agent whether to forward the ssh agent
    @domain_suffix a string appended verbatim to each hostname (include the leading dot),
                   i.e. hosts=admin@node0, domain_suffix=.example.com -> admin@node0.example.com
    @default_user user to choose if no ssh user is specified with the hostname

    @return A deploy group containing all hosts specified in hosts
    @raise ValueError if a port is given that is not an integer
    """
    deploy_hosts = []
    for entry in hosts.split(","):
        entry = entry.strip()
        if not entry:
            # Tolerate "a,,b", trailing commas and an empty string instead of
            # creating a host with an empty name.
            continue
        parts = entry.split("@")
        if len(parts) > 1:
            user: Optional[str] = parts[0]
            hostname = parts[1]
        else:
            user = default_user
            hostname = parts[0]
        maybe_port = hostname.split(":")
        port = None
        if len(maybe_port) > 1:
            hostname = maybe_port[0]
            port = int(maybe_port[1])
        deploy_hosts.append(
            DeployHost(
                hostname + domain_suffix,
                user=user,
                port=port,
                key=key,
                host_key_check=host_key_check,
                forward_agent=forward_agent,
            )
        )
    return DeployGroup(deploy_hosts)
13 | 14 | # This is running locally 15 | host.run_local("hostname") 16 | 17 | # This is running on the remote machine 18 | host.run("hostname") 19 | 20 | # We can also use our `DeployHost` object to get connection info for other ssh hosts 21 | # host.run_local( 22 | # f"rsync {' --exclude -vaF --delete -e ssh . {host.user}@{host.host}:/etc/nixos" 23 | # ) 24 | 25 | 26 | def main() -> None: 27 | parser = argparse.ArgumentParser() 28 | parser.add_argument("hosts") 29 | args = parser.parse_args() 30 | 31 | # This command runs on the local machine 32 | run("hostname") 33 | 34 | # parse_host accepts hosts in comma seperated for form and returns a DeployGroup 35 | # Hosts can contain username and port numbers 36 | # i.e. host1,host2,admin@host3:2222 37 | g = parse_hosts(args.hosts) 38 | # g will now contain a group of hosts. This is a shorthand for writing 39 | # g = deploykit.DeployGroup([ 40 | # DeployHost(host="myhostname"), 41 | # DeployHost(host="myhostname2"), 42 | # ]) 43 | # Let's see what we can do with a `DeployGroup` 44 | 45 | # This command runs locally in parallel for all hosts 46 | g.run_local("hostname") 47 | # This commands runs remotely in parallel for all hosts 48 | g.run("hostname") 49 | 50 | # This function runs in parallel for all hosts. This is useful if you want 51 | # to run a series of commands per host. 52 | g.run_function(deploy) 53 | 54 | # By default all functions will throw a subprocess.CalledProcess exception if a command fails. 
55 | # When check=False is passed, instead a subprocess.CompletedProcess value is returned and the user 56 | # can check the result of command by inspecting the `returncode` attribute 57 | runs = g.run_local("false", check=False) 58 | print(runs[0].result.returncode) 59 | 60 | # To capture the output of a command, set stdout/stderr parameter 61 | runs = g.run_local("hostname", stdout=subprocess.PIPE) 62 | print(runs[0].result.stdout) 63 | 64 | # To select a subset of the hosts, you can use the filter function 65 | g2 = g.filter(lambda h: h.host == "host2") 66 | # This should then only output "host2" 67 | g2.run("hostname") 68 | 69 | 70 | if __name__ == "__main__": 71 | main() 72 | -------------------------------------------------------------------------------- /flake.lock: -------------------------------------------------------------------------------- 1 | { 2 | "nodes": { 3 | "flake-parts": { 4 | "inputs": { 5 | "nixpkgs-lib": [ 6 | "nixpkgs" 7 | ] 8 | }, 9 | "locked": { 10 | "lastModified": 1690933134, 11 | "narHash": "sha256-ab989mN63fQZBFrkk4Q8bYxQCktuHmBIBqUG1jl6/FQ=", 12 | "owner": "hercules-ci", 13 | "repo": "flake-parts", 14 | "rev": "59cf3f1447cfc75087e7273b04b31e689a8599fb", 15 | "type": "github" 16 | }, 17 | "original": { 18 | "owner": "hercules-ci", 19 | "repo": "flake-parts", 20 | "type": "github" 21 | } 22 | }, 23 | "nixpkgs": { 24 | "locked": { 25 | "lastModified": 1691216863, 26 | "narHash": "sha256-OXWbblQhOMI5tBZ/WINDjT955NckikGhPuk+wybx/ho=", 27 | "owner": "NixOS", 28 | "repo": "nixpkgs", 29 | "rev": "bbe9ae26ec7ab1399418e7c56030215063c45757", 30 | "type": "github" 31 | }, 32 | "original": { 33 | "owner": "NixOS", 34 | "repo": "nixpkgs", 35 | "type": "github" 36 | } 37 | }, 38 | "root": { 39 | "inputs": { 40 | "flake-parts": "flake-parts", 41 | "nixpkgs": "nixpkgs", 42 | "treefmt-nix": "treefmt-nix" 43 | } 44 | }, 45 | "treefmt-nix": { 46 | "inputs": { 47 | "nixpkgs": [ 48 | "nixpkgs" 49 | ] 50 | }, 51 | "locked": { 52 | "lastModified": 
1690874496, 53 | "narHash": "sha256-qYZJVAfilFbUL6U+euMjKLXUADueMNQBqwihpNzTbDU=", 54 | "owner": "numtide", 55 | "repo": "treefmt-nix", 56 | "rev": "fab56c8ce88f593300cd8c7351c9f97d10c333c5", 57 | "type": "github" 58 | }, 59 | "original": { 60 | "owner": "numtide", 61 | "repo": "treefmt-nix", 62 | "type": "github" 63 | } 64 | } 65 | }, 66 | "root": "root", 67 | "version": 7 68 | } 69 | -------------------------------------------------------------------------------- /flake.nix: -------------------------------------------------------------------------------- 1 | { 2 | description = "Execute commands remotely and locally in parallel for a group of hosts with 3 | python"; 4 | 5 | inputs = { 6 | flake-parts.inputs.nixpkgs-lib.follows = "nixpkgs"; 7 | flake-parts.url = "github:hercules-ci/flake-parts"; 8 | nixpkgs.url = "github:NixOS/nixpkgs"; 9 | treefmt-nix.url = "github:numtide/treefmt-nix"; 10 | treefmt-nix.inputs.nixpkgs.follows = "nixpkgs"; 11 | }; 12 | 13 | outputs = inputs @ { flake-parts, nixpkgs, ... }: 14 | (flake-parts.lib.evalFlakeModule { inherit inputs; } ({ lib, pkgs, ... }: { 15 | imports = [ 16 | inputs.treefmt-nix.flakeModule 17 | ]; 18 | systems = 19 | let 20 | opensshPlatforms = lib.intersectLists lib.systems.flakeExposed nixpkgs.legacyPackages.x86_64-linux.openssh.meta.platforms; 21 | in 22 | nixpkgs.lib.subtractLists [ "mipsel-linux" "armv5tel-linux" ] opensshPlatforms; 23 | perSystem = { self', pkgs, ... 
}: { 24 | packages.deploykit = pkgs.python3.pkgs.callPackage ./nix/default.nix { }; 25 | packages.default = self'.packages.deploykit; 26 | devShells.default = pkgs.callPackage ./nix/shell.nix { }; 27 | treefmt = import ./treefmt.nix; 28 | }; 29 | })).config.flake; 30 | } 31 | -------------------------------------------------------------------------------- /nix/default.nix: -------------------------------------------------------------------------------- 1 | { buildPythonPackage 2 | , mypy 3 | , setuptools 4 | , glibcLocales 5 | , pytestCheckHook 6 | , openssh 7 | , bash 8 | , lib 9 | , stdenv 10 | }: 11 | 12 | buildPythonPackage { 13 | name = "deploykit"; 14 | src = ./..; 15 | 16 | buildInputs = [ 17 | setuptools 18 | ]; 19 | 20 | nativeCheckInputs = [ openssh mypy bash glibcLocales pytestCheckHook ]; 21 | 22 | disabledTests = lib.optionals stdenv.isDarwin [ "test_ssh" ]; 23 | 24 | # don't swallow stdout/stderr 25 | pytestFlagsArray = [ "-s" ]; 26 | 27 | postCheck = '' 28 | echo -e "\x1b[32m## run mypy\x1b[0m" 29 | MYPYPATH=$(pwd):$(pwd)/tests mypy --strict --namespace-packages --explicit-package-bases . 30 | ''; 31 | meta = with lib; { 32 | description = "Execute commands remote via ssh and locally in parallel with python"; 33 | homepage = "https://github.com/numtide/deploykit"; 34 | license = licenses.mit; 35 | maintainers = with maintainers; [ mic92 ]; 36 | platforms = platforms.unix; 37 | }; 38 | } 39 | -------------------------------------------------------------------------------- /nix/shell.nix: -------------------------------------------------------------------------------- 1 | { pkgs ? 
import { } }: 2 | pkgs.mkShell { 3 | nativeBuildInputs = [ 4 | pkgs.bashInteractive 5 | pkgs.openssh 6 | pkgs.mypy 7 | pkgs.python3.pkgs.pytest 8 | pkgs.python3.pkgs.setuptools 9 | ]; 10 | } 11 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.ruff] 2 | line-length = 88 3 | 4 | select = ["E", "F", "I"] 5 | ignore = [ "E501" ] 6 | 7 | [tool.mypy] 8 | python_version = "3.10" 9 | warn_redundant_casts = true 10 | disallow_untyped_calls = true 11 | disallow_untyped_defs = true 12 | no_implicit_optional = true 13 | 14 | [[tool.mypy.overrides]] 15 | module = "setuptools.*" 16 | ignore_missing_imports = true 17 | 18 | [[tool.mypy.overrides]] 19 | module = "pytest.*" 20 | ignore_missing_imports = true 21 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | name = deploykit 3 | version = 0.0.0 4 | author = Jörg Thalheim 5 | author-email = joerg@thalheim.io 6 | home-page = https://github.com/numtide/deploykit 7 | description = Execute commands remote via ssh and locally in parallel with python 8 | long-description = file: README.rst 9 | license = MIT 10 | license-file = LICENSE 11 | platform = any 12 | 13 | classifiers = 14 | Development Status :: 5 - Production/Stable 15 | Environment :: Console 16 | Intended Audience :: Developers 17 | Intended Audience :: Information Technology 18 | Intended Audience :: System Administrators 19 | License :: OSI Approved :: MIT License 20 | Natural Language :: English 21 | Operating System :: Posix 22 | Programming Language :: Python :: 3 23 | Programming Language :: Python :: 3.8 24 | Programming Language :: Python :: 3.9 25 | Programming Language :: Python :: 3.10 26 | Programming Language :: Python :: 3.11 27 | Programming Language :: Python :: 3 :: Only 28 | 
class Command:
    """
    Starts background processes and kills all of them on `terminate()`.

    Every process started through `run()` is remembered in `self.processes`.
    """

    def __init__(self) -> None:
        self.processes: List[subprocess.Popen[str]] = []

    def run(
        self,
        command: List[str],
        extra_env: Union[None, Dict[str, str]] = None,
        stdin: _FILE = None,
        stdout: _FILE = None,
        stderr: _FILE = None,
    ) -> subprocess.Popen[str]:
        """
        Start `command` in the background and return its Popen handle.

        @command program and arguments, executed without a shell
        @extra_env environment variables added on top of a copy of os.environ
        @stdin, @stdout, @stderr passed on unchanged to subprocess.Popen
        """
        env = os.environ.copy()
        if extra_env is not None:
            env.update(extra_env)
        # We start a new session here so that we can then more reliably kill
        # all children as well (the child becomes its own process group leader).
        p = subprocess.Popen(
            command,
            env=env,
            start_new_session=True,
            stdout=stdout,
            stderr=stderr,
            stdin=stdin,
            text=True,
        )
        self.processes.append(p)
        return p

    def terminate(self) -> None:
        """
        Kill the process group of every started process.
        """
        # Stop in reverse order in case there are dependencies.
        # We just kill all processes as quickly as possible because we don't
        # care about corrupted state and want to make tests fast.
        for p in reversed(self.processes):
            try:
                os.killpg(os.getpgid(p.pid), signal.SIGKILL)
            except OSError:
                # Process (group) is already gone or cannot be signalled;
                # nothing more we can do for this one.
                continue
            # SIGKILL was delivered: reap the child so it does not linger as
            # a zombie until the interpreter exits.
            p.wait()
def check_port(port: int) -> bool:
    """
    Return True if `port` can be bound on 127.0.0.1 for both TCP and UDP,
    False if either bind fails with a socket error.
    """
    address = ("127.0.0.1", port)
    tcp_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    udp_sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
    # Both sockets are closed again when the with-block is left, so the port
    # is only probed, not kept.
    with tcp_sock, udp_sock:
        try:
            for sock in (tcp_sock, udp_sock):
                sock.bind(address)
        except socket.error:
            return False
        return True
-------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import subprocess 4 | import time 5 | from pathlib import Path 6 | from sys import platform 7 | from tempfile import TemporaryDirectory 8 | from typing import Iterator, Optional 9 | 10 | import pytest 11 | from command import Command 12 | from ports import Ports 13 | 14 | 15 | class Sshd: 16 | def __init__(self, port: int, proc: subprocess.Popen[str], key: str) -> None: 17 | self.port = port 18 | self.proc = proc 19 | self.key = key 20 | 21 | 22 | class SshdConfig: 23 | def __init__(self, path: str, key: str, preload_lib: Optional[str]) -> None: 24 | self.path = path 25 | self.key = key 26 | self.preload_lib = preload_lib 27 | 28 | 29 | @pytest.fixture(scope="session") 30 | def sshd_config(project_root: Path, test_root: Path) -> Iterator[SshdConfig]: 31 | # FIXME, if any parent of `project_root` is world-writable than sshd will refuse it. 32 | with TemporaryDirectory(dir=project_root) as _dir: 33 | dir = Path(_dir) 34 | host_key = dir / "host_ssh_host_ed25519_key" 35 | subprocess.run( 36 | [ 37 | "ssh-keygen", 38 | "-t", 39 | "ed25519", 40 | "-f", 41 | host_key, 42 | "-N", 43 | "", 44 | ], 45 | check=True, 46 | ) 47 | 48 | sshd_config = dir / "sshd_config" 49 | sshd_config.write_text( 50 | f""" 51 | HostKey {host_key} 52 | LogLevel DEBUG3 53 | # In the nix build sandbox we don't get any meaningful PATH after login 54 | SetEnv PATH={os.environ.get("PATH", "")} 55 | MaxStartups 64:30:256 56 | AuthorizedKeysFile {host_key}.pub 57 | """ 58 | ) 59 | 60 | lib_path = None 61 | if platform == "linux": 62 | # This enforces a login shell by overriding the login shell of `getpwnam(3)` 63 | lib_path = str(dir / "libgetpwnam-preload.so") 64 | subprocess.run( 65 | [ 66 | os.environ.get("CC", "cc"), 67 | "-shared", 68 | "-o", 69 | lib_path, 70 | str(test_root / "getpwnam-preload.c"), 71 | ], 72 | check=True, 73 | ) 74 | 75 | yield SshdConfig(str(sshd_config), 
@pytest.fixture
def sshd(sshd_config: SshdConfig, command: Command, ports: Ports) -> Iterator[Sshd]:
    """
    Start an sshd on a freshly allocated port and yield it once a test login
    with the generated key succeeds.

    Raises an Exception if sshd exits before becoming reachable, or if it does
    not become reachable within the startup deadline.
    """
    port = ports.allocate(1)
    sshd = shutil.which("sshd")
    assert sshd is not None, "no sshd binary found"
    env = {}
    if sshd_config.preload_lib is not None:
        # The preload library overrides the login shell with $LOGIN_SHELL.
        bash = shutil.which("bash")
        assert bash is not None
        env = dict(LD_PRELOAD=str(sshd_config.preload_lib), LOGIN_SHELL=bash)
    proc = command.run(
        [sshd, "-f", sshd_config.path, "-D", "-p", str(port)], extra_env=env
    )

    probe = [
        "ssh",
        "-o",
        "StrictHostKeyChecking=no",
        "-o",
        "UserKnownHostsFile=/dev/null",
        "-i",
        sshd_config.key,
        "localhost",
        "-p",
        str(port),
        "true",
    ]
    # Bound the wait: without a deadline a live sshd that never accepts the
    # login would hang the whole test session.
    startup_timeout = 60.0
    deadline = time.monotonic() + startup_timeout
    while subprocess.run(probe).returncode != 0:
        rc = proc.poll()
        if rc is not None:
            raise Exception(f"sshd processes was terminated with {rc}")
        if time.monotonic() > deadline:
            raise Exception(
                f"sshd on port {port} did not accept logins within {startup_timeout}s"
            )
        time.sleep(0.1)

    yield Sshd(port, proc, sshd_config.key)
stdout=subprocess.PIPE 30 | ) 31 | assert p2[0].result.stdout == "true\n" 32 | 33 | p3 = hosts.run_local( 34 | ["env"], extra_env=dict(env_var="true"), stdout=subprocess.PIPE 35 | ) 36 | assert "env_var=true" in p3[0].result.stdout 37 | 38 | 39 | def test_run_non_shell() -> None: 40 | p = run(["echo", "$hello"], stdout=subprocess.PIPE) 41 | assert p.stdout == "$hello\n" 42 | 43 | 44 | def test_run_stderr_stdout() -> None: 45 | p = run("echo 1; echo 2 >&2", stdout=subprocess.PIPE, stderr=subprocess.PIPE) 46 | assert p.stdout == "1\n" 47 | assert p.stderr == "2\n" 48 | 49 | 50 | def test_run_local() -> None: 51 | hosts = parse_hosts("some_host") 52 | hosts.run_local("echo hello") 53 | 54 | 55 | def test_timeout() -> None: 56 | hosts = parse_hosts("some_host") 57 | try: 58 | hosts.run_local("sleep 10", timeout=0.01) 59 | except Exception: 60 | pass 61 | else: 62 | assert False, "should have raised TimeoutExpired" 63 | 64 | 65 | def test_run_function() -> None: 66 | def some_func(h: DeployHost) -> bool: 67 | p = h.run_local("echo hello", stdout=subprocess.PIPE) 68 | return p.stdout == "hello\n" 69 | 70 | hosts = parse_hosts("some_host") 71 | res = hosts.run_function(some_func) 72 | assert res[0].result 73 | 74 | 75 | def test_run_exception() -> None: 76 | hosts = parse_hosts("some_host") 77 | try: 78 | hosts.run_local("exit 1") 79 | except Exception: 80 | pass 81 | else: 82 | assert False, "should have raised Exception" 83 | 84 | 85 | def test_run_function_exception() -> None: 86 | def some_func(h: DeployHost) -> None: 87 | h.run_local("exit 1") 88 | 89 | hosts = parse_hosts("some_host") 90 | try: 91 | hosts.run_function(some_func) 92 | except Exception: 93 | pass 94 | else: 95 | assert False, "should have raised Exception" 96 | 97 | 98 | def test_run_local_non_shell() -> None: 99 | hosts = parse_hosts("some_host") 100 | p2 = hosts.run_local(["echo", "1"], stdout=subprocess.PIPE) 101 | assert p2[0].result.stdout == "1\n" 102 | 
-------------------------------------------------------------------------------- /tests/test_ssh.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pwd 3 | import subprocess 4 | 5 | from sshd import Sshd 6 | 7 | from deploykit import DeployGroup, DeployHost, HostKeyCheck, parse_hosts 8 | 9 | 10 | def deploy_group(sshd: Sshd) -> DeployGroup: 11 | login = pwd.getpwuid(os.getuid()).pw_name 12 | return parse_hosts( 13 | f"{login}@127.0.0.1:{sshd.port}", host_key_check=HostKeyCheck.NONE, key=sshd.key 14 | ) 15 | 16 | 17 | def test_run(sshd: Sshd) -> None: 18 | g = deploy_group(sshd) 19 | proc = g.run("echo hello", stdout=subprocess.PIPE) 20 | assert proc[0].result.stdout == "hello\n" 21 | 22 | 23 | def test_run_environment(sshd: Sshd) -> None: 24 | g = deploy_group(sshd) 25 | p1 = g.run("echo $env_var", stdout=subprocess.PIPE, extra_env=dict(env_var="true")) 26 | assert p1[0].result.stdout == "true\n" 27 | p2 = g.run(["env"], stdout=subprocess.PIPE, extra_env=dict(env_var="true")) 28 | assert "env_var=true" in p2[0].result.stdout 29 | 30 | 31 | def test_run_no_shell(sshd: Sshd) -> None: 32 | g = deploy_group(sshd) 33 | proc = g.run(["echo", "$hello"], stdout=subprocess.PIPE) 34 | assert proc[0].result.stdout == "$hello\n" 35 | 36 | 37 | def test_run_function(sshd: Sshd) -> None: 38 | def some_func(h: DeployHost) -> bool: 39 | p = h.run("echo hello", stdout=subprocess.PIPE) 40 | return p.stdout == "hello\n" 41 | 42 | g = deploy_group(sshd) 43 | res = g.run_function(some_func) 44 | assert res[0].result 45 | 46 | 47 | def test_timeout(sshd: Sshd) -> None: 48 | g = deploy_group(sshd) 49 | try: 50 | g.run("sleep 10", timeout=0.01) 51 | except Exception: 52 | pass 53 | else: 54 | assert False, "should have raised TimeoutExpired" 55 | 56 | 57 | def test_run_exception(sshd: Sshd) -> None: 58 | g = deploy_group(sshd) 59 | 60 | r = g.run("exit 1", check=False) 61 | assert r[0].result.returncode == 1 62 | 63 | try: 64 | 
g.run("exit 1") 65 | except Exception: 66 | pass 67 | else: 68 | assert False, "should have raised Exception" 69 | 70 | 71 | def test_run_function_exception(sshd: Sshd) -> None: 72 | def some_func(h: DeployHost) -> subprocess.CompletedProcess[str]: 73 | return h.run_local("exit 1") 74 | 75 | g = deploy_group(sshd) 76 | 77 | try: 78 | g.run_function(some_func) 79 | except Exception: 80 | pass 81 | else: 82 | assert False, "should have raised Exception" 83 | -------------------------------------------------------------------------------- /treefmt.nix: -------------------------------------------------------------------------------- 1 | { lib, pkgs, ... }: 2 | { 3 | projectRootFile = "flake.lock"; 4 | 5 | programs.nixpkgs-fmt.enable = true; 6 | settings.formatter = { 7 | python = { 8 | command = "sh"; 9 | options = [ 10 | "-eucx" 11 | '' 12 | ${lib.getExe pkgs.ruff} --fix "$@" 13 | ${lib.getExe pkgs.black} "$@" 14 | '' 15 | "--" # this argument is ignored by bash 16 | ]; 17 | includes = [ "*.py" ]; 18 | }; 19 | 20 | }; 21 | } 22 | --------------------------------------------------------------------------------