├── .gitignore ├── .gitmodules ├── LICENSE ├── README.md ├── docs ├── astdown │ ├── __init__.py │ ├── docstring.py │ ├── loader.py │ └── markdown.py ├── fora.png └── make.py ├── pylint-check.sh ├── pyproject.toml ├── setup.cfg ├── src └── fora │ ├── __init__.py │ ├── connection.py │ ├── connectors │ ├── __init__.py │ ├── connector.py │ ├── local.py │ ├── ssh.py │ ├── tunnel_connector.py │ └── tunnel_dispatcher.py │ ├── example_deploys.py │ ├── inventory_wrapper.py │ ├── loader.py │ ├── logger.py │ ├── main.py │ ├── operations │ ├── __init__.py │ ├── api.py │ ├── apt.py │ ├── files.py │ ├── git.py │ ├── local.py │ ├── pacman.py │ ├── pip.py │ ├── portage.py │ ├── postgres.py │ ├── system.py │ ├── systemd.py │ └── utils.py │ ├── remote_settings.py │ ├── types.py │ └── utils.py └── test ├── group_dependency_cycle ├── inventory.py └── test_group_dependency_cycle.py ├── group_dependency_cycle_complex ├── inventory.py └── test_group_dependency_cycle_complex.py ├── group_dependency_cycle_self ├── inventory.py └── test_group_dependency_cycle_self.py ├── group_variable_conflict ├── groups │ ├── group1.py │ ├── group2.py │ └── group3.py ├── inventory.py └── test_group_variable_conflict.py ├── inventory ├── mock_inventories │ ├── empty.py │ ├── hosts │ │ ├── host1.py │ │ ├── host2.py │ │ └── host_templ.py │ ├── invalid_hosts_entries.py │ ├── missing_definition.py │ ├── simple_test.py │ └── single_host1.py ├── test_dynamic_instanciation.py ├── test_empty.py └── test_missing_hosts.py ├── operations └── subdeploy.py ├── simple_deploy ├── deploy.py ├── deploy_bad.py ├── deploy_bad_recursive.py └── inventory.py ├── simple_inventory ├── groups │ ├── all.py │ ├── desktops.py │ ├── only34.py │ └── somehosts.py ├── hosts │ ├── host1.py │ ├── host2.py │ ├── host3.py │ ├── host4.py │ └── host5.py ├── inventory.py ├── test_simple_inventory.py └── testlink ├── templates └── test.j2 ├── test_connection.py ├── test_connector_resolve.py ├── test_init_deploy.py ├── test_loading.py └── 
test_operations.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | .mypy_cache 3 | .tox 4 | .vim 5 | .coverage 6 | 7 | venv/ 8 | build/ 9 | htmlcov/ 10 | dist/ 11 | *.egg-info 12 | 13 | AUTHORS 14 | ChangeLog 15 | src/fora/version.py 16 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/oddlama/fora/e270a021c45666c8b22250f1ae7a1534fe6040d3/.gitmodules -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 oddlama 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
4 | 5 | 11 | 12 | ## What is Fora? 13 | 14 | Fora is an infrastructure and configuration management tool inspired by [Ansible](https://www.ansible.com) and [pyinfra](https://pyinfra.com). 15 | Yet, it implements a drastically different approach to inventory management (and some other aspects), when compared to these well-known tools. 16 | See [how it differs](https://oddlama.gitbook.io/fora/outlining-the-differences#how-is-fora-different-from-existing-tools) for more details. 17 | 18 | ## Installation & Quickstart 19 | 20 | You can install Fora with pip: 21 | 22 | ```bash 23 | pip install fora 24 | ``` 25 | 26 | Afterwards, you can use it to write scripts which will be used to run operations or commands on a remote host. 27 | 28 | ```python 29 | # deploy.py 30 | from fora.operations import files, system 31 | 32 | files.directory( 33 | name="Create a temporary directory", 34 | path="/tmp/hello") 35 | 36 | system.package( 37 | name="Install neovim", 38 | package="neovim") 39 | ``` 40 | 41 | These scripts are executed against an inventory, or a specific remote host (usually via SSH). 42 | 43 | ```bash 44 | fora root@example.com deploy.py 45 | ``` 46 | 47 | To start with your own (more complex) deploy, you can have Fora create a scaffolding in an empty directory. There are [different scaffoldings](https://oddlama.gitbook.io/fora/usage/introduction#deploy-structure) available for different use-cases. 48 | 49 | ```bash 50 | fora --init minimal 51 | ``` 52 | 53 | Fora can do a lot more than this, which is explained in the [Introduction](https://oddlama.gitbook.io/fora/usage/introduction). If you are interested in how Fora is different from existing tools, have a look at [Outlining the differences](https://oddlama.gitbook.io/fora/outlining-the-differences). 
54 | -------------------------------------------------------------------------------- /docs/astdown/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/oddlama/fora/e270a021c45666c8b22250f1ae7a1534fe6040d3/docs/astdown/__init__.py -------------------------------------------------------------------------------- /docs/astdown/docstring.py: -------------------------------------------------------------------------------- 1 | import ast 2 | from dataclasses import dataclass, field 3 | import sys 4 | from typing import Optional 5 | from textwrap import dedent 6 | 7 | from astdown.loader import Module, docstring 8 | 9 | @dataclass 10 | class DocstringSection: 11 | name: str 12 | decls: dict[str, str] = field(default_factory=dict) 13 | 14 | @dataclass 15 | class Docstring: 16 | content: Optional[str] = None 17 | sections: dict[str, DocstringSection] = field(default_factory=dict) 18 | 19 | def parse_numpy_docstring(node: ast.AST, module: Module) -> Optional[Docstring]: 20 | docstr = None 21 | if isinstance(node, ast.Constant) and isinstance(node.value, str): 22 | docstr = node.value 23 | else: 24 | docstr = docstring(node, module) 25 | if docstr is None: 26 | return None 27 | 28 | doc = Docstring() 29 | is_in_code = False 30 | 31 | section: Optional[DocstringSection] = None 32 | decl: Optional[str] = None 33 | content = "" 34 | def _commit_content(): 35 | nonlocal section, content 36 | content = dedent(content).strip() 37 | 38 | if content != "": 39 | if section is None: 40 | doc.content = content 41 | elif decl is not None: 42 | section.decls[decl] = content 43 | content = "" 44 | 45 | lines = docstr.splitlines() 46 | skip_next_line = False 47 | for line, next in zip(lines, lines[1:] + [""]): 48 | if skip_next_line: 49 | skip_next_line = False 50 | continue 51 | 52 | line_stripped = line.strip() 53 | line_has_leading_whitespace = line.startswith((" ", "\t")) 54 | 55 | # Don't interpret anything in 
code blocks 56 | if line_stripped.startswith("```"): 57 | is_in_code = not is_in_code 58 | if is_in_code: 59 | content += line + "\n" 60 | continue 61 | 62 | # If a section start is encountered, skip the section lines, 63 | # commit the currently accumulated content and start the section 64 | if next.startswith("----") and not line_has_leading_whitespace: 65 | if len(line_stripped) != len(next): 66 | print(f"warning: Encountered invalid section underline below '{line_stripped}' in {module.path}:{node.lineno}", file=sys.stderr) 67 | 68 | _commit_content() 69 | section = DocstringSection(name=line_stripped) 70 | doc.sections[section.name.lower()] = section 71 | skip_next_line = True 72 | continue 73 | 74 | # If a line without leading whitespace is encountered in a section, 75 | # and it wasn't a new section start, we have a new decl. 76 | if section is not None and not line_has_leading_whitespace: 77 | _commit_content() 78 | decl = line_stripped 79 | continue 80 | 81 | content += line + "\n" 82 | _commit_content() 83 | 84 | return doc 85 | -------------------------------------------------------------------------------- /docs/astdown/loader.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | import ast 3 | import os 4 | import sys 5 | from ast import Module as AstModule 6 | from dataclasses import dataclass, field 7 | from pathlib import Path 8 | from typing import Literal, Optional 9 | 10 | @dataclass 11 | class IndexEntry: 12 | type: Literal["module", "function", "attribute", "class"] 13 | fqname: str 14 | name: str 15 | parent: Optional[IndexEntry] 16 | module: Optional[IndexEntry] 17 | module_url: Optional[str] 18 | node: ast.AST 19 | 20 | def short_name(self): 21 | return '.'.join(self.fqname.split(".")[-2:]) 22 | 23 | def display_name(self): 24 | dn = self.short_name() 25 | if self.type == "function": 26 | dn += "()" 27 | return dn 28 | 29 | def url(self, relative_to: Optional[str] = 
None): 30 | if self.type == "module": 31 | subref = "" 32 | elif self.type == "function": 33 | subref = f"#def-{self.short_name()}" 34 | elif self.type == "class": 35 | subref = f"#class-{self.short_name()}" 36 | elif self.type == "attribute": 37 | subref = f"#attr-{self.short_name()}" 38 | else: 39 | raise ValueError(f"Invalid {self.type=}") 40 | 41 | module = self.module or self 42 | relative_path = module.module_url or "" 43 | if relative_to is not None: 44 | relative_path = os.path.relpath(relative_path, start=os.path.dirname(relative_to)) 45 | return relative_path + subref 46 | 47 | def replace_crossrefs(content: str, node: ast.AST, module: Module) -> str: 48 | """Currently intended to be monkeypatched.""" 49 | _ = (node, module) 50 | return content 51 | 52 | def docstring(node: ast.AST, module: Module) -> Optional[str]: 53 | docstr = ast.get_docstring(node) 54 | if docstr is None: 55 | return None 56 | return replace_crossrefs(docstr, node, module) 57 | 58 | def short_docstring(node: ast.AST, module: Module) -> Optional[str]: 59 | content = docstring(node, module) 60 | if content is None or content == "": 61 | return None 62 | locs = [len(content) - 1, content.find(". 
"), content.find(".\n"), content.find("\n\n")] 63 | first_dot_or_paragraph_end = min(l for l in locs if l > 0) 64 | return content[:first_dot_or_paragraph_end + 1] 65 | 66 | @dataclass 67 | class Module: 68 | name: str 69 | path: str 70 | modules: list[Module] = field(default_factory=list) 71 | packages: list[Module] = field(default_factory=list) 72 | index: Optional[dict[str, IndexEntry]] = None 73 | _ast: Optional[AstModule] = None 74 | 75 | @property 76 | def basename(self) -> str: 77 | return self.name.split('.')[-1] 78 | 79 | @property 80 | def ast(self) -> AstModule: 81 | if self._ast is None: 82 | with open(self.path, "r") as f: 83 | self._ast = ast.parse(f.read(), filename=self.path) 84 | return self._ast 85 | 86 | @property 87 | def docstring(self) -> Optional[str]: 88 | return docstring(self.ast, self) 89 | 90 | def generate_index(self, url: str) -> None: 91 | index: dict[str, IndexEntry] = {} 92 | index[self.name] = module_index = IndexEntry("module", self.name, self.name, None, None, url, self.ast) 93 | 94 | def _fqname(parent: IndexEntry, name: str): 95 | return f"{parent.fqname}.{name}" 96 | 97 | def _index_functions(body: list[ast.stmt], parent: IndexEntry): 98 | for node in body: 99 | # Functions 100 | if isinstance(node, ast.FunctionDef): 101 | fqname = _fqname(parent, node.name) 102 | index[fqname] = IndexEntry("function", fqname, node.name, parent, module_index, None, node) 103 | 104 | def _index_attributes(body: list[ast.stmt], parent: IndexEntry): 105 | for node in body: 106 | # Attributes 107 | if isinstance(node, ast.AnnAssign) and isinstance(node.target, ast.Name): 108 | fqname = _fqname(parent, node.target.id) 109 | index[fqname] = IndexEntry("attribute", fqname, node.target.id, parent, module_index, None, node) 110 | if isinstance(node, ast.Assign): 111 | for target in node.targets: 112 | if isinstance(target, ast.Name): 113 | fqname = _fqname(parent, target.id) 114 | index[fqname] = IndexEntry("attribute", fqname, target.id, parent, 
module_index, None, node) 115 | 116 | def _index_classes(body: list[ast.stmt], parent: IndexEntry): 117 | for node in body: 118 | if isinstance(node, ast.ClassDef): 119 | fqname = _fqname(parent, node.name) 120 | index[fqname] = class_index = IndexEntry("class", fqname, node.name, parent, module_index, None, node) 121 | _index_functions(node.body, class_index) 122 | _index_attributes(node.body, class_index) 123 | _index_classes(node.body, class_index) 124 | 125 | _index_functions(self.ast.body, module_index) 126 | _index_attributes(self.ast.body, module_index) 127 | _index_classes(self.ast.body, module_index) 128 | 129 | self.index = index 130 | 131 | def find_module(module_name: str) -> Optional[str]: 132 | parts = module_name.split('.') 133 | filenames = [os.path.join(*parts, '__init__.py'), 134 | os.path.join(*parts) + '.py'] 135 | 136 | for path in sys.path: 137 | for choice in filenames: 138 | abs_path = os.path.normpath(os.path.join(path, choice)) 139 | if os.path.isfile(abs_path): 140 | return abs_path 141 | 142 | def _package_module_recursive(package_name: str, package_dir: Path) -> Optional[Module]: 143 | module_path = package_dir / "__init__.py" 144 | if not module_path.is_file(): 145 | return None 146 | 147 | module = Module(name=package_name, path=str(module_path)) 148 | for x in package_dir.iterdir(): 149 | if x.is_dir() and (x / "__init__.py").is_file(): 150 | sub_package = _package_module_recursive(f"{package_name}.{x.name}", x) 151 | if sub_package is not None: 152 | module.packages.append(sub_package) 153 | elif x.is_file() and x.name.endswith(".py") and x.name != "__init__.py": 154 | module.modules.append(Module(name=f"{package_name}.{x.name[:-len('.py')]}", path=str(x))) 155 | return module 156 | 157 | def find_package_or_module(package_name: str) -> Optional[Module]: 158 | module_path = find_module(package_name) 159 | if module_path is None: 160 | return None 161 | if module_path.endswith("__init__.py"): 162 | return 
_package_module_recursive(package_name, Path(module_path).parent) 163 | return Module(name=package_name, path=module_path) 164 | -------------------------------------------------------------------------------- /docs/astdown/markdown.py: -------------------------------------------------------------------------------- 1 | import ast 2 | from dataclasses import dataclass, field 3 | from itertools import zip_longest 4 | from typing import Any, Callable, ContextManager, Optional, Union 5 | 6 | import astdown.loader 7 | from astdown.docstring import DocstringSection, parse_numpy_docstring 8 | from astdown.loader import Module, short_docstring 9 | 10 | from rich import print as rprint 11 | from rich.markup import escape as rescape 12 | 13 | def print(msg: Any, *args, **kwargs): 14 | rprint(rescape(msg) if isinstance(msg, str) else msg, *args, **kwargs) 15 | 16 | separate_each_parameter_in_function = False 17 | include_types_in_signature = True 18 | include_types_in_parameter_descriptions = False 19 | include_defaults_in_parameter_descriptions = False 20 | max_function_signature_width = 76 21 | 22 | @dataclass 23 | class _DelegateContextManager: 24 | f_enter: Callable[[], None] 25 | f_exit: Callable[[], None] 26 | 27 | def __enter__(self) -> None: 28 | self.f_enter() 29 | 30 | def __exit__(self, exc_type: Any, exc_value: Any, exc_traceback: Any) -> None: 31 | _ = (exc_type, exc_value, exc_traceback) 32 | self.f_exit() 33 | 34 | @dataclass 35 | class MarkdownWriter: 36 | content: str = "" 37 | title_depth: int = 0 38 | lists: list = field(default_factory=list) 39 | needs_bullet_point: bool = False 40 | indents: list = field(default_factory=list) 41 | indent_str: str = "" 42 | 43 | def _calculate_indent(self): 44 | self.indent_str = ''.join(self.indents) 45 | 46 | def margin(self, empty_lines: int): 47 | existing_newlines = 0 48 | for c in reversed(self.content): 49 | if c != "\n": 50 | break 51 | existing_newlines += 1 52 | self.content += max(empty_lines + 1 - 
existing_newlines, 0) * "\n" 53 | 54 | def add_line(self, line: str): 55 | indent_prefix = self.indent_str 56 | if self.needs_bullet_point: 57 | indent_prefix = ''.join(self.indents[:-1]) + self.lists[-1] + " " * (len(self.indents[-1]) - len(self.lists[-1])) 58 | self.needs_bullet_point = False 59 | 60 | self.content += indent_prefix + line 61 | if not line.endswith("\n"): 62 | self.content += "\n" 63 | 64 | def add_content(self, text: str): 65 | paragraphs = text.split("\n\n") 66 | for p in paragraphs: 67 | self.margin(1) 68 | for line in p.splitlines(): 69 | self.add_line(line) 70 | self.margin(1) 71 | 72 | def title(self, title: str, additional_depth: int = 1) -> ContextManager: 73 | def _enter(): 74 | self.title_depth += additional_depth 75 | self.margin(1) 76 | self.add_line("#" * self.title_depth + " " + title) 77 | self.margin(1) 78 | def _exit(): 79 | self.title_depth -= additional_depth 80 | return _DelegateContextManager(_enter, _exit) 81 | 82 | def unordered_list(self, sign: str = " -") -> ContextManager: 83 | def _enter(): 84 | self.margin(1) 85 | self.lists.append(sign) 86 | def _exit(): 87 | self.margin(1) 88 | self.lists.pop() 89 | return _DelegateContextManager(_enter, _exit) 90 | 91 | def list_item(self, indent=" ") -> ContextManager: 92 | def _enter(): 93 | self.indents.append(indent) 94 | self._calculate_indent() 95 | self.needs_bullet_point = True 96 | def _exit(): 97 | self.indents.pop() 98 | self._calculate_indent() 99 | if self.needs_bullet_point: 100 | self.add_line("") 101 | self.needs_bullet_point = False 102 | return _DelegateContextManager(_enter, _exit) 103 | 104 | def function_def_to_markdown(markdown: MarkdownWriter, func: ast.FunctionDef, parent_basename: Optional[str]) -> None: 105 | if len(func.args.posonlyargs) > 0: 106 | raise NotImplementedError(f"functions with 'posonlyargs' are not supported.") 107 | 108 | # Word tokens that need to be joined together, but should be wrapped 109 | # before overflowing to the right. 
If the required alignment width 110 | # becomes too high, we fall back to simple fixed alignment. 111 | if parent_basename is None: 112 | initial_str = f"def {func.name}(" 113 | else: 114 | initial_str = f"def {parent_basename}.{func.name}(" 115 | align_width = len(initial_str) 116 | if align_width > 32: 117 | align_width = 8 118 | 119 | def _arg_to_str(arg: ast.arg, default: Optional[ast.expr] = None, prefix: str = ""): 120 | arg_token = prefix + arg.arg 121 | equals_separator = "=" 122 | if arg.annotation is not None and include_types_in_signature: 123 | arg_token += f": {ast.unparse(arg.annotation) or ''}" 124 | equals_separator = " = " 125 | if default is not None: 126 | arg_token += equals_separator + ast.unparse(default) or "" 127 | return arg_token 128 | 129 | arg: ast.arg 130 | default: Optional[ast.expr] 131 | tokens = [] 132 | for arg, default in reversed(list(zip_longest(reversed(func.args.args), reversed(func.args.defaults)))): 133 | tokens.append(_arg_to_str(arg, default)) 134 | if func.args.vararg is not None: 135 | tokens.append(_arg_to_str(func.args.vararg, prefix="*")) 136 | 137 | for arg, default in zip(func.args.kwonlyargs, func.args.kw_defaults): 138 | tokens.append(_arg_to_str(arg, default)) 139 | if func.args.kwarg is not None: 140 | tokens.append(_arg_to_str(func.args.kwarg, prefix="**")) 141 | 142 | # Append commata to arguments followed by arguments. 143 | for i,_ in enumerate(tokens[:-1]): 144 | tokens[i] += ", " 145 | 146 | # Return type or end. 
147 | if func.returns is not None and include_types_in_signature: 148 | tokens.append(f") -> {ast.unparse(func.returns)}:") 149 | else: 150 | tokens.append(f"):") 151 | 152 | markdown.add_line("```python") 153 | line = initial_str 154 | def _commit(): 155 | nonlocal line 156 | if line != "": 157 | markdown.add_line(line) 158 | line = "" 159 | 160 | for t in tokens: 161 | if len(line) + len(t.rstrip()) > max_function_signature_width: 162 | _commit() 163 | 164 | if line == "": 165 | line += align_width * " " 166 | line += t 167 | if separate_each_parameter_in_function: 168 | _commit() 169 | 170 | _commit() 171 | markdown.add_line("```") 172 | 173 | def function_parameters_docstring_to_markdown(markdown: MarkdownWriter, func: ast.FunctionDef, decls: dict[str, str]) -> None: 174 | all_args: dict[str, tuple[ast.arg, Optional[ast.expr]]] = {} 175 | for arg, default in reversed(list(zip_longest(reversed(func.args.args), reversed(func.args.defaults)))): 176 | all_args[arg.arg] = (arg, default) 177 | if func.args.vararg is not None: 178 | all_args[func.args.vararg.arg] = (func.args.vararg, None) 179 | for arg, default in zip(func.args.kwonlyargs, func.args.kw_defaults): 180 | all_args[arg.arg] = (arg, default) 181 | if func.args.kwarg is not None: 182 | all_args[func.args.kwarg.arg] = (func.args.kwarg, None) 183 | 184 | for name, value in decls.items(): 185 | with markdown.list_item(): 186 | arg, default = all_args.get(name) or (None, None) 187 | if arg is not None: 188 | content = f"**{arg.arg}**" 189 | if arg.annotation is not None and include_types_in_parameter_descriptions: 190 | content += f" (`{ast.unparse(arg.annotation)}`)" 191 | if default is not None and include_defaults_in_parameter_descriptions: 192 | content += f" (*Default: {ast.unparse(default)}*)" 193 | content += f": {value}" 194 | markdown.add_content(content) 195 | else: 196 | # TODO link to name if name is a type 197 | markdown.add_content(f"**{name}**: {value}") 198 | 199 | def 
docstring_section_to_markdown(markdown: MarkdownWriter, node: Optional[ast.AST], section: DocstringSection) -> None: 200 | with markdown.title(section.name): 201 | with markdown.unordered_list(): 202 | if section.name.lower() == "parameters" and isinstance(node, ast.FunctionDef): 203 | function_parameters_docstring_to_markdown(markdown, node, section.decls) 204 | return 205 | 206 | for name, value in section.decls.items(): 207 | with markdown.list_item(): 208 | markdown.add_content(f"**{name}**: {value}") 209 | 210 | def docstring_to_markdown(markdown: MarkdownWriter, node: ast.AST, module: Module) -> None: 211 | docstring = parse_numpy_docstring(node, module) 212 | if docstring is not None: 213 | if docstring.content is not None: 214 | markdown.add_content(docstring.content) 215 | 216 | for section_id in ["parameters", "returns", "raises"]: 217 | if section_id in docstring.sections: 218 | section = docstring.sections[section_id] 219 | docstring_section_to_markdown(markdown, node, section) 220 | 221 | def function_to_markdown(markdown: MarkdownWriter, func: ast.FunctionDef, parent_basename: Optional[str], module: Module) -> None: 222 | title = "def `" 223 | if parent_basename is not None: 224 | title += f"{parent_basename}." 
225 | title += f"{func.name}()`" 226 | with markdown.title(title): 227 | function_def_to_markdown(markdown, func, parent_basename) 228 | docstring_to_markdown(markdown, func, module) 229 | 230 | def class_to_markdown(markdown: MarkdownWriter, cls: ast.ClassDef, parent_basename: str, module: Module) -> None: 231 | with markdown.title(f"class `{parent_basename}.{cls.name}`"): 232 | docstring_to_markdown(markdown, cls, module) 233 | 234 | # Global attributes 235 | attributes_to_markdown(markdown, cls.body, None, module) 236 | 237 | # Functions 238 | function_defs = [node for node in cls.body if isinstance(node, ast.FunctionDef) and not node.name.startswith("_")] 239 | for func in function_defs: 240 | function_to_markdown(markdown, func, cls.name, module) 241 | 242 | def extract_attributes(nodes: list[ast.stmt]) -> dict[str, tuple[ast.AST, Optional[str], Optional[str]]]: 243 | attrs = {} 244 | def _add(ass: Union[ast.Assign, ast.AnnAssign], docnode: ast.Constant): 245 | if isinstance(ass, ast.AnnAssign) and isinstance(ass.target, ast.Name): 246 | if ass.target.id.startswith("_"): 247 | return 248 | attrs[ass.target.id] = (docnode, ast.unparse(ass.annotation), ast.unparse(ass.value) if ass.value is not None else None) 249 | elif isinstance(ass, ast.Assign): 250 | for target in ass.targets: 251 | if not isinstance(target, ast.Name) or target.id.startswith("_"): 252 | return 253 | attrs[target.id] = (docnode, None, ast.unparse(ass.value) if ass.value is not None else None) 254 | 255 | ass_node = None 256 | for node in nodes: 257 | if isinstance(node, (ast.Assign, ast.AnnAssign)): 258 | ass_node = node 259 | elif isinstance(node, ast.Expr) and isinstance(node.value, ast.Constant) and isinstance(node.value.value, str) and ass_node is not None: 260 | _add(ass_node, node.value) 261 | else: 262 | ass_node = None 263 | return attrs 264 | 265 | def attributes_to_markdown(markdown: MarkdownWriter, nodes: list[ast.stmt], parent_basename: Optional[str], module: Module) -> None: 266 
| attributes = extract_attributes(nodes) 267 | if len(attributes) > 0: 268 | with markdown.title("Attributes"): 269 | for name, (docnode, annotation, value) in attributes.items(): 270 | attr_name = name if parent_basename is None else f"{parent_basename}.{name}" 271 | with markdown.title(f"attr `{attr_name}`"): 272 | markdown.add_line("```python") 273 | repr = f"{attr_name}" 274 | if annotation is not None: 275 | repr += f": {annotation}" 276 | if value is not None: 277 | repr += f" = {value}" 278 | markdown.add_line(repr) 279 | markdown.add_line("```") 280 | docstring_to_markdown(markdown, docnode, module) 281 | 282 | def module_to_markdown(markdown: MarkdownWriter, module: Module) -> None: 283 | with markdown.title(module.name): 284 | module_doc = module.docstring 285 | if module_doc is not None: 286 | markdown.add_content(module_doc) 287 | 288 | # Subpackages 289 | if len(module.packages) > 0: 290 | with markdown.title("Subpackages"): 291 | with markdown.unordered_list(): 292 | for submod in module.packages: 293 | with markdown.list_item(): 294 | submod_ref = astdown.loader.replace_crossrefs(f"`{module.basename}.{submod.basename}`", submod.ast, module) 295 | markdown.add_content(f"{submod_ref} ‒ {short_docstring(submod.ast, submod) or '*No description.*'}") 296 | 297 | # Submodules 298 | if len(module.modules) > 0: 299 | with markdown.title("Submodules"): 300 | with markdown.unordered_list(): 301 | for submod in module.modules: 302 | with markdown.list_item(): 303 | submod_ref = astdown.loader.replace_crossrefs(f"`{module.basename}.{submod.basename}`", submod.ast, module) 304 | markdown.add_content(f"{submod_ref} ‒ {short_docstring(submod.ast, submod) or '*No description.*'}") 305 | 306 | # Global attributes 307 | attributes_to_markdown(markdown, module.ast.body, module.basename, module) 308 | 309 | # Classes 310 | class_defs = [node for node in module.ast.body if isinstance(node, ast.ClassDef) and not node.name.startswith("_")] 311 | for cls in class_defs: 312 
| class_to_markdown(markdown, cls, module.basename, module) 313 | 314 | # Functions 315 | function_defs = [node for node in module.ast.body if isinstance(node, ast.FunctionDef) and not node.name.startswith("_")] 316 | if len(function_defs) > 0: 317 | with markdown.title("Functions"): 318 | for func in function_defs: 319 | function_to_markdown(markdown, func, module.basename, module) 320 | -------------------------------------------------------------------------------- /docs/fora.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/oddlama/fora/e270a021c45666c8b22250f1ae7a1534fe6040d3/docs/fora.png -------------------------------------------------------------------------------- /docs/make.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import argparse 4 | import ast 5 | import functools 6 | import re 7 | import shutil 8 | import sys 9 | from pathlib import Path 10 | from textwrap import dedent 11 | from typing import Any, Optional, Union 12 | 13 | import astdown.loader 14 | from astdown.loader import Module, find_package_or_module, short_docstring 15 | from astdown.markdown import MarkdownWriter, module_to_markdown 16 | 17 | from rich import print as rprint 18 | from rich.markup import escape as rescape 19 | 20 | def print(msg: Any, *args, **kwargs): 21 | rprint(rescape(msg) if isinstance(msg, str) else msg, *args, **kwargs) 22 | 23 | def main(): 24 | parser = argparse.ArgumentParser(description="Builds the documentation for fora.") 25 | parser.add_argument('-o', '--output-dir', dest='output_dir', default="build", type=str, 26 | help="Specifies the output directory for the documentation. (default: 'build')") 27 | parser.add_argument('-I', '--include-path', dest='include_path', action='append', default=[], type=str, 28 | help="Specify an additional directory to add to the python module search path. 
Can be given multiple times.") 29 | parser.add_argument('--clean', action='store_true', 30 | help="Clean the build directory before generating new documentation.") 31 | ### parser.add_argument('modules', nargs='+', type=str, 32 | ### help="The modules to generate documentation for.") 33 | args = parser.parse_args() 34 | 35 | # TODO this is only for fora. make the tool generic at some point 36 | args.modules = ["fora"] 37 | ref_prefix = "api/" 38 | 39 | build_path = Path(args.output_dir) 40 | if args.clean: 41 | # Clean last build 42 | if build_path.exists(): 43 | shutil.rmtree(build_path) 44 | build_path.mkdir(parents=True, exist_ok=True) 45 | 46 | # Add args to python module search path 47 | for p in args.include_path: 48 | sys.path.insert(0, p) 49 | 50 | # Find packages and modules 51 | stack: list[Module] = [] 52 | for i in args.modules: 53 | module = find_package_or_module(i) 54 | if module is None: 55 | raise ModuleNotFoundError(f"Could not find source file for module '{i}'") 56 | stack.append(module) 57 | 58 | # Deduplicate and flatten 59 | modules: dict[str, Module] = {} 60 | while len(stack) > 0: 61 | m = stack.pop() 62 | if m.name in modules: 63 | continue 64 | modules[m.name] = m 65 | stack.extend(m.packages) 66 | stack.extend(m.modules) 67 | 68 | def _to_path(module: Module) -> str: 69 | #if len(module.modules) > 0 or len(module.packages) > 0: 70 | # return f"{module.name.replace('.', '/')}/__init__.md" 71 | #else: 72 | return f"{module.name.replace('.', '/')}.md" 73 | 74 | # Index references 75 | print("Indexing references") 76 | index = {} 77 | for module in modules.values(): 78 | module.generate_index(ref_prefix + _to_path(module)) 79 | index.update(module.index or {}) 80 | 81 | # Register cross-reference replacer 82 | def _replace_crossref(match: Any, node: ast.AST, module: Module) -> str: 83 | fqname = match.group(1) 84 | if fqname.startswith(".") or fqname.endswith("."): 85 | return match.group(0) 86 | if fqname not in index and "." 
in fqname: 87 | for key in index: 88 | if key.endswith(fqname): 89 | fqname = key 90 | break 91 | if fqname not in index: 92 | if "." in fqname: 93 | print(f"warning: Skipping invalid reference '{match.group(1)}' in {module.path}:{node.lineno}", file=sys.stderr) 94 | return match.group(0) 95 | 96 | idx = index[fqname] 97 | module_idx = index[module.name] 98 | url = idx.url(relative_to=module_idx.url()).replace('_', r'\_') 99 | return f"[`{idx.display_name()}`]({url})" 100 | 101 | ref_pattern = re.compile(r"(? str: 103 | return ref_pattern.sub(functools.partial(_replace_crossref, node=node, module=module), content) 104 | astdown.loader.replace_crossrefs = _replace_crossrefs 105 | 106 | def _link_to(fqname: str, display_name: Optional[str] = None, relative_to: Optional[Union[Module, str]] = None, code: bool = True) -> str: 107 | idx = index[fqname] 108 | to = None 109 | if relative_to is not None: 110 | to = relative_to.name if isinstance(relative_to, Module) else relative_to 111 | if display_name is None: 112 | display_name = idx.display_name() 113 | assert display_name is not None 114 | if not code: 115 | display_name = display_name.replace('_', r'\_') 116 | url = idx.url(to).replace('_', r'\_') 117 | cm = "`" if code else "" 118 | return f"[{cm}{display_name}{cm}]({url})" 119 | 120 | # Generate documentation 121 | print("Generating markdown") 122 | for i,module in enumerate(modules.values()): 123 | print(f"[{100*(i+1)/len(modules):6.2f}%] Processing {module.name}") 124 | markdown = MarkdownWriter() 125 | module_to_markdown(markdown, module) 126 | file_path = build_path / _to_path(module) 127 | file_path.parent.mkdir(parents=True, exist_ok=True) 128 | with open(file_path, "w") as f: 129 | f.write(markdown.content.strip("\n") + "\n") 130 | 131 | # Generate API index 132 | print("Generating API index") 133 | markdown = MarkdownWriter() 134 | name_overrides = {"fora": "Fora API"} 135 | def _recursive_list_module(module: Module): 136 | with markdown.list_item(indent=" 
"): 137 | markdown.add_line(_link_to(module.name, display_name=name_overrides.get(module.name), code=False)) 138 | if len(module.modules) > 0: 139 | with markdown.unordered_list(sign="*"): 140 | for submod in sorted(module.packages, key=lambda x: x.name): 141 | _recursive_list_module(submod) 142 | for submod in sorted(module.modules, key=lambda x: x.name): 143 | with markdown.list_item(indent=" "): 144 | markdown.add_line(_link_to(submod.name, code=False)) 145 | 146 | with markdown.title("Fora API"): 147 | markdown.margin(2) 148 | with markdown.unordered_list(sign="*"): 149 | with markdown.list_item(indent=" "): 150 | markdown.add_line(r"[Operations Index](api/index\_operations.md)") 151 | _recursive_list_module(modules["fora"]) 152 | 153 | with open(build_path / f"API_SUMMARY.md", "w") as f: 154 | f.write(dedent( 155 | markdown.content 156 | .replace("* ", "* ") 157 | .strip("\n") 158 | .replace("\n\n", "\n")) + "\n") 159 | 160 | # Generate Operations index 161 | print("Generating operations index") 162 | markdown = MarkdownWriter() 163 | with markdown.title("Operations"): 164 | for submod in sorted(modules["fora.operations"].modules, key=lambda x: x.name): 165 | if submod.name in ["fora.operations.api", "fora.operations.utils"]: 166 | continue 167 | with markdown.title(index[submod.name].display_name()): 168 | assert submod.index is not None 169 | for key, idx in submod.index.items(): 170 | if "._" in key: 171 | continue 172 | if idx.type != "function": 173 | continue 174 | with markdown.unordered_list(): 175 | with markdown.list_item(): 176 | markdown.add_content(_link_to(key, relative_to=ref_prefix) + f" ‒ {short_docstring(idx.node, submod)}") 177 | 178 | with open(build_path / f"index_operations.md", "w") as f: 179 | f.write(markdown.content.strip("\n") + "\n") 180 | 181 | if __name__ == "__main__": 182 | main() 183 | -------------------------------------------------------------------------------- /pylint-check.sh: 
-------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | pylint src/fora --ignore=version.py --disable=line-too-long,invalid-name,too-many-instance-attributes,too-few-public-methods,too-many-arguments,too-many-locals,duplicate-code"${1:+,}$1" 3 | mypy src/fora 4 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=45", "wheel", "setuptools_scm>=6.2"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [tool.setuptools_scm] 6 | write_to = "src/fora/version.py" 7 | git_describe_command = "git describe --dirty --tags --long --match 'v*' --first-parent" 8 | 9 | [tool.twine] 10 | sign = true 11 | 12 | [tool.mypy] 13 | ignore_missing_imports = true 14 | disallow_untyped_defs = true 15 | disallow_incomplete_defs = true 16 | no_implicit_optional = true 17 | warn_redundant_casts = true 18 | warn_return_any = true 19 | warn_unreachable = true 20 | local_partial_types = true 21 | strict_equality = true 22 | show_error_codes = true 23 | show_traceback = true 24 | pretty = true 25 | no_implicit_reexport = true 26 | 27 | [tool.pylint.'MESSAGES CONTROL'] 28 | ignore = "version.py" 29 | disable = "fixme, line-too-long, invalid-name, too-many-instance-attributes, too-few-public-methods, too-many-arguments, too-many-locals, duplicate-code" 30 | 31 | [tool.coverage.run] 32 | branch = true 33 | source = ["src/fora"] 34 | parallel = true 35 | 36 | [tool.coverage.report] 37 | show_missing = true 38 | exclude_lines = [ 39 | '^\s*raise AssertionError\b', 40 | '^\s*raise NotImplementedError\b', 41 | '^\s*raise$', 42 | '^\s*except ModuleNotFoundError:$', 43 | '^\s*pass$', 44 | "^if __name__ == ['\"]__main__['\"]:$", 45 | ] 46 | 47 | [tool.tox] 48 | legacy_tox_ini = """ 49 | [tox] 50 | isolated_build = True 51 | envlist = py39,pylint,type,docs,coverage 52 | 53 | [testenv] 54 
| description = run tests and create coverage data 55 | deps = 56 | pytest 57 | coverage[toml] 58 | pytest-cov 59 | passenv = SSH_AUTH_SOCK 60 | commands = pytest -v --cov={envsitepackagesdir}/fora 61 | 62 | [testenv:pylint] 63 | description = check with pylint 64 | deps = pylint 65 | basepython = python3.9 66 | commands = pylint src/fora 67 | 68 | [testenv:type] 69 | description = type-check with mypy 70 | deps = mypy 71 | basepython = python3.9 72 | commands = python -m mypy src/fora 73 | 74 | [testenv:docs] 75 | description = check if docs can be built 76 | deps = pdoc 77 | basepython = python3.9 78 | extras = docs 79 | commands = python docs/make.py -I src/ -o {toxworkdir}/docs_build 80 | 81 | [testenv:coverage] 82 | description = [run after tests]: combine coverage data and create report 83 | deps = coverage[toml] 84 | skip_install = true 85 | commands = 86 | coverage html 87 | coverage report --fail-under=80 88 | """ 89 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | name = fora 3 | author = oddlama 4 | author_email = oddlama@oddlama.org 5 | description = A simple infrastructure and configuration management tool 6 | long_description = file: README.md 7 | long_description_content_type = text/markdown 8 | license = MIT 9 | license_files = LICENSE 10 | url = https://github.com/oddlama/fora 11 | project_urls = 12 | Documentation = https://oddlama.gitbook.io/fora/ 13 | Source = https://github.com/oddlama/fora 14 | Issues = https://github.com/oddlama/fora/issues 15 | classifier = 16 | Development Status :: 4 - Beta 17 | Environment :: Console 18 | Intended Audience :: Developers 19 | Intended Audience :: Information Technology 20 | Intended Audience :: System Administrators 21 | License :: OSI Approved :: MIT License 22 | Operating System :: POSIX :: Linux 23 | Programming Language :: Python :: 3 :: Only 24 | Programming 
Language :: Python :: 3.9 25 | Topic :: System :: Installation/Setup 26 | Topic :: System :: Systems Administration 27 | Topic :: Utilities 28 | keywords = ansible, configuration, deploy, deployment, infra, infrastructure, management, puppet, saltstack, config, fora 29 | 30 | [options] 31 | zip_safe = False 32 | python_requires = >=3.9 33 | install_requires = 34 | jinja2 35 | package_dir= 36 | =src 37 | packages=find: 38 | 39 | [options.packages.find] 40 | where=src 41 | 42 | [options.entry_points] 43 | console_scripts = 44 | fora = fora.main:main 45 | -------------------------------------------------------------------------------- /src/fora/__init__.py: -------------------------------------------------------------------------------- 1 | """The main module of fora.""" 2 | 3 | from __future__ import annotations 4 | 5 | import argparse 6 | from typing import TYPE_CHECKING, cast 7 | 8 | if TYPE_CHECKING: 9 | from fora.types import GroupWrapper, HostWrapper, ScriptWrapper 10 | from fora.inventory_wrapper import InventoryWrapper 11 | 12 | args: argparse.Namespace = cast(argparse.Namespace, None) 13 | """ 14 | The parsed command line arguments (an argparse.Namespace), read e.g. as `fora.args.debug` by the connectors. 15 | Initialized to None (cast only to ease typechecking); it is expected to be assigned by the entry point 16 | once the command line has been parsed, so it must not be accessed before that. 17 | """ 18 | 19 | inventory: InventoryWrapper = cast("InventoryWrapper", None) 20 | """ 21 | The inventory module we are operating on. 22 | This is loaded from the inventory definition file. 23 | """ 24 | 25 | group: GroupWrapper = cast("GroupWrapper", None) 26 | """ 27 | This variable wraps the currently loaded group module. 28 | It must not be accessed anywhere else but inside the 29 | definition (source) of the actual group module. 30 | """ 31 | 32 | host: HostWrapper = cast("HostWrapper", None) # Cast None to ease typechecking in user code.
33 | """ 34 | This variable wraps the currently loaded hosts module (in case a host is just being defined), 35 | or the currently active host while executing a script. It must not be used anywhere else 36 | but inside the definition (source) of the actual module or inside of a script. 37 | """ 38 | 39 | script: ScriptWrapper = cast("ScriptWrapper", None) # Cast None to ease typechecking in user code. 40 | """This variable wraps the currently executed script module (if any).""" 41 | -------------------------------------------------------------------------------- /src/fora/connection.py: -------------------------------------------------------------------------------- 1 | """ 2 | Provides a class to manage a remote connection via the host's connector. 3 | Stores state along with the connection. 4 | """ 5 | 6 | from __future__ import annotations 7 | from copy import copy 8 | 9 | from types import TracebackType 10 | from typing import Type, cast, Optional 11 | 12 | import fora 13 | from fora import logger 14 | from fora.connectors.connector import Connector, CompletedRemoteCommand, GroupEntry, StatResult, UserEntry 15 | from fora.remote_settings import RemoteSettings 16 | from fora.types import HostWrapper 17 | 18 | class Connection: 19 | """ 20 | The connection class represents a connection to a host. 21 | It consists of a connector, which is actually responsible for 22 | providing remote access, and some state, which determines defaults 23 | for the commands executed on the remote system. 
24 | """ 25 | 26 | def __init__(self, host: HostWrapper): 27 | self.host = host 28 | self.connector: Connector = self.host.create_connector() 29 | self.base_settings: RemoteSettings = copy(self.host.inventory.base_remote_settings()) 30 | 31 | def __enter__(self) -> Connection: 32 | self.connector.open() 33 | self.host.connection = self 34 | self._resolve_identity() 35 | return self 36 | 37 | def __exit__(self, exc_type: Optional[Type[BaseException]], exc: Optional[BaseException], traceback: Optional[TracebackType]) -> None: 38 | _ = (exc_type, exc, traceback) 39 | self.host.connection = cast(Connection, None) 40 | self.connector.close() 41 | 42 | def _resolve_identity(self) -> None: 43 | """ 44 | Query the user and group under which we are operating, and store it 45 | in our base_settings. This ensures that the base settings reflect 46 | the actual user as which we operate. 47 | """ 48 | user = self.resolve_user(None) 49 | group = self.resolve_group(None) 50 | self.base_settings.as_user = user 51 | self.base_settings.as_group = group 52 | self.base_settings.owner = user 53 | self.base_settings.group = group 54 | 55 | def resolve_defaults(self, settings: RemoteSettings) -> RemoteSettings: 56 | """ 57 | Resolves (and verifies) the given settings against the current defaults, 58 | and returns tha actual values that should now be in effect. Verification 59 | means that this method will fail if e.g. the cwd doesn't exist on the remote. 
60 | 61 | Parameters 62 | ---------- 63 | settings 64 | Additional overrides for the current defaults 65 | 66 | Returns 67 | ------- 68 | RemoteSettings 69 | The resolved settings 70 | """ 71 | # pylint: disable=protected-access 72 | if fora.script is None: 73 | raise RuntimeError("Cannot resolve defaults, when no script is currently running.") 74 | 75 | # Overlay settings on top of defaults 76 | settings = fora.script.current_defaults().overlay(settings) 77 | 78 | # A function to check whether a mask is octal 79 | def check_mask(mask: Optional[str], name: str) -> None: 80 | if mask is None: 81 | raise ValueError(f"Error while resolving settings: {name} cannot be None!") 82 | try: 83 | int(mask, 8) 84 | except ValueError: 85 | raise ValueError(f"Error while resolving settings: {name} is '{mask}' but must be octal!") # pylint: disable=raise-missing-from 86 | 87 | settings.as_user = None if settings.as_user is None else self.resolve_user(settings.as_user) 88 | settings.as_group = None if settings.as_group is None else self.resolve_group(settings.as_group) 89 | settings.owner = None if settings.owner is None else self.resolve_user(settings.owner) 90 | settings.group = None if settings.group is None else self.resolve_group(settings.group) 91 | check_mask(settings.file_mode, "file_mode") 92 | check_mask(settings.dir_mode, "dir_mode") 93 | check_mask(settings.umask, "umask") 94 | if settings.cwd: 95 | s = self.stat(settings.cwd) 96 | if not s: 97 | raise ValueError(f"The selected working directory '{settings.cwd}' doesn't exist!") 98 | if s.type != "dir": 99 | raise ValueError(f"The selected working directory '{settings.cwd}' is not a directory!") 100 | 101 | return settings 102 | 103 | def run(self, 104 | command: list[str], 105 | input: Optional[bytes] = None, # pylint: disable=redefined-builtin 106 | capture_output: bool = True, 107 | check: bool = True, 108 | user: Optional[str] = None, 109 | group: Optional[str] = None, 110 | umask: Optional[str] = None, 111 | cwd: 
Optional[str] = None) -> CompletedRemoteCommand: 112 | """See `fora.connectors.connector.Connector.run`.""" 113 | logger.debug_args("Connection.run", locals()) 114 | defaults = fora.script.current_defaults() 115 | return self.connector.run( 116 | command=command, 117 | input=input, 118 | capture_output=capture_output, 119 | check=check, 120 | user=user if user is not None else defaults.as_user, 121 | group=group if group is not None else defaults.as_group, 122 | umask=umask if umask is not None else defaults.umask, 123 | cwd=cwd if cwd is not None else defaults.cwd) 124 | 125 | def resolve_user(self, user: Optional[str]) -> str: 126 | """See `fora.connectors.connector.Connector.resolve_user`.""" 127 | logger.debug_args("Connection.resolve_user", locals()) 128 | return self.connector.resolve_user(user) 129 | 130 | def resolve_group(self, group: Optional[str]) -> str: 131 | """See `fora.connectors.connector.Connector.resolve_group`.""" 132 | logger.debug_args("Connection.resolve_group", locals()) 133 | return self.connector.resolve_group(group) 134 | 135 | def stat(self, path: str, follow_links: bool = False, sha512sum: bool = False) -> Optional[StatResult]: 136 | """See `fora.connectors.connector.Connector.stat`.""" 137 | logger.debug_args("Connection.stat", locals()) 138 | return self.connector.stat( 139 | path=path, 140 | follow_links=follow_links, 141 | sha512sum=sha512sum) 142 | 143 | def upload(self, 144 | file: str, 145 | content: bytes, 146 | mode: Optional[str] = None, 147 | owner: Optional[str] = None, 148 | group: Optional[str] = None) -> None: 149 | """See `fora.connectors.connector.Connector.upload`.""" 150 | logger.debug_args("Connection.upload", locals()) 151 | return self.connector.upload( 152 | file=file, 153 | content=content, 154 | mode=mode, 155 | owner=owner, 156 | group=group) 157 | 158 | def download(self, file: str) -> bytes: 159 | """See `fora.connectors.connector.Connector.download`.""" 160 | logger.debug_args("Connection.download", 
locals()) 161 | return self.connector.download(file=file) 162 | 163 | def download_or(self, file: str, default: Optional[bytes] = None) -> Optional[bytes]: 164 | """ 165 | Same as `Connection.download`, but returns the given default in case the file doesn't exist. 166 | 167 | Parameters 168 | ---------- 169 | file 170 | The file to download. 171 | default 172 | The alternative to return if the file doesn't exist. 173 | 174 | Returns 175 | ------- 176 | Optional[bytes] 177 | The downloaded file or the default if the file didn't exist. 178 | 179 | Raises 180 | ------ 181 | fora.connectors.tunnel_dispatcher.RemoteOSError 182 | If the remote command fails for any reason other than file not found. 183 | IOError 184 | An error occurred with the connection. 185 | """ 186 | try: 187 | return self.download(file=file) 188 | except ValueError: 189 | return default 190 | 191 | def query_user(self, user: str, query_password_hash: bool = False, default: Optional[UserEntry] = None) -> Optional[UserEntry]: 192 | """See `fora.connectors.connector.Connector.query_user`, but returns the given default in case the user doesn't exist.""" 193 | logger.debug_args("Connection.query_user", locals()) 194 | try: 195 | return self.connector.query_user(user=user, query_password_hash=query_password_hash) 196 | except ValueError: 197 | return default 198 | 199 | def query_group(self, group: str, default: Optional[GroupEntry] = None) -> Optional[GroupEntry]: 200 | """See `fora.connectors.connector.Connector.query_group`, but returns the given default in case the group doesn't exist.""" 201 | logger.debug_args("Connection.query_group", locals()) 202 | try: 203 | return self.connector.query_group(group=group) 204 | except ValueError: 205 | return default 206 | 207 | def home_dir(self, user: Optional[str] = None) -> str: 208 | """ 209 | Return's the home directory of the given user. If the user is None, 210 | it defaults to the current user. 
211 | 212 | Parameters 213 | ---------- 214 | user 215 | The user. 216 | 217 | Returns 218 | ------- 219 | str 220 | The home directory of the requested user. 221 | 222 | Raises 223 | ------ 224 | ValueError 225 | If the user could not be resolved. 226 | fora.connectors.tunnel_dispatcher.RemoteOSError 227 | If the remote command fails because of an remote OSError. 228 | IOError 229 | An error occurred with the connection. 230 | """ 231 | logger.debug_args("Connection.home_dir", locals()) 232 | if user is None: 233 | user = self.resolve_user(None) 234 | return self.connector.query_user(user=user).home 235 | 236 | def getenv(self, key: str, default: Optional[str] = None) -> Optional[str]: 237 | """See `fora.connectors.connector.Connector.getenv`, but returns the given default in case the key doesn't exist.""" 238 | logger.debug_args("Connection.getenv", locals()) 239 | val = self.connector.getenv(key=key) 240 | return default if val is None else val 241 | 242 | def open_connection(host: HostWrapper) -> Connection: 243 | """ 244 | Returns a connection (context manager) that opens the connection when it is entered and 245 | closes it when it is exited. The connection can be obtained via host.connection, 246 | as long as it is opened. 247 | 248 | Parameters 249 | ---------- 250 | host 251 | The host to which a connection should be opened 252 | 253 | Returns 254 | ------- 255 | Connection 256 | The connection (context manager) 257 | """ 258 | return Connection(host) 259 | -------------------------------------------------------------------------------- /src/fora/connectors/__init__.py: -------------------------------------------------------------------------------- 1 | """Contains all standard conectors to register them by default.""" 2 | 3 | from fora.utils import import_submodules 4 | 5 | # Import all submodules to ensure that decorators have a chance 6 | # to register operations to a registry (e.g. package_managers). 
7 | import_submodules(__name__) 8 | -------------------------------------------------------------------------------- /src/fora/connectors/connector.py: -------------------------------------------------------------------------------- 1 | """ 2 | Defines the connector interface. 3 | """ 4 | 5 | from __future__ import annotations 6 | from dataclasses import dataclass 7 | from typing import Callable, Optional, Type, Union 8 | 9 | from fora.types import HostWrapper 10 | 11 | @dataclass 12 | class CompletedRemoteCommand: 13 | """The return value of `Connector.run()`, representing a finished remote process.""" 14 | stdout: Optional[bytes] 15 | """The stdout of the remote command. If you need a string, but are not sure that the formatting is utf-8, 16 | be sure to decode with `errors="surrogateescape"` or `errors=backslashreplace` depending on your use.""" 17 | stderr: Optional[bytes] 18 | """The stderr of the remote command. If you need a string, but are not sure that the formatting is utf-8, 19 | be sure to decode with `errors="surrogateescape"` or `errors=backslashreplace` depending on your use.""" 20 | returncode: int 21 | """The return code of the remote command.""" 22 | 23 | class StatResult: 24 | """ 25 | The return value of stat(), representing information about a remote file. 26 | The type will be one of [ "dir", "chr", "blk", "file", "fifo", "link", "sock", "other" ]. 27 | If requested, the sha512sum of the file will be included. 
28 | """ 29 | def __init__(self, 30 | type: str, # pylint: disable=redefined-builtin 31 | mode: Union[int, str], 32 | owner: str, 33 | group: str, 34 | size: int, 35 | mtime: int, 36 | ctime: int, 37 | sha512sum: Optional[bytes]): 38 | self.type = type 39 | self.mode: str = mode if isinstance(mode, str) else oct(mode)[2:] 40 | self.owner = owner 41 | self.group = group 42 | self.size = size 43 | self.mtime = mtime 44 | self.ctime = ctime 45 | self.sha512sum = sha512sum 46 | 47 | @dataclass 48 | class UserEntry: 49 | """The result of a user query.""" 50 | name: str 51 | """The name of the user""" 52 | uid: int 53 | """The numerical user id""" 54 | group: str 55 | """The name of the primary group""" 56 | gid: int 57 | """The numerical primary group id""" 58 | groups: list[str] 59 | """All names of the supplementary groups this user belongs to""" 60 | password_hash: Optional[str] 61 | """The password hash from shadow, if requested.""" 62 | gecos: str 63 | """The comment (GECOS) field of the user""" 64 | home: str 65 | """The home directory of the user""" 66 | shell: str 67 | """The default shell of the user""" 68 | 69 | @dataclass 70 | class GroupEntry: 71 | """The result of a group query.""" 72 | name: str 73 | """The name of the group""" 74 | gid: int 75 | """The numerical group id""" 76 | members: list[str] 77 | """All the group member's user names""" 78 | 79 | class Connector: 80 | """The base class for all connectors.""" 81 | 82 | schema: str 83 | """ 84 | The schema of the connector. Must match the schema used in urls of this connector, 85 | such as `ssh` for `ssh:...`. May also appear in log messages. A schema is the part 86 | of the url until (but not including) the first colon. Set by the @connector decorator. 
87 | """ 88 | 89 | registered_connectors: dict[str, Type[Connector]] = {} 90 | """The list of all registered connectors.""" 91 | 92 | def __init__(self, url: Optional[str], host: HostWrapper): 93 | self.url = url 94 | self.host = host 95 | 96 | def open(self) -> None: 97 | """Opens the connection to the remote host.""" 98 | raise NotImplementedError("Must be overwritten by subclass.") 99 | 100 | def close(self) -> None: 101 | """Closes the connection to the remote host.""" 102 | raise NotImplementedError("Must be overwritten by subclass.") 103 | 104 | def run(self, 105 | command: list[str], 106 | input: Optional[bytes] = None, # pylint: disable=redefined-builtin 107 | capture_output: bool = True, 108 | check: bool = True, 109 | user: Optional[str] = None, 110 | group: Optional[str] = None, 111 | umask: Optional[str] = None, 112 | cwd: Optional[str] = None) -> CompletedRemoteCommand: 113 | """ 114 | Runs the given command on the remote, returning a CompletedRemoteCommand 115 | containing the returned information (if any) and the status code. 116 | 117 | Parameters 118 | ---------- 119 | command 120 | The command to be executed on the remote host. 121 | input 122 | Input to the remote command. 123 | capture_output 124 | Whether the output of the command should be captured. 125 | check 126 | Whether to raise an exception if the remote command returns with a non-zero exit status. 127 | user 128 | The remote user under which the command should be run. Also sets the group 129 | to the primary group of that user if it isn't explicitly given. If not given, the command 130 | is run as the user under which the remote dispatcher is running (usually root). 131 | group 132 | The remote group under which the command should be run. If not given, the command 133 | is run as the group under which the remote dispatcher is running (usually root), 134 | or in case the user was explicitly specified, the primary group of that user. 
135 | umask 136 | The umask to use when executing the command on the remote system. Defaults to "077". 137 | cwd 138 | The remote working directory under which the command should be run. 139 | 140 | Returns 141 | ------- 142 | CompletedRemoteCommand 143 | The result of the remote command. 144 | 145 | Raises 146 | ------ 147 | subprocess.CalledProcessError 148 | If check is True and the process returned a non-zero exit status. 149 | ValueError 150 | A parameter was invalid. 151 | fora.connectors.tunnel_dispatcher.RemoteOSError 152 | If the remote command fails because of an remote OSError. 153 | IOError 154 | An error occurred with the connection. 155 | """ 156 | _ = (self, command, input, capture_output, check, user, group, umask, cwd) 157 | raise NotImplementedError("Must be overwritten by subclass.") 158 | 159 | def resolve_user(self, user: Optional[str]) -> str: 160 | """ 161 | Resolves the given user on the remote, returning 162 | the canonicalized username. If the given user is None, instead 163 | returns the user as which the remote command is running. 164 | 165 | Parameters 166 | ---------- 167 | user 168 | The username or uid that should be resolved, or None to query the current user. 169 | 170 | Returns 171 | ------- 172 | str 173 | The resolved username or None if the input was None. 174 | 175 | Raises 176 | ------ 177 | ValueError 178 | If the user could not be resolved. 179 | fora.connectors.tunnel_dispatcher.RemoteOSError 180 | If the remote command fails because of an remote OSError. 181 | IOError 182 | An error occurred with the connection. 183 | """ 184 | _ = (self, user) 185 | raise NotImplementedError("Must be overwritten by subclass.") 186 | 187 | def resolve_group(self, group: Optional[str]) -> str: 188 | """ 189 | Resolves the given group on the remote, returning 190 | the canonicalized groupname. If the given group is None, instead 191 | returns the group as which the remote command is running. 
192 | 193 | Parameters 194 | ---------- 195 | group 196 | The groupname or gid that should be resolved, or None to query the current group. 197 | 198 | Returns 199 | ------- 200 | str 201 | The resolved groupname or None if the input was None. 202 | 203 | Raises 204 | ------ 205 | ValueError 206 | If the group could not be resolved. 207 | fora.connectors.tunnel_dispatcher.RemoteOSError 208 | If the remote command fails because of an remote OSError. 209 | IOError 210 | An error occurred with the connection. 211 | """ 212 | _ = (self, group) 213 | raise NotImplementedError("Must be overwritten by subclass.") 214 | 215 | def stat(self, path: str, follow_links: bool = False, sha512sum: bool = False) -> Optional[StatResult]: 216 | """ 217 | Runs `os.stat()` on the given path on the remote. Follows links if follow_links 218 | is true. Includes the sha512sum if desired and if the path is a file. 219 | 220 | Returns None if the remote path doesn't exist. 221 | 222 | Parameters 223 | ---------- 224 | path 225 | The path to stat. 226 | follow_links 227 | Whether to follow symbolic links instead of running stat on the link. 228 | sha512sum 229 | Whether to include the sha512sum if the path is a file. 230 | 231 | Returns 232 | ------- 233 | Optional[StatResult] 234 | The stat result or None if the path didn't exist. 235 | 236 | Raises 237 | ------ 238 | fora.connectors.tunnel_dispatcher.RemoteOSError 239 | If the remote command fails for any reason other than file not found. 240 | IOError 241 | An error occurred with the connection. 242 | """ 243 | _ = (self, path, follow_links, sha512sum) 244 | raise NotImplementedError("Must be overwritten by subclass.") 245 | 246 | def upload(self, 247 | file: str, 248 | content: bytes, 249 | mode: Optional[str] = None, 250 | owner: Optional[str] = None, 251 | group: Optional[str] = None) -> None: 252 | """ 253 | Uploads the given content to the remote system and saves it under the given file path. Overwrites existing files. 
254 | 255 | Parameters 256 | ---------- 257 | file 258 | The file where the content will be saved. 259 | content 260 | The file content. 261 | owner 262 | The owner for the file. Defaults to root if not given. 263 | group 264 | The group for the file. If the owner is given, defaults to the primary 265 | group of the owner, otherwise defaults to root. 266 | mode 267 | The mode for the file. Defaults to '600' if not given. 268 | 269 | Raises 270 | ------ 271 | ValueError 272 | A parameter was invalid. 273 | fora.connectors.tunnel_dispatcher.RemoteOSError 274 | If the remote command fails because of an remote OSError. 275 | IOError 276 | An error occurred with the connection. 277 | """ 278 | _ = (self, file, content, mode, owner, group) 279 | raise NotImplementedError("Must be overwritten by subclass.") 280 | 281 | def download(self, file: str) -> bytes: 282 | """ 283 | Downloads the given file from the remote system. 284 | 285 | Parameters 286 | ---------- 287 | file 288 | The file to download. 289 | 290 | Raises 291 | ------ 292 | ValueError 293 | If the file was not found. 294 | fora.connectors.tunnel_dispatcher.RemoteOSError 295 | If the remote command fails for any reason other than file not found. 296 | IOError 297 | An error occurred with the connection. 298 | """ 299 | _ = (self, file) 300 | raise NotImplementedError("Must be overwritten by subclass.") 301 | 302 | def query_user(self, user: str, query_password_hash: bool = False) -> UserEntry: 303 | """ 304 | Queries information about a user on the reomte system. 305 | 306 | Parameters 307 | ---------- 308 | user 309 | The username or uid that should be queried. 310 | query_password_hash 311 | Whether the password hash should also be returned. Requires elevated privileges. 312 | 313 | Returns 314 | ------- 315 | UserEntry 316 | The information about the user. 317 | 318 | Raises 319 | ------ 320 | ValueError 321 | If the user could not be resolved. 
322 | fora.connectors.tunnel_dispatcher.RemoteOSError 323 | If the remote command fails because of an remote OSError. 324 | IOError 325 | An error occurred with the connection. 326 | """ 327 | _ = (self, user, query_password_hash) 328 | raise NotImplementedError("Must be overwritten by subclass.") 329 | 330 | def query_group(self, group: str) -> GroupEntry: 331 | """ 332 | Queries information about a group on the reomte system. 333 | 334 | Parameters 335 | ---------- 336 | group 337 | The groupname or gid that should be queried. 338 | 339 | Returns 340 | ------- 341 | GroupEntry 342 | The resolved groupname or None if the input was None. 343 | 344 | Raises 345 | ------ 346 | ValueError 347 | If the group could not be resolved. 348 | fora.connectors.tunnel_dispatcher.RemoteOSError 349 | If the remote command fails because of an remote OSError. 350 | IOError 351 | An error occurred with the connection. 352 | """ 353 | _ = (self, group) 354 | raise NotImplementedError("Must be overwritten by subclass.") 355 | 356 | def getenv(self, key: str) -> Optional[str]: 357 | """ 358 | Return's an environment variable from the remote host. 359 | 360 | Parameters 361 | ---------- 362 | key 363 | The variable to get. 364 | 365 | Returns 366 | ------- 367 | str 368 | The corresponding variable if found, None otherwise. 369 | 370 | Raises 371 | ------ 372 | ValueError 373 | If the user could not be resolved. 374 | fora.connectors.tunnel_dispatcher.RemoteOSError 375 | If the remote command fails because of an remote OSError. 376 | IOError 377 | An error occurred with the connection. 378 | """ 379 | _ = (self, key) 380 | raise NotImplementedError("Must be overwritten by subclass.") 381 | 382 | @classmethod 383 | def extract_hostname(cls, url: str) -> str: 384 | """ 385 | Extracts the hostname from a given url where the schema matches this connector. 386 | 387 | Raises 388 | ------ 389 | ValueError 390 | The provided url was invalid. 
391 | 392 | Parameters 393 | ---------- 394 | url 395 | The url to extract the hostname from. 396 | 397 | Returns 398 | ------- 399 | str 400 | The extracted hostname. 401 | """ 402 | _ = (url) 403 | raise NotImplementedError("Must be overwritten by subclass.") 404 | 405 | def connector(schema: str) -> Callable[[Type[Connector]], Type[Connector]]: 406 | """ 407 | The @connector class decorator used to register the connector 408 | to the global registry. 409 | 410 | Parameters 411 | ---------- 412 | schema 413 | The schema for the connector, for example 'ssh'. 414 | """ 415 | def wrapper(cls: Type[Connector]) -> Type[Connector]: 416 | cls.schema = schema 417 | Connector.registered_connectors[cls.schema] = cls 418 | return cls 419 | return wrapper 420 | -------------------------------------------------------------------------------- /src/fora/connectors/local.py: -------------------------------------------------------------------------------- 1 | """Contains a connector which handles connections to hosts via SSH.""" 2 | 3 | import os 4 | from typing import Optional 5 | 6 | import fora 7 | from fora.connectors import tunnel_dispatcher as td 8 | from fora.connectors.connector import connector 9 | from fora.connectors.tunnel_connector import TunnelConnector 10 | from fora.types import HostWrapper 11 | 12 | @connector(schema='local') 13 | class LocalConnector(TunnelConnector): 14 | """A tunnel connector that provides remote access to the current local machine via a subprocess.""" 15 | 16 | def __init__(self, url: Optional[str], host: HostWrapper): 17 | super().__init__(url, host) 18 | 19 | if url is not None and url.startswith(f"{self.schema}:"): 20 | self.url = url 21 | else: 22 | self.url = "local:localhost" 23 | 24 | def command(self) -> list[str]: 25 | """ 26 | Constructs the full command needed to execute a tunnel dispatcher on this machine. 27 | 28 | Returns 29 | ------- 30 | list[str] 31 | The required ssh command. 
32 | """ 33 | command = ["python3", os.path.realpath(td.__file__)] 34 | if fora.args.debug: 35 | command.append("--debug") 36 | return command 37 | 38 | @classmethod 39 | def extract_hostname(cls, url: str) -> str: 40 | if not url.startswith(f"{cls.schema}:"): 41 | raise ValueError(f"Cannot extract hostname from url without matching schema (expected '{cls.schema}', got '{url}').") 42 | hostname = url[len(cls.schema) + 1:] 43 | return hostname if len(hostname) > 0 else "localhost" 44 | -------------------------------------------------------------------------------- /src/fora/connectors/ssh.py: -------------------------------------------------------------------------------- 1 | """Contains a connector which handles connections to hosts via SSH.""" 2 | 3 | import base64 4 | import zlib 5 | from typing import Optional 6 | 7 | import fora 8 | from fora.connectors import tunnel_dispatcher as td 9 | from fora.connectors.connector import connector 10 | from fora.connectors.tunnel_connector import TunnelConnector 11 | from fora.types import HostWrapper 12 | 13 | @connector(schema='ssh') 14 | class SshConnector(TunnelConnector): 15 | """A tunnel connector that provides remote access via SSH.""" 16 | 17 | def __init__(self, url: Optional[str], host: HostWrapper): 18 | super().__init__(url, host) 19 | 20 | self.ssh_opts: list[str] = host.ssh_opts if hasattr(host, 'ssh_opts') else [] 21 | schema_prefix = f"{self.schema}:" 22 | if url is not None and url.startswith(schema_prefix): 23 | # change ssh:host -> ssh://host if necessary 24 | if not url.startswith(f"{schema_prefix}//"): 25 | self.url = f"{schema_prefix}//{url[len(schema_prefix):]}" 26 | else: 27 | self.url = url 28 | else: 29 | self.url: str = f"{self.schema}://{host.ssh_host}:{host.ssh_port}" 30 | 31 | def command(self) -> list[str]: 32 | """ 33 | Constructs the full ssh command needed to execute a 34 | tunnel dispatcher on the remote host. 35 | 36 | Returns 37 | ------- 38 | list[str] 39 | The required ssh command. 
40 | """ 41 | with open(td.__file__, 'rb') as f: 42 | tunnel_dispatcher_gz_b64 = base64.b64encode(zlib.compress(f.read(), 9)).decode('ascii') 43 | 44 | # Start the remote dispatcher by uploading it inline as base64 45 | param_debug = "--debug" if fora.args.debug else "" 46 | 47 | command = ["ssh"] 48 | command.extend(self.ssh_opts) 49 | command.append(self.url) 50 | command.append(f"env python3 -c \"$(echo '{tunnel_dispatcher_gz_b64}' | base64 -d | python -c 'import zlib,sys;sys.stdout.buffer.write(zlib.decompress(sys.stdin.buffer.read()))')\" {param_debug}") 51 | 52 | return command 53 | 54 | @classmethod 55 | def extract_hostname(cls, url: str) -> str: 56 | if not url.startswith(f"{cls.schema}:"): 57 | raise ValueError(f"Cannot extract hostname from url without matching schema (expected '{cls.schema}', got '{url}').") 58 | 59 | # strip ssh:// 60 | # remaining: [user@]hostname[:port] 61 | hostname = url[len(cls.schema) + 3:] 62 | 63 | # Remove user 64 | pos = hostname.find("@") 65 | if pos >= 0: 66 | hostname = hostname[pos + 1:] 67 | 68 | # Remove port 69 | pos = hostname.find(":") 70 | if pos >= 0: 71 | hostname = hostname[:pos] 72 | 73 | return hostname 74 | -------------------------------------------------------------------------------- /src/fora/connectors/tunnel_connector.py: -------------------------------------------------------------------------------- 1 | """Contains a connector base which handles communication via any spawned subprocess command that can run a tunnel dispatcher on the remote host.""" 2 | 3 | import sys 4 | import subprocess 5 | from typing import Any, Optional, Type, cast 6 | 7 | from fora import logger 8 | from fora.connectors import tunnel_dispatcher as td 9 | from fora.connectors.connector import CompletedRemoteCommand, Connector, GroupEntry, StatResult, UserEntry 10 | from fora.types import HostWrapper 11 | 12 | def _expect_response_packet(packet: Any, expected_type: Type) -> None: 13 | """ 14 | Check if the given packet is of the 
expected type, otherwise raise a IOError. 15 | 16 | Parameters 17 | ---------- 18 | packet 19 | The packet to check. 20 | expected_type 21 | The expected type. 22 | """ 23 | if not isinstance(packet, expected_type): 24 | raise IOError(f"Invalid response '{type(packet)}' from remote dispatcher. This is a bug.") 25 | 26 | class TunnelConnector(Connector): 27 | """A connector that handles requests via an externally supplied subprocess running a tunnel dispatcher. 28 | Any subclass must override command().""" 29 | 30 | def __init__(self, url: Optional[str], host: HostWrapper): 31 | super().__init__(url, host) 32 | 33 | self.process: Optional[subprocess.Popen] = None 34 | self.conn: td.Connection 35 | self.is_open: bool = False 36 | 37 | def command(self) -> list[str]: 38 | """Returns the command that should be executed to open a tunnel dispatcher to the destination.""" 39 | raise NotImplementedError("Must be overwritten by subclass.") 40 | 41 | def open(self) -> None: 42 | logger.connection_init(self) 43 | 44 | # pylint: disable=consider-using-with 45 | # The process must outlive this function. 46 | self.process = subprocess.Popen(self.command(), stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=sys.stderr) 47 | if self.process.stdout is None or self.process.stdin is None: 48 | raise RuntimeError("Subprocess has no stdin/stdout. If is a bug.") 49 | self.conn = td.Connection(self.process.stdout, self.process.stdin) 50 | 51 | try: 52 | response = self._request(td.PacketCheckAlive()) 53 | _expect_response_packet(response, td.PacketAck) 54 | 55 | # As a last action record that the connection is opened successfully, 56 | # otherwise the finally block will kill the process. 
57 | self.is_open = True 58 | except IOError as e: 59 | returncode = self.process.poll() 60 | if returncode is None: 61 | logger.connection_failed(str(e)) 62 | else: 63 | logger.connection_failed(f"command exited with code {returncode}") 64 | raise 65 | finally: 66 | # If the connection failed for any reason, be sure to kill the background process. 67 | if not self.is_open: 68 | self.process.terminate() 69 | self.process = None 70 | 71 | logger.connection_established() 72 | 73 | def close(self) -> None: 74 | if self.is_open: 75 | self.conn.write_packet(td.PacketExit()) 76 | 77 | if self.process is not None: 78 | if self.process.stdin is not None: 79 | self.process.stdin.close() 80 | self.process.wait() 81 | if self.process.stdout is not None: 82 | self.process.stdout.close() 83 | self.process = None 84 | 85 | def _request(self, packet: Any) -> Any: 86 | """Sends the request packet and returns the response. 87 | Propagates exceptions from raised from td.receive_packet.""" 88 | self.conn.write_packet(packet) 89 | return td.receive_packet(self.conn, request=packet) 90 | 91 | def run(self, 92 | command: list[str], 93 | input: Optional[bytes] = None, # pylint: disable=redefined-builtin 94 | capture_output: bool = True, 95 | check: bool = True, 96 | user: Optional[str] = None, 97 | group: Optional[str] = None, 98 | umask: Optional[str] = None, 99 | cwd: Optional[str] = None) -> CompletedRemoteCommand: 100 | # Construct and send packet with process information 101 | request = td.PacketProcessRun( 102 | command=command, 103 | stdin=input, 104 | capture_output=capture_output, 105 | user=user, 106 | group=group, 107 | umask=umask, 108 | cwd=cwd) 109 | response = self._request(request) 110 | 111 | if isinstance(response, td.PacketProcessError): 112 | raise ValueError(response.message) 113 | 114 | _expect_response_packet(response, td.PacketProcessCompleted) 115 | result = CompletedRemoteCommand(stdout=response.stdout, 116 | stderr=response.stderr, 117 | 
returncode=response.returncode) 118 | 119 | # Check output if requested 120 | if check and result.returncode != 0: 121 | raise subprocess.CalledProcessError(returncode=result.returncode, 122 | output=result.stdout, 123 | stderr=result.stderr, 124 | cmd=command) 125 | 126 | return result 127 | 128 | def stat(self, path: str, follow_links: bool = False, sha512sum: bool = False) -> Optional[StatResult]: 129 | # Construct and send packet with process information 130 | request = td.PacketStat( 131 | path=path, 132 | follow_links=follow_links, 133 | sha512sum=sha512sum) 134 | 135 | try: 136 | response = self._request(request) 137 | except ValueError: 138 | # File was not found, return None 139 | return None 140 | 141 | _expect_response_packet(response, td.PacketStatResult) 142 | return StatResult( 143 | type=response.type, 144 | mode=response.mode, 145 | owner=response.owner, 146 | group=response.group, 147 | size=response.size, 148 | mtime=response.mtime, 149 | ctime=response.ctime, 150 | sha512sum=response.sha512sum) 151 | 152 | def resolve_user(self, user: Optional[str]) -> str: 153 | request = td.PacketResolveUser(user=user) 154 | response = self._request(request) 155 | 156 | _expect_response_packet(response, td.PacketResolveResult) 157 | return cast(td.PacketResolveResult, response).value 158 | 159 | def resolve_group(self, group: Optional[str]) -> str: 160 | request = td.PacketResolveGroup(group=group) 161 | response = self._request(request) 162 | 163 | _expect_response_packet(response, td.PacketResolveResult) 164 | return cast(td.PacketResolveResult, response).value 165 | 166 | def query_user(self, user: str, query_password_hash: bool = False) -> UserEntry: 167 | request = td.PacketQueryUser(user=user, query_password_hash=query_password_hash) 168 | response = self._request(request) 169 | 170 | _expect_response_packet(response, td.PacketUserEntry) 171 | return UserEntry( 172 | name=response.name, 173 | uid=response.uid, 174 | group=response.group, 175 | 
gid=response.gid, 176 | groups=response.groups, 177 | password_hash=response.password_hash, 178 | gecos=response.gecos, 179 | home=response.home, 180 | shell=response.shell) 181 | 182 | def query_group(self, group: str) -> GroupEntry: 183 | request = td.PacketQueryGroup(group=group) 184 | response = self._request(request) 185 | 186 | _expect_response_packet(response, td.PacketGroupEntry) 187 | return GroupEntry( 188 | name=response.name, 189 | gid=response.gid, 190 | members=response.members) 191 | 192 | def getenv(self, key: str) -> Optional[str]: 193 | request = td.PacketGetenv(key=key) 194 | response = self._request(request) 195 | 196 | _expect_response_packet(response, td.PacketEnvironVar) 197 | return cast(td.PacketEnvironVar, response).value 198 | 199 | def upload(self, 200 | file: str, 201 | content: bytes, 202 | mode: Optional[str] = None, 203 | owner: Optional[str] = None, 204 | group: Optional[str] = None) -> None: 205 | request = td.PacketUpload( 206 | file=file, 207 | content=content, 208 | mode=mode, 209 | owner=owner, 210 | group=group) 211 | response = self._request(request) 212 | _expect_response_packet(response, td.PacketOk) 213 | 214 | def download(self, file: str) -> bytes: 215 | request = td.PacketDownload(file=file) 216 | response = self._request(request) 217 | 218 | _expect_response_packet(response, td.PacketDownloadResult) 219 | return cast(td.PacketDownloadResult, response).content 220 | -------------------------------------------------------------------------------- /src/fora/example_deploys.py: -------------------------------------------------------------------------------- 1 | """Provides example deploys, which can be used as a starting point.""" 2 | 3 | import os 4 | from pathlib import Path 5 | import sys 6 | from textwrap import dedent 7 | from typing import Literal, NoReturn 8 | 9 | from fora import logger 10 | from fora.utils import print_status 11 | 12 | _inventory_def = dedent("""\ 13 | # Defines which hosts belong to this 
inventory. 14 | hosts = [ 15 | "local:", # Local machine, executed as the user who invokes fora 16 | # "example", # Some remote machine via ssh (probably requires matching entry in `.ssh/config`) 17 | # "ssh://root@example.com", # An explicit user on a remote machine via ssh 18 | ] 19 | """) 20 | 21 | _localhost_def = dedent("""\ 22 | # Define a (different) url for this host. Useful 23 | # if the inventory entry is just a name like "localhost" 24 | # url = "ssh://root@localhost" 25 | 26 | # Define a variable for this host 27 | somevariable = "this was defined by the host" 28 | """) 29 | 30 | _nginx_add_site = dedent("""\ 31 | from fora.operations import files 32 | 33 | @Params 34 | class params: 35 | site: str 36 | 37 | files.template( 38 | name="Create the site definition", 39 | src="templates/site.j2", 40 | dest=f"/etc/nginx/sites/{params.site}") 41 | files.line_in_file( 42 | name="Add ", 43 | file="/etc/nginx/sites", 44 | line=f"sites/{params.site}") 45 | """) 46 | 47 | _nginx_site_j2 = dedent("""\ 48 | {{ fora_managed }} 49 | server { 50 | # ... 
51 | } 52 | """) 53 | 54 | _nginx_install = dedent("""\ 55 | from fora.operations import system 56 | 57 | system.package( 58 | name="Install the application", 59 | packages=["nginx"]) 60 | system.service( 61 | name="(Re-)start the service", 62 | service="nginx", 63 | state="restarted", 64 | enabled=True) 65 | """) 66 | 67 | _modular_nginx_deploy = dedent("""\ 68 | from fora.operations import local, system 69 | 70 | local.script( 71 | script="tasks/example_task/install.py") 72 | local.script( 73 | name="Add test1.example.com site", 74 | script="tasks/example_task/add_site.py", 75 | params=dict(site="site1")) 76 | local.script( 77 | name="Add test2.example.com site", 78 | script="tasks/example_task/add_site.py", 79 | params=dict(site="site2")) 80 | 81 | system.service( 82 | name="(Re-)start the service", 83 | service="nginx", 84 | state="restarted") 85 | """) 86 | 87 | _all_def = dedent("""\ 88 | somevariable = "defined fallback in 'all' group" 89 | """) 90 | 91 | def _create_dirs(dirs: list[str]) -> None: 92 | """ 93 | Creates the given list of directories in the current working directory. 94 | 95 | Parameters 96 | ---------- 97 | dirs 98 | The directories to create 99 | """ 100 | for d in dirs: 101 | Path(d).mkdir(exist_ok=True) 102 | 103 | def _write_file(file: str, content: str) -> None: 104 | """ 105 | Writes the given content to the specified file. 
106 | 107 | Parameters 108 | ---------- 109 | file 110 | The file 111 | content 112 | The content 113 | """ 114 | with open(file, "w", encoding="utf-8") as f: 115 | f.write(content) 116 | 117 | def init_structure_minimal() -> None: 118 | """Creates a minimal deploy structure.""" 119 | _create_dirs(["hosts"]) 120 | _write_file("hosts/localhost.py", _localhost_def) 121 | _write_file("inventory.py", _inventory_def) 122 | _write_file("deploy.py", dedent("""\ 123 | from fora import host 124 | from fora.operations import files 125 | 126 | files.upload_content( 127 | name="A temporary example file", 128 | content=f"Hello from {host.name}, also remember that {host.somevariable=}!", 129 | dest="/tmp/hello_world") 130 | """)) 131 | 132 | def init_structure_flat() -> None: 133 | """Creates a flat deploy structure.""" 134 | _create_dirs(["hosts", "groups", "tasks", "files", "templates"]) 135 | _write_file("hosts/localhost.py", _localhost_def) 136 | _write_file("groups/all.py", _all_def) 137 | _write_file("inventory.py", _inventory_def) 138 | _write_file("tasks/example_task.py", dedent("""\ 139 | from fora.operations import files 140 | 141 | files.upload( 142 | name="A temporary example file", 143 | src="../files/staticfile", 144 | dest="/tmp/hello_world") 145 | """)) 146 | _write_file("tasks/example_params.py", dedent("""\ 147 | from fora.operations import files 148 | 149 | @Params 150 | class params: 151 | filename: str 152 | 153 | script_var = "this is a fallback value defined in a script" 154 | 155 | files.template( 156 | name="Render a template to the file that was specified in the parameters", 157 | src="../templates/template.j2", 158 | dest=params.filename) 159 | """)) 160 | _write_file("deploy.py", dedent("""\ 161 | from fora.operations import local 162 | 163 | local.script( 164 | name="Run example task", 165 | script="tasks/example_task.py") 166 | local.script( 167 | name="Run parameter example task", 168 | script="tasks/example_params.py", 169 | 
params=dict(filename="/tmp/paramtest.txt")) 170 | """)) 171 | _write_file("templates/template.j2", dedent("""\ 172 | {{ fora_managed }} 173 | This file was specified by script parameters! See the fallback to the script var for the host: {{host.script_var}} 174 | """)) 175 | _write_file("files/staticfile", dedent("""\ 176 | Hello I am static content! 177 | """)) 178 | 179 | def init_structure_dotfiles() -> None: 180 | """Creates a dotfiles deploy structure.""" 181 | _write_file("deploy.py", dedent("""\ 182 | from fora import host 183 | from fora.operations import files 184 | 185 | # Get home directory of current user 186 | home = host.home_dir() 187 | 188 | # zsh 189 | files.upload(src="zsh/zshrc", dest=f"{home}/.zshrc") 190 | 191 | # kitty 192 | files.directory(path=f"{home}/.config/kitty") 193 | files.upload(src="kitty/kitty.conf", dest=f"{home}/.config/kitty/kitty.conf") 194 | 195 | # neovim 196 | files.directory(path=f"{home}/.config/nvim") 197 | files.upload(src="neovim/init.lua", dest=f"{home}/.config/init.lua") 198 | """)) 199 | 200 | def init_structure_modular() -> None: 201 | """Creates a modular deploy structure.""" 202 | _create_dirs(["hosts", "groups", "tasks", "tasks/example_task", "tasks/example_task/files", "tasks/example_task/templates"]) 203 | _write_file("hosts/localhost.py", _localhost_def) 204 | _write_file("groups/all.py", _all_def) 205 | _write_file("inventory.py", _inventory_def) 206 | _write_file("tasks/example_task/install.py", _nginx_install) 207 | _write_file("tasks/example_task/add_site.py", _nginx_add_site) 208 | _write_file("tasks/example_task/templates/site.j2", _nginx_site_j2) 209 | _write_file("deploy.py", _modular_nginx_deploy) 210 | 211 | def init_structure_staging_prod() -> None: 212 | """Creates a staging_prod deploy structure.""" 213 | _create_dirs(["inventories", "inventories/hosts", "inventories/groups", "tasks", "tasks/example_task", "tasks/example_task/files", "tasks/example_task/templates"]) 214 | 
_write_file("inventories/hosts/example.py", dedent("""\ 215 | # Same hostfile definition for all example.com hosts, to avoid repetition 216 | domain = "example.com" 217 | """)) 218 | _write_file("inventories/groups/all.py", _all_def) 219 | _write_file("inventories/staging.py", dedent("""\ 220 | import os 221 | 222 | # Global variables (and functions) that are not inventory-related 223 | # will automatically be exported to the `all` group. 224 | api_key = os.getenv("API_KEY_STAGING") 225 | 226 | hosts = [dict(url="staging1.example.com", file="hosts/example.py", groups=["staging"]), 227 | dict(url="staging2.example.com", file="hosts/example.py", groups=["staging"])] 228 | """)) 229 | _write_file("inventories/prod.py", dedent("""\ 230 | import os 231 | 232 | # Global variables (and functions) that are not inventory-related 233 | # will automatically be exported to the `all` group. 234 | api_key = os.getenv("API_KEY_PROD") 235 | 236 | hosts = [dict(url="prod1.example.com", file="hosts/example.py", groups=["prod"]), 237 | dict(url="prod2.example.com", file="hosts/example.py", groups=["prod"]), 238 | dict(url="prod3.example.com", file="hosts/example.py", groups=["prod"]), 239 | dict(url="prod4.example.com", file="hosts/example.py", groups=["prod"])] 240 | """)) 241 | _write_file("tasks/example_task/install.py", _nginx_install) 242 | _write_file("tasks/example_task/add_site.py", _nginx_add_site) 243 | _write_file("tasks/example_task/templates/site.j2", _nginx_site_j2) 244 | _write_file("deploy.py", _modular_nginx_deploy) 245 | 246 | def init_deploy_structure(layout: Literal["minimal", "flat", "dotfiles", "modular", "staging_prod"]) -> NoReturn: # type: ignore[misc] 247 | """ 248 | Initializes the current directory with a default deploy structure, if it is empty. 249 | Prompts the user to confirm operation if the current directory is not empty. 250 | 251 | Parameters 252 | ---------- 253 | layout 254 | The layout for the deploy. 
255 | 256 | Raises 257 | ------ 258 | ValueError 259 | Invalid layout. 260 | """ 261 | if layout not in _init_fns: 262 | raise ValueError(f"Unknown deploy layout structure '{layout}'") 263 | 264 | # Check if directory is empty. If not, ask whether to proceed. 265 | cwd = os.getcwd() 266 | if any(os.scandir(cwd)): 267 | response = input(f"{logger.col('[1;33m')}warning:{logger.col('[m')} current directory is not empty, proceed anyway? (conflicting files will be overwritten) [y/N] ") 268 | if response.lower() not in ["y", "yes"]: 269 | sys.exit(1) 270 | 271 | print_status("init:", f"creating {logger.col('[1;33m')}{layout}{logger.col('[m')} deploy structure") 272 | _init_fns[layout]() 273 | sys.exit(0) 274 | 275 | _init_fns = { 276 | "minimal": init_structure_minimal, 277 | "flat": init_structure_flat, 278 | "dotfiles": init_structure_dotfiles, 279 | "modular": init_structure_modular, 280 | "staging_prod": init_structure_staging_prod, 281 | } 282 | -------------------------------------------------------------------------------- /src/fora/loader.py: -------------------------------------------------------------------------------- 1 | """Provides the dynamic module loading utilities.""" 2 | 3 | import inspect 4 | import os 5 | import subprocess 6 | from types import ModuleType 7 | from typing import Union, Any, Optional 8 | 9 | import fora 10 | 11 | from fora import logger 12 | from fora.inventory_wrapper import InventoryWrapper 13 | from fora.types import ScriptWrapper 14 | from fora.utils import FatalError, load_py_module, print_process_error 15 | 16 | script_stack: list[tuple[ScriptWrapper, inspect.FrameInfo]] = [] 17 | """A stack of all currently executed scripts ((name, file), frame).""" 18 | 19 | class ImmediateInventory: 20 | """A temporary inventory just for a single run, without the ability to load host or group module files.""" 21 | def __init__(self, hosts: list[Union[str, tuple[str, str]]]) -> None: 22 | self.hosts = list(hosts) 23 | 24 | def base_dir(self) -> 
str: 25 | """An immediate inventory has no base directory.""" 26 | _ = (self) 27 | raise RuntimeError("Immediate inventories have no base directory!") 28 | 29 | def group_module_file(self, name: str) -> Optional[str]: # pylint: disable=useless-return 30 | """An immediate inventory has no group modules.""" 31 | _ = (self, name) 32 | return None 33 | 34 | def host_module_file(self, name: str) -> Optional[str]: # pylint: disable=useless-return 35 | """An immediate inventory has no host modules.""" 36 | _ = (self, name) 37 | return None 38 | 39 | def available_groups(self) -> set[str]: 40 | """An immediate inventory has no groups.""" 41 | _ = (self) 42 | return set() 43 | 44 | def load_inventory(inventory_file_or_host_url: str) -> None: 45 | """ 46 | Loads the global inventory from the given filename or single-host url 47 | and validates the definintions. 48 | 49 | Parameters 50 | ---------- 51 | inventory_file_or_host_url 52 | Either a single host url or an inventory module file (`*.py`). If a single host url 53 | is given without a connection schema (like `ssh://`), ssh will be used. 54 | 55 | Raises 56 | ------ 57 | FatalError 58 | The loaded inventory was invalid. 59 | """ 60 | wrapper = InventoryWrapper() 61 | fora.inventory = wrapper 62 | 63 | if inventory_file_or_host_url.endswith(".py"): 64 | # Load inventory from module file 65 | def _pre_exec(module: ModuleType) -> None: 66 | wrapper.wrap(module) 67 | 68 | inv = load_py_module(inventory_file_or_host_url, pre_exec=_pre_exec) 69 | 70 | # Check that the hosts definition is valid. 71 | if not hasattr(inv, "hosts"): 72 | raise FatalError("Inventory must define a list of hosts!", loc=wrapper.definition_file()) 73 | hosts = getattr(inv, "hosts") 74 | if not isinstance(hosts, list): 75 | raise FatalError(f"`hosts` definition must be of type list, not {type(hosts)}!", loc=wrapper.definition_file()) 76 | else: 77 | # Create an immediate inventory with just the given host. 
78 | wrapper.wrap(ImmediateInventory([inventory_file_or_host_url])) 79 | 80 | try: 81 | wrapper.load() 82 | except ValueError as e: 83 | raise FatalError(str(e), loc=wrapper.definition_file()) from e 84 | 85 | def run_script(script: str, 86 | frame: inspect.FrameInfo, 87 | params: Optional[dict[str, Any]] = None, 88 | name: Optional[str] = None) -> None: 89 | """ 90 | Loads and implicitly runs the given script by creating a new instance. 91 | 92 | Parameters 93 | ---------- 94 | script 95 | The path to the script that should be instanciated 96 | frame 97 | The FrameInfo object as given by inspect.getouterframes(inspect.currentframe())[?] 98 | where the script call originates from. Used to keep track of the script invocation stack, 99 | which helps with debugging (e.g. cyclic script calls). 100 | name 101 | A printable name for the script. Defaults to the script path. 102 | """ 103 | 104 | # It is intended that the name is passed before resolving it, so 105 | # that it is None if the user didn't pass one specifically. 106 | logger.run_script(script, name=name) 107 | 108 | with logger.indent(): 109 | if name is None: 110 | name = os.path.splitext(os.path.basename(script))[0] 111 | 112 | wrapper = ScriptWrapper(name) 113 | script_stack.append((wrapper, frame)) 114 | try: 115 | previous_script = fora.script 116 | previous_working_directory = os.getcwd() 117 | canonical_script = os.path.realpath(script) 118 | 119 | # Change into script's containing directory, so a script 120 | # can reliably use relative paths while it is executed. 121 | new_working_directory = os.path.dirname(canonical_script) 122 | os.chdir(new_working_directory) 123 | 124 | try: 125 | fora.script = wrapper 126 | # New script instance should start with a fresh set of default values. 127 | # Therefore, we use defaults() here to apply the connection's base settings. 
128 | with wrapper.defaults(): 129 | def _pre_exec(module: ModuleType) -> None: 130 | wrapper.wrap(module, copy_members=True, copy_functions=True) 131 | setattr(module, '_params', params or {}) 132 | load_py_module(canonical_script, pre_exec=_pre_exec) 133 | finally: 134 | os.chdir(previous_working_directory) 135 | fora.script = previous_script 136 | except Exception as e: 137 | if isinstance(e, subprocess.CalledProcessError) and not hasattr(e, "__fora_already_printed"): 138 | print_process_error(e) 139 | setattr(e, "__fora_already_printed", True) 140 | 141 | # Save the current script_stack in any exception thrown from this context 142 | # for later use in any exception handler. 143 | if not hasattr(e, 'script_stack'): 144 | setattr(e, 'script_stack', script_stack.copy()) 145 | raise 146 | finally: 147 | script_stack.pop() 148 | -------------------------------------------------------------------------------- /src/fora/logger.py: -------------------------------------------------------------------------------- 1 | """ 2 | Provides logging utilities. 
3 | """ 4 | 5 | import argparse 6 | import difflib 7 | import os 8 | from dataclasses import dataclass 9 | import sys 10 | from types import TracebackType 11 | from typing import Any, Optional, Type, cast 12 | 13 | import fora 14 | 15 | @dataclass 16 | class State: 17 | """Global state for logging.""" 18 | 19 | indentation_level: int = 0 20 | """The current global indentation level.""" 21 | 22 | state: State = State() 23 | """The global logger state.""" 24 | 25 | def use_color() -> bool: 26 | """Returns true if color should be used.""" 27 | if not isinstance(cast(Any, fora.args), argparse.Namespace): 28 | return os.getenv("NO_COLOR") is None 29 | return not fora.args.no_color 30 | 31 | def col(color_code: str) -> str: 32 | """Returns the given argument only if color is enabled.""" 33 | return color_code if use_color() else "" 34 | 35 | class IndentationContext: 36 | """A context manager to modify the indentation level.""" 37 | def __enter__(self) -> None: 38 | state.indentation_level += 1 39 | 40 | def __exit__(self, exc_type: Optional[Type[BaseException]], exc: Optional[BaseException], traceback: Optional[TracebackType]) -> None: 41 | _ = (exc_type, exc, traceback) 42 | state.indentation_level -= 1 43 | 44 | def ellipsis(s: str, width: int) -> str: 45 | """ 46 | Shrinks the given string to width (including an ellipsis character). 47 | 48 | Parameters 49 | ---------- 50 | s 51 | The string. 52 | width 53 | The maximum width. 54 | 55 | Returns 56 | ------- 57 | str 58 | A modified string with at most `width` characters. 
59 | """ 60 | if len(s) > width: 61 | s = s[:width - 1] + "…" 62 | return s 63 | 64 | def indent() -> IndentationContext: 65 | """Retruns a context manager that increases the indentation level.""" 66 | return IndentationContext() 67 | 68 | def indent_prefix() -> str: 69 | """Returns the indentation prefix for the current indentation level.""" 70 | if not use_color(): 71 | return " " * state.indentation_level 72 | ret = "" 73 | for i in range(state.indentation_level): 74 | if i % 2 == 0: 75 | ret += "[90m│[m " 76 | else: 77 | ret += "[90m╵[m " 78 | return ret 79 | 80 | def debug(msg: str) -> None: 81 | """Prints the given message only in debug mode.""" 82 | if not fora.args.debug: 83 | return 84 | 85 | print(f" [1;34mDEBUG[m: {msg}", file=sys.stderr) 86 | 87 | def debug_args(msg: str, args: dict[str, Any]) -> None: 88 | """Prints all given arguments when in debug mode.""" 89 | if not fora.args.debug: 90 | return 91 | 92 | str_args = "" 93 | args = {k: v for k,v in args.items() if k != "self"} 94 | if len(args) > 0: 95 | str_args = " " + ", ".join(f"{k}={v}" for k,v in args.items()) 96 | 97 | print(f" [1;34mDEBUG[m: {msg}{str_args}", file=sys.stderr) 98 | 99 | def print_indented(msg: str, **kwargs: Any) -> None: 100 | """Same as print(), but prefixes the message with the indentation prefix.""" 101 | print(f"{indent_prefix()}{msg}", **kwargs) 102 | 103 | 104 | def connection_init(connector: Any) -> None: 105 | """Prints connection initialization information.""" 106 | print_indented(f"{col('[1;34m')}host{col('[m')} {connector.host.name} via {col('[1;33m')}{connector.host.url}{col('[m')}", flush=True) 107 | 108 | def connection_failed(error_msg: str) -> None: 109 | """Signals that an error has occurred while establishing the connection.""" 110 | print(col("[1;31m") + "ERR" + col("[m")) 111 | print_indented(f" {col('[90m')}└{col('[m')} " + f"{col('[31m')}{error_msg}{col('[m')}") 112 | 113 | def connection_established() -> None: 114 | """Signals that the connection has 
been successfully established.""" 115 | #print(col("[1;32m") + "OK" + col("[m")) 116 | 117 | 118 | def run_script(script: str, name: Optional[str] = None) -> None: 119 | """Prints the script file and name that is being executed next.""" 120 | if name is not None: 121 | print_indented(f"{col('[33;1m')}script{col('[m')} {script} {col('[90m')}({name}){col('[m')}") 122 | else: 123 | print_indented(f"{col('[33;1m')}script{col('[m')} {script}") 124 | 125 | def print_operation_title(op: Any, title_color: str, end: str = "\n") -> None: 126 | """Prints the operation title and description.""" 127 | name_if_given = (" " + col('[90m') + f"({op.name})" + col('[m')) if op.name is not None else "" 128 | dry_run_info = f" {col('[90m')}(dry){col('[m')}" if fora.args.dry else "" 129 | print_indented(f"{title_color}{op.op_name}{col('[m')}{dry_run_info} {op.description}{name_if_given}", end=end, flush=True) 130 | 131 | def print_operation_early(op: Any) -> None: 132 | """Prints the operation title and description before the final status is known.""" 133 | title_color = col("[1;33m") 134 | # Only overwrite status later if debugging is not enabled. 135 | print_operation_title(op, title_color, end=" (early status)\n" if fora.args.debug else "") 136 | 137 | 138 | def decode_escape(data: bytes, encoding: str = 'utf-8') -> str: 139 | """ 140 | Tries to decode the given data with the given encoding, but replaces all non-decodeable 141 | and non-printable characters with backslash escape sequences. 142 | 143 | Example: 144 | 145 | ```python 146 | >>> decode_escape(b'It is Wednesday\\nmy dudes\\r\\n🐸\\xff\\0') 147 | 'It is Wednesday\\\\nMy Dudes\\\\r\\\\n🐸\\\\xff\\\\0' 148 | ``` 149 | 150 | Parameters 151 | ---------- 152 | content 153 | The content that should be decoded and escaped. 154 | encoding 155 | The encoding that should be tried. To preserve utf-8 symbols, use 'utf-8', 156 | to replace any non-ascii character with an escape sequence use 'ascii'. 
157 | 158 | Returns 159 | ------- 160 | str 161 | The decoded and escaped string. 162 | """ 163 | def escape_char(c: str) -> str: 164 | special = {'\x00': '\\0', '\n': '\\n', '\r': '\\r', '\t': '\\t'} 165 | if c in special: 166 | return special[c] 167 | 168 | num = ord(c) 169 | if not c.isprintable() and num <= 0xff: 170 | return f"\\x{num:02x}" 171 | return c 172 | return ''.join([escape_char(c) for c in data.decode(encoding, 'backslashreplace')]) 173 | 174 | def diff(filename: str, old: Optional[bytes], new: Optional[bytes], color: bool = True) -> list[str]: 175 | """ 176 | Creates a diff between the old and new content of the given filename, 177 | that can be printed to the console. This function returns the diff 178 | output as an array of lines. The lines in the output array are not 179 | terminated by newlines. 180 | 181 | If color is True, the diff is colored using ANSI escape sequences. 182 | 183 | If you want to provide an alternative diffing function, beware that 184 | the input can theoretically contain any bytes and therefore should 185 | be decoded as utf-8 if possible, but non-decodeable 186 | or non-printable charaters should be replaced with human readable 187 | variants such as `\\x00`, `^@` or similar represenations. 188 | 189 | Your diffing function should still be able to work on the raw bytes 190 | representation, after you aquire the diff and before you apply colors, 191 | your output should be made printable with a function such as `fora.logger.decode_escape`: 192 | 193 | ```python 194 | # First decode and escape 195 | line = logger.decode_escape(byteline) 196 | # Add coloring afterwards so ANSI escape sequences are not escaped 197 | ``` 198 | 199 | Parameters 200 | ---------- 201 | filename 202 | The filename of the file that is being diffed. 203 | old 204 | The old content, or None if the file didn't exist before. 205 | new 206 | The new content, or None if the file was deleted. 
207 | color 208 | Whether the output should be colored (with ANSI color sequences). 209 | 210 | Returns 211 | ------- 212 | list[str] 213 | The lines of the diff output. The individual lines will not have a terminating newline. 214 | """ 215 | bdiff = list(difflib.diff_bytes(difflib.unified_diff, 216 | a=[] if old is None else old.split(b'\n'), 217 | b=[] if new is None else new.split(b'\n'), 218 | lineterm=b'')) 219 | # Strip file name header and decode diff to be human readable. 220 | difflines = map(decode_escape, bdiff[2:]) 221 | 222 | # Create custom file name header 223 | action = 'created' if old is None else 'deleted' if new is None else 'modified' 224 | title = f"{action}: {filename}" 225 | N = len(title) 226 | header = ['─' * N, title, '─' * N] 227 | 228 | # Apply coloring if desired 229 | if color: 230 | def apply_color(line: str) -> str: 231 | linecolor = { 232 | '+': '[32m', 233 | '-': '[31m', 234 | '@': '[34m', 235 | } 236 | return linecolor.get(line[0], '[90m') + line + '[m' 237 | # Apply color to diff 238 | difflines = map(apply_color, difflines) 239 | # Apply color to header 240 | header = list(map(lambda line: f"[33m{line}[m", header)) 241 | 242 | return header + list(difflines) 243 | 244 | # TODO: move functions to operation api. cleaner and has type access. 
245 | def _operation_state_infos(result: Any) -> list[str]: 246 | def to_str(v: Any) -> str: 247 | return v.hex() if isinstance(v, bytes) else str(v) 248 | 249 | # Print "key: value" pairs with changes 250 | state_infos: list[str] = [] 251 | for k,final_v in result.final.items(): 252 | if final_v is None: 253 | continue 254 | 255 | initial_v = result.initial[k] 256 | str_initial_v = to_str(initial_v) 257 | str_final_v = to_str(final_v) 258 | 259 | # Add ellipsis on long strings, if we are not in verbose mode 260 | if fora.args.verbose == 0: 261 | k = ellipsis(k, 12) 262 | str_initial_v = ellipsis(to_str(initial_v), 9) 263 | str_final_v = ellipsis(to_str(final_v), 9+3+9 if initial_v is None else 9) 264 | 265 | if initial_v == final_v: 266 | if fora.args.verbose >= 1: 267 | # TODO = instead of : for better readability 268 | entry_str = f"{col('[90m')}{k}: {str_initial_v}{col('[m')}" 269 | state_infos.append(entry_str) 270 | else: 271 | if initial_v is None: 272 | entry_str = f"{col('[33m')}{k}: {col('[32m')}{str_final_v}{col('[m')}" 273 | else: 274 | entry_str = f"{col('[33m')}{k}: {col('[31m')}{str_initial_v}{col('[33m')} → {col('[32m')}{str_final_v}{col('[m')}" 275 | state_infos.append(entry_str) 276 | return state_infos 277 | 278 | def print_operation(op: Any, result: Any) -> None: 279 | """Prints the operation summary after it has finished execution.""" 280 | if result.success: 281 | title_color = col("[1;32m") if result.changed else col("[1;90m") 282 | else: 283 | title_color = col("[1;31m") 284 | 285 | # Print title and name, overwriting the transitive status 286 | print("\r", end="") 287 | print_operation_title(op, title_color) 288 | 289 | if not result.success: 290 | print_indented(f" {col('[90m')}└{col('[m')} " + f"{col('[31m')}{result.failure_message}{col('[m')}") 291 | return 292 | 293 | if not fora.args.changes: 294 | return 295 | 296 | # Cache number of upcoming diffs to determine what box character to print 297 | n_diffs = len(op.diffs) if 
fora.args.diff else 0 298 | box_char = '└' if n_diffs == 0 else '├' 299 | 300 | # Print "key: value" pairs with changes 301 | state_infos = _operation_state_infos(result) 302 | if len(state_infos) > 0: 303 | print_indented(f"{col('[90m')}{box_char}{col('[m')} " + f"{col('[90m')},{col('[m')} ".join(state_infos)) 304 | 305 | if fora.args.diff: 306 | diff_lines = [] 307 | # Generate diffs 308 | for file, old, new in op.diffs: 309 | diff_lines.extend(diff(file, old, new)) 310 | # Print diffs with block character line 311 | if len(diff_lines) > 0: 312 | for l in diff_lines[:-1]: 313 | print_indented(f"{col('[90m')}│ {col('[m')}" + l) 314 | print_indented(f"{col('[90m')}└ {col('[m')}" + diff_lines[-1]) 315 | -------------------------------------------------------------------------------- /src/fora/main.py: -------------------------------------------------------------------------------- 1 | """ 2 | Provides the top-level logic of fora such as 3 | the CLI interface and main script dispatching. 4 | """ 5 | 6 | import argparse 7 | import inspect 8 | import os 9 | import sys 10 | from types import ModuleType 11 | from typing import Any, Callable, NoReturn, Optional, cast 12 | 13 | import fora 14 | from fora.connection import open_connection 15 | from fora.example_deploys import init_deploy_structure 16 | from fora.loader import load_inventory, run_script 17 | from fora.logger import col 18 | from fora.types import GroupWrapper, HostWrapper, ModuleWrapper, VariableActionSnapshot 19 | from fora.utils import FatalError, die_error, install_exception_hook, print_fullwith, print_table 20 | from fora.version import version 21 | 22 | def main_run(args: argparse.Namespace) -> None: 23 | """ 24 | Main method used to run a script on an inventory. 
25 | 26 | Parameters 27 | ---------- 28 | args 29 | The parsed arguments 30 | """ 31 | try: 32 | load_inventory(args.inventory) 33 | except FatalError as e: 34 | die_error(str(e), loc=e.loc) 35 | 36 | # Deduplicate host selection and check if every host is valid 37 | selected_hosts = [] 38 | for host in (args.hosts.split(",") if args.hosts is not None else fora.inventory.loaded_hosts): 39 | # Skip duplicate entries 40 | if host in selected_hosts: 41 | continue 42 | # Ensure host existence 43 | if host not in fora.inventory.loaded_hosts: 44 | die_error(f"Unknown host '{host}'") 45 | selected_hosts.append(host) 46 | 47 | # TODO: multiprocessing? 48 | # - displaying must then be handled by ncurses which makes things a lot more complex. 49 | # - would open the door to a more interactive experience, e.g. allow to select past operations 50 | # and view information about them, scroll through diffs, ... 51 | # - we need to save some kind of log file as the output won't persist in the terminal 52 | # - fatal errors must be delayed until all executions are fininshed. 53 | 54 | # Instanciate (run) the given script for each selected host 55 | for k in selected_hosts: 56 | host = fora.inventory.loaded_hosts[k] 57 | with open_connection(host): 58 | fora.host = host 59 | run_script(args.script, inspect.getouterframes(inspect.currentframe())[0], name="cmdline") 60 | fora.host = cast(HostWrapper, None) 61 | 62 | if host.name != selected_hosts[-1]: 63 | # Separate hosts by a newline for better visibility 64 | print() 65 | 66 | def show_inventory(inventory: str) -> None: 67 | """ 68 | Display a summary of the given inventory. 
69 | 70 | Parameters 71 | ---------- 72 | inventory 73 | The inventory argument 74 | """ 75 | # pylint: disable=protected-access,too-many-branches,too-many-statements 76 | try: 77 | load_inventory(inventory) 78 | except FatalError as e: 79 | die_error(str(e), loc=e.loc) 80 | 81 | col_red = col("\033[31m") 82 | col_red_b = col("\033[1;31m") 83 | col_green = col("\033[32m") 84 | col_green_b = col("\033[1;32m") 85 | col_yellow = col("\033[33m") 86 | col_blue = col("\033[34m") 87 | col_darker = col("\033[90m") 88 | col_darker_b = col("\033[1;90m") 89 | col_reset = col("\033[m") 90 | 91 | try: 92 | base_dir = fora.inventory.base_dir() 93 | except RuntimeError: 94 | base_dir = "." 95 | 96 | def relpath(path: Optional[str]) -> Optional[str]: 97 | return None if path is None else os.path.relpath(path, start=base_dir) 98 | 99 | def value_repr(x: Any) -> list[str]: 100 | color = col_reset 101 | if x is None: 102 | color = col_red 103 | elif isinstance(x, bool): 104 | color = col_green if x else col_red 105 | elif isinstance(x, (list, tuple, range, dict, set)): 106 | color = col_blue 107 | elif isinstance(x, (str, bytes)): 108 | color = col_green 109 | elif isinstance(x, (int, float)): 110 | color = col_yellow 111 | else: 112 | color = col_darker 113 | 114 | return [color, repr(value), col_reset] 115 | 116 | def precedence(wrapper: ModuleWrapper) -> int: 117 | """Calculates a numeric variable precedence in accordance with the hierachical lookup rules.""" 118 | if isinstance(wrapper, GroupWrapper): 119 | return fora.inventory._topological_order.index(wrapper.name) 120 | if isinstance(wrapper, HostWrapper): 121 | return len(fora.inventory._topological_order) 122 | return -1 123 | 124 | print_fullwith(["──────── ", col_red_b, "inventory", col_reset, " ", col_darker_b, inventory, col_reset, " "], [col_darker, f" {relpath(fora.inventory.definition_file())}", col_reset]) 125 | 126 | pretty_group_names = { name: f"{col_darker}- ({index}){col_reset} {col_yellow}{name}{col_reset}" for 
index,name in enumerate(fora.inventory._topological_order) } 127 | print(f"{col_blue}groups{col_reset} {col_darker}(precedence, low to high){col_reset}") 128 | for i in pretty_group_names.values(): 129 | print(f" {i}") 130 | 131 | pretty_host_names = { name: f"{col_darker}-{col_reset} {col_green}{name}{col_reset} {col_darker}({host.url}, {relpath(host.definition_file())}){col_reset}" for name,host in fora.inventory.loaded_hosts.items() } 132 | print(f"{col_blue}hosts{col_reset} {col_darker}(url, module){col_reset}") 133 | for i in pretty_host_names.values(): 134 | print(f" {i}") 135 | 136 | global_vars = fora.inventory.exported_variables() 137 | if len(global_vars) > 0: 138 | print(f"{col_blue}variables{col_reset}") 139 | for attr, value in global_vars.items(): 140 | print(f"{col_green}{attr}{col_reset}\t(type {type(value)}) = {value}") 141 | 142 | for name, host in fora.inventory.loaded_hosts.items(): 143 | print() 144 | print_fullwith(["──────── ", col_red_b, "host", col_reset, " ", col_green_b, name, col_reset, " "], [col_darker, f" {relpath(host.definition_file())}", col_reset]) 145 | entries = [] 146 | for attr, value in host.vars_hierarchical().items(): 147 | if attr.startswith("_") or isinstance(value, ModuleType): 148 | continue 149 | is_declared_by_wrapper = attr in HostWrapper.__dict__ or attr in HostWrapper.__annotations__ 150 | last_actor = host._variable_action_history.get(attr, [VariableActionSnapshot("definition", host, value)])[-1].actor 151 | entries.append((attr, value, is_declared_by_wrapper, last_actor)) 152 | 153 | table = [] 154 | for attr, value, is_declared_by_wrapper, last_actor in sorted(entries, key=lambda tup: (not tup[2], precedence(tup[3]), tup[0])): 155 | definition_str: list[str] = [] 156 | for action in reversed(host._variable_action_history.get(attr, [VariableActionSnapshot("definition", host, value)])): 157 | if isinstance(action.actor, GroupWrapper): 158 | definition_str = [col_darker, f"({precedence(action.actor)}) ", col_reset, 
col_yellow, action.actor.name, col_reset, col_darker, ", ", col_reset] + definition_str 159 | elif isinstance(action.actor, HostWrapper): 160 | definition_str = [col_darker, f"({precedence(action.actor)}) ", col_reset, col_green, action.actor.name, col_reset, col_darker, ", ", col_reset] + definition_str 161 | if action.action == "definition": 162 | break 163 | 164 | # Strip last ", " 165 | definition_str = definition_str[:-3] 166 | 167 | col_var = col_darker 168 | if is_declared_by_wrapper: 169 | if host.is_overridden(attr): 170 | col_var = col_darker_b 171 | else: 172 | col_var = col_darker 173 | elif isinstance(last_actor, GroupWrapper): 174 | col_var = col_yellow 175 | elif isinstance(last_actor, HostWrapper): 176 | col_var = col_green 177 | else: 178 | col_var = col_reset 179 | 180 | table.append([[col_var, attr, col_reset], [col_darker, type(value).__name__, col_reset], definition_str, value_repr(value)]) 181 | print_table([[col_blue, "variable", col_reset], 182 | [col_blue, "type", col_reset], 183 | [col_darker, "(prec) ", col_reset, col_blue, "defined by", col_reset], 184 | [col_blue, "value", col_reset]], 185 | table, min_col_width=[24, 0, 12, 0]) 186 | 187 | sys.exit(0) 188 | 189 | class ArgumentParserError(Exception): 190 | """Error class for argument parsing errors.""" 191 | 192 | class ThrowingArgumentParser(argparse.ArgumentParser): 193 | """An argument parser that throws when invalid argument types are passed.""" 194 | 195 | def error(self, message: str) -> NoReturn: 196 | """Raises an exception on error.""" 197 | raise ArgumentParserError(message) 198 | 199 | class ActionImmediateFunction(argparse.Action): 200 | """An action that calls a function immediately when the argument is encountered.""" 201 | def __init__(self, option_strings: Any, func: Callable[[Any], Any], *args: Any, **kwargs: Any): 202 | self.func = func 203 | super().__init__(option_strings, *args, **kwargs) 204 | 205 | def __call__(self, parser: argparse.ArgumentParser, namespace: 
argparse.Namespace, values: Any, option_string: Any = None) -> None: 206 | _ = (parser, namespace, values, option_string) 207 | self.func(values) 208 | 209 | def main(argv: Optional[list[str]] = None) -> None: 210 | """ 211 | The main program entry point. This will parse arguments, load inventory and task 212 | definitions and run the given user script. Defaults to sys.argv[1:] if argv is None. 213 | """ 214 | if argv is None: 215 | argv = sys.argv[1:] 216 | parser = ThrowingArgumentParser(description="Runs a fora script.") 217 | 218 | # General options 219 | parser.add_argument('-V', '--version', action='version', 220 | version=f"%(prog)s version {version}") 221 | 222 | # Run script options 223 | parser.add_argument('--init', action=ActionImmediateFunction, func=init_deploy_structure, choices=["minimal", "flat", "dotfiles", "modular", "staging_prod"], 224 | help="Initialize the current directory with a default deploy structure and exit. The various choices are explained in-depth in the documentation. As a rule of thumb, 'minimal' is the most basic starting point, 'flat' is well-suited for small and simple deploys, 'dotfiles' is explicitly intended for dotfile deploys, 'modular' is the most versatile layout intended to be used with modular sub-tasks, and 'staging_prod' is the modular layout with two separate inventories.") 225 | parser.add_argument('--inspect-inventory', action=ActionImmediateFunction, func=show_inventory, 226 | help="Display all available information about a specific inventory. This includes a summary as well as the specific variables available on each group or host.") 227 | parser.add_argument('-H', '--hosts', dest='hosts', default=None, type=str, 228 | help="Specifies a comma separated list of hosts to run on. By default all hosts are selected. Duplicates will be ignored.") 229 | parser.add_argument('--dry', '--dry-run', '--pretend', dest='dry', action='store_true', 230 | help="Print what would be done instead of performing any actions. 
Probing commands will still be executed to determine the current state of the systems.") 231 | parser.add_argument('-v', '--verbose', dest='verbose', action='count', default=0, 232 | help="Increase output verbosity. Can be given multiple times.") 233 | parser.add_argument('--no-changes', dest='changes', action='store_false', 234 | help="Don't display changes for each operation in a short diff-like format.") 235 | parser.add_argument('--diff', dest='diff', action='store_true', 236 | help="Display an actual diff when an operation changes a file. Use with care, as this might print secrets!") 237 | parser.add_argument('--debug', dest='debug', action='store_true', 238 | help="Enable debugging output. Forces verbosity to max value.") 239 | parser.add_argument('--no-color', dest='no_color', action='store_true', 240 | help="Disables any color output. Color can also be disabled by setting the NO_COLOR environment variable.") 241 | parser.add_argument('inventory', type=str, 242 | help="The inventory to run on. Either a single host url or an inventory module (`*.py`). If a single host url is given without a connection schema (like `ssh://`), ssh will be used. Single hosts also do not load any groups or host modules.") 243 | parser.add_argument('script', type=str, 244 | help="The user script containing the logic of what should be executed on the inventory.") 245 | parser.set_defaults(func=main_run) 246 | 247 | try: 248 | args: argparse.Namespace = parser.parse_args(argv) 249 | except ArgumentParserError as e: 250 | die_error(str(e)) 251 | 252 | # Force max verbosity with --debug 253 | if args.debug: 254 | args.verbose = 99 255 | 256 | # Disable color when NO_COLOR is set, or output is not a TTY 257 | is_a_tty = hasattr(sys.stdout, 'isatty') and sys.stdout.isatty() 258 | if os.getenv("NO_COLOR") is not None or not is_a_tty: 259 | args.no_color = True 260 | 261 | # Install exception hook to modify traceback, if debug isn't set. 
262 | # Exceptions raised from a dynamically loaded module will then 263 | # be displayed a lot cleaner. 264 | if not args.debug: 265 | install_exception_hook() 266 | 267 | if 'func' not in args: 268 | # Fallback to --help. 269 | parser.print_help() 270 | else: 271 | fora.args = args 272 | args.func(args) 273 | -------------------------------------------------------------------------------- /src/fora/operations/__init__.py: -------------------------------------------------------------------------------- 1 | """This package contains all standard operation modules.""" 2 | 3 | from fora.utils import import_submodules 4 | 5 | # Import all submodules to ensure that decorators have a chance 6 | # to register operations to a registry (e.g. package_managers). 7 | import_submodules(__name__) 8 | -------------------------------------------------------------------------------- /src/fora/operations/api.py: -------------------------------------------------------------------------------- 1 | """Provides API to define operations.""" 2 | 3 | from dataclasses import dataclass 4 | import subprocess 5 | import sys 6 | 7 | from functools import wraps 8 | from typing import Callable, TypeVar, cast, Any, Optional 9 | from types import TracebackType, FrameType 10 | 11 | import fora 12 | from fora import logger 13 | from fora.types import RemoteDefaultsContext 14 | from fora.utils import check_host_active, print_process_error 15 | 16 | class OperationError(Exception): 17 | """An exception that indicates an error while executing an operation.""" 18 | 19 | @dataclass 20 | class OperationResult: 21 | """Stores the result of an operation.""" 22 | success: bool 23 | """Whether the operation succeeded.""" 24 | changed: bool 25 | """Whether the operation changed something.""" 26 | initial: dict[str, Any] 27 | """The initial state of the host.""" 28 | final: dict[str, Any] 29 | """The final state of the host.""" 30 | failure_message: Optional[str] = None 31 | """The failure message, if success is 
False.""" 32 | 33 | class Operation: 34 | """This class is used to ease the building of operations with consistent output and state tracking.""" 35 | 36 | internal_use_only: "Operation" = cast("Operation", None) 37 | """operation's op variable is defaulted to this value to indicate that it must not be given by the user.""" 38 | 39 | def __init__(self, op_name: str, name: Optional[str]): 40 | self.op_name = op_name 41 | self.name = name 42 | self.has_nested = False 43 | self.description: str = "?" 44 | self.initial_state_dict: Optional[dict[str, Any]] = None 45 | self.final_state_dict: Optional[dict[str, Any]] = None 46 | self.diffs: list[tuple[str, Optional[bytes], Optional[bytes]]] = [] 47 | 48 | def nested(self, has_nested: bool) -> None: 49 | """ 50 | Sets whet this operation spawns nested operations. In this case, 51 | this operation will not have separate state, and the printing will be 52 | handled differently. 53 | 54 | Parameters 55 | ---------- 56 | has_nested 57 | Whether the operation has nested operations. 58 | """ 59 | self.has_nested = has_nested 60 | 61 | def add_nested_result(self, key: str, result: OperationResult) -> None: 62 | """ 63 | Adds initial and final state of a nested operation under the given key 64 | into this operation's state dictionaries. 65 | 66 | Parameters 67 | ---------- 68 | key 69 | The key under which to add the nested result. 70 | result 71 | The result to add. 
72 | """ 73 | if not self.has_nested: 74 | raise OperationError("An operation can only accumulate nested results if it is marked as nested.") 75 | if self.initial_state_dict is None: 76 | self.initial_state_dict = {} 77 | if self.final_state_dict is None: 78 | self.final_state_dict = {} 79 | if key in self.initial_state_dict or key in self.final_state_dict: 80 | raise OperationError(f"Cannot add nested operation result under existing key '{key}'.") 81 | self.initial_state_dict[key] = result.initial 82 | self.final_state_dict[key] = result.final 83 | 84 | def desc(self, description: str) -> None: 85 | """ 86 | Sets the description of the operation, and prints an 87 | early status via the logger. 88 | 89 | Parameters 90 | ---------- 91 | description 92 | The new description. 93 | """ 94 | self.description = description 95 | logger.print_operation_early(self) 96 | if self.has_nested: 97 | print() 98 | 99 | def defaults(self, *args: Any, **kwargs: Any) -> RemoteDefaultsContext: 100 | """Sets defaults on the current script. 
See `fora.types.ScriptWrapper.defaults`.""" 101 | _ = (self) 102 | return fora.script.defaults(*args, **kwargs) 103 | 104 | def initial_state(self, **kwargs: Any) -> None: 105 | """Sets the initial state.""" 106 | if self.has_nested: 107 | raise OperationError("An operation that nests other operations cannot have state on its own.") 108 | if self.initial_state_dict is not None: 109 | raise OperationError("An operation's 'initial_state' can only be set once.") 110 | self.initial_state_dict = dict(kwargs) 111 | 112 | def final_state(self, **kwargs: Any) -> None: 113 | """Sets the final state.""" 114 | if self.has_nested: 115 | raise OperationError("An operation that nests other operations cannot have state on its own.") 116 | if self.final_state_dict is not None: 117 | raise OperationError("An operation's 'final_state' can only be set once.") 118 | self.final_state_dict = dict(kwargs) 119 | 120 | def unchanged(self, ignore_none: bool = False) -> bool: 121 | """ 122 | Checks whether the initial and final states differ. 123 | 124 | Parameters 125 | ---------- 126 | ignore_none 127 | Set to `True` to not count states where the final value is None. 128 | 129 | Returns 130 | ------- 131 | bool 132 | Whether the states differ. 133 | """ 134 | if self.initial_state_dict is None or self.final_state_dict is None: 135 | raise OperationError("Both initial and final state must have been set before 'unchanged()' may be called.") 136 | 137 | if not ignore_none: 138 | return self.initial_state_dict == self.final_state_dict 139 | 140 | keys_not_none = (k for k in self.final_state_dict if k is not None) 141 | for k in keys_not_none: 142 | if self.initial_state_dict[k] != self.final_state_dict[k]: 143 | return False 144 | return True 145 | 146 | 147 | def changed(self, key: str) -> bool: 148 | """ 149 | Checks whether a specific key will change. 150 | 151 | Parameters 152 | ---------- 153 | key 154 | The key to check for changes. 
155 | 156 | Returns 157 | ------- 158 | bool 159 | Whether the states differ. 160 | """ 161 | if self.has_nested: 162 | raise OperationError("An operation that nests other operations cannot have state on its own.") 163 | if self.initial_state_dict is None or self.final_state_dict is None: 164 | raise OperationError("Both initial and final state must have been set before 'changed()' may be called.") 165 | return bool(self.initial_state_dict[key] != self.final_state_dict[key]) 166 | 167 | def diff(self, file: str, old: Optional[bytes], new: Optional[bytes]) -> None: 168 | """ 169 | Adds a file to the diffing output. 170 | 171 | Parameters 172 | ---------- 173 | file 174 | The filename which the diff belongs to. 175 | old 176 | The previous content or None if the file didn't exist previously. 177 | new 178 | The new content or None if the file was deleted. 179 | """ 180 | if self.has_nested: 181 | raise OperationError("An operation that nests other operations cannot have state on its own.") 182 | if old == new: 183 | return 184 | self.diffs.append((file, old, new)) 185 | 186 | def failure(self, msg: str) -> OperationResult: 187 | """ 188 | Returns a failed operation result. 189 | 190 | Returns 191 | ------- 192 | OperationResult 193 | The OperationResult for this failed operation. 194 | """ 195 | result = OperationResult(success=False, 196 | changed=False, 197 | initial=self.initial_state_dict or {}, 198 | final=self.final_state_dict or {}, 199 | failure_message=msg) 200 | if not self.has_nested: 201 | logger.print_operation(self, result) 202 | return result 203 | 204 | def success(self) -> OperationResult: 205 | """ 206 | Returns a successful operation result. 207 | 208 | Returns 209 | ------- 210 | OperationResult 211 | The OperationResult for this successful operation. 
212 | """ 213 | if self.initial_state_dict is None or self.final_state_dict is None: 214 | raise OperationError("Both initial and final state must have been set before 'success()' may be called.") 215 | result = OperationResult(success=True, 216 | changed=not self.unchanged(), 217 | initial=self.initial_state_dict, 218 | final=self.final_state_dict) 219 | if not self.has_nested: 220 | logger.print_operation(self, result) 221 | return result 222 | 223 | _TFunc = TypeVar("_TFunc", bound=Callable[..., Any]) 224 | # This is untyped as the language server can then apparently 225 | # complete the wrapped function correctly, and ParamSpec 226 | # was only introduced in python 3.10. 227 | def operation(op_name: str): # type: ignore[no-untyped-def] 228 | """Operation function decorator.""" 229 | 230 | def _calling_site_traceback() -> TracebackType: 231 | """ 232 | Returns a modified traceback object which can be used in Exception.with_traceback() to make 233 | the exception appear as if it originated at the calling site of the operation. 234 | """ 235 | try: 236 | raise AssertionError 237 | except AssertionError: 238 | traceback = sys.exc_info()[2] 239 | if traceback is None: 240 | raise RuntimeError("Traceback cannot be None. This is a bug!") from None 241 | back_frame: Optional[FrameType] = traceback.tb_frame 242 | back_frame = back_frame.f_back if back_frame else None # Omit this function 243 | back_frame = back_frame.f_back if back_frame else None # Omit the function where _calling_site_traceback is called (the operation_wrapper below) 244 | if back_frame is None: 245 | raise RuntimeError("Error in site traceback: back_frame cannot be None. 
This is a bug!") from None 246 | 247 | return TracebackType(tb_next=None, 248 | tb_frame=back_frame, 249 | tb_lasti=back_frame.f_lasti, 250 | tb_lineno=back_frame.f_lineno) 251 | 252 | def operation_wrapper(function: _TFunc) -> _TFunc: 253 | @wraps(function) 254 | def wrapper(*args: Any, **kwargs: Any) -> Any: 255 | check_host_active() 256 | 257 | op = Operation(op_name=op_name, name=kwargs.get("name", None)) 258 | check = kwargs.get("check", True) 259 | 260 | try: 261 | ret = function(*args, **kwargs, op=op) 262 | except OperationError as e: 263 | ret = op.failure(str(e)) 264 | # If we are not in debug mode, we modify the traceback such that the exception 265 | # seems to originate at the calling site where the operation is called. 266 | if fora.args.debug: 267 | raise 268 | raise e.with_traceback(_calling_site_traceback()) 269 | except subprocess.CalledProcessError as e: 270 | ret = op.failure(str(e)) 271 | if not hasattr(e, "__fora_already_printed"): 272 | print_process_error(e) 273 | setattr(e, "__fora_already_printed", True) 274 | if fora.args.debug: 275 | raise 276 | raise e.with_traceback(_calling_site_traceback()) 277 | except Exception as e: 278 | ret = op.failure(str(e)) 279 | raise 280 | 281 | if ret is None: 282 | raise OperationError("The operation failed to return a status. THIS IS A BUG! Please report it to the package maintainer of the package which the operation belongs to.") 283 | 284 | if check and not ret.success: 285 | error = OperationError(ret.failure_message) 286 | # If we are not in debug mode, we modify the traceback such that the exception 287 | # seems to originate at the calling site where the operation is called. 
288 | if fora.args.debug: 289 | raise error 290 | raise error.with_traceback(_calling_site_traceback()) 291 | 292 | return ret 293 | return cast(_TFunc, wrapper) 294 | return operation_wrapper 295 | -------------------------------------------------------------------------------- /src/fora/operations/apt.py: -------------------------------------------------------------------------------- 1 | """Provides operations related to the apt package manager.""" 2 | 3 | from functools import partial 4 | from typing import Optional 5 | import fora 6 | from fora.operations.api import Operation, OperationResult, operation 7 | from fora.operations.utils import generic_package, package_manager 8 | 9 | def _is_installed(package: str, opts: Optional[list[str]] = None) -> bool: # pylint: disable=redefined-outer-name 10 | """Checks whether a package is installed with dpkg-query on the remote host.""" 11 | opts = opts or [] 12 | ret = fora.host.connection.run(["dpgk-query", "--show", "--showformat=${Status}"] + opts + ["--", package]) 13 | return ret.stdout is not None and b"ok installed" in ret.stdout 14 | 15 | def _install(package: str, opts: Optional[list[str]] = None) -> None: # pylint: disable=redefined-outer-name 16 | """Installs a package with apt-get on the remote host.""" 17 | opts = opts or [] 18 | fora.host.connection.run(["apt-get", "install"] + opts + ["--", package]) 19 | 20 | def _uninstall(package: str, opts: Optional[list[str]] = None) -> None: # pylint: disable=redefined-outer-name 21 | """Uninstalls a package with apt-get on the remote host.""" 22 | opts = opts or [] 23 | fora.host.connection.run(["apt-get", "remove"] + opts + ["--", package]) 24 | 25 | @package_manager(command="apt-get") 26 | @operation("package") 27 | def package(packages: list[str], 28 | present: bool = True, 29 | opts: Optional[list[str]] = None, 30 | name: Optional[str] = None, 31 | check: bool = True, 32 | op: Operation = Operation.internal_use_only) -> OperationResult: 33 | """ 34 | Adds or 
removes system packages with apt-get. 35 | 36 | Parameters 37 | ---------- 38 | packages 39 | The packages to modify. 40 | present 41 | Whether the given package should be installed or uninstalled. 42 | opts 43 | Extra options passed to apt-get when installing/uninstalling. 44 | name 45 | The name for the operation. 46 | check 47 | If True, returning `op.failure()` will raise an OperationError. All manually raised 48 | OperationErrors will be propagated. When False, any manually raised OperationError will 49 | be caught and `op.failure()` will be returned with the given message while continuing execution. 50 | op 51 | The operation wrapper. Must not be supplied by the user. 52 | """ 53 | _ = (name, check) # Processed automatically. 54 | op.desc(str(packages)) 55 | 56 | return generic_package(op, packages, 57 | present=present, 58 | is_installed=_is_installed, 59 | install=partial(_install, opts=opts), 60 | uninstall=partial(_uninstall, opts=opts)) 61 | -------------------------------------------------------------------------------- /src/fora/operations/git.py: -------------------------------------------------------------------------------- 1 | """Provides operations related to git.""" 2 | 3 | import os 4 | from typing import Optional 5 | import fora 6 | from fora.operations.api import Operation, OperationResult, operation 7 | from fora.operations.utils import check_absolute_path 8 | 9 | @operation("repo") 10 | def repo(url: str, 11 | path: str, 12 | branch_or_tag: Optional[str] = None, 13 | update: bool = True, 14 | depth: Optional[int] = None, 15 | rebase: bool = True, 16 | ff_only: bool = False, 17 | update_submodules: bool = False, 18 | recursive_submodules: bool = False, 19 | shallow_submodules: bool = False, 20 | name: Optional[str] = None, 21 | check: bool = True, 22 | op: Operation = Operation.internal_use_only) -> OperationResult: 23 | """ 24 | Clones or updates a git repository and its submodules. 
25 | 26 | Parameters 27 | ---------- 28 | url 29 | The url to the git repository. 30 | path 31 | The path where the repository should be cloned. 32 | branch_or_tag 33 | Either a branch name or a tag to clone. Follows the default branch of the remote if not given. 34 | update 35 | Whether to keep the repository up to date if it has already been cloned. 36 | depth 37 | Keep the repository as a shallow clone with the specified number of commits. 38 | Also applies when pulling updates. 39 | rebase 40 | Use `--rebase` when pulling updates. 41 | ff_only 42 | Use `--ff-only` when pulling updates. 43 | update_submodules 44 | Also initialize and update submodules after cloning or pulling. 45 | recursive_submodules 46 | Recursively update submodules after cloning or pulling. 47 | shallow_submodules 48 | Also apply the given `depth` to submodule updates. 49 | name 50 | The name for the operation. 51 | check 52 | If True, returning `op.failure()` will raise an OperationError. All manually raised 53 | OperationErrors will be propagated. When False, any manually raised OperationError will 54 | be caught and `op.failure()` will be returned with the given message while continuing execution. 55 | op 56 | The operation wrapper. Must not be supplied by the user. 57 | """ 58 | # pylint: disable=too-many-return-statements,too-many-branches,too-many-statements 59 | _ = (name, check) # Processed automatically. 
60 | check_absolute_path(path, f"{path=}") 61 | op.desc(f"{path} [{url}]") 62 | 63 | conn = fora.host.connection 64 | 65 | stat_path = conn.stat(path) 66 | if stat_path is None: 67 | op.initial_state(initialized=False, commit=None) 68 | cur_commit = None 69 | elif stat_path.type == "dir": 70 | # Assert that it is a git directory 71 | stat_git = conn.stat(os.path.join(path, ".git")) 72 | if stat_git is None: 73 | return op.failure(f"directory '{path}' already exists but is not a git repository") 74 | 75 | if stat_git.type != "dir": 76 | return op.failure(f"directory '{path}' already exists but doesn't contains a valid .git directory") 77 | 78 | remote_commit = conn.run(["git", "-C", path, "rev-parse", "HEAD"]) 79 | cur_commit = (remote_commit.stdout or b"").decode("utf-8", errors="backslashreplace").strip() 80 | op.initial_state(initialized=True, commit=cur_commit) 81 | else: 82 | return op.failure(f"path '{path}' exists but is not a directory!") 83 | 84 | # If the repository is already cloned but we shouldn't update, 85 | # nothing will change and we are done. 
86 | if stat_path is not None and not update: 87 | op.final_state(initialized=True, commit=cur_commit) 88 | return op.success() 89 | 90 | # Check the newest available commit 91 | remote_newest_commit = conn.run(["git", "ls-remote", "--exit-code", "--", url, branch_or_tag or "HEAD"]) 92 | newest_commit = (remote_newest_commit.stdout or b"").decode("utf-8", errors="backslashreplace").strip().split()[0] 93 | 94 | op.final_state(initialized=True, commit=newest_commit) 95 | 96 | # Return success if nothing needs to be changed 97 | if op.unchanged(): 98 | return op.success() 99 | 100 | # Apply actions to reach new state, if we aren't in pretend mode 101 | if not fora.args.dry: 102 | if stat_path is None: 103 | # Create a fresh clone of the repository 104 | clone_cmd = ["git", "clone"] 105 | if depth is not None: 106 | clone_cmd.extend(["--depth", str(depth)]) 107 | if branch_or_tag is not None: 108 | clone_cmd.extend(["--branch", branch_or_tag]) 109 | clone_cmd.extend(["--", url, path]) 110 | conn.run(clone_cmd) 111 | 112 | if update_submodules: 113 | # Initialize submodules if requested 114 | submodule_cmd = ["git", "-C", path, "submodule", "update", "--init"] 115 | if shallow_submodules and depth is not None: 116 | submodule_cmd.extend(["--depth", str(depth)]) 117 | if recursive_submodules: 118 | submodule_cmd.extend(["--recursive"]) 119 | conn.run(submodule_cmd) 120 | elif update: 121 | # Assert that the existing repository's remote url matches the given url to prevent pulling an unrelated repo 122 | ret_current_remote = conn.run(["git", "-C", path, "config", "--get", "remote.origin.url"]) 123 | current_remote = (ret_current_remote.stdout or b"").decode("utf-8", errors="backslashreplace").strip() 124 | if current_remote != url: 125 | return op.failure(f"refusing to update existing git repository with different remote url '{current_remote}'") 126 | 127 | # Update the existing repository 128 | update_cmd = ["git", "-C", path, "pull"] 129 | if depth is not None: 130 | 
update_cmd.extend(["--depth", str(depth)]) 131 | if rebase: 132 | update_cmd.append("--rebase") 133 | if ff_only: 134 | update_cmd.append("--ff-only") 135 | conn.run(update_cmd) 136 | 137 | if update_submodules: 138 | # Update submodules if requested 139 | submodule_update_cmd = ["git", "-C", path, "submodule", "update", "--init"] 140 | if shallow_submodules and depth is not None: 141 | submodule_update_cmd.extend(["--depth", str(depth)]) 142 | if recursive_submodules: 143 | submodule_update_cmd.extend(["--recursive"]) 144 | conn.run(submodule_update_cmd) 145 | 146 | return op.success() 147 | -------------------------------------------------------------------------------- /src/fora/operations/local.py: -------------------------------------------------------------------------------- 1 | """Provides operations that are related to the local system on which the fora scripts are executed.""" 2 | 3 | import inspect 4 | import os 5 | from typing import Any, Optional 6 | 7 | from fora.loader import script_stack, run_script 8 | from fora.utils import check_host_active 9 | 10 | def script(script: str, # pylint: disable=redefined-outer-name 11 | recursive: bool = False, 12 | params: Optional[dict[str, Any]] = None, 13 | name: Optional[str] = None) -> None: 14 | """ 15 | Executes the given script for the current host. 16 | Useful to split functionality into smaller sub-scripts. 17 | 18 | Scripts can take parameters. Parameters to scripts are passed by 19 | supplying a `params` dictionary. The script declares its parameters 20 | by annotating them. (The annotation then transparently extracts the 21 | value from a separately passed global variable). 22 | 23 | ```python 24 | @Params 25 | class params: 26 | username: str 27 | website_title: str = "Default website title." 28 | 29 | # Use a parameter 30 | print(params.username) 31 | ``` 32 | 33 | Parameters 34 | ---------- 35 | script 36 | The local path to the script to execute. 
37 | recursive 38 | Whether recursive calls should be allowed. 39 | params 40 | The parameters for the script. 41 | name 42 | The name for the script execution (used for logging). 43 | """ 44 | check_host_active() 45 | 46 | # Asserts that the call is not recursive, if not explicitly allowed 47 | if not recursive: 48 | for wrapper, _ in script_stack: 49 | # pylint: disable=protected-access 50 | if os.path.samefile(script, wrapper.definition_file()): 51 | raise ValueError(f"Invalid recursive call to script '{script}'. Use recursive=True to allow this.") 52 | 53 | outer_frame = inspect.getouterframes(inspect.currentframe())[1] 54 | run_script(script, outer_frame, params=params, name=name) 55 | -------------------------------------------------------------------------------- /src/fora/operations/pacman.py: -------------------------------------------------------------------------------- 1 | """Provides operations related to the pacman package manager.""" 2 | 3 | from functools import partial 4 | from typing import Optional 5 | import fora 6 | from fora.operations.api import Operation, OperationResult, operation 7 | from fora.operations.utils import generic_package, package_manager 8 | 9 | def _is_installed(package: str, opts: Optional[list[str]] = None) -> bool: # pylint: disable=redefined-outer-name 10 | """Checks whether a package is installed with pacman on the remote host.""" 11 | opts = opts or [] 12 | return fora.host.connection.run(["pacman", "-Ql"] + opts + ["--", package], check=False).returncode == 0 13 | 14 | def _install(package: str, opts: Optional[list[str]] = None) -> None: # pylint: disable=redefined-outer-name 15 | """Installs a package with pacman on the remote host.""" 16 | opts = opts or [] 17 | fora.host.connection.run(["pacman", "--color", "always", "--noconfirm", "-S"] + opts + ["--", package]) 18 | 19 | def _uninstall(package: str, opts: Optional[list[str]] = None) -> None: # pylint: disable=redefined-outer-name 20 | """Uninstalls a package with 
pacman on the remote host.""" 21 | opts = opts or [] 22 | fora.host.connection.run(["pacman", "--color", "always", "--noconfirm", "-Rns"] + opts + ["--", package]) 23 | 24 | @package_manager(command="pacman") 25 | @operation("package") 26 | def package(packages: list[str], 27 | present: bool = True, 28 | opts: Optional[list[str]] = None, 29 | name: Optional[str] = None, 30 | check: bool = True, 31 | op: Operation = Operation.internal_use_only) -> OperationResult: 32 | """ 33 | Adds or removes system packages with pacman. 34 | 35 | Parameters 36 | ---------- 37 | packages 38 | The packages to modify. 39 | present 40 | Whether the given package should be installed or uninstalled. 41 | opts 42 | Extra options passed to pacman when installing/uninstalling. 43 | name 44 | The name for the operation. 45 | check 46 | If True, returning `op.failure()` will raise an OperationError. All manually raised 47 | OperationErrors will be propagated. When False, any manually raised OperationError will 48 | be caught and `op.failure()` will be returned with the given message while continuing execution. 49 | op 50 | The operation wrapper. Must not be supplied by the user. 51 | """ 52 | _ = (name, check) # Processed automatically. 
53 | op.desc(str(packages)) 54 | 55 | return generic_package(op, packages, 56 | present=present, 57 | is_installed=_is_installed, 58 | install=partial(_install, opts=opts), 59 | uninstall=partial(_uninstall, opts=opts)) 60 | -------------------------------------------------------------------------------- /src/fora/operations/pip.py: -------------------------------------------------------------------------------- 1 | """Provides operations related to pip.""" 2 | -------------------------------------------------------------------------------- /src/fora/operations/portage.py: -------------------------------------------------------------------------------- 1 | """Provides operations related to the portage package manager.""" 2 | 3 | from functools import partial 4 | from typing import Optional 5 | import fora 6 | from fora.operations.api import Operation, OperationResult, operation 7 | from fora.operations.utils import generic_package, package_manager 8 | 9 | def _is_installed(package: str, opts: Optional[list[str]] = None) -> bool: # pylint: disable=redefined-outer-name 10 | """Checks whether a package is installed with portage on the remote host.""" 11 | opts = opts or [] 12 | ret = fora.host.connection.run(["emerge", "--info"] + opts + ["--", package]) 13 | return ret.stdout is not None and b"was built with the following" in ret.stdout 14 | 15 | def _install(package: str, opts: Optional[list[str]] = None, oneshot: bool = False) -> None: # pylint: disable=redefined-outer-name 16 | """Installs a package with portage on the remote host.""" 17 | opts = opts or [] 18 | if oneshot: 19 | opts = ["--oneshot"] + opts 20 | fora.host.connection.run(["emerge", "--color=y", "--verbose"] + opts + ["--", package]) 21 | 22 | def _uninstall(package: str, opts: Optional[list[str]] = None) -> None: # pylint: disable=redefined-outer-name 23 | """Uninstalls a package with portage on the remote host.""" 24 | opts = opts or [] 25 | fora.host.connection.run(["emerge", "--color=y", 
"--verbose", "--depclean"] + opts + ["--", package]) 26 | 27 | @package_manager(command="emerge") 28 | @operation("package") 29 | def package(packages: list[str], 30 | present: bool = True, 31 | oneshot: bool = False, 32 | opts: Optional[list[str]] = None, 33 | name: Optional[str] = None, 34 | check: bool = True, 35 | op: Operation = Operation.internal_use_only) -> OperationResult: 36 | """ 37 | Adds or removes system packages with portage. 38 | 39 | Parameters 40 | ---------- 41 | packages 42 | The packages to modify. 43 | present 44 | Whether the given package should be installed or uninstalled. 45 | oneshot 46 | Whether to use --oneshot to install packages, which prevents them from being added to the world file. 47 | opts 48 | Extra options passed to emerge when installing/uninstalling. 49 | name 50 | The name for the operation. 51 | check 52 | If True, returning `op.failure()` will raise an OperationError. All manually raised 53 | OperationErrors will be propagated. When False, any manually raised OperationError will 54 | be caught and `op.failure()` will be returned with the given message while continuing execution. 55 | op 56 | The operation wrapper. Must not be supplied by the user. 57 | """ 58 | _ = (name, check) # Processed automatically. 
59 | op.desc(str(package)) 60 | 61 | return generic_package(op, packages, 62 | present=present, 63 | is_installed=_is_installed, 64 | install=partial(_install, opts=opts, oneshot=oneshot), 65 | uninstall=partial(_uninstall, opts=opts)) 66 | -------------------------------------------------------------------------------- /src/fora/operations/postgres.py: -------------------------------------------------------------------------------- 1 | """Provides operations related to postgres.""" 2 | -------------------------------------------------------------------------------- /src/fora/operations/systemd.py: -------------------------------------------------------------------------------- 1 | """Provides operations related to the systemd init system.""" 2 | 3 | from typing import Optional 4 | 5 | from fora.operations.api import Operation, OperationResult, operation 6 | from fora.operations.utils import service_manager 7 | import fora 8 | 9 | @operation("systemctl") 10 | def daemon_reload(user_mode: bool = False, 11 | name: Optional[str] = None, 12 | check: bool = True, 13 | op: Operation = Operation.internal_use_only) -> OperationResult: 14 | """ 15 | Manages a systemd unit. 16 | 17 | Parameters 18 | ---------- 19 | user_mode 20 | Whether `systemctl --user` should be used to make user specific changes. 21 | name 22 | The name for the operation. 23 | check 24 | If True, returning `op.failure()` will raise an OperationError. All manually raised 25 | OperationErrors will be propagated. When False, any manually raised OperationError will 26 | be caught and `op.failure()` will be returned with the given message while continuing execution. 27 | op 28 | The operation wrapper. Must not be supplied by the user. 29 | """ 30 | _ = (name, check) # Processed automatically. 31 | op.desc("daemon_reload") 32 | conn = fora.host.connection 33 | 34 | # This operation has no dynamic state. 
35 | op.initial_state(reloaded=False) 36 | op.final_state(reloaded=True) 37 | 38 | if not fora.args.dry: 39 | if user_mode: 40 | conn.run(["systemctl", "--user", "daemon-reload"]) 41 | else: 42 | conn.run(["systemctl", "daemon-reload"]) 43 | 44 | return op.success() 45 | 46 | @service_manager(command="systemctl") 47 | @operation("service") 48 | def service(service: str, # pylint: disable=redefined-outer-name 49 | state: Optional[str] = None, 50 | enabled: Optional[bool] = None, 51 | user_mode: bool = False, 52 | name: Optional[str] = None, 53 | check: bool = True, 54 | op: Operation = Operation.internal_use_only) -> OperationResult: 55 | """ 56 | Manages a systemd unit. 57 | 58 | Parameters 59 | ---------- 60 | service 61 | The unit to manage. 62 | state 63 | The desired state of the unit. Valid options are `started`, `restarted`, `reloaded` and `stopped`. 64 | If None, the service's current state will not be changed. 65 | enabled 66 | Whether the unit should be started on boot. 67 | user_mode 68 | Whether `systemctl --user` should be used to make user specific changes. 69 | name 70 | The name for the operation. 71 | check 72 | If True, returning `op.failure()` will raise an OperationError. All manually raised 73 | OperationErrors will be propagated. When False, any manually raised OperationError will 74 | be caught and `op.failure()` will be returned with the given message while continuing execution. 75 | op 76 | The operation wrapper. Must not be supplied by the user. 77 | """ 78 | _ = (name, check) # Processed automatically. 
79 | op.desc(service) 80 | conn = fora.host.connection 81 | 82 | state_actions: dict[str, str] = { 83 | "started": "start", 84 | "restarted": "restart", 85 | "reloaded": "reload", 86 | "stopped": "stop", 87 | } 88 | 89 | if state is not None and state not in state_actions: 90 | raise ValueError(f"Invalid target state '{state}'") 91 | 92 | # Examine current state 93 | systemd_active_state = conn.run(["systemctl", "show", "--value", "--property", "ActiveState", "--", service]).stdout 94 | if (systemd_active_state or b"").decode('utf-8', errors='ignore').strip() in ["active", "activating"]: 95 | cur_state = "started" 96 | else: 97 | cur_state = "stopped" 98 | 99 | systemd_unit_file_state = conn.run(["systemctl", "show", "--value", "--property", "UnitFileState", "--", service]).stdout 100 | cur_enabled = (systemd_unit_file_state or b"").decode('utf-8', errors='ignore').strip() == "enabled" 101 | 102 | op.initial_state(state=cur_state, enabled=cur_enabled) 103 | op.final_state(state=state, enabled=enabled) 104 | 105 | # Return success if nothing needs to be changed 106 | if op.unchanged(ignore_none=True): 107 | return op.success() 108 | 109 | # Apply actions to reach desired state, but only if we are not doing a dry run 110 | if not fora.args.dry: 111 | base_command = ["systemctl", "--user"] if user_mode else ["systemctl"] 112 | if op.changed("state") and state is not None: 113 | conn.run(base_command + [state_actions[state], "--", service]) 114 | 115 | if op.changed("enabled") and enabled is not None: 116 | conn.run(base_command + ["enable" if enabled else "disable", "--", service]) 117 | 118 | return op.success() 119 | -------------------------------------------------------------------------------- /src/fora/operations/utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Provides utiliy functions for operations. 
3 | """ 4 | 5 | import hashlib 6 | from typing import Any, Callable, Optional, Union 7 | from fora.connection import Connection 8 | import fora 9 | 10 | from fora.operations.api import Operation, OperationError, OperationResult 11 | 12 | package_managers: dict[str, Any] = {} 13 | """All registered package managers as a map from (command name -> package function).""" 14 | 15 | service_managers: dict[str, Any] = {} 16 | """All registered service managers as a map from (command name -> service function).""" 17 | 18 | def find_command(conn: Connection, command_to_result_map: dict[str, Any]) -> Optional[Any]: 19 | """ 20 | Searches for any of the commands provided as keys in `command_to_result_map`, 21 | and if found on the target system, returns the associated value from the map. 22 | """ 23 | query = " || ".join([f"{{ type &>/dev/null {cmd} && echo {cmd} ; }}" for cmd in command_to_result_map]) 24 | query += " || echo __unknown__" 25 | res = conn.run(["bash", "-c", query]) 26 | 27 | return command_to_result_map.get((res.stdout or b"").decode('utf-8', errors='ignore').strip(), None) 28 | 29 | def package_manager(command: str) -> Callable[[Callable], Callable]: 30 | """ 31 | Operation function decorator to denote that this operation constitutes the package() operation of a package manager. 32 | This will cause it to be registered such that system.package() can call this package function if the given command is detected on the remote system. 33 | 34 | See `fora.operations.pacman.package` for an example usage. 35 | """ 36 | def operation_wrapper(function: Callable) -> Callable: 37 | package_managers[command] = function 38 | return function 39 | return operation_wrapper 40 | 41 | def service_manager(command: str) -> Callable[[Callable], Callable]: 42 | """ 43 | Operation function decorator to denote that this operation constitutes the service() operation of a service manager. 
44 | This will cause it to be registered such that system.service() can call this service function if the given command is detected on the remote system. 45 | 46 | See `fora.operations.systemd.service` for an example usage. 47 | """ 48 | def operation_wrapper(function: Callable) -> Callable: 49 | service_managers[command] = function 50 | return function 51 | return operation_wrapper 52 | 53 | def generic_package(op: Operation, 54 | packages: list[str], 55 | present: bool, 56 | is_installed: Callable[[str], bool], 57 | install: Callable[[str], None], 58 | uninstall: Callable[[str], None]) -> OperationResult: 59 | """ 60 | A generic package operation that will query the current system state and 61 | call install/uninstall on each of the packages where an action is required 62 | to reach the target state. 63 | 64 | Parameters 65 | ---------- 66 | op 67 | The operation wrapper. 68 | packages 69 | The packages to modify. 70 | present 71 | Whether the given package should be installed or uninstalled. 72 | is_installed 73 | A function that returns whether a given package is installed. 74 | install 75 | A function that installs the given package on the remote system. 76 | uninstall 77 | A function that uninstalls the given package on the remote system. 78 | """ 79 | # Examine current state 80 | installed = set() 81 | if not isinstance(packages, list): 82 | raise ValueError("'packages' must be a list!") 83 | 84 | for p in packages: 85 | if is_installed(p): 86 | installed.add(p) 87 | 88 | # Set initial and target state. 
89 | op.initial_state(installed=sorted(list(installed))) 90 | op.final_state(installed=sorted(list(packages)) if present else []) 91 | 92 | # Return success if nothing needs to be changed 93 | if op.unchanged(): 94 | return op.success() 95 | 96 | # Apply actions to reach desired state, but only if we are not doing a dry run 97 | if not fora.args.dry: 98 | if present: 99 | for p in set(packages) - installed: 100 | install(p) 101 | else: 102 | for p in installed: 103 | uninstall(p) 104 | 105 | return op.success() 106 | 107 | def save_content(op: Operation, 108 | content: Union[bytes, str], 109 | dest: str, 110 | mode: Optional[str] = None, 111 | owner: Optional[str] = None, 112 | group: Optional[str] = None) -> OperationResult: 113 | """ 114 | Saves the given content as dest on the remote host. Only for use within an operation, 115 | if save_content is the main functionality. You must supply the op parameter. 116 | 117 | Parameters 118 | ---------- 119 | op 120 | The operation wrapper. 121 | content 122 | The file content. 123 | dest 124 | The remote destination path. 125 | mode 126 | The file mode. Uses the remote execution defaults if None. 127 | owner 128 | The file owner. Uses the remote execution defaults if None. 129 | group 130 | The file group. Uses the remote execution defaults if None. 
131 | """ 132 | if isinstance(content, str): 133 | content = content.encode('utf-8') 134 | 135 | conn = fora.host.connection 136 | with op.defaults(file_mode=mode, owner=owner, group=group) as attr: 137 | final_sha512sum = hashlib.sha512(content).digest() 138 | op.final_state(exists=True, mode=attr.file_mode, owner=attr.owner, group=attr.group, sha512=final_sha512sum) 139 | 140 | # Examine current state 141 | stat = conn.stat(dest, sha512sum=True) 142 | if stat is None: 143 | # The file doesn't exist yet 144 | op.initial_state(exists=False, mode=None, owner=None, group=None, sha512=None) 145 | else: 146 | if stat.type != "file": 147 | return op.failure(f"path '{dest}' exists but is not a file!") 148 | 149 | # The file exists but may have different attributes or content 150 | op.initial_state(exists=True, mode=stat.mode, owner=stat.owner, group=stat.group, sha512=stat.sha512sum) 151 | 152 | # Return success if nothing needs to be changed 153 | if op.unchanged(): 154 | return op.success() 155 | 156 | # Add diff if desired 157 | if fora.args.diff: 158 | op.diff(dest, conn.download_or(dest), content) 159 | 160 | # Apply actions to reach desired state, but only if we are not doing a dry run 161 | if not fora.args.dry: 162 | # Upload the file if it is missing or its content differs (the upload also applies mode/owner/group) 163 | if op.changed("exists") or op.changed("sha512"): 164 | conn.upload( 165 | file=dest, 166 | content=content, 167 | mode=attr.file_mode, 168 | owner=attr.owner, 169 | group=attr.group) 170 | else: 171 | # Set correct mode, if needed 172 | if op.changed("mode"): 173 | conn.run(["chmod", attr.file_mode, "--", dest]) 174 | 175 | # Set correct owner and group, if needed 176 | if op.changed("owner") or op.changed("group"): 177 | conn.run(["chown", f"{attr.owner}:{attr.group}", "--", dest]) 178 | 179 | return op.success() 180 | 181 | def check_absolute_path(path: str, path_desc: str) -> None: 182 | """ 183 | Asserts that a given path is non empty and absolute.
184 | 185 | Parameters 186 | ---------- 187 | path 188 | The path to check. 189 | path_desc 190 | A description of the path (e.g. the variable name) that is included in the error message. 191 | """ 192 | if not path: 193 | raise ValueError(f"Path {path_desc} must be non-empty") 194 | if path[0] != "/": 195 | raise ValueError(f"Path {path_desc} must be absolute") 196 | 197 | def new_op_fail(op_name: str, name: Optional[str], desc: str, error: str) -> OperationError: 198 | """ 199 | Creates a new operation with the given name and description and immediately 200 | marks it as failed with the given error message. Returns an OperationError in case 201 | the caller wants to raise an exception. 202 | 203 | This is useful for meta-operations that have a failure condition before the 204 | required sub-operation is determined (e.g. `system.package` can call different package 205 | managers' package() operations, but can also fail to find a suitable one). 206 | """ 207 | op = Operation(op_name=op_name, name=name) 208 | op.desc(desc) 209 | op.failure(error) 210 | return OperationError(error) 211 | -------------------------------------------------------------------------------- /src/fora/remote_settings.py: -------------------------------------------------------------------------------- 1 | """ 2 | Provides a class that represents execution defaults for a remote host. 3 | """ 4 | 5 | from __future__ import annotations 6 | from typing import Any, Optional 7 | from dataclasses import dataclass 8 | 9 | @dataclass 10 | class ResolvedRemoteSettings: 11 | """ 12 | This class stores a resolved version of the RemoteSettings object, 13 | it only has more strict types for typechecking and is otherwise 14 | identical to the original object.
15 | """ 16 | owner: str 17 | group: str 18 | file_mode: str 19 | dir_mode: str 20 | umask: str 21 | cwd: str 22 | as_user: Optional[str] = None 23 | as_group: Optional[str] = None 24 | 25 | @dataclass 26 | class RemoteSettings: 27 | """ 28 | This class stores certain values that determine how things are executed on 29 | the remote host. This includes things such as the owner and group of newly 30 | created files, or the user as which commands are run. 31 | """ 32 | as_user: Optional[str] = None 33 | as_group: Optional[str] = None 34 | owner: Optional[str] = None 35 | group: Optional[str] = None 36 | file_mode: Optional[str] = None 37 | dir_mode: Optional[str] = None 38 | umask: Optional[str] = None 39 | cwd: Optional[str] = None 40 | 41 | def overlay(self, settings: RemoteSettings) -> RemoteSettings: 42 | """ 43 | Overlays settings on top of this. Values will only be overwritten 44 | if the new value is not None, effectively overlaying the given settings 45 | on top of the current settings. 
46 | 47 | Parameters 48 | ---------- 49 | settings 50 | The setting values to overwrite 51 | 52 | Returns 53 | ------- 54 | The resulting overlayed remote settings 55 | """ 56 | return RemoteSettings( 57 | as_user = self.as_user if settings.as_user is None else settings.as_user, 58 | as_group = self.as_group if settings.as_group is None else settings.as_group, 59 | owner = self.owner if settings.owner is None else settings.owner, 60 | group = self.group if settings.group is None else settings.group, 61 | file_mode = self.file_mode if settings.file_mode is None else settings.file_mode, 62 | dir_mode = self.dir_mode if settings.dir_mode is None else settings.dir_mode, 63 | umask = self.umask if settings.umask is None else settings.umask, 64 | cwd = self.cwd if settings.cwd is None else settings.cwd) 65 | 66 | def __repr__(self) -> str: 67 | members = [None if self.as_user is None else ("as_user", self.as_user), 68 | None if self.as_group is None else ("as_group", self.as_group), 69 | None if self.owner is None else ("owner", self.owner), 70 | None if self.group is None else ("group", self.group), 71 | None if self.file_mode is None else ("file_mode", self.file_mode), 72 | None if self.dir_mode is None else ("dir_mode", self.dir_mode), 73 | None if self.umask is None else ("umask", self.umask), 74 | None if self.cwd is None else ("cwd", self.cwd)] 75 | existing_members: list[tuple[str, Any]] = [x for x in members if x is not None] 76 | member_str = ','.join([f"{n}={v}" for (n,v) in existing_members]) 77 | return f"RemoteSettings{{{member_str}}}" 78 | -------------------------------------------------------------------------------- /src/fora/utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Provides utility functions. 
3 | """ 4 | 5 | from __future__ import annotations 6 | 7 | import importlib 8 | import importlib.machinery 9 | import importlib.util 10 | import inspect 11 | import os 12 | import pkgutil 13 | import shutil 14 | import subprocess 15 | import sys 16 | import traceback 17 | import uuid 18 | from types import ModuleType, TracebackType 19 | from typing import Any, Collection, NoReturn, Type, TypeVar, Callable, Iterable, Optional, Union 20 | 21 | import fora 22 | from fora.logger import col 23 | 24 | class FatalError(Exception): 25 | """An exception type for fatal errors, optionally including a file location.""" 26 | def __init__(self, msg: str, loc: Optional[str] = None): 27 | super().__init__(msg) 28 | self.loc = loc 29 | 30 | T = TypeVar('T') 31 | 32 | # A set of all modules names that are dynamically loaded modules. 33 | # These are guaranteed to be unique across all possible modules, 34 | # as a random uuid will be generated at load-time for each module. 35 | dynamically_loaded_modules: set[str] = set() 36 | 37 | class CycleError(ValueError): 38 | """An error that is throw to report a cycle in a graph that must be cycle free.""" 39 | 40 | def __init__(self, msg: str, cycle: list[Any]): 41 | super().__init__(msg) 42 | self.cycle = cycle 43 | 44 | def print_status(status: str, msg: str) -> None: 45 | """Prints a message with a (possibly colored) status prefix.""" 46 | print(f"{col('[1;32m')}{status}{col('[m')} {msg}") 47 | 48 | def print_warning(msg: str) -> None: 49 | """Prints a message with a (possibly colored) 'warning: ' prefix.""" 50 | print(f"{col('[1;33m')}warning:{col('[m')} {msg}") 51 | 52 | def print_error(msg: str, loc: Optional[str] = None) -> None: 53 | """Prints a message with a (possibly colored) 'error: ' prefix.""" 54 | if loc is None: 55 | print(f"{col('[1;31m')}error:{col('[m')} {msg}", file=sys.stderr) 56 | else: 57 | print(f"{col('[1m')}{loc}: {col('[1;31m')}error:{col('[m')} {msg}", file=sys.stderr) 58 | 59 | def len_ignore_leading_ansi(s: str) 
-> int: 60 | """Returns the length of the string or 0 if it starts with `\033[`""" 61 | return 0 if s.startswith("\033[") else len(s) 62 | 63 | def ansilen(ss: Collection[str]) -> int: 64 | """Returns the length of all strings combined ignoring ansi control sequences""" 65 | return sum(map(len_ignore_leading_ansi, ss)) 66 | 67 | def ansipad(ss: Collection[str], pad: int = 0) -> str: 68 | """Joins an array of string and ansi codes together and pads the result with spaces to at least `pad` characters.""" 69 | return ''.join(ss) + " " * max(0, pad - ansilen(ss)) 70 | 71 | def print_fullwith(left: Optional[list[str]] = None, right: Optional[list[str]] = None, pad: str = '─', **kwargs: Any) -> None: 72 | """Prints a message padded to the terminal width to stderr.""" 73 | if not left: 74 | left = [] 75 | if not right: 76 | right = [] 77 | 78 | cols = max(shutil.get_terminal_size((80, 20)).columns, 80) 79 | n_pad = max(0, (cols - ansilen(left) - ansilen(right))) 80 | print(''.join(left) + pad * n_pad + ''.join(right), **kwargs) 81 | 82 | def print_table(header: Collection[Collection[str]], rows: Collection[Collection[Collection[str]]], box_color: str = "\033[90m", min_col_width: Optional[list[int]] = None) -> None: 83 | """Prints the given rows as an ascii box table.""" 84 | max_col_width = 40 85 | terminal_cols = max(shutil.get_terminal_size((80, 20)).columns, 80) 86 | 87 | # Calculate max needed with for each column 88 | cols = len(header) 89 | max_value_width = [0] * cols 90 | for i,v in enumerate(header): 91 | max_value_width[i] = max(max_value_width[i], ansilen(v)) 92 | for row in rows: 93 | for i,v in enumerate(row): 94 | max_value_width[i] = max(max_value_width[i], ansilen(v)) 95 | 96 | # Fairly distribute space between columns 97 | if min_col_width is None: 98 | min_col_width = [0] * cols 99 | even_col_width = terminal_cols // cols 100 | chars_needed_for_table_boxes = 3 * (len(header) - 1) 101 | available_space = terminal_cols - chars_needed_for_table_boxes 102 | 
col_width = [max(min_col_width[i], min(max_col_width, w, available_space - (cols - i - 1) * even_col_width)) for i,w in enumerate(max_value_width)] 103 | 104 | # Distribute remaining space to first column from the back that would need the space 105 | total_width = sum(col_width) 106 | rest = available_space - total_width 107 | if rest > 0: 108 | for i,w in reversed(list(enumerate(max_value_width))): 109 | if w > col_width[i]: 110 | col_width[i] += rest 111 | break 112 | 113 | # Print table 114 | col_reset = col("\033[m") 115 | col_box = col(box_color) 116 | delim = col_box + " │ " + col_reset 117 | print(delim.join([ansipad(col, w) for col,w in zip(header, col_width)])) 118 | print(col_box + "─┼─".join(["─" * w for w in col_width]) + col_reset) 119 | for row in rows: 120 | print(delim.join([ansipad(col, w) for col,w in zip(row, col_width)])) 121 | 122 | def die_error(msg: str, loc: Optional[str] = None, status_code: int = 1) -> NoReturn: 123 | """Prints a message with a colored 'error: ' prefix, and exit with the given status code afterwards.""" 124 | print_error(msg, loc=loc) 125 | sys.exit(status_code) 126 | 127 | def load_py_module(file: str, pre_exec: Optional[Callable[[ModuleType], None]] = None) -> ModuleType: 128 | """ 129 | Loads a module from the given filename and assigns a unique module name to it. 130 | Calling this function twice for the same file will yield distinct instances. 
131 | """ 132 | module_id = str(uuid.uuid4()).replace('-', '_') 133 | module_name = f"{os.path.splitext(os.path.basename(file))[0]}__dynamic__{module_id}" 134 | dynamically_loaded_modules.add(module_name) 135 | loader = importlib.machinery.SourceFileLoader(module_name, file) 136 | spec = importlib.util.spec_from_loader(loader.name, loader) 137 | if spec is None: 138 | raise ValueError(f"Failed to load module from file '{file}'") 139 | 140 | mod = importlib.util.module_from_spec(spec) 141 | setattr(mod, "__path__", os.path.realpath(os.path.dirname(file))) 142 | # Run pre_exec callback after the module is loaded but before it is executed 143 | if pre_exec is not None: 144 | pre_exec(mod) 145 | loader.exec_module(mod) 146 | return mod 147 | 148 | def rank_sort(vertices: Iterable[T], preds_of: Callable[[T], Iterable[T]], childs_of: Callable[[T], Iterable[T]]) -> dict[T, int]: 149 | """ 150 | Calculates the top-down rank for each vertex. Supports graphs with multiple components. 151 | The graph must not have any cycles or a CycleError will be thrown. 152 | 153 | Parameters 154 | ---------- 155 | vertices 156 | A list of vertices 157 | preds_of 158 | A function that returns a list of predecessors given a vertex 159 | childs_of 160 | A function that returns a list of successors given a vertex 161 | 162 | Raises 163 | ------ 164 | CycleError 165 | The given graph is cyclic. 166 | 167 | Returns 168 | ------- 169 | dict[T, int] 170 | A dict associating a rank to each vertex 171 | """ 172 | # Create a mapping of vertices to ranks. 173 | ranks = {v: -1 for v in vertices} 174 | 175 | # While there is at least one node without a rank, 176 | # find the "tree root" of that portion of the graph and 177 | # assign ranks to all reachable children without ranks. 
178 | while -1 in ranks.values(): 179 | # Start at any unvisited node 180 | root = next(filter(lambda k: ranks[k] == -1, ranks.keys())) 181 | 182 | # Initialize a visited mapping to detect cycles 183 | visited = {v: False for v in vertices} 184 | visited[root] = True 185 | 186 | # Find the root of the current subtree, 187 | # or detect a cycle and abort. 188 | while any(True for _ in preds_of(root)): 189 | root = next(x for x in preds_of(root)) 190 | if visited[root]: 191 | cycle = list(filter(lambda v: visited[v], vertices)) 192 | raise CycleError("Cannot apply rank_sort to cyclic graph.", cycle) 193 | 194 | visited[root] = True 195 | 196 | # The root node has rank 0 197 | ranks[root] = 0 198 | 199 | # Now assign increasing ranks to children in a breadth-first manner 200 | # to avoid transitive dependencies from causing additional subtree-updates. 201 | # We start with a list of nodes to process and their parents stored as pairs. 202 | needs_rank_list = list((c, root) for c in childs_of(root)) 203 | while len(needs_rank_list) > 0: 204 | # Take the next node to process 205 | n, p = needs_rank_list.pop(0) 206 | r = ranks[p] + 1 207 | 208 | # If the rank to assign is greater than the total number of nodes, the graph must be cyclic. 209 | if r > len(list(vertices)): 210 | raise CycleError("Cannot apply rank_sort to cyclic graph.", [p]) 211 | 212 | # Skip nodes that already have a rank higher than 213 | # or equal to the one we would assign 214 | if ranks[n] >= r: 215 | continue 216 | 217 | # Assign rank 218 | ranks[n] = r 219 | # Queue childenii for rank assignment 220 | needs_rank_list.extend([(c, n) for c in childs_of(n)]) 221 | 222 | # Find cycles in dependencies by checking for the existence of any edge 223 | # that doesn't increase the rank. This is an error. 
224 | for v in vertices: 225 | for c in childs_of(v): 226 | if ranks[c] <= ranks[v]: 227 | raise CycleError("Cannot apply rank_sort to cyclic graph (late detection).", [c, v]) 228 | 229 | return ranks 230 | 231 | def script_trace(script_stack: list[tuple[Any, inspect.FrameInfo]], 232 | include_root: bool = False) -> str: 233 | """ 234 | Creates a script trace similar to a python backtrace. 235 | 236 | Parameters 237 | ---------- 238 | script_stack 239 | The script stack to print 240 | include_root 241 | Whether or not to include the root frame in the script trace. 242 | """ 243 | def format_frame(f: inspect.FrameInfo) -> str: 244 | frame = f" File \"{f.filename}\", line {f.lineno}, in {f.frame.f_code.co_name}\n" 245 | if f.code_context is not None: 246 | for context in f.code_context: 247 | frame += f" {context.strip()}\n" 248 | return frame 249 | 250 | ret = "Script stack (most recent call last):\n" 251 | for _, frame in script_stack if include_root else script_stack[1:]: 252 | ret += format_frame(frame) 253 | 254 | return ret[:-1] # Strip last newline 255 | 256 | def print_exception(exc_type: Optional[Type[BaseException]], exc_info: Optional[BaseException], tb: Optional[TracebackType]) -> None: 257 | """ 258 | A function that hook that prints an exception traceback beginning from the 259 | last dynamically loaded module, but including a script stack so the error 260 | location is more easily understood and printed in a cleaner way. 
261 | """ 262 | original_tb = tb 263 | last_dynamic_tb = None 264 | # Iterate over all frames in the traceback and 265 | # find the last dynamically loaded module in the traceback 266 | while tb: 267 | frame = tb.tb_frame 268 | if "__name__" in frame.f_locals and frame.f_locals['__name__'] in dynamically_loaded_modules: 269 | last_dynamic_tb = tb 270 | tb = tb.tb_next 271 | 272 | # Print the script stack if at least one user script is involved, 273 | # which means we need to have at least two entries as the root context 274 | # is also involved. 275 | script_stack = getattr(exc_info, "script_stack", None) 276 | if script_stack is not None and len(script_stack) > 1: 277 | print(script_trace(script_stack), file=sys.stderr) 278 | 279 | # Print the exception as usual begining from the last dynamic module, 280 | # if one is involved. 281 | traceback.print_exception(exc_type, exc_info, last_dynamic_tb or original_tb) 282 | 283 | def install_exception_hook() -> None: 284 | """ 285 | Installs a new global exception handler, that will modify the 286 | traceback of exceptions raised from dynamically loaded modules 287 | so that they are printed in a cleaner and more meaningful way (for the user). 288 | """ 289 | sys.excepthook = print_exception 290 | 291 | def import_submodules(package: Union[str, ModuleType], recursive: bool = False) -> dict[str, ModuleType]: 292 | """ 293 | Import all submodules of a module, possibly recursively including subpackages. 294 | 295 | Parameters 296 | ---------- 297 | package 298 | The package to import all submodules from. 299 | recursive 300 | Whether to recursively include subpackages. 301 | 302 | Returns 303 | ------- 304 | dict[str, ModuleType] 305 | """ 306 | if isinstance(package, str): 307 | package = importlib.import_module(package) 308 | results = {} 309 | for _, name, is_pkg in pkgutil.walk_packages(package.__path__): # type: ignore[attr-defined] 310 | full_name = package.__name__ + '.' 
+ name 311 | results[full_name] = importlib.import_module(full_name) 312 | if recursive and is_pkg: 313 | results.update(import_submodules(results[full_name])) 314 | return results 315 | 316 | def check_host_active() -> None: 317 | """Asserts that an inventory has been loaded and a host is active.""" 318 | if fora.inventory is None or not fora.inventory.is_initialized(): 319 | raise FatalError("Invalid attempt to call operation before inventory was loaded! Did you maybe swap the inventory and deploy file on the command line?") 320 | if fora.host is None: 321 | raise FatalError("Invalid attempt to call operation while no host is active!") 322 | 323 | def print_process_error(err: subprocess.CalledProcessError) -> None: 324 | """ 325 | Pretty-prints the output of a failed process. 326 | 327 | Parameters 328 | ---------- 329 | err 330 | The error 331 | """ 332 | # Print output of failed command for debugging 333 | col_red = col("\033[1;31m") 334 | col_yellow = col("\033[1;33m") 335 | col_reset = col("\033[m") 336 | col_darker = col("\033[90m") 337 | print_fullwith(["──────── ", 338 | col_red, "command", col_reset, " ", 339 | str(err.cmd), " ", 340 | col_red, "failed", col_reset, " ", 341 | f"with code {err.returncode} ]"], file=sys.stderr) 342 | print_fullwith(["──────── ", col_yellow, "stdout", col_reset, col_darker, " (special characters escaped) ", col_reset], file=sys.stderr) 343 | print(err.stdout.decode("utf-8", errors="backslashreplace"), file=sys.stderr) 344 | print_fullwith(["──────── ", col_yellow, "stderr", col_reset, col_darker, " (special characters escaped) ", col_reset], file=sys.stderr) 345 | print(err.stderr.decode("utf-8", errors="backslashreplace"), file=sys.stderr) 346 | print_fullwith(["──────── ", col_yellow, "end", col_reset, col_darker, " ", col_reset], file=sys.stderr) 347 | -------------------------------------------------------------------------------- /test/group_dependency_cycle/inventory.py: 
-------------------------------------------------------------------------------- 1 | hosts = ["local:dummy"] 2 | groups = [dict(name="group1", before=["group2"]), dict(name="group2", before=["group1"])] 3 | -------------------------------------------------------------------------------- /test/group_dependency_cycle/test_group_dependency_cycle.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pytest 3 | 4 | import fora 5 | import fora.loader 6 | from fora.utils import FatalError 7 | 8 | def test_init(): 9 | class DefaultArgs: 10 | debug = False 11 | diff = False 12 | fora.args = DefaultArgs() 13 | 14 | def test_group_dependency_cycle(request): 15 | os.chdir(request.fspath.dirname) 16 | 17 | with pytest.raises(FatalError, match="cycle"): 18 | fora.loader.load_inventory("inventory.py") 19 | 20 | os.chdir(request.config.invocation_dir) 21 | -------------------------------------------------------------------------------- /test/group_dependency_cycle_complex/inventory.py: -------------------------------------------------------------------------------- 1 | hosts = ["local:dummy"] 2 | groups = [dict(name="group1", after=["group2"]), 3 | dict(name="group2", after=["group3"]), 4 | dict(name="group3", after=["group2"])] 5 | -------------------------------------------------------------------------------- /test/group_dependency_cycle_complex/test_group_dependency_cycle_complex.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pytest 3 | 4 | import fora 5 | import fora.loader 6 | from fora.utils import FatalError 7 | 8 | def test_init(): 9 | class DefaultArgs: 10 | debug = False 11 | diff = False 12 | fora.args = DefaultArgs() 13 | 14 | def test_group_dependency_cycle_complex(request): 15 | os.chdir(request.fspath.dirname) 16 | 17 | with pytest.raises(FatalError, match="cycle"): 18 | fora.loader.load_inventory("inventory.py") 19 | 20 | 
os.chdir(request.config.invocation_dir) 21 | -------------------------------------------------------------------------------- /test/group_dependency_cycle_self/inventory.py: -------------------------------------------------------------------------------- 1 | hosts = ["local:dummy"] 2 | groups = [dict(name="group1", before=["group1"])] 3 | -------------------------------------------------------------------------------- /test/group_dependency_cycle_self/test_group_dependency_cycle_self.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pytest 3 | 4 | import fora 5 | import fora.loader 6 | from fora.utils import FatalError 7 | 8 | def test_init(): 9 | class DefaultArgs: 10 | debug = False 11 | diff = False 12 | fora.args = DefaultArgs() 13 | 14 | def test_group_dependency_cycle_self(request): 15 | os.chdir(request.fspath.dirname) 16 | 17 | with pytest.raises(FatalError, match="must not depend on itself"): 18 | fora.loader.load_inventory("inventory.py") 19 | 20 | os.chdir(request.config.invocation_dir) 21 | -------------------------------------------------------------------------------- /test/group_variable_conflict/groups/group1.py: -------------------------------------------------------------------------------- 1 | overwrite_group = "group1" 2 | overwrite_group_second = "group1" 3 | -------------------------------------------------------------------------------- /test/group_variable_conflict/groups/group2.py: -------------------------------------------------------------------------------- 1 | overwrite_group = "group2" 2 | overwrite_group_second = "group2" 3 | -------------------------------------------------------------------------------- /test/group_variable_conflict/groups/group3.py: -------------------------------------------------------------------------------- 1 | overwrite_group = "group3" 2 | overwrite_group_second = "group3" 3 | 
-------------------------------------------------------------------------------- /test/group_variable_conflict/inventory.py: -------------------------------------------------------------------------------- 1 | hosts = [dict(url="local:dummy", groups=["group1", "group2", "group3"])] 2 | -------------------------------------------------------------------------------- /test/group_variable_conflict/test_group_variable_conflict.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pytest 3 | 4 | import fora 5 | import fora.loader 6 | from fora.utils import FatalError 7 | 8 | def test_init(): 9 | class DefaultArgs: 10 | debug = False 11 | diff = False 12 | fora.args = DefaultArgs() 13 | 14 | def test_group_variable_conflict(request): 15 | os.chdir(request.fspath.dirname) 16 | 17 | with pytest.raises(FatalError, match="Conflict in variable assignment"): 18 | fora.loader.load_inventory("inventory.py") 19 | 20 | os.chdir(request.config.invocation_dir) 21 | -------------------------------------------------------------------------------- /test/inventory/mock_inventories/empty.py: -------------------------------------------------------------------------------- 1 | hosts = [] 2 | -------------------------------------------------------------------------------- /test/inventory/mock_inventories/hosts/host1.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | pyfile = os.path.basename(__file__) 3 | url = "ssh://nobody@host1.localhost" 4 | -------------------------------------------------------------------------------- /test/inventory/mock_inventories/hosts/host2.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | pyfile = os.path.basename(__file__) 3 | url = "ssh://nobody@host2.localhost" 4 | -------------------------------------------------------------------------------- 
/test/inventory/mock_inventories/hosts/host_templ.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | pyfile = os.path.basename(__file__) 3 | url = "ssh://nobody@localhost" 4 | -------------------------------------------------------------------------------- /test/inventory/mock_inventories/invalid_hosts_entries.py: -------------------------------------------------------------------------------- 1 | # This contains bullshit. 2 | hosts = [ 123, 41 ] 3 | -------------------------------------------------------------------------------- /test/inventory/mock_inventories/missing_definition.py: -------------------------------------------------------------------------------- 1 | # No hosts= definition 2 | -------------------------------------------------------------------------------- /test/inventory/mock_inventories/simple_test.py: -------------------------------------------------------------------------------- 1 | hosts = ["host1", "host2", dict(url="host3", file="hosts/host_templ.py"), dict(url="host4", file="hosts/host_templ.py")] 2 | -------------------------------------------------------------------------------- /test/inventory/mock_inventories/single_host1.py: -------------------------------------------------------------------------------- 1 | inventory_var = "from_inventory" 2 | 3 | hosts = ["host1"] 4 | -------------------------------------------------------------------------------- /test/inventory/test_dynamic_instanciation.py: -------------------------------------------------------------------------------- 1 | import os 2 | import fora 3 | import fora.loader 4 | 5 | def test_dynamic_instanciation(request): 6 | os.chdir(request.fspath.dirname) 7 | 8 | try: 9 | fora.loader.load_inventory("mock_inventories/simple_test.py") 10 | assert fora.inventory is not None 11 | 12 | hosts = fora.inventory.hosts 13 | expected_files = ["host1.py", "host2.py", "host_templ.py", "host_templ.py"] 14 | 15 | for i in hosts: 16 | if 
isinstance(i, dict): 17 | i = i["url"] 18 | assert i in fora.inventory.loaded_hosts 19 | 20 | for h, e in zip(fora.inventory.loaded_hosts.values(), expected_files): 21 | assert hasattr(h, 'pyfile') 22 | assert getattr(h, 'pyfile', None) == e 23 | finally: 24 | os.chdir(request.config.invocation_dir) 25 | 26 | def test_inventory_global_variables(request): 27 | os.chdir(request.fspath.dirname) 28 | 29 | try: 30 | fora.loader.load_inventory("mock_inventories/single_host1.py") 31 | assert fora.inventory is not None 32 | 33 | assert "host1" in fora.inventory.loaded_hosts 34 | assert fora.inventory.loaded_hosts["host1"].inventory_var == "from_inventory" 35 | finally: 36 | os.chdir(request.config.invocation_dir) 37 | -------------------------------------------------------------------------------- /test/inventory/test_empty.py: -------------------------------------------------------------------------------- 1 | import os 2 | import fora 3 | import fora.loader 4 | 5 | def test_empty(request): 6 | os.chdir(request.fspath.dirname) 7 | fora.loader.load_inventory("mock_inventories/empty.py") 8 | assert fora.inventory is not None 9 | assert fora.inventory.hosts == [] 10 | 11 | os.chdir(request.config.invocation_dir) 12 | -------------------------------------------------------------------------------- /test/inventory/test_missing_hosts.py: -------------------------------------------------------------------------------- 1 | import os 2 | from fora.utils import FatalError 3 | import pytest 4 | import fora.loader 5 | 6 | def test_missing_hosts(request): 7 | os.chdir(request.fspath.dirname) 8 | with pytest.raises(FatalError, match=r"must define a list of hosts"): 9 | fora.loader.load_inventory("mock_inventories/missing_definition.py") 10 | 11 | os.chdir(request.config.invocation_dir) 12 | -------------------------------------------------------------------------------- /test/operations/subdeploy.py: -------------------------------------------------------------------------------- 1 | 
print("subdeploy executed") 2 | -------------------------------------------------------------------------------- /test/simple_deploy/deploy.py: -------------------------------------------------------------------------------- 1 | from fora import host 2 | from fora.operations import files 3 | 4 | @Params 5 | class params: 6 | filename: str = "def" 7 | 8 | some_script_default = "def" 9 | 10 | files.template_content( 11 | dest="/tmp/__pytest_fora/test_deploy", 12 | content="{{ myvar }}", 13 | context=dict(myvar="testdeploy made this"), 14 | mode="644") 15 | 16 | assert params.filename == "def" 17 | assert not hasattr(host, 'bullshit') 18 | assert hasattr(host, 'some_script_default') 19 | assert getattr(host, 'some_script_default') == "def" 20 | -------------------------------------------------------------------------------- /test/simple_deploy/deploy_bad.py: -------------------------------------------------------------------------------- 1 | from fora.operations import files 2 | 3 | files.template_content( 4 | dest="invalid", 5 | content="", 6 | mode="invalid") 7 | -------------------------------------------------------------------------------- /test/simple_deploy/deploy_bad_recursive.py: -------------------------------------------------------------------------------- 1 | from fora.operations import local 2 | 3 | local.script(script="deploy.py", recursive=True) 4 | local.script(script="deploy_bad_recursive.py") 5 | -------------------------------------------------------------------------------- /test/simple_deploy/inventory.py: -------------------------------------------------------------------------------- 1 | hosts = ["root@localhost"] 2 | -------------------------------------------------------------------------------- /test/simple_inventory/groups/all.py: -------------------------------------------------------------------------------- 1 | value_from_all = "all" 2 | overwrite_group = "all" 3 | 
-------------------------------------------------------------------------------- /test/simple_inventory/groups/desktops.py: -------------------------------------------------------------------------------- 1 | overwrite_group = "desktops" 2 | -------------------------------------------------------------------------------- /test/simple_inventory/groups/only34.py: -------------------------------------------------------------------------------- 1 | from fora import group as this 2 | 3 | assert name == "only34" 4 | assert this.name == "only34" 5 | 6 | overwrite_group = "only34" 7 | -------------------------------------------------------------------------------- /test/simple_inventory/groups/somehosts.py: -------------------------------------------------------------------------------- 1 | overwrite_group = "somehosts" 2 | -------------------------------------------------------------------------------- /test/simple_inventory/hosts/host1.py: -------------------------------------------------------------------------------- 1 | overwrite_host = "host1" 2 | assert name == "host1" 3 | -------------------------------------------------------------------------------- /test/simple_inventory/hosts/host2.py: -------------------------------------------------------------------------------- 1 | overwrite_host = "host2" 2 | -------------------------------------------------------------------------------- /test/simple_inventory/hosts/host3.py: -------------------------------------------------------------------------------- 1 | overwrite_host = "host3" 2 | -------------------------------------------------------------------------------- /test/simple_inventory/hosts/host4.py: -------------------------------------------------------------------------------- 1 | overwrite_host = "host4" 2 | -------------------------------------------------------------------------------- /test/simple_inventory/hosts/host5.py: -------------------------------------------------------------------------------- 1 | 
overwrite_host = "host5" 2 | -------------------------------------------------------------------------------- /test/simple_inventory/inventory.py: -------------------------------------------------------------------------------- 1 | hosts = [dict(url="host1", groups=["desktops", "desktops", "desktops"]), 2 | dict(url="host2", groups=["desktops", "somehosts"]), 3 | dict(url="host3", groups=["only34", "desktops"]), 4 | dict(url="host4", groups=["only34", "somehosts"]), 5 | "host5"] 6 | groups = ["desktops", 7 | dict(name="somehosts", after=["desktops"]), 8 | dict(name="only34", after=["all"], before=["desktops"]), 9 | ] 10 | -------------------------------------------------------------------------------- /test/simple_inventory/test_simple_inventory.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Any, cast 3 | import fora 4 | import fora.loader 5 | 6 | def test_simple_inventory(request): 7 | os.chdir(request.fspath.dirname) 8 | 9 | fora.loader.load_inventory("inventory.py") 10 | for i in ["host1", "host2", "host3", "host4", "host5"]: 11 | assert i in fora.inventory.loaded_hosts 12 | 13 | assert set(fora.inventory.loaded_hosts["host1"].groups) == set(["all", "desktops"]) 14 | assert set(fora.inventory.loaded_hosts["host2"].groups) == set(["all", "desktops", "somehosts"]) 15 | assert set(fora.inventory.loaded_hosts["host3"].groups) == set(["all", "desktops", "only34"]) 16 | assert set(fora.inventory.loaded_hosts["host4"].groups) == set(["all", "somehosts", "only34"]) 17 | assert set(fora.inventory.loaded_hosts["host5"].groups) == set(["all"]) 18 | 19 | assert not hasattr(fora.inventory.loaded_hosts["host1"], '_bullshit') 20 | assert not hasattr(fora.inventory.loaded_hosts["host2"], '_bullshit') 21 | assert not hasattr(fora.inventory.loaded_hosts["host3"], '_bullshit') 22 | assert not hasattr(fora.inventory.loaded_hosts["host4"], '_bullshit') 23 | assert not 
hasattr(fora.inventory.loaded_hosts["host5"], '_bullshit') 24 | 25 | assert not hasattr(fora.inventory.loaded_hosts["host1"], 'bullshit') 26 | assert not hasattr(fora.inventory.loaded_hosts["host2"], 'bullshit') 27 | assert not hasattr(fora.inventory.loaded_hosts["host3"], 'bullshit') 28 | assert not hasattr(fora.inventory.loaded_hosts["host4"], 'bullshit') 29 | assert not hasattr(fora.inventory.loaded_hosts["host5"], 'bullshit') 30 | 31 | assert hasattr(fora.inventory.loaded_hosts["host1"], 'value_from_all') 32 | assert hasattr(fora.inventory.loaded_hosts["host2"], 'value_from_all') 33 | assert hasattr(fora.inventory.loaded_hosts["host3"], 'value_from_all') 34 | assert hasattr(fora.inventory.loaded_hosts["host4"], 'value_from_all') 35 | assert hasattr(fora.inventory.loaded_hosts["host5"], 'value_from_all') 36 | 37 | assert hasattr(fora.inventory.loaded_hosts["host1"], 'overwrite_host') 38 | assert hasattr(fora.inventory.loaded_hosts["host2"], 'overwrite_host') 39 | assert hasattr(fora.inventory.loaded_hosts["host3"], 'overwrite_host') 40 | assert hasattr(fora.inventory.loaded_hosts["host4"], 'overwrite_host') 41 | assert hasattr(fora.inventory.loaded_hosts["host5"], 'overwrite_host') 42 | 43 | assert cast(Any, fora.inventory.loaded_hosts["host1"]).overwrite_host == "host1" 44 | assert cast(Any, fora.inventory.loaded_hosts["host2"]).overwrite_host == "host2" 45 | assert cast(Any, fora.inventory.loaded_hosts["host3"]).overwrite_host == "host3" 46 | assert cast(Any, fora.inventory.loaded_hosts["host4"]).overwrite_host == "host4" 47 | assert cast(Any, fora.inventory.loaded_hosts["host5"]).overwrite_host == "host5" 48 | 49 | assert hasattr(fora.inventory.loaded_hosts["host1"], 'overwrite_group') 50 | assert hasattr(fora.inventory.loaded_hosts["host2"], 'overwrite_group') 51 | assert hasattr(fora.inventory.loaded_hosts["host3"], 'overwrite_group') 52 | assert hasattr(fora.inventory.loaded_hosts["host4"], 'overwrite_group') 53 | assert 
hasattr(fora.inventory.loaded_hosts["host5"], 'overwrite_group') 54 | 55 | assert cast(Any, fora.inventory.loaded_hosts["host1"]).overwrite_group == "desktops" 56 | assert cast(Any, fora.inventory.loaded_hosts["host2"]).overwrite_group == "somehosts" 57 | assert cast(Any, fora.inventory.loaded_hosts["host3"]).overwrite_group == "desktops" 58 | assert cast(Any, fora.inventory.loaded_hosts["host4"]).overwrite_group == "somehosts" 59 | assert cast(Any, fora.inventory.loaded_hosts["host5"]).overwrite_group == "all" 60 | 61 | os.chdir(request.config.invocation_dir) 62 | -------------------------------------------------------------------------------- /test/simple_inventory/testlink: -------------------------------------------------------------------------------- 1 | inventory.py -------------------------------------------------------------------------------- /test/templates/test.j2: -------------------------------------------------------------------------------- 1 | {{ myvar }} 2 | -------------------------------------------------------------------------------- /test/test_connection.py: -------------------------------------------------------------------------------- 1 | import grp 2 | import hashlib 3 | import os 4 | import pwd 5 | import pytest 6 | import subprocess 7 | from typing import cast 8 | 9 | import fora 10 | import fora.loader 11 | from fora.connection import Connection 12 | from fora.connectors.tunnel_dispatcher import RemoteOSError 13 | from fora.types import HostWrapper, ScriptWrapper 14 | 15 | host: HostWrapper = cast(HostWrapper, None) 16 | connection: Connection = cast(Connection, None) 17 | 18 | def test_init(): 19 | class DefaultArgs: 20 | debug = False 21 | diff = False 22 | fora.args = DefaultArgs() 23 | fora.loader.load_inventory("local:") 24 | 25 | global host 26 | host = fora.inventory.loaded_hosts["localhost"] 27 | fora.host = host 28 | fora.script = ScriptWrapper("__internal_test") 29 | class Empty: 30 | pass 31 | fora.script.wrap(Empty()) 32 | 
def current_test_user():
    """Return the login name of the user running the test suite."""
    return pwd.getpwuid(os.getuid()).pw_name

def current_test_group():
    """Return the name of the primary group of the user running the test suite."""
    return grp.getgrgid(os.getgid()).gr_name

def test_open_connection():
    """Open the shared connection and check the default script settings."""
    global connection
    connection = Connection(host)
    connection.__enter__()
    assert host.connection is connection

    # NOTE(review): this defaults context is entered but never exited; presumably
    # intentional so the remaining tests run under it -- confirm before changing.
    ctx = fora.script.defaults()
    defs = ctx.__enter__()
    current_user = current_test_user()
    current_group = current_test_group()
    assert defs.as_user == current_user
    assert defs.as_group == current_group
    assert defs.owner == current_user
    assert defs.group == current_group
    assert defs.cwd == "/tmp"
    assert int(defs.dir_mode, 8) == 0o700
    assert int(defs.file_mode, 8) == 0o600
    assert int(defs.umask, 8) == 0o77

def test_resolve_identity():
    """The connection's base settings resolve to the invoking user and group."""
    current_user = current_test_user()
    current_group = current_test_group()
    assert connection.base_settings.as_user == current_user
    assert connection.base_settings.as_group == current_group
    assert connection.base_settings.owner == current_user
    assert connection.base_settings.group == current_group

def test_run_false():
    """A failing command raises CalledProcessError carrying its exit status."""
    with pytest.raises(subprocess.CalledProcessError) as excinfo:
        connection.run(["false"])
    assert excinfo.value.returncode == 1

def test_run_false_unchecked():
    """With check=False a failing command returns its result instead of raising."""
    ret = connection.run(["false"], check=False)
    assert ret.returncode == 1
    assert ret.stdout == b""
    assert ret.stderr == b""

def test_run_true():
    ret = connection.run(["true"])
    assert ret.returncode == 0
    assert ret.stdout == b""
    assert ret.stderr == b""

def test_run_not_a_shell():
    """argv is executed directly, not via a shell: "echo test" is one program name."""
    with pytest.raises(RemoteOSError) as excinfo:
        connection.run(["echo test"])
    # errno 2 is ENOENT ("No such file or directory") on Linux.
    assert excinfo.value.errno == 2

def test_run_echo():
    ret = connection.run(["echo", "abc"])
    assert ret.returncode == 0
    assert ret.stdout == b"abc\n"
    assert ret.stderr == b""

def test_run_cat_input():
    """Bytes passed via input= are delivered to the command's stdin unchanged."""
    ret = connection.run(["cat"], input=b"test\nb")
    assert ret.returncode == 0
    assert ret.stdout == b"test\nb"
    assert ret.stderr == b""

def test_run_id():
    """Commands run with the invoking user's uid and gid."""
    ret = connection.run(["id"])
    assert ret.returncode == 0
    assert ret.stdout is not None
    stdout = ret.stdout.decode('utf-8', 'ignore')
    assert f"uid={os.getuid()}({current_test_user()})" in stdout
    assert f"gid={os.getgid()}({current_test_group()})" in stdout
    assert ret.stderr == b""

def test_run_pwd():
    """Without an explicit cwd, commands run in /tmp."""
    ret = connection.run(["pwd"])
    assert ret.returncode == 0
    assert ret.stdout == b"/tmp\n"
    assert ret.stderr == b""

def test_run_pwd_in_var_tmp():
    """An explicit cwd= overrides the default working directory."""
    ret = connection.run(["pwd"], cwd="/var/tmp")
    assert ret.returncode == 0
    assert ret.stdout == b"/var/tmp\n"
    assert ret.stderr == b""

def test_resolve_user_self():
    assert connection.resolve_user(None) == current_test_user()

def test_resolve_user_root_by_uid():
    assert connection.resolve_user("0") == "root"

def test_resolve_user_nobody():
    assert connection.resolve_user("nobody") == "nobody"

def test_resolve_user_invalid():
    with pytest.raises(ValueError):
        connection.resolve_user("_invalid_")

def test_resolve_group_self():
    assert connection.resolve_group(None) == current_test_group()

def test_resolve_group_root_by_uid():
    """Group "0" (a gid, despite the test name) resolves to root."""
    assert connection.resolve_group("0") == "root"

def test_resolve_group_nobody():
    assert connection.resolve_group("nobody") == "nobody"

def test_resolve_group_invalid():
    with pytest.raises(ValueError):
        connection.resolve_group("_invalid_")

# Scratch file used by the upload/download round-trip tests.
_UPLOAD_PATH = "/tmp/__pytest_fora_upload"

def _remove_if_exists(path):
    """Delete the local file at *path*; a missing file is not an error."""
    try:
        os.remove(path)
    except FileNotFoundError:
        pass

@pytest.mark.parametrize("n", [0, 1, 8, 32, 128, 1024, 1024 * 32, 1024 * 256])
def test_upload_download(n):
    """Round-trip *n* random bytes through upload/download and check stat metadata."""
    content = os.urandom(n)
    _remove_if_exists(_UPLOAD_PATH)
    try:
        connection.upload(_UPLOAD_PATH, content=content)
        assert connection.download(_UPLOAD_PATH) == content
        assert connection.download_or(_UPLOAD_PATH) == content
        stat = connection.stat(_UPLOAD_PATH, sha512sum=True)
        assert stat is not None
        assert stat.type == "file"
        assert stat.size == n
        assert stat.sha512sum == hashlib.sha512(content).digest()
    finally:
        # Clean up even when an assertion fails, so a stale file cannot
        # leak into the next parametrized run or other tests.
        _remove_if_exists(_UPLOAD_PATH)

def test_upload_owner_group():
    """Upload with an explicit mode plus numeric owner/group and read it back."""
    content = b"1234"
    _remove_if_exists(_UPLOAD_PATH)
    try:
        connection.upload(_UPLOAD_PATH, content=content, mode="644", owner=str(os.getuid()), group=str(os.getgid()))
        assert connection.download(_UPLOAD_PATH) == content
        assert connection.download_or(_UPLOAD_PATH) == content
        stat = connection.stat(_UPLOAD_PATH, sha512sum=True)
        assert stat is not None
        assert stat.type == "file"
        assert stat.sha512sum == hashlib.sha512(content).digest()
    finally:
        _remove_if_exists(_UPLOAD_PATH)

def test_stat_nonexistent():
    stat = connection.stat("/tmp/__nonexistent")
    assert stat is None

def test_download_nonexistent():
    """download_or yields None for a missing file, while download raises ValueError."""
    assert connection.download_or("/tmp/__nonexistent") is None
    with pytest.raises(ValueError):
        connection.download("/tmp/__nonexistent")

def test_run_none_in_fields():
    """The connector accepts None for every optional execution field."""
    ret = connection.connector.run(["true"], umask=None, user=None, group=None, cwd=None)
    assert ret.returncode == 0

def test_run_invalid_command():
    with pytest.raises(RemoteOSError, match=r"No such file or directory"):
        connection.run(["_invalid_"])

def test_run_invalid_umask():
    with pytest.raises(ValueError, match=r"Invalid value.*given for field 'umask'"):
        connection.run(["true"], umask="_invalid_")

def test_run_invalid_user():
    with pytest.raises(ValueError, match=r"Invalid value.*given for field 'user'"):
        connection.run(["true"], user="_invalid_")

def test_run_invalid_user_id():
    with pytest.raises(ValueError, match=r"Invalid value.*given for field 'user'"):
        connection.run(["true"], user="1234567890")

def test_run_invalid_group():
    with pytest.raises(ValueError, match=r"Invalid value.*given for field 'group'"):
        connection.run(["true"], group="_invalid_")

def test_run_invalid_group_id():
    with pytest.raises(ValueError, match=r"Invalid value.*given for field 'group'"):
        connection.run(["true"], group="1234567890")

def test_run_invalid_cwd():
    with pytest.raises(ValueError, match=r"Invalid value.*given for field 'cwd'"):
        connection.run(["true"], cwd="/_invalid_")

def test_upload_invalid_mode():
    with pytest.raises(ValueError, match=r"Invalid value.*given for field 'mode'"):
        connection.upload("/invalid", content=b"", mode="_invalid_")

def test_upload_invalid_owner():
    with pytest.raises(ValueError, match=r"Invalid value.*given for field 'owner'"):
        connection.upload("/invalid", content=b"", owner="_invalid_")

def test_upload_invalid_group():
    with pytest.raises(ValueError, match=r"Invalid value.*given for field 'group'"):
        connection.upload("/invalid", content=b"", group="_invalid_")

def test_query_group_nonexistent():
    entry = connection.query_group("__nonexistent")
    assert entry is None

def test_query_user_nonexistent():
    entry = connection.query_user("__nonexistent")
    assert entry is None

def test_query_group():
    entry = connection.query_group("nobody")
    assert entry is not None
    assert entry.name == "nobody"

def test_query_user():
    """Querying a password hash is denied to the unprivileged test user; a plain query works."""
    with pytest.raises(RemoteOSError, match=r"Permission denied"):
        connection.query_user("nobody", query_password_hash=True)

    entry = connection.query_user("nobody")
    assert entry is not None
    assert entry.name == "nobody"

def test_home_dir():
    assert connection.home_dir() == pwd.getpwuid(os.getuid()).pw_dir

def test_getenv():
    """Remote environment lookups match the local environment; unknown names give None."""
    assert connection.getenv("HOME") == os.getenv("HOME")
    assert connection.getenv("PATH") == os.getenv("PATH")
    assert connection.getenv("_nonexistent") is None

def test_close_connection():
    """Closing the shared connection detaches it from the host."""
    connection.__exit__(None, None, None)
    assert host.connection is None

# --------------------------------------------------------------------------------
# /test/test_connector_resolve.py
# --------------------------------------------------------------------------------
from dataclasses import dataclass
from types import SimpleNamespace
from typing import Any, cast

import pytest

from fora.inventory_wrapper import InventoryWrapper
from fora.utils import FatalError
from fora.connectors.ssh import SshConnector
from fora.types import HostWrapper

def create_host(name: str) -> HostWrapper:
    """Build a HostWrapper named *name* (also used as its url) around an empty namespace."""
    wrapper = HostWrapper(InventoryWrapper(), name=name, url=name)
    wrapper.wrap(SimpleNamespace())
    return wrapper

def test_explicit_connector():
    """An explicitly assigned connector class wins over the url's schema."""
    @dataclass
    class TestConnector:
        name: str
        url: str

    h = cast(Any, create_host("ssh://red@herring.sea"))
    h.connector = TestConnector

    assert isinstance(h.create_connector(), TestConnector)

def test_connector_invalid():
    """A url without a schema and no explicit connector is a fatal error."""
    h = create_host("cannotresolve")
    with pytest.raises(FatalError, match=r"Url doesn't include a schema and no connector was specified"):
        h.create_connector()

def test_connector_ssh():
    """ssh:// urls resolve to an SshConnector."""
    h = cast(Any, create_host("ssh://user@host.localhost"))
    assert isinstance(h.create_connector(), SshConnector)

def test_connector_unknown():
    """A url whose schema has no matching connector is a fatal error."""
    wrapper = cast(Any, create_host("unknown://user@host.localhost"))
    with pytest.raises(FatalError, match=r"No connector found for schema"):
        wrapper.create_connector()

# --------------------------------------------------------------------------------
# /test/test_init_deploy.py
# --------------------------------------------------------------------------------
import os
import pytest
import shutil
from pathlib import Path
from typing import Any, cast

from fora import example_deploys

# Deploy layouts exercised by test_init, in the order they are scaffolded.
_LAYOUTS = ("minimal", "flat", "dotfiles", "modular", "staging_prod")

def test_init_unknown():
    """An unrecognized layout name is rejected with ValueError."""
    bogus_layout = cast(Any, "__unknown")
    with pytest.raises(ValueError, match="Unknown deploy layout structure"):
        example_deploys.init_deploy_structure(bogus_layout)

def test_init(request, tmp_path):
    """Scaffold every layout in its own fresh directory; each call ends in SystemExit."""
    try:
        for layout in _LAYOUTS:
            layout_dir = Path(tmp_path, layout)
            if layout_dir.exists():
                shutil.rmtree(layout_dir)
            layout_dir.mkdir(exist_ok=False)

            # The scaffolder presumably writes relative to the cwd, hence the chdir.
            os.chdir(layout_dir)
            with pytest.raises(SystemExit):
                example_deploys.init_deploy_structure(cast(Any, layout))
    finally:
        # Always restore the original cwd so later tests are unaffected.
        os.chdir(request.config.invocation_dir)

# --------------------------------------------------------------------------------
# /test/test_loading.py
# --------------------------------------------------------------------------------
import pytest

from fora.main import main
import fora
import fora.loader

def _exit_code_of(argv):
    """Run main(argv), require that it raises SystemExit, and return its exit code."""
    with pytest.raises(SystemExit) as excinfo:
        main(argv)
    return excinfo.value.code

def test_init():
    """Initialize fora's global state by loading the "local:" inventory.

    The other tests in this module both inspect and rely on that state.
    """
    fora.loader.load_inventory("local:")

def test_group_functions_from_outside_definition():
    """Outside of a group definition, fora.group is unset."""
    assert fora.group is None

def test_host_functions_from_outside_definition():
    """Outside of a host definition, fora.host is unset."""
    assert fora.host is None

def test_help_output():
    """--help exits with status 0."""
    assert _exit_code_of(["--help"]) == 0

def test_invalid_args():
    """An unknown command-line option exits with status 1."""
    assert _exit_code_of(["--whatisthis_nonexistent"]) == 1