├── src └── retrocookie │ ├── py.typed │ ├── pr │ ├── adapters │ │ ├── __init__.py │ │ ├── retrocookie.py │ │ └── github.py │ ├── compat │ │ ├── __init__.py │ │ ├── typing.py │ │ └── shlex.py │ ├── protocols │ │ ├── __init__.py │ │ ├── retrocookie.py │ │ └── github.py │ ├── base │ │ ├── __init__.py │ │ ├── bus.py │ │ └── exceptionhandlers.py │ ├── __init__.py │ ├── repository.py │ ├── list.py │ ├── cache.py │ ├── events.py │ ├── importer.py │ ├── core.py │ ├── __main__.py │ └── console.py │ ├── compat │ ├── __init__.py │ └── contextlib.py │ ├── __init__.py │ ├── utils.py │ ├── __main__.py │ ├── filter.py │ ├── core.py │ └── git.py ├── .gitattributes ├── .darglint ├── docs ├── license.rst ├── codeofconduct.rst ├── requirements.txt ├── usage.rst ├── usage-pr.rst ├── contributing.rst ├── reference.rst ├── conf.py └── index.rst ├── tests ├── pr │ ├── unit │ │ ├── __init__.py │ │ ├── fakes │ │ │ ├── __init__.py │ │ │ ├── retrocookie.py │ │ │ └── github.py │ │ ├── base │ │ │ ├── __init__.py │ │ │ ├── test_bus.py │ │ │ └── test_exceptionhandlers.py │ │ ├── adapters │ │ │ ├── __init__.py │ │ │ └── test_github.py │ │ ├── test_repository.py │ │ ├── utils.py │ │ ├── data.py │ │ ├── conftest.py │ │ ├── test_cache.py │ │ ├── test_importer.py │ │ ├── test_main.py │ │ ├── test_list.py │ │ ├── test_console.py │ │ └── test_core.py │ ├── functional │ │ ├── __init__.py │ │ ├── test_github.py │ │ ├── test_main.py │ │ └── conftest.py │ └── __init__.py ├── __init__.py ├── test_utils.py ├── test_filter.py ├── helpers.py ├── conftest.py ├── test_main.py ├── test_core.py └── test_git.py ├── .github ├── workflows │ ├── constraints.txt │ ├── labeler.yml │ ├── release.yml │ └── tests.yml ├── dependabot.yml ├── release-drafter.yml └── labels.yml ├── .gitignore ├── codecov.yml ├── .readthedocs.yml ├── .cookiecutter.json ├── .flake8 ├── cutty.json ├── LICENSE.rst ├── .pre-commit-config.yaml ├── pyproject.toml ├── CONTRIBUTING.rst ├── CODE_OF_CONDUCT.rst ├── noxfile.py └── README.rst 
/src/retrocookie/py.typed: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | * text=auto eol=lf 2 | -------------------------------------------------------------------------------- /.darglint: -------------------------------------------------------------------------------- 1 | [darglint] 2 | strictness = long 3 | -------------------------------------------------------------------------------- /docs/license.rst: -------------------------------------------------------------------------------- 1 | .. include:: ../LICENSE.rst 2 | -------------------------------------------------------------------------------- /tests/pr/unit/__init__.py: -------------------------------------------------------------------------------- 1 | """Unit tests.""" 2 | -------------------------------------------------------------------------------- /docs/codeofconduct.rst: -------------------------------------------------------------------------------- 1 | .. 
include:: ../CODE_OF_CONDUCT.rst 2 | -------------------------------------------------------------------------------- /src/retrocookie/pr/adapters/__init__.py: -------------------------------------------------------------------------------- 1 | """Adapters.""" 2 | -------------------------------------------------------------------------------- /tests/pr/functional/__init__.py: -------------------------------------------------------------------------------- 1 | """Functional tests.""" 2 | -------------------------------------------------------------------------------- /src/retrocookie/compat/__init__.py: -------------------------------------------------------------------------------- 1 | """Compatibility shims.""" 2 | -------------------------------------------------------------------------------- /src/retrocookie/pr/compat/__init__.py: -------------------------------------------------------------------------------- 1 | """Compatibility shims.""" 2 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | """Test suite for the retrocookie package.""" 2 | -------------------------------------------------------------------------------- /tests/pr/__init__.py: -------------------------------------------------------------------------------- 1 | """Tests for the retrocookie.pr package.""" 2 | -------------------------------------------------------------------------------- /tests/pr/unit/fakes/__init__.py: -------------------------------------------------------------------------------- 1 | """Fakes for retrocookie.pr.""" 2 | -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | furo==2022.1.2 2 | sphinx==4.5.0 3 | sphinx-click==3.0.3 4 | 
-------------------------------------------------------------------------------- /src/retrocookie/pr/protocols/__init__.py: -------------------------------------------------------------------------------- 1 | """Protocols for retrocookie.pr.""" 2 | -------------------------------------------------------------------------------- /tests/pr/unit/base/__init__.py: -------------------------------------------------------------------------------- 1 | """Tests for the retrocookie.pr.base package.""" 2 | -------------------------------------------------------------------------------- /src/retrocookie/pr/base/__init__.py: -------------------------------------------------------------------------------- 1 | """General-purpose modules without dependencies.""" 2 | -------------------------------------------------------------------------------- /tests/pr/unit/adapters/__init__.py: -------------------------------------------------------------------------------- 1 | """Tests for the retrocookie.pr.adapters package.""" 2 | -------------------------------------------------------------------------------- /src/retrocookie/pr/__init__.py: -------------------------------------------------------------------------------- 1 | """Package initialization.""" 2 | appname = __package__.replace(".", "-") 3 | -------------------------------------------------------------------------------- /docs/usage.rst: -------------------------------------------------------------------------------- 1 | Usage 2 | ===== 3 | 4 | .. 
click:: retrocookie.__main__:main 5 | :prog: retrocookie 6 | :nested: full 7 | -------------------------------------------------------------------------------- /.github/workflows/constraints.txt: -------------------------------------------------------------------------------- 1 | pip==22.0.2 2 | nox==2022.1.7 3 | nox-poetry==0.9.0 4 | poetry==1.1.13 5 | virtualenv==20.13.0 6 | -------------------------------------------------------------------------------- /docs/usage-pr.rst: -------------------------------------------------------------------------------- 1 | Usage 2 | ===== 3 | 4 | .. click:: retrocookie.pr.__main__:main 5 | :prog: retrocookie-pr 6 | :nested: full 7 | -------------------------------------------------------------------------------- /docs/contributing.rst: -------------------------------------------------------------------------------- 1 | .. include:: ../CONTRIBUTING.rst 2 | :end-before: github-only 3 | 4 | .. _Code of Conduct: codeofconduct.html 5 | -------------------------------------------------------------------------------- /docs/reference.rst: -------------------------------------------------------------------------------- 1 | Reference 2 | ========= 3 | 4 | 5 | retrocookie 6 | ----------- 7 | 8 | .. 
autofunction:: retrocookie.retrocookie 9 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .mypy_cache/ 2 | /.coverage 3 | /.coverage.* 4 | /.nox/ 5 | /.python-version 6 | /.pytype/ 7 | /dist/ 8 | /docs/_build/ 9 | /src/*.egg-info/ 10 | __pycache__/ 11 | -------------------------------------------------------------------------------- /src/retrocookie/__init__.py: -------------------------------------------------------------------------------- 1 | """Retrocookie imports commits from template instances into the template.""" 2 | from .core import retrocookie 3 | 4 | 5 | __all__ = ["retrocookie"] 6 | -------------------------------------------------------------------------------- /codecov.yml: -------------------------------------------------------------------------------- 1 | comment: false 2 | coverage: 3 | status: 4 | project: 5 | default: 6 | target: "90" 7 | patch: 8 | default: 9 | target: "100" 10 | -------------------------------------------------------------------------------- /src/retrocookie/pr/compat/typing.py: -------------------------------------------------------------------------------- 1 | """Compatibility shims for typing.""" 2 | import sys 3 | 4 | 5 | if sys.version_info >= (3, 8): 6 | from typing import Protocol 7 | else: 8 | from typing_extensions import Protocol 9 | 10 | __all__ = ["Protocol"] 11 | -------------------------------------------------------------------------------- /tests/pr/unit/fakes/retrocookie.py: -------------------------------------------------------------------------------- 1 | """Fake for retrocookie.retrocookie.""" 2 | from pathlib import Path 3 | 4 | 5 | def retrocookie(repository: Path, *, branch: str, upstream: str, path: Path) -> None: 6 | """Import commits from instance repository into template repository.""" 7 | -------------------------------------------------------------------------------- 
/.readthedocs.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | sphinx: 3 | configuration: docs/conf.py 4 | formats: all 5 | python: 6 | version: 3.9 7 | install: 8 | - requirements: docs/requirements.txt 9 | - method: pip 10 | path: . 11 | extra_requirements: 12 | - pr 13 | build: 14 | image: testing 15 | -------------------------------------------------------------------------------- /.cookiecutter.json: -------------------------------------------------------------------------------- 1 | { 2 | "_template": "https://github.com/cjolowicz/cookiecutter-hypermodern-python.git", 3 | "author": "Claudio Jolowicz", 4 | "email": "mail@claudiojolowicz.com", 5 | "friendly_name": "Retrocookie", 6 | "github_user": "cjolowicz", 7 | "package_name": "retrocookie", 8 | "project_name": "retrocookie", 9 | "version": "0.1.0" 10 | } 11 | -------------------------------------------------------------------------------- /docs/conf.py: -------------------------------------------------------------------------------- 1 | """Sphinx configuration.""" 2 | from datetime import datetime 3 | 4 | 5 | project = "Retrocookie" 6 | author = "Claudio Jolowicz" 7 | copyright = f"{datetime.now().year}, {author}" 8 | extensions = [ 9 | "sphinx.ext.autodoc", 10 | "sphinx.ext.napoleon", 11 | "sphinx_click", 12 | ] 13 | autodoc_typehints = "description" 14 | html_theme = "furo" 15 | -------------------------------------------------------------------------------- /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | select = B,B9,C,D,DAR,E,F,N,RST,S,W 3 | ignore = E203,E501,RST201,RST203,RST301,W503 4 | max-line-length = 80 5 | max-complexity = 10 6 | docstring-convention = google 7 | per-file-ignores = 8 | src/*/__main__.py:D301 9 | tests/*:S101 10 | tests/test_git.py:S101,E241 11 | rst-roles = class,const,func,meth,mod,ref 12 | rst-directives = deprecated 13 | 
-------------------------------------------------------------------------------- /src/retrocookie/pr/adapters/retrocookie.py: -------------------------------------------------------------------------------- 1 | """Retrocookie interface.""" 2 | from pathlib import Path 3 | 4 | from retrocookie import retrocookie as _retrocookie 5 | 6 | 7 | def retrocookie(repository: Path, *, branch: str, upstream: str, path: Path) -> None: 8 | """Wrap retrocookie.retrocookie with protocols.Retrocookie signature.""" 9 | _retrocookie(repository, branch=branch, upstream=upstream, path=path) 10 | -------------------------------------------------------------------------------- /.github/workflows/labeler.yml: -------------------------------------------------------------------------------- 1 | name: Labeler 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | - master 8 | 9 | jobs: 10 | labeler: 11 | runs-on: ubuntu-latest 12 | steps: 13 | - name: Check out the repository 14 | uses: actions/checkout@v3.0.2 15 | 16 | - name: Run Labeler 17 | uses: crazy-max/ghaction-github-labeler@v3.1.1 18 | with: 19 | skip-delete: true 20 | -------------------------------------------------------------------------------- /docs/index.rst: -------------------------------------------------------------------------------- 1 | .. include:: ../README.rst 2 | :end-before: github-only 3 | 4 | .. _Contributor Guide: contributing.html 5 | .. _Usage: usage.html 6 | 7 | .. 
toctree:: 8 | :hidden: 9 | :maxdepth: 1 10 | 11 | retrocookie <usage> 12 | retrocookie-pr <usage-pr> 13 | reference 14 | contributing 15 | Code of Conduct <codeofconduct> 16 | License <license> 17 | Changelog <https://github.com/cjolowicz/retrocookie/releases> 18 | -------------------------------------------------------------------------------- /tests/pr/unit/test_repository.py: -------------------------------------------------------------------------------- 1 | """Tests for retrocookie.pr.repository.""" 2 | from retrocookie.pr.cache import Cache 3 | from retrocookie.pr.protocols import github 4 | from retrocookie.pr.repository import Repository 5 | 6 | 7 | def test_repository(api: github.API, cache: Cache) -> None: 8 | """It loads the repository.""" 9 | repository = Repository.load("owner/name", api=api, cache=cache) 10 | assert repository.clone.path.exists() 11 | -------------------------------------------------------------------------------- /src/retrocookie/pr/protocols/retrocookie.py: -------------------------------------------------------------------------------- 1 | """Protocol for retrocookie.retrocookie.""" 2 | from pathlib import Path 3 | 4 | from retrocookie.pr.compat.typing import Protocol 5 | 6 | 7 | class Retrocookie(Protocol): 8 | """The retrocookie function.""" 9 | 10 | def __call__( 11 | self, repository: Path, *, branch: str, upstream: str, path: Path 12 | ) -> None: 13 | """Import commits from instance repository into template repository.""" 14 | -------------------------------------------------------------------------------- /src/retrocookie/pr/compat/shlex.py: -------------------------------------------------------------------------------- 1 | """Compatibility shims for shlex.""" 2 | import sys 3 | 4 | 5 | if sys.version_info >= (3, 8): 6 | from shlex import join 7 | else: 8 | from shlex import quote 9 | from typing import Iterable 10 | 11 | def join(split_command: Iterable[str]) -> str: 12 | """Return a shell-escaped string from *split_command*.""" 13 | return " 
".join(quote(arg) for arg in split_command) 14 | 15 | 16 | __all__ = ["join"]
17 | -------------------------------------------------------------------------------- /tests/pr/unit/utils.py: -------------------------------------------------------------------------------- 1 | """Utilities for testing.""" 2 | from typing import Iterator 3 | from typing import Type 4 | 5 | import pytest 6 | 7 | from retrocookie.compat import contextlib 8 | from retrocookie.pr.base import bus 9 | 10 | 11 | @contextlib.contextmanager 12 | def raises(event_type: Type[bus.Event]) -> Iterator[None]: 13 | """It raises the event on the bus.""" 14 | with pytest.raises(bus.Error) as exception_info: 15 | yield 16 | event = exception_info.value.event 17 | del exception_info 18 | assert isinstance(event, event_type) 19 | -------------------------------------------------------------------------------- /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | updates: 3 | - package-ecosystem: github-actions 4 | directory: "/" 5 | schedule: 6 | interval: daily 7 | - package-ecosystem: pip 8 | directory: "/.github/workflows" 9 | schedule: 10 | interval: daily 11 | - package-ecosystem: pip 12 | directory: "/docs" 13 | schedule: 14 | interval: daily 15 | - package-ecosystem: pip 16 | directory: "/" 17 | schedule: 18 | interval: daily 19 | versioning-strategy: lockfile-only 20 | allow: 21 | - dependency-type: "all" 22 | -------------------------------------------------------------------------------- /tests/pr/unit/data.py: -------------------------------------------------------------------------------- 1 | """Test data.""" 2 | from tests.pr.unit.fakes import github 3 | 4 | 5 | EXAMPLE_PROJECT = github.Repository("owner", "project") 6 | EXAMPLE_TEMPLATE = github.Repository("owner", "template") 7 | EXAMPLE_PROJECT_PULL = github.PullRequest( 8 | 1, 9 | "title", 10 | "body", 11 | "branch", 12 | "owner", 13 | "https://github.com/owner/project/pull/1", 14 | ) 15 | EXAMPLE_TEMPLATE_PULL = github.PullRequest( 16 | 1, 17 | 
"title", 18 | "body", 19 | "retrocookie-pr/branch", 20 | "owner", 21 | "https://github.com/owner/template/pull/1", 22 | ) 23 | -------------------------------------------------------------------------------- /cutty.json: -------------------------------------------------------------------------------- 1 | { 2 | "template": { 3 | "location": "https://github.com/cjolowicz/cookiecutter-hypermodern-python", 4 | "revision": "8f800d4f1e5b0e91d58584439c673858e34a5627", 5 | "directory": null 6 | }, 7 | "bindings": { 8 | "project_name": "retrocookie", 9 | "package_name": "retrocookie", 10 | "friendly_name": "Retrocookie", 11 | "author": "Claudio Jolowicz", 12 | "email": "mail@claudiojolowicz.com", 13 | "github_user": "cjolowicz", 14 | "version": "0.1.0", 15 | "license": "MIT", 16 | "development_status": "Development Status :: 3 - Alpha" 17 | } 18 | } 19 | -------------------------------------------------------------------------------- /src/retrocookie/compat/contextlib.py: -------------------------------------------------------------------------------- 1 | """Compatibility shim for contextlib.""" 2 | import contextlib 3 | from typing import Callable 4 | from typing import ContextManager 5 | from typing import Iterator 6 | from typing import TypeVar 7 | 8 | 9 | T = TypeVar("T") 10 | 11 | 12 | def contextmanager( 13 | func: Callable[..., Iterator[T]] 14 | ) -> Callable[..., ContextManager[T]]: 15 | """Fix annotations of functions decorated by contextlib.contextmanager.""" 16 | result = contextlib.contextmanager(func) 17 | result.__annotations__ = {**func.__annotations__, "return": ContextManager[T]} 18 | return result 19 | -------------------------------------------------------------------------------- /tests/pr/functional/test_github.py: -------------------------------------------------------------------------------- 1 | """Functional tests for retrocookie.pr.adapters.github.""" 2 | from pathlib import Path 3 | 4 | from retrocookie.pr.adapters import github 5 | from 
tests.pr.functional.conftest import CreatePullRequest 6 | from tests.pr.functional.conftest import skip_without_token 7 | 8 | 9 | @skip_without_token 10 | def test_html_url( 11 | create_project_pull_request: CreatePullRequest, 12 | ) -> None: 13 | """It imports pull requests of the given user.""" 14 | project_pull = create_project_pull_request( 15 | Path("README.md"), 16 | "# Welcome to example", 17 | ) 18 | 19 | pull_request = github.PullRequest(project_pull) 20 | 21 | assert pull_request.html_url.startswith("https://github.com/") 22 | -------------------------------------------------------------------------------- /src/retrocookie/utils.py: -------------------------------------------------------------------------------- 1 | """Utilities.""" 2 | import os 3 | from pathlib import Path 4 | from typing import Iterator 5 | 6 | from retrocookie.compat import contextlib 7 | 8 | 9 | def removeprefix(string: str, prefix: str) -> str: 10 | """Remove prefix from string, if present.""" 11 | return string[len(prefix) :] if string.startswith(prefix) else string 12 | 13 | 14 | def removesuffix(string: str, suffix: str) -> str: 15 | """Remove suffix from string, if present.""" 16 | return string[: -len(suffix)] if suffix and string.endswith(suffix) else string 17 | 18 | 19 | @contextlib.contextmanager 20 | def chdir(path: Path) -> Iterator[None]: 21 | """Context manager for changing the directory.""" 22 | cwd = Path.cwd() 23 | os.chdir(path) 24 | 25 | try: 26 | yield 27 | finally: 28 | os.chdir(cwd) 29 | -------------------------------------------------------------------------------- /.github/release-drafter.yml: -------------------------------------------------------------------------------- 1 | categories: 2 | - title: ":boom: Breaking Changes" 3 | label: "breaking" 4 | - title: ":rocket: Features" 5 | label: "enhancement" 6 | - title: ":fire: Removals and Deprecations" 7 | label: "removal" 8 | - title: ":beetle: Fixes" 9 | label: "bug" 10 | - title: ":racehorse: Performance" 
11 | label: "performance" 12 | - title: ":rotating_light: Testing" 13 | label: "testing" 14 | - title: ":construction_worker: Continuous Integration" 15 | label: "ci" 16 | - title: ":books: Documentation" 17 | label: "documentation" 18 | - title: ":hammer: Refactoring" 19 | label: "refactoring" 20 | - title: ":lipstick: Style" 21 | label: "style" 22 | - title: ":package: Dependencies" 23 | labels: 24 | - "dependencies" 25 | - "build" 26 | template: | 27 | ## Changes 28 | 29 | $CHANGES 30 | -------------------------------------------------------------------------------- /src/retrocookie/pr/repository.py: -------------------------------------------------------------------------------- 1 | """High-level repository abstraction.""" 2 | from __future__ import annotations 3 | 4 | from dataclasses import dataclass 5 | 6 | from retrocookie import git 7 | from retrocookie.pr.cache import Cache 8 | from retrocookie.pr.protocols import github as gh 9 | 10 | 11 | @dataclass 12 | class Repository: 13 | """High-level repository abstraction.""" 14 | 15 | github: gh.Repository 16 | clone: git.Repository 17 | 18 | @classmethod 19 | def load( 20 | cls, 21 | repository: str, 22 | *, 23 | api: gh.API, 24 | cache: Cache, 25 | ) -> Repository: 26 | """Load all the data associated with a GitHub repository.""" 27 | owner, name = repository.split("/", 1) 28 | github = api.repository(owner, name) 29 | clone = cache.repository(github.clone_url) 30 | return cls(github, clone) 31 | -------------------------------------------------------------------------------- /LICENSE.rst: -------------------------------------------------------------------------------- 1 | MIT License 2 | =========== 3 | 4 | Copyright © 2020 Claudio Jolowicz 5 | 6 | Permission is hereby granted, free of charge, to any person obtaining a copy 7 | of this software and associated documentation files (the "Software"), to deal 8 | in the Software without restriction, including without limitation the rights 9 | to use, copy, 
modify, merge, publish, distribute, sublicense, and/or sell 10 | copies of the Software, and to permit persons to whom the Software is 11 | furnished to do so, subject to the following conditions: 12 | 13 | The above copyright notice and this permission notice shall be included in all 14 | copies or substantial portions of the Software. 15 | 16 | **The software is provided "as is", without warranty of any kind, express or 17 | implied, including but not limited to the warranties of merchantability, 18 | fitness for a particular purpose and noninfringement. In no event shall the 19 | authors or copyright holders be liable for any claim, damages or other 20 | liability, whether in an action of contract, tort or otherwise, arising from, 21 | out of or in connection with the software or the use or other dealings in the 22 | software.** 23 | -------------------------------------------------------------------------------- /tests/test_utils.py: -------------------------------------------------------------------------------- 1 | """Tests for retrocookie.utils.""" 2 | import pytest 3 | 4 | from retrocookie.utils import removeprefix 5 | from retrocookie.utils import removesuffix 6 | 7 | 8 | @pytest.mark.parametrize( 9 | "string, prefix, expected", 10 | [ 11 | ("Lorem ipsum", "Lorem", " ipsum"), 12 | ("Lorem ipsum", "", "Lorem ipsum"), 13 | ("Lorem ipsum", "Lorem ipsum", ""), 14 | ("Lorem ipsum", "x", "Lorem ipsum"), 15 | ("", "Lorem ipsum", ""), 16 | ("", "", ""), 17 | ], 18 | ) 19 | def test_removeprefix(string: str, prefix: str, expected: str) -> None: 20 | """It removes the prefix.""" 21 | assert expected == removeprefix(string, prefix) 22 | 23 | 24 | @pytest.mark.parametrize( 25 | "string, suffix, expected", 26 | [ 27 | ("Lorem ipsum", "ipsum", "Lorem "), 28 | ("Lorem ipsum", "", "Lorem ipsum"), 29 | ("Lorem ipsum", "Lorem ipsum", ""), 30 | ("Lorem ipsum", "x", "Lorem ipsum"), 31 | ("", "Lorem ipsum", ""), 32 | ("", "", ""), 33 | ], 34 | ) 35 | def 
test_removesuffix(string: str, suffix: str, expected: str) -> None: 36 | """It removes the suffix.""" 37 | assert expected == removesuffix(string, suffix) 38 | -------------------------------------------------------------------------------- /tests/test_filter.py: -------------------------------------------------------------------------------- 1 | """Tests for filter module.""" 2 | from typing import Any 3 | from typing import Dict 4 | from typing import List 5 | from typing import Tuple 6 | 7 | import pytest 8 | 9 | from retrocookie.filter import escape_jinja 10 | from retrocookie.filter import get_replacements 11 | 12 | 13 | @pytest.mark.parametrize( 14 | "text, expected", 15 | [ 16 | ("", ""), 17 | ("{{}}", '{{"{{"}}{{"}}"}}'), 18 | ("${{ matrix.os }}", '${{"{{"}} matrix.os {{"}}"}}'), 19 | ("lorem {# ipsum #} dolor", 'lorem {{"{#"}} ipsum {{"#}"}} dolor'), 20 | ("{{ a }} {{ b }}", '{{"{{"}} a {{"}}"}} {{"{{"}} b {{"}}"}}'), 21 | ], 22 | ) 23 | def test_escape_jinja(text: str, expected: str) -> None: 24 | """It returns the expected result.""" 25 | assert expected == escape_jinja(text.encode()).decode() 26 | 27 | 28 | @pytest.mark.parametrize( 29 | "context, expected", 30 | [ 31 | ( 32 | {"test_key": "a test string", "other_key": ["this", "is", "a list"]}, 33 | [(b"a test string", b"{{cookiecutter.test_key}}")], 34 | ), 35 | ( 36 | {"test_key": "a test string", "other_key": None}, 37 | [(b"a test string", b"{{cookiecutter.test_key}}")], 38 | ), 39 | ], 40 | ) 41 | def test_get_replacements( 42 | context: Dict[str, Any], expected: List[Tuple[bytes, bytes]] 43 | ) -> None: 44 | """It ignore non string values.""" 45 | assert expected == get_replacements(context, [], []) 46 | -------------------------------------------------------------------------------- /tests/helpers.py: -------------------------------------------------------------------------------- 1 | """Utilities for testing.""" 2 | import contextlib 3 | from pathlib import Path 4 | from typing import cast 5 | 
from typing import Iterator 6 | 7 | from retrocookie import git 8 | 9 | 10 | def commit(repository: git.Repository, message: str = "") -> str: 11 | """Create a commit and return the hash.""" 12 | repository.commit(message) 13 | return cast(str, repository.repo.head.target.hex) 14 | 15 | 16 | def write(repository: git.Repository, path: Path, text: str) -> str: 17 | """Write file in repository.""" 18 | path = repository.path / path 19 | path.write_text(text) 20 | repository.add(path) 21 | return commit(repository, f"Update {path.name}") 22 | 23 | 24 | def touch(repository: git.Repository, path: Path) -> str: 25 | """Create an empty file in a repository.""" 26 | path = repository.path / path 27 | path.touch() 28 | repository.add(path) 29 | return commit(repository, f"Touch {path.name}") 30 | 31 | 32 | def append(repository: git.Repository, path: Path, text: str) -> None: 33 | """Append to file in repository.""" 34 | text = repository.read_text(path) + text 35 | write(repository, path, text) 36 | 37 | 38 | @contextlib.contextmanager 39 | def branch( 40 | repository: git.Repository, branch: str, *, create: bool = False 41 | ) -> Iterator[None]: 42 | """Switch to branch.""" 43 | original = repository.get_current_branch() 44 | 45 | if create and not repository.exists_branch(branch): 46 | repository.create_branch(branch) 47 | 48 | repository.switch_branch(branch) 49 | 50 | try: 51 | yield 52 | finally: 53 | repository.switch_branch(original) 54 | -------------------------------------------------------------------------------- /src/retrocookie/pr/list.py: -------------------------------------------------------------------------------- 1 | """Listing pull requests.""" 2 | import contextlib 3 | from typing import Collection 4 | from typing import Iterator 5 | from typing import Optional 6 | 7 | from retrocookie.pr import events 8 | from retrocookie.pr.base.bus import Bus 9 | from retrocookie.pr.protocols import github 10 | 11 | 12 | def get_pull_request( 13 | repository: 
github.Repository, 14 | spec: str, 15 | *, 16 | bus: Bus, 17 | ) -> github.PullRequest: 18 | """Return the pull request matching the provided spec.""" 19 | with contextlib.suppress(ValueError): 20 | number = int(spec) 21 | return repository.pull_request(number) 22 | 23 | pull = repository.pull_request_by_head(f"{repository.owner}:{spec}") 24 | if pull is not None: 25 | return pull 26 | 27 | bus.events.raise_(events.PullRequestNotFound(spec)) 28 | 29 | 30 | def get_pull_requests( 31 | repository: github.Repository, 32 | specs: Collection[str] = (), 33 | *, 34 | bus: Bus, 35 | ) -> Iterator[github.PullRequest]: 36 | """Return pull requests. With specs, filter those matching specs.""" 37 | if specs: 38 | for spec in specs: 39 | yield get_pull_request(repository, spec, bus=bus) 40 | else: 41 | yield from repository.pull_requests() 42 | 43 | 44 | def list_pull_requests( 45 | repository: github.Repository, 46 | specs: Collection[str] = (), 47 | *, 48 | user: Optional[str] = None, 49 | bus: Bus, 50 | ) -> Iterator[github.PullRequest]: 51 | """List matching pull requests in repository.""" 52 | for pull in get_pull_requests(repository, specs, bus=bus): 53 | if user is None or pull.user == user: 54 | yield pull 55 | -------------------------------------------------------------------------------- /tests/pr/unit/conftest.py: -------------------------------------------------------------------------------- 1 | """Fixtures for retrocookie.pr.""" 2 | import sys 3 | from pathlib import Path 4 | from typing import Callable 5 | from typing import Iterator 6 | 7 | import pytest 8 | from pytest import MonkeyPatch 9 | 10 | from retrocookie.pr.base.bus import Bus 11 | from retrocookie.pr.cache import Cache 12 | from retrocookie.pr.protocols import github 13 | from retrocookie.pr.protocols.retrocookie import Retrocookie 14 | from retrocookie.pr.repository import Repository 15 | from tests.pr.unit.fakes.github import API as FakeAPI # noqa: N811 16 | from tests.pr.unit.fakes.retrocookie 
import retrocookie 17 | 18 | 19 | @pytest.fixture 20 | def bus() -> Iterator[Bus]: 21 | """Return a bus whose event error handler is active during the test.""" 22 | bus = Bus() 23 | with bus.events.errorhandler(): 24 | yield bus 25 | 26 | 27 | @pytest.fixture 28 | def cache(tmp_path: Path) -> Cache: 29 | """Return a cache in the test's temporary directory.""" 30 | return Cache(tmp_path / "cache") 31 | 32 | 33 | @pytest.fixture 34 | def api(tmp_path: Path) -> github.API: 35 | """Return a fake GitHub API backed by the test's temporary directory.""" 36 | return FakeAPI(_backend=tmp_path / "github") 37 | 38 | 39 | @pytest.fixture(name="retrocookie") 40 | def fixture_retrocookie() -> Retrocookie: 41 | """Return a fake retrocookie function.""" 42 | return retrocookie 43 | 44 | 45 | @pytest.fixture 46 | def repository(api: github.API, cache: Cache) -> Callable[[str], Repository]: 47 | """Return a factory that loads repositories by full name from the fake API.""" 48 | 49 | def _repository(fullname: str) -> Repository: 50 | return Repository.load(fullname, api=api, cache=cache) 51 | 52 | return _repository 53 | 54 | 55 | @pytest.fixture(autouse=sys.platform == "win32") 56 | def mock_cache_repository_path(monkeypatch: MonkeyPatch) -> None: 57 | """Avoid errors due to excessively long paths on Windows.""" 58 | monkeypatch.setattr("retrocookie.pr.cache.DIGEST_SIZE", 3) 59 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: local 3 | hooks: 4 | - id: black 5 | name: black 6 | entry: black 7 | language: system 8 | types: [python] 9 | require_serial: true 10 | - id: check-added-large-files 11 | name: Check for added large files 12 | entry: check-added-large-files 13 | language: system 14 | - id: check-toml 15 | name: Check Toml 16 | entry: check-toml 17 | language: system 18 | types: [toml] 19 | - id: check-yaml 20 | name: Check Yaml 21 | entry: check-yaml 22 | language: system 23 | types: [yaml] 24 | - id: end-of-file-fixer 25 | name: Fix End of Files 26 | entry: end-of-file-fixer 27 |
language: system 28 | types: [text] 29 | stages: [commit, push, manual] 30 | - id: flake8 31 | name: flake8 32 | entry: flake8 33 | language: system 34 | types: [python] 35 | require_serial: true 36 | - id: isort 37 | name: isort 38 | entry: isort 39 | require_serial: true 40 | language: system 41 | types_or: [cython, pyi, python] 42 | args: ["--filter-files"] 43 | - id: pyupgrade 44 | name: pyupgrade 45 | description: Automatically upgrade syntax for newer versions. 46 | entry: pyupgrade 47 | language: system 48 | types: [python] 49 | args: [--py37-plus, --keep-runtime-typing] 50 | - id: trailing-whitespace 51 | name: Trim Trailing Whitespace 52 | entry: trailing-whitespace-fixer 53 | language: system 54 | types: [text] 55 | stages: [commit, push, manual] 56 | - repo: https://github.com/pre-commit/mirrors-prettier 57 | rev: v2.3.0 58 | hooks: 59 | - id: prettier 60 | -------------------------------------------------------------------------------- /.github/labels.yml: -------------------------------------------------------------------------------- 1 | --- 2 | # Label names are important as they are used by Release Drafter to decide 3 | # where to record them in the changelog or whether to skip them. 4 | # 5 | # The repository labels will be automatically configured using this file and 6 | # the GitHub Action https://github.com/marketplace/actions/github-labeler.
7 | - name: breaking 8 | description: Breaking Changes 9 | color: bfd4f2 10 | - name: bug 11 | description: Something isn't working 12 | color: d73a4a 13 | - name: build 14 | description: Build System and Dependencies 15 | color: bfdadc 16 | - name: ci 17 | description: Continuous Integration 18 | color: 4a97d6 19 | - name: dependencies 20 | description: Pull requests that update a dependency file 21 | color: 0366d6 22 | - name: documentation 23 | description: Improvements or additions to documentation 24 | color: 0075ca 25 | - name: duplicate 26 | description: This issue or pull request already exists 27 | color: cfd3d7 28 | - name: enhancement 29 | description: New feature or request 30 | color: a2eeef 31 | - name: github_actions 32 | description: Pull requests that update Github_actions code 33 | color: "000000" 34 | - name: good first issue 35 | description: Good for newcomers 36 | color: 7057ff 37 | - name: help wanted 38 | description: Extra attention is needed 39 | color: 008672 40 | - name: invalid 41 | description: This doesn't seem right 42 | color: e4e669 43 | - name: performance 44 | description: Performance 45 | color: "016175" 46 | - name: python 47 | description: Pull requests that update Python code 48 | color: 2b67c6 49 | - name: question 50 | description: Further information is requested 51 | color: d876e3 52 | - name: refactoring 53 | description: Refactoring 54 | color: ef67c4 55 | - name: removal 56 | description: Removals and Deprecations 57 | color: 9ae7ea 58 | - name: style 59 | description: Style 60 | color: c120e5 61 | - name: testing 62 | description: Testing 63 | color: b1fc6f 64 | - name: wontfix 65 | description: This will not be worked on 66 | color: ffffff 67 | -------------------------------------------------------------------------------- /tests/pr/unit/adapters/test_github.py: -------------------------------------------------------------------------------- 1 | """Tests for retrocookie.pr.github.""" 2 | from typing import Any 3 | 4 
import github3.exceptions 5 | import pytest 6 | import requests.exceptions 7 | 8 | from retrocookie.pr import events 9 | from retrocookie.pr.adapters import github 10 | from retrocookie.pr.base.bus import Bus 11 | from tests.pr.unit.utils import raises 12 | 13 | 14 | class FakeRequest: 15 | """A fake for requests.Request with just a URL and an HTTP method.""" 16 | 17 | url = "https://api.github.com/repos/owner/name" 18 | method = "GET" 19 | 20 | 21 | class FakeResponse: 22 | """A fake for requests.Response: a 403 reply to a FakeRequest.""" 23 | 24 | request = FakeRequest() 25 | status_code = 403 26 | 27 | def json(self) -> Any: 28 | """Return the JSON body.""" 29 | return {"message": "Forbidden"} 30 | 31 | 32 | def test_errorhandler_github_with_request(bus: Bus) -> None: 33 | """It raises a GitHubError event when the response carries a request.""" 34 | response = FakeResponse() 35 | 36 | with raises(events.GitHubError): 37 | with github.errorhandler(bus=bus): 38 | raise github3.exceptions.GitHubError(response) 39 | 40 | 41 | def test_errorhandler_github_without_request(bus: Bus) -> None: 42 | """It reraises a GitHub error that carries no request.""" 43 | with pytest.raises(github3.exceptions.GitHubError): 44 | with github.errorhandler(bus=bus): 45 | raise github3.exceptions.IncompleteResponse(json={}, exception=None) 46 | 47 | 48 | def test_errorhandler_connection_with_requests_exception(bus: Bus) -> None: 49 | """It raises a ConnectionError event when caused by a requests error.""" 50 | response = FakeResponse() 51 | cause = requests.RequestException(response=response) 52 | 53 | with raises(events.ConnectionError): 54 | with github.errorhandler(bus=bus): 55 | raise github3.exceptions.ConnectionError(cause) 56 | 57 | 58 | def test_errorhandler_connection_without_requests_exception(bus: Bus) -> None: 59 | """It reraises a connection error not caused by a requests error.""" 60 | cause = Exception() 61 | 62 | with pytest.raises(github3.exceptions.ConnectionError): 63 | with github.errorhandler(bus=bus): 64 | raise github3.exceptions.ConnectionError(cause) 65 | --------------------------------------------------------------------------------
/tests/pr/unit/test_cache.py: -------------------------------------------------------------------------------- 1 | """Tests for retrocookie.pr.cache.""" 2 | from pathlib import Path 3 | 4 | import pytest 5 | 6 | from retrocookie import git 7 | from retrocookie.pr.cache import Cache 8 | from tests.helpers import commit 9 | 10 | 11 | @pytest.fixture 12 | def repository(tmp_path: Path) -> git.Repository: 13 | """Return a repository with an initial commit.""" 14 | repository = git.Repository.init(tmp_path / "repository") 15 | repository.commit("") 16 | return repository 17 | 18 | 19 | @pytest.fixture 20 | def url(repository: git.Repository) -> str: 21 | """Return a file:// URL for the repository.""" 22 | return f"file://{repository.path.resolve()}" 23 | 24 | 25 | def test_repository_clone(cache: Cache, url: str) -> None: 26 | """It clones the repository within the cache directory.""" 27 | clone = cache.repository(url) 28 | assert cache.path in clone.path.parents 29 | 30 | 31 | def test_repository_idempotent(cache: Cache, url: str) -> None: 32 | """It returns the same repository as before.""" 33 | clone1 = cache.repository(url) 34 | clone2 = cache.repository(url) 35 | assert clone1.path == clone2.path 36 | 37 | 38 | def test_repository_update(cache: Cache, url: str, repository: git.Repository) -> None: 39 | """It updates the repository from its remote.""" 40 | cache.repository(url) 41 | sha1 = commit(repository) # upstream moves on after the first clone 42 | clone = cache.repository(url) 43 | assert sha1 == clone.repo.head.target.hex 44 | 45 | 46 | def test_worktree(cache: Cache, url: str) -> None: 47 | """It creates a worktree for the given branch.""" 48 | repository = cache.repository(url) 49 | with cache.worktree(repository, "branch") as worktree: 50 | assert "branch" == worktree.get_current_branch() 51 | 52 | 53 | def test_token_roundtrip(cache: Cache) -> None: 54 | """It saves and loads the token.""" 55 | cache.save_token("token") 56 | assert "token" == cache.load_token() 57 | 58 | 59 | def test_load_invalid_token(cache: Cache) -> None: 60 |
"""It rejects invalid tokens on load.""" 61 | cache.save_token("token") 62 | path = cache.path / "token.json" 63 | path.write_text('{"token": 666}') # a token that is not a string 64 | with pytest.raises(TypeError): 65 | cache.load_token() 66 | -------------------------------------------------------------------------------- /tests/pr/unit/base/test_bus.py: -------------------------------------------------------------------------------- 1 | """Tests for retrocookie.pr.base.bus.""" 2 | import contextlib 3 | from typing import ContextManager 4 | 5 | import pytest 6 | 7 | from retrocookie.pr.base.bus import Bus 8 | from retrocookie.pr.base.bus import Context 9 | from retrocookie.pr.base.bus import Error 10 | from retrocookie.pr.base.bus import Event 11 | 12 | 13 | def test_events_publish() -> None: 14 | """It invokes the handler.""" 15 | events = [] 16 | bus = Bus() 17 | 18 | @bus.events.subscribe 19 | def handler(event: Event) -> None: 20 | events.append(event) 21 | 22 | event = Event() 23 | bus.events.publish(event) 24 | [seen] = events 25 | assert event is seen 26 | 27 | 28 | def test_contexts_publish() -> None: 29 | """It invokes the handler.""" 30 | contexts = [] 31 | bus = Bus() 32 | 33 | @bus.contexts.subscribe 34 | def handler(context: Context) -> ContextManager[None]: 35 | contexts.append(context) 36 | return contextlib.nullcontext() 37 | 38 | context = Context() 39 | with bus.contexts.publish(context): 40 | pass 41 | [seen] = contexts 42 | assert context is seen 43 | 44 | 45 | def test_events_subscribe_without_annotations() -> None: 46 | """It fails when the handler has no type annotations.""" 47 | bus = Bus() 48 | with pytest.raises(Exception): 49 | 50 | @bus.events.subscribe 51 | def handler(event): # type: ignore[no-untyped-def] 52 | pass 53 | 54 | 55 | def test_contexts_subscribe_without_annotations() -> None: 56 | """It fails when the handler has no type annotations.""" 57 | bus = Bus() 58 | with pytest.raises(Exception): 59 | 60 | @bus.contexts.subscribe 61 | def handler(event): # type:
ignore[no-untyped-def] 62 | pass 63 | 64 | 65 | def test_events_errorhandler_raise() -> None: 66 | """It raises an Error when an event is "raised".""" 67 | bus = Bus() 68 | with pytest.raises(Error): 69 | with bus.events.errorhandler(): 70 | bus.events.raise_(Event()) 71 | 72 | 73 | def test_events_errorhandler_reraise() -> None: 74 | """It raises an Error when an event is "reraised".""" 75 | bus = Bus() 76 | with pytest.raises(Error): 77 | with bus.events.errorhandler(): 78 | with bus.events.reraise(Event(), when=Exception): 79 | raise Exception("Boom") 80 | -------------------------------------------------------------------------------- /src/retrocookie/pr/cache.py: -------------------------------------------------------------------------------- 1 | """Application cache.""" 2 | import hashlib 3 | import json 4 | import sys 5 | from dataclasses import dataclass 6 | from pathlib import Path 7 | from typing import Iterator 8 | 9 | from retrocookie import git 10 | from retrocookie.compat import contextlib 11 | 12 | 13 | # Windows does not properly support paths longer than 260 characters 14 | # (MAX_PATH). Use a smaller digest size on this platform.
15 | DIGEST_SIZE: int = 64 if sys.platform != "win32" else 32 16 | 17 | 18 | @dataclass 19 | class Cache: 20 | """Application cache.""" 21 | 22 | path: Path 23 | 24 | def save_token(self, token: str) -> None: 25 | """Save a token.""" 26 | data = {"token": token} 27 | text = json.dumps(data) 28 | path = self.path / "token.json" 29 | path.parent.mkdir(exist_ok=True, parents=True) 30 | path.touch(mode=0o600) 31 | path.write_text(text) 32 | 33 | def load_token(self) -> str: 34 | """Load a token.""" 35 | path = self.path / "token.json" 36 | text = path.read_text() 37 | data = json.loads(text) 38 | token = data["token"] 39 | if not isinstance(token, str): 40 | raise TypeError("invalid token") 41 | return token 42 | 43 | def _repository_path(self, url: str) -> Path: 44 | h = hashlib.blake2b(url.encode(), digest_size=DIGEST_SIZE).hexdigest() 45 | return self.path / "repositories" / h[:2] / h[2:] / "repo.git" 46 | 47 | def repository(self, url: str) -> git.Repository: 48 | """Clone or update repository.""" 49 | path = self._repository_path(url) 50 | 51 | if not path.exists(): 52 | return git.Repository.clone(url, path, mirror=True) 53 | 54 | repository = git.Repository(path) 55 | repository.update_remote() 56 | return repository 57 | 58 | @contextlib.contextmanager 59 | def worktree( 60 | self, 61 | repository: git.Repository, 62 | branch: str, 63 | *, 64 | base: str = "HEAD", 65 | force: bool = False, 66 | ) -> Iterator[git.Repository]: 67 | """Context manager to add and remove a worktree.""" 68 | assert self.path in repository.path.parents # noqa: S101 69 | path = repository.path.parent / "worktrees" / branch 70 | 71 | with repository.worktree( 72 | branch, path, base=base, force=force, force_remove=True 73 | ) as worktree: 74 | yield worktree 75 | -------------------------------------------------------------------------------- /.github/workflows/release.yml: -------------------------------------------------------------------------------- 1 | name: Release 2 | 3 | on: 4 | 
push: 5 | branches: 6 | - main 7 | - master 8 | 9 | jobs: 10 | release: 11 | name: Release 12 | runs-on: ubuntu-latest 13 | steps: 14 | - name: Check out the repository 15 | uses: actions/checkout@v3.0.2 16 | with: 17 | fetch-depth: 2 18 | 19 | - name: Set up Python 20 | uses: actions/setup-python@v3.1.2 21 | with: 22 | python-version: "3.9" 23 | 24 | - name: Upgrade pip 25 | run: | 26 | pip install --constraint=.github/workflows/constraints.txt pip 27 | pip --version 28 | 29 | - name: Install Poetry 30 | run: | 31 | pip install --constraint=.github/workflows/constraints.txt poetry 32 | poetry --version 33 | 34 | - name: Check if there is a parent commit 35 | id: check-parent-commit 36 | run: | 37 | echo "::set-output name=sha::$(git rev-parse --verify --quiet HEAD^)" 38 | 39 | - name: Detect and tag new version 40 | id: check-version 41 | if: steps.check-parent-commit.outputs.sha 42 | uses: salsify/action-detect-and-tag-new-version@v2.0.1 43 | with: 44 | version-command: | 45 | bash -o pipefail -c "poetry version | awk '{ print \$2 }'" 46 | 47 | - name: Bump version for developmental release 48 | if: "! steps.check-version.outputs.tag" 49 | run: | 50 | poetry version patch && 51 | version=$(poetry version | awk '{ print $2 }') && 52 | poetry version $version.dev.$(date +%s) 53 | 54 | - name: Build package 55 | run: | 56 | poetry build --ansi 57 | 58 | - name: Publish package on PyPI 59 | if: steps.check-version.outputs.tag 60 | uses: pypa/gh-action-pypi-publish@v1.5.0 61 | with: 62 | user: __token__ 63 | password: ${{ secrets.PYPI_TOKEN }} 64 | 65 | - name: Publish package on TestPyPI 66 | if: "! 
steps.check-version.outputs.tag" 67 | uses: pypa/gh-action-pypi-publish@v1.5.0 68 | with: 69 | user: __token__ 70 | password: ${{ secrets.TEST_PYPI_TOKEN }} 71 | repository_url: https://test.pypi.org/legacy/ 72 | 73 | - name: Publish the release notes 74 | uses: release-drafter/release-drafter@v5.20.0 75 | with: 76 | publish: ${{ steps.check-version.outputs.tag != '' }} 77 | tag: ${{ steps.check-version.outputs.tag }} 78 | env: 79 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 80 | -------------------------------------------------------------------------------- /tests/pr/unit/test_importer.py: -------------------------------------------------------------------------------- 1 | """Tests for retrocookie.pr.importer.""" 2 | from typing import Callable 3 | 4 | import pytest 5 | 6 | from retrocookie.pr import appname 7 | from retrocookie.pr import events 8 | from retrocookie.pr.base.bus import Bus 9 | from retrocookie.pr.cache import Cache 10 | from retrocookie.pr.importer import Importer 11 | from retrocookie.pr.protocols.retrocookie import Retrocookie 12 | from retrocookie.pr.repository import Repository 13 | from tests.pr.unit.fakes import github 14 | from tests.pr.unit.utils import raises 15 | 16 | 17 | @pytest.fixture 18 | def template(repository: Callable[[str], Repository]) -> Repository: 19 | """Return the template repository.""" 20 | return repository("owner/template") 21 | 22 | 23 | @pytest.fixture 24 | def project(repository: Callable[[str], Repository]) -> Repository: 25 | """Return the project repository.""" 26 | return repository("owner/project") 27 | 28 | 29 | @pytest.fixture 30 | def importer( 31 | project: Repository, 32 | template: Repository, 33 | bus: Bus, 34 | cache: Cache, 35 | retrocookie: Retrocookie, 36 | ) -> Importer: 37 | """Return an importer from the project into the template.""" 38 | return Importer( 39 | project, 40 | template, 41 | bus=bus, 42 | cache=cache, 43 | retrocookie=retrocookie, 44 | ) 45 | 46 | 47 | @pytest.fixture 48 | def pull_request() -> github.PullRequest: 49 |
"""Return the pull request.""" 50 | return github.PullRequest(1, "title", "body", "branch", "user") 51 | 52 | 53 | def test_import( 54 | importer: Importer, 55 | pull_request: github.PullRequest, 56 | template: Repository, 57 | cache: Cache, 58 | ) -> None: 59 | """It opens a template pull request for the prefixed branch.""" 60 | cache.save_token("token") 61 | 62 | importer.import_(pull_request) 63 | 64 | [imported] = template.github.pull_requests() 65 | 66 | assert imported.branch == f"{appname}/{pull_request.branch}" 67 | 68 | 69 | def test_import_exists_branch( 70 | importer: Importer, 71 | pull_request: github.PullRequest, 72 | template: Repository, 73 | cache: Cache, 74 | ) -> None: 75 | """It fails if the pull request was already imported.""" 76 | cache.save_token("token") 77 | 78 | with raises(events.PullRequestAlreadyExists): 79 | for _ in range(2): 80 | importer.import_(pull_request) 81 | 82 | 83 | def test_import_force( 84 | importer: Importer, 85 | pull_request: github.PullRequest, 86 | template: Repository, 87 | cache: Cache, 88 | ) -> None: 89 | """It reuses the existing pull request when forced.""" 90 | cache.save_token("token") 91 | 92 | for force in [False, True]: 93 | importer.import_(pull_request, force=force) 94 | 95 | [imported] = template.github.pull_requests() 96 | 97 | assert imported.branch == f"{appname}/{pull_request.branch}" 98 | -------------------------------------------------------------------------------- /src/retrocookie/pr/protocols/github.py: -------------------------------------------------------------------------------- 1 | """Protocols for retrocookie.pr.github.""" 2 | from __future__ import annotations 3 | 4 | from typing import AbstractSet 5 | from typing import Iterable 6 | from typing import Optional 7 | from typing import Set 8 | 9 | from retrocookie.pr.compat.typing import Protocol 10 | 11 | 12 | class PullRequest(Protocol): 13 | """Pull request.""" 14 | 15 | @property 16 | def number(self) -> int: 17 | """The number of the pull request.""" 18 | 19 | @property 20 | def title(self) -> str: 21 | """The title
of the pull request.""" 22 | 23 | @property 24 | def body(self) -> Optional[str]: 25 | """The body of the pull request.""" 26 | 27 | @property 28 | def branch(self) -> str: 29 | """The branch merged by the pull request.""" 30 | 31 | @property 32 | def user(self) -> str: 33 | """The user who opened the pull request.""" 34 | 35 | @property 36 | def html_url(self) -> str: 37 | """URL for viewing the pull request in a browser.""" 38 | 39 | @property 40 | def labels(self) -> Set[str]: 41 | """Labels applied to the pull request.""" 42 | 43 | def update(self, title: str, body: Optional[str], labels: AbstractSet[str]) -> None: 44 | """Update the title, body, and labels of the pull request.""" 45 | 46 | 47 | class Repository(Protocol): 48 | """GitHub repository.""" 49 | 50 | @property 51 | def owner(self) -> str: 52 | """The user who owns the repository.""" 53 | 54 | @property 55 | def full_name(self) -> str: 56 | """The full name of the repository.""" 57 | 58 | @property 59 | def clone_url(self) -> str: 60 | """URL for cloning the repository.""" 61 | 62 | @property 63 | def push_url(self) -> str: 64 | """URL for pushing to the repository.""" 65 | 66 | @property 67 | def default_branch(self) -> str: 68 | """The default branch of the repository.""" 69 | 70 | def pull_requests(self) -> Iterable[PullRequest]: 71 | """Pull requests open in the repository.""" 72 | 73 | def pull_request(self, number: int) -> PullRequest: 74 | """Return the pull request identified by the given number.""" 75 | 76 | def pull_request_by_head(self, head: str) -> Optional[PullRequest]: 77 | """Return the pull request for the given head, or None if there is none.""" 78 | 79 | def create_pull_request( 80 | self, 81 | *, 82 | head: str, 83 | title: str, 84 | body: Optional[str], 85 | labels: AbstractSet[str], 86 | ) -> PullRequest: 87 | """Create a pull request.""" 88 | 89 | 90 | class API(Protocol): 91 | """GitHub API.""" 92 | 93 | @property 94 | def me(self) -> str: 95 | """Return the login of the authenticated user.""" 96 | 97 | def repository(self, owner: str,
name: str) -> Repository: 98 | """Return the repository with the given owner and name.""" 99 | -------------------------------------------------------------------------------- /src/retrocookie/pr/events.py: -------------------------------------------------------------------------------- 1 | """Events and contexts for the message bus.""" 2 | from dataclasses import dataclass 3 | from typing import List 4 | 5 | from retrocookie import git 6 | from retrocookie.pr.base import bus 7 | from retrocookie.pr.protocols import github 8 | 9 | 10 | @dataclass 11 | class GitNotFound(bus.Event): 12 | """Cannot find an installation of git.""" 13 | 14 | 15 | @dataclass 16 | class BadGitVersion(bus.Event): 17 | """The installed version of git is too old.""" 18 | 19 | version: git.Version 20 | expected: git.Version 21 | 22 | 23 | @dataclass 24 | class ProjectNotFound(bus.Event): 25 | """The generated project could not be identified.""" 26 | 27 | 28 | @dataclass 29 | class TemplateNotFound(bus.Event): 30 | """The project template could not be identified.""" 31 | 32 | project: github.Repository 33 | 34 | 35 | @dataclass 36 | class LoadProject(bus.Context): 37 | """The project repository is being loaded.""" 38 | 39 | repository: str 40 | 41 | 42 | @dataclass 43 | class LoadTemplate(bus.Context): 44 | """A template repository is being loaded.""" 45 | 46 | repository: str 47 | 48 | 49 | @dataclass 50 | class RepositoryNotFound(bus.Event): 51 | """The repository was not found.""" 52 | 53 | repository: str 54 | 55 | 56 | @dataclass 57 | class PullRequestNotFound(bus.Event): 58 | """The pull request was not found.""" 59 | 60 | pull_request: str # the number or branch that was looked up 61 | 62 | 63 | @dataclass 64 | class PullRequestAlreadyExists(bus.Event): 65 | """The pull request already exists.""" 66 | 67 | template: github.Repository 68 | template_pull: github.PullRequest 69 | project_pull: github.PullRequest 70 | 71 | 72 | @dataclass 73 | class GitFailed(bus.Event): 74 | """The git command exited with a non-zero status.""" 75 |
76 | command: str 77 | options: List[str] 78 | status: int 79 | stdout: str 80 | stderr: str 81 | 82 | 83 | @dataclass 84 | class GitHubError(bus.Event): 85 | """The GitHub API returned an error response.""" 86 | 87 | url: str 88 | method: str 89 | code: int 90 | message: str 91 | errors: List[str] 92 | 93 | 94 | @dataclass 95 | class ConnectionError(bus.Event): # note: shadows the builtin within this module 96 | """A connection to the GitHub API could not be established.""" 97 | 98 | url: str 99 | method: str 100 | error: str 101 | 102 | 103 | @dataclass 104 | class CreatePullRequest(bus.Context): 105 | """A pull request is being created.""" 106 | 107 | template: github.Repository 108 | template_branch: str # NOTE(review): Importer passes the "owner:branch" head here 109 | project_pull: github.PullRequest 110 | 111 | 112 | @dataclass 113 | class UpdatePullRequest(bus.Context): 114 | """A pull request is being updated.""" 115 | 116 | template: github.Repository 117 | template_pull: github.PullRequest 118 | project_pull: github.PullRequest 119 | 120 | 121 | @dataclass 122 | class PullRequestCreated(bus.Event): 123 | """A pull request was created.""" 124 | 125 | template: github.Repository 126 | template_pull: github.PullRequest 127 | project_pull: github.PullRequest 128 | -------------------------------------------------------------------------------- /src/retrocookie/pr/importer.py: -------------------------------------------------------------------------------- 1 | """Import pull requests.""" 2 | from dataclasses import dataclass 3 | 4 | from retrocookie.pr import appname 5 | from retrocookie.pr import events 6 | from retrocookie.pr.base.bus import Bus 7 | from retrocookie.pr.base.bus import Context 8 | from retrocookie.pr.cache import Cache 9 | from retrocookie.pr.protocols import github 10 | from retrocookie.pr.protocols.retrocookie import Retrocookie 11 | from retrocookie.pr.repository import Repository 12 | 13 | 14 | @dataclass 15 | class Importer: 16 | """Import pull requests from a project into its template.""" 17 | 18 | project: Repository 19 | template: Repository 20 | bus: Bus 21 |
cache: Cache 22 | retrocookie: Retrocookie 23 | 24 | def import_( 25 | self, 26 | project_pull: github.PullRequest, 27 | *, 28 | force: bool = False, 29 | ) -> None: 30 | """Import the given pull request from the project into its template.""" 31 | template_branch = f"{appname}/{project_pull.branch}" 32 | head = f"{self.template.github.owner}:{template_branch}" 33 | template_pull = self.template.github.pull_request_by_head(head) 34 | event: Context # published on the bus around the import below 35 | 36 | if template_pull is None: 37 | event = events.CreatePullRequest(self.template.github, head, project_pull) 38 | elif force: 39 | event = events.UpdatePullRequest( 40 | self.template.github, template_pull, project_pull 41 | ) 42 | else: 43 | self.bus.events.raise_( # must not return: event is unbound on this path 44 | events.PullRequestAlreadyExists( 45 | self.template.github, template_pull, project_pull 46 | ) 47 | ) 48 | 49 | with self.bus.contexts.publish(event): 50 | with self.cache.worktree( 51 | self.template.clone, template_branch, force=True 52 | ) as worktree: 53 | self.retrocookie( 54 | self.project.clone.path, 55 | branch=project_pull.branch, 56 | upstream=self.project.github.default_branch, 57 | path=worktree.path, 58 | ) 59 | # https://stackoverflow.com/a/41153073/1355754 60 | self.template.clone.push( 61 | self.template.github.push_url, f"+{template_branch}", force=force 62 | ) 63 | 64 | if template_pull is None: # first import: open a new pull request 65 | template_pull = self.template.github.create_pull_request( 66 | head=head, 67 | title=project_pull.title, 68 | body=project_pull.body, 69 | labels=project_pull.labels, 70 | ) 71 | else: # forced re-import: sync the existing pull request 72 | template_pull.update( 73 | title=project_pull.title, 74 | body=project_pull.body, 75 | labels=project_pull.labels, 76 | ) 77 | # Only newly created pull requests are announced. 78 | if isinstance(event, events.CreatePullRequest): 79 | self.bus.events.publish( 80 | events.PullRequestCreated( 81 | self.template.github, template_pull, project_pull 82 | ) 83 | ) 84 | -------------------------------------------------------------------------------- /pyproject.toml:
-------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "retrocookie" 3 | version = "0.4.2" 4 | description = "Update Cookiecutter templates with changes from their instances" 5 | authors = ["Claudio Jolowicz "] 6 | license = "MIT" 7 | readme = "README.rst" 8 | homepage = "https://github.com/cjolowicz/retrocookie" 9 | repository = "https://github.com/cjolowicz/retrocookie" 10 | documentation = "https://retrocookie.readthedocs.io" 11 | classifiers = [ 12 | "Development Status :: 3 - Alpha", 13 | "Environment :: Console", 14 | "Intended Audience :: Developers", 15 | "Operating System :: Microsoft :: Windows", 16 | "Operating System :: MacOS", 17 | "Operating System :: POSIX :: Linux", 18 | ] 19 | 20 | [tool.poetry.urls] 21 | Changelog = "https://github.com/cjolowicz/retrocookie/releases" 22 | 23 | [tool.poetry.dependencies] 24 | python = "^3.7" 25 | click = "^8.0.3" 26 | git-filter-repo = "^2.26.0" 27 | pygit2 = "^1.2.1" 28 | typing-extensions = {version = ">=3.7.4,<5.0.0", optional = true} 29 | "github3.py" = {version = "^3.0.0", optional = true} 30 | tenacity = {version = ">=6.3.1,<9.0.0", optional = true} 31 | appdirs = {version = "^1.4.4", optional = true} 32 | rich = {version = ">=9.5.1,<11.0.0", optional = true} 33 | 34 | [tool.poetry.dev-dependencies] 35 | pytest = ">=6.2.5" 36 | coverage = {extras = ["toml"], version = ">=6.1"} 37 | safety = ">=1.10.3" 38 | mypy = ">=0.800" 39 | typeguard = ">=2.9.1" 40 | xdoctest = {extras = ["colors"], version = ">=0.15.2"} 41 | sphinx = ">=4.3.0" 42 | sphinx-autobuild = ">=2020.9.1" 43 | pre-commit = ">=2.15.0" 44 | cookiecutter = ">=1.7.3" 45 | pygments = ">=2.10.0" 46 | flake8 = ">=4.0.1" 47 | black = ">=21.10b0" 48 | flake8-bandit = ">=2.1.2" 49 | flake8-bugbear = ">=21.9.2" 50 | flake8-docstrings = ">=1.6.0" 51 | flake8-rst-docstrings = ">=0.2.3" 52 | pep8-naming = ">=0.11.1" 53 | darglint = ">=1.8.1" 54 | pre-commit-hooks = ">=4.0.1" 55 | sphinx-click = ">=3.0.2" 
56 | Pygments = ">=2.10.0" 57 | furo = ">=2021.11.12" 58 | types-requests = ">=2.26.0" 59 | pyupgrade = ">=2.29.0" 60 | isort = "^5.10.1" 61 | 62 | [tool.poetry.extras] 63 | pr = ["appdirs", "github3.py", "rich", "tenacity", "typing-extensions"] 64 | 65 | [tool.poetry.scripts] 66 | retrocookie = "retrocookie.__main__:main" 67 | retrocookie-pr = "retrocookie.pr.__main__:main" 68 | 69 | [tool.coverage.paths] 70 | source = ["src", "*/site-packages"] 71 | tests = ["tests", "*/tests"] 72 | 73 | [tool.coverage.run] 74 | branch = true 75 | source = ["retrocookie", "tests"] 76 | 77 | [tool.coverage.report] 78 | show_missing = true 79 | fail_under = 90 80 | 81 | [tool.isort] 82 | profile = "black" 83 | force_single_line = true 84 | order_by_type = false 85 | lines_after_imports = 2 86 | 87 | [tool.mypy] 88 | strict = true 89 | pretty = true 90 | show_column_numbers = true 91 | show_error_codes = true 92 | show_error_context = true 93 | warn_unreachable = true 94 | 95 | [[tool.mypy.overrides]] 96 | module = [ 97 | "appdirs", 98 | "cookiecutter.*", 99 | "git_filter_repo", 100 | "github3.*", 101 | "pygit2.*", 102 | ] 103 | ignore_missing_imports = true 104 | 105 | [tool.pytest.ini_options] 106 | filterwarnings = [ 107 | "ignore:no type annotations present -- not typechecking .*", 108 | "ignore:no code associated -- not typechecking .*", 109 | ] 110 | 111 | [build-system] 112 | requires = ["poetry-core>=1.0.0"] 113 | build-backend = "poetry.core.masonry.api" 114 | -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | """Fixtures.""" 2 | from __future__ import annotations 3 | 4 | import json 5 | import textwrap 6 | from pathlib import Path 7 | from typing import Dict 8 | from typing import TYPE_CHECKING 9 | 10 | import pytest 11 | from cookiecutter.main import cookiecutter 12 | 13 | 14 | if TYPE_CHECKING: 15 | from retrocookie import git 16 | 17 | 
18 | @pytest.fixture 19 | def context() -> Dict[str, str]: 20 | """Cookiecutter context dictionary.""" 21 | return {"project_slug": "example"} 22 | 23 | 24 | @pytest.fixture 25 | def cookiecutter_path(tmp_path: Path) -> Path: 26 | """Cookiecutter path.""" 27 | path = tmp_path / "cookiecutter" 28 | path.mkdir() 29 | return path 30 | 31 | 32 | @pytest.fixture 33 | def cookiecutter_json(cookiecutter_path: Path, context: Dict[str, str]) -> Path: 34 | """The cookiecutter.json file.""" 35 | path = cookiecutter_path / "cookiecutter.json" 36 | text = json.dumps(context) 37 | path.write_text(text) 38 | return path 39 | 40 | 41 | @pytest.fixture 42 | def cookiecutter_subdirectory(cookiecutter_path: Path) -> Path: 43 | """The template directory in the cookiecutter.""" 44 | path = cookiecutter_path / "{{cookiecutter.project_slug}}" 45 | path.mkdir() 46 | return path 47 | 48 | 49 | @pytest.fixture 50 | def cookiecutter_readme(cookiecutter_subdirectory: Path) -> Path: 51 | """The README file in the cookiecutter.""" 52 | path = cookiecutter_subdirectory / "README.md" 53 | text = """\ 54 | # {{cookiecutter.project_slug}} 55 | 56 | Welcome to `{{cookiecutter.project_slug}}`! 
57 | """ 58 | path.write_text(textwrap.dedent(text)) 59 | return path 60 | 61 | 62 | @pytest.fixture 63 | def dot_cookiecutter_json(cookiecutter_subdirectory: Path) -> Path: 64 | """The .cookiecutter.json file in the cookiecutter.""" 65 | path = cookiecutter_subdirectory / ".cookiecutter.json" 66 | text = """\ 67 | {{cookiecutter | jsonify}} 68 | """ 69 | path.write_text(textwrap.dedent(text)) 70 | return path 71 | 72 | 73 | @pytest.fixture 74 | def cookiecutter_project( 75 | cookiecutter_path: Path, 76 | cookiecutter_json: Path, 77 | cookiecutter_readme: Path, 78 | dot_cookiecutter_json: Path, 79 | ) -> Path: 80 | """The cookiecutter.""" 81 | return cookiecutter_path 82 | 83 | 84 | def make_repository(path: Path) -> git.Repository: 85 | """Turn a directory into a git repository.""" 86 | from retrocookie import git 87 | 88 | repository = git.Repository.init(path) 89 | repository.add() 90 | repository.commit("Initial commit") 91 | return repository 92 | 93 | 94 | @pytest.fixture 95 | def cookiecutter_repository(cookiecutter_project: Path) -> git.Repository: 96 | """The cookiecutter repository.""" 97 | return make_repository(cookiecutter_project) 98 | 99 | 100 | @pytest.fixture 101 | def cookiecutter_instance( 102 | cookiecutter_repository: git.Repository, 103 | context: Dict[str, str], 104 | tmp_path: Path, 105 | ) -> Path: 106 | """The cookiecutter instance.""" 107 | template = str(cookiecutter_repository.path) 108 | cookiecutter(template, no_input=True, output_dir=str(tmp_path)) 109 | return tmp_path / context["project_slug"] 110 | 111 | 112 | @pytest.fixture 113 | def cookiecutter_instance_repository(cookiecutter_instance: Path) -> git.Repository: 114 | """The cookiecutter instance repository.""" 115 | return make_repository(cookiecutter_instance) 116 | -------------------------------------------------------------------------------- /CONTRIBUTING.rst: -------------------------------------------------------------------------------- 1 | Contributor Guide 2 | 
================= 3 | 4 | Thank you for your interest in improving this project. 5 | This project is open-source under the `MIT license`_ and 6 | welcomes contributions in the form of bug reports, feature requests, and pull requests. 7 | 8 | Here is a list of important resources for contributors: 9 | 10 | - `Source Code`_ 11 | - `Documentation`_ 12 | - `Issue Tracker`_ 13 | - `Code of Conduct`_ 14 | 15 | .. _MIT license: https://opensource.org/licenses/MIT 16 | .. _Source Code: https://github.com/cjolowicz/retrocookie 17 | .. _Documentation: https://retrocookie.readthedocs.io/ 18 | .. _Issue Tracker: https://github.com/cjolowicz/retrocookie/issues 19 | 20 | How to report a bug 21 | ------------------- 22 | 23 | Report bugs on the `Issue Tracker`_. 24 | 25 | When filing an issue, make sure to answer these questions: 26 | 27 | - Which operating system and Python version are you using? 28 | - Which version of this project are you using? 29 | - What did you do? 30 | - What did you expect to see? 31 | - What did you see instead? 32 | 33 | The best way to get your bug fixed is to provide a test case, 34 | and/or steps to reproduce the issue. 35 | 36 | 37 | How to request a feature 38 | ------------------------ 39 | 40 | Request features on the `Issue Tracker`_. 41 | 42 | 43 | How to set up your development environment 44 | ------------------------------------------ 45 | 46 | You need Python 3.7+ and the following tools: 47 | 48 | - Poetry_ 49 | - Nox_ 50 | - nox-poetry_ 51 | 52 | Install the package with development requirements: 53 | 54 | .. code:: console 55 | 56 | $ poetry install 57 | 58 | You can now run an interactive Python session, 59 | or the command-line interface: 60 | 61 | .. code:: console 62 | 63 | $ poetry run python 64 | $ poetry run retrocookie 65 | 66 | .. _Poetry: https://python-poetry.org/ 67 | .. _Nox: https://nox.thea.codes/ 68 | .. 
_nox-poetry: https://nox-poetry.readthedocs.io/ 69 | 70 | 71 | How to test the project 72 | ----------------------- 73 | 74 | Run the full test suite: 75 | 76 | .. code:: console 77 | 78 | $ nox 79 | 80 | List the available Nox sessions: 81 | 82 | .. code:: console 83 | 84 | $ nox --list-sessions 85 | 86 | You can also run a specific Nox session. 87 | For example, invoke the unit test suite like this: 88 | 89 | .. code:: console 90 | 91 | $ nox --session=tests 92 | 93 | Unit tests are located in the ``tests`` directory, 94 | and are written using the pytest_ testing framework. 95 | 96 | .. _pytest: https://pytest.readthedocs.io/ 97 | 98 | 99 | How to submit changes 100 | --------------------- 101 | 102 | Open a `pull request`_ to submit changes to this project. 103 | 104 | Your pull request needs to meet the following guidelines for acceptance: 105 | 106 | - The Nox test suite must pass without errors and warnings. 107 | - Include unit tests. This project maintains 100% code coverage. 108 | - If your changes add functionality, update the documentation accordingly. 109 | 110 | Feel free to submit early, though—we can always iterate on this. 111 | 112 | To run linting and code formatting checks before committing your change, you can install pre-commit as a Git hook by running the following command: 113 | 114 | .. code:: console 115 | 116 | $ nox --session=pre-commit -- install 117 | 118 | It is recommended to open an issue before starting work on anything. 119 | This will allow a chance to talk it over with the owners and validate your approach. 120 | 121 | .. _pull request: https://github.com/cjolowicz/retrocookie/pulls 122 | .. github-only 123 | .. 
_Code of Conduct: CODE_OF_CONDUCT.rst 124 | -------------------------------------------------------------------------------- /tests/test_main.py: -------------------------------------------------------------------------------- 1 | """Test cases for the __main__ module.""" 2 | from pathlib import Path 3 | from typing import Any 4 | from typing import List 5 | from typing import NoReturn 6 | 7 | import pytest 8 | from click.testing import CliRunner 9 | from pytest import MonkeyPatch 10 | 11 | from retrocookie import __main__ 12 | from retrocookie import git 13 | from retrocookie import utils 14 | 15 | from .helpers import append 16 | from .helpers import branch 17 | 18 | 19 | @pytest.fixture 20 | def runner() -> CliRunner: 21 | """Fixture for invoking command-line interfaces.""" 22 | return CliRunner() 23 | 24 | 25 | @pytest.fixture 26 | def mock_retrocookie(monkeypatch: MonkeyPatch) -> None: 27 | """Replace retrocookie function by noop.""" 28 | monkeypatch.setattr(__main__, "retrocookie", lambda *args, **kwargs: None) 29 | 30 | 31 | @pytest.mark.parametrize( 32 | "args", 33 | [ 34 | ["--help"], 35 | ["/home/user/src/repo", "--branch=topic"], 36 | ["--branch=topic", "/home/user/src/repo"], 37 | ["--branch=topic", "/home/user/src/repo", "18240dd"], 38 | ["--branch=topic", "https://example.com/owner/repo.git"], 39 | ["--branch=topic", "--create", "repo"], 40 | ["--branch=topic", "--create-branch=other", "repo"], 41 | [ 42 | "--branch=topic", 43 | "--include-variable=project_name", 44 | "--include-variable=package_name", 45 | "repo", 46 | ], 47 | ["--branch=topic", "--exclude-variable=github_user", "repo"], 48 | ["/home/user/src/repo"], 49 | ["/home/user/src/repo", "18240dd"], 50 | ["/home/user/src/repo", "--create-branch=topic"], 51 | ], 52 | ) 53 | def test_usage_success( 54 | runner: CliRunner, args: List[str], mock_retrocookie: None 55 | ) -> None: 56 | """It succeeds when invoked with the given arguments.""" 57 | result = runner.invoke(__main__.main, args) 58 | 
assert result.exit_code == 0, result.output 59 | 60 | 61 | @pytest.mark.parametrize( 62 | "args", 63 | [ 64 | [], 65 | ["--create", "repo"], 66 | ["--branch=topic"], 67 | ["--branch=topic", "--create", "--create-branch=another", "repo"], 68 | ], 69 | ) 70 | def test_usage_error( 71 | runner: CliRunner, args: List[str], mock_retrocookie: None 72 | ) -> None: 73 | """It fails when invoked with the given arguments.""" 74 | result = runner.invoke(__main__.main, args) 75 | assert result.exit_code == 2 76 | 77 | 78 | def test_functional( 79 | runner: CliRunner, 80 | cookiecutter_repository: git.Repository, 81 | cookiecutter_instance_repository: git.Repository, 82 | ) -> None: 83 | """It succeeds when importing a topic branch with a README update.""" 84 | cookiecutter, instance = cookiecutter_repository, cookiecutter_instance_repository 85 | path = Path("README.md") 86 | text = "Lorem Ipsum\n" 87 | 88 | with branch(instance, "topic", create=True): 89 | append(instance, path, text) 90 | 91 | with utils.chdir(cookiecutter.path): 92 | result = runner.invoke(__main__.main, ["--branch=topic", str(instance.path)]) 93 | assert result.exit_code == 0 94 | 95 | 96 | def test_giterror(runner: CliRunner, monkeypatch: MonkeyPatch) -> None: 97 | """It prints an error message.""" 98 | message = "boom\nThis is the command output\n" 99 | 100 | def _(*args: Any, **kwargs: Any) -> NoReturn: 101 | raise git.CommandError(message) 102 | 103 | monkeypatch.setattr(__main__, "retrocookie", _) 104 | 105 | result = runner.invoke(__main__.main, ["repository"]) 106 | 107 | assert f"error: {message}" == result.output 108 | -------------------------------------------------------------------------------- /tests/pr/functional/test_main.py: -------------------------------------------------------------------------------- 1 | """Functional tests for retrocookie.pr.__main__.""" 2 | from pathlib import Path 3 | from typing import Callable 4 | from typing import List 5 | 6 | import pytest 7 | from 
github3.repos.repo import Repository 8 | 9 | from retrocookie.pr import __main__ 10 | from retrocookie.pr import appname 11 | from tests.pr.functional.conftest import CreatePullRequest 12 | from tests.pr.functional.conftest import FindPullRequest 13 | from tests.pr.functional.conftest import skip_without_token 14 | 15 | 16 | InvokeMain = Callable[[List[str]], None] 17 | 18 | 19 | @pytest.fixture 20 | def invoke_main(project: Repository, token: str) -> InvokeMain: 21 | """Invoke the main function.""" 22 | 23 | def _main(args: List[str]) -> None: 24 | args = [ 25 | f"--repository={project.full_name}", 26 | f"--token={token}", 27 | *args, 28 | ] 29 | 30 | __main__.main.main( 31 | args=args, 32 | prog_name=__main__.main.name, 33 | standalone_mode=False, 34 | ) 35 | 36 | return _main 37 | 38 | 39 | # This test connects to the real API. Skip it if we don't have a token (even 40 | # though the test does not actually use it). 41 | @skip_without_token 42 | def test_invalid_token() -> None: 43 | """It fails.""" 44 | with pytest.raises(SystemExit): 45 | __main__.main.main( 46 | args=["--repository=REPOSITORY", "--token=INVALID", "1"], 47 | prog_name=__main__.main.name, 48 | standalone_mode=False, 49 | ) 50 | 51 | 52 | @skip_without_token 53 | def test_number( 54 | create_project_pull_request: CreatePullRequest, 55 | invoke_main: InvokeMain, 56 | find_template_pull_request: FindPullRequest, 57 | ) -> None: 58 | """It imports the pull request with the given number.""" 59 | project_pull = create_project_pull_request( 60 | Path("README.md"), 61 | "# Welcome to example", 62 | ) 63 | 64 | invoke_main([str(project_pull.number)]) 65 | 66 | template_pull = find_template_pull_request() 67 | assert project_pull.title == template_pull.title 68 | 69 | 70 | @skip_without_token 71 | def test_rewrite( 72 | create_project_pull_request: CreatePullRequest, 73 | invoke_main: InvokeMain, 74 | template: Repository, 75 | ) -> None: 76 | """It rewrites the commits associated with the pull 
request.""" 77 | project_pull = create_project_pull_request( 78 | Path("README.md"), 79 | "# Welcome to example", 80 | ) 81 | 82 | invoke_main([str(project_pull.number)]) 83 | 84 | template_file = template.file_contents( 85 | "{{cookiecutter.project}}/README.md", 86 | ref=f"refs/heads/{appname}/{project_pull.head.ref}", 87 | ) 88 | assert "# Welcome to {{cookiecutter.project}}" == template_file.decoded.decode() 89 | 90 | 91 | @skip_without_token 92 | def test_force( 93 | create_project_pull_request: CreatePullRequest, 94 | invoke_main: InvokeMain, 95 | ) -> None: 96 | """It updates the pull request.""" 97 | project_pull = create_project_pull_request( 98 | Path("README.md"), 99 | "# Welcome to example", 100 | ) 101 | 102 | invoke_main([str(project_pull.number)]) 103 | invoke_main(["--force", "--all"]) 104 | 105 | 106 | @skip_without_token 107 | def test_user( 108 | create_project_pull_request: CreatePullRequest, 109 | invoke_main: InvokeMain, 110 | find_template_pull_request: FindPullRequest, 111 | ) -> None: 112 | """It imports pull requests of the given user.""" 113 | project_pull = create_project_pull_request( 114 | Path("README.md"), 115 | "# Welcome to example", 116 | ) 117 | 118 | invoke_main([f"--user={project_pull.user.login}", str(project_pull.number)]) 119 | 120 | template_pull = find_template_pull_request() 121 | assert project_pull.title == template_pull.title 122 | -------------------------------------------------------------------------------- /tests/pr/unit/test_main.py: -------------------------------------------------------------------------------- 1 | """Tests for retrocookie.pr.__main__.""" 2 | import subprocess # noqa: S404 3 | from typing import List 4 | 5 | import pytest 6 | from click.testing import CliRunner 7 | from pytest import MonkeyPatch 8 | 9 | from retrocookie.pr import __main__ 10 | from retrocookie.pr import events 11 | from retrocookie.pr.base.bus import Bus 12 | from retrocookie.pr.base.bus import Error 13 | from 
retrocookie.pr.base.bus import Event 14 | from retrocookie.pr.cache import Cache 15 | from tests.pr.unit.data import EXAMPLE_PROJECT_PULL 16 | from tests.pr.unit.data import EXAMPLE_TEMPLATE 17 | from tests.pr.unit.data import EXAMPLE_TEMPLATE_PULL 18 | from tests.pr.unit.utils import raises 19 | 20 | 21 | def test_get_token_from_cache(cache: Cache) -> None: 22 | """It loads the token from the cache.""" 23 | cache.save_token("token") 24 | assert "token" == __main__.get_token(cache) 25 | 26 | 27 | def test_get_token_from_user(cache: Cache) -> None: 28 | """It prompts the user for the token.""" 29 | runner = CliRunner() 30 | with runner.isolation(input="token"): 31 | assert "token" == __main__.get_token(cache) 32 | 33 | 34 | def test_all_with_arguments() -> None: 35 | """It fails when --all is used with arguments.""" 36 | runner = CliRunner() 37 | result = runner.invoke(__main__.main, ["--all", "1"]) 38 | assert result.exit_code == 2 39 | 40 | 41 | def test_without_arguments() -> None: 42 | """It fails when no pull requests are specified.""" 43 | runner = CliRunner() 44 | result = runner.invoke(__main__.main, []) 45 | assert result.exit_code == 2 46 | 47 | 48 | def test_giterrorhandler_with_git_error(bus: Bus) -> None: 49 | """It raises a GitFailed event.""" 50 | with raises(events.GitFailed): 51 | with __main__.giterrorhandler(bus=bus): 52 | raise subprocess.CalledProcessError(1, ["git", "rev-parse", "HEAD"], "", "") 53 | 54 | 55 | def test_collect() -> None: 56 | """It collects errors.""" 57 | error = Error(Event()) 58 | errors: List[Error] = [] 59 | with __main__.collect(errors): 60 | raise error 61 | [caught] = errors # type: ignore[unreachable] 62 | assert error is caught 63 | 64 | 65 | @pytest.mark.parametrize( 66 | "error", 67 | [ 68 | subprocess.CalledProcessError(1, "false", "", ""), 69 | subprocess.CalledProcessError(1, ["echo", "hello"], "", ""), 70 | ], 71 | ) 72 | def test_giterrorhandler_without_git_error(bus: Bus, error: Exception) -> None: 73 | """It 
reraises the original exception.""" 74 | with pytest.raises(subprocess.CalledProcessError): 75 | with __main__.giterrorhandler(bus=bus): 76 | raise error 77 | 78 | 79 | def test_open_after_create(bus: Bus, monkeypatch: MonkeyPatch) -> None: 80 | """It opens the URL in a browser.""" 81 | event = events.PullRequestCreated( 82 | EXAMPLE_TEMPLATE, 83 | EXAMPLE_TEMPLATE_PULL, 84 | EXAMPLE_PROJECT_PULL, 85 | ) 86 | 87 | urls = [] 88 | 89 | def _webbrowser_open(url: str) -> None: 90 | urls.append(url) 91 | 92 | monkeypatch.setattr("webbrowser.open", _webbrowser_open) 93 | 94 | __main__.register_pull_request_viewer(bus=bus) 95 | 96 | bus.events.publish(event) 97 | 98 | assert urls == [event.template_pull.html_url] 99 | 100 | 101 | def test_open_after_update(bus: Bus, monkeypatch: MonkeyPatch) -> None: 102 | """It opens the URL in a browser.""" 103 | event = events.UpdatePullRequest( 104 | EXAMPLE_TEMPLATE, 105 | EXAMPLE_TEMPLATE_PULL, 106 | EXAMPLE_PROJECT_PULL, 107 | ) 108 | 109 | urls = [] 110 | 111 | def _webbrowser_open(url: str) -> None: 112 | urls.append(url) 113 | 114 | monkeypatch.setattr("webbrowser.open", _webbrowser_open) 115 | 116 | __main__.register_pull_request_viewer(bus=bus) 117 | 118 | with bus.contexts.publish(event): 119 | pass 120 | 121 | assert urls == [event.template_pull.html_url] 122 | -------------------------------------------------------------------------------- /src/retrocookie/pr/core.py: -------------------------------------------------------------------------------- 1 | """Import pull requests.""" 2 | from typing import Collection 3 | from typing import Optional 4 | from urllib.parse import urlparse 5 | 6 | from retrocookie import git 7 | from retrocookie.core import load_context 8 | from retrocookie.pr import events 9 | from retrocookie.pr.base.bus import Bus 10 | from retrocookie.pr.base.exceptionhandlers import ExceptionHandler 11 | from retrocookie.pr.cache import Cache 12 | from retrocookie.pr.importer import Importer 13 | from 
retrocookie.pr.list import list_pull_requests 14 | from retrocookie.pr.protocols import github 15 | from retrocookie.pr.protocols.retrocookie import Retrocookie 16 | from retrocookie.pr.repository import Repository 17 | from retrocookie.utils import removeprefix 18 | from retrocookie.utils import removesuffix 19 | 20 | 21 | def parse_repository_name(url: str) -> Optional[str]: 22 | """Extract the repository name from the URL, if possible.""" 23 | for prefix in ("gh:", "git@github.com:"): 24 | if url.startswith(prefix): 25 | path = removeprefix(url, prefix) 26 | return removesuffix(path, ".git") 27 | 28 | result = urlparse(url) 29 | if result.hostname == "github.com": 30 | return removesuffix(result.path[1:], ".git") 31 | 32 | return None 33 | 34 | 35 | def get_project_name(*, bus: Bus) -> str: 36 | """Return the repository name of the project.""" 37 | with bus.events.reraise(events.ProjectNotFound()): 38 | repository = git.Repository() 39 | 40 | for remote in repository.repo.remotes: 41 | name = parse_repository_name(remote.url) 42 | if name is not None: 43 | return name 44 | 45 | bus.events.raise_(events.ProjectNotFound()) 46 | 47 | 48 | def get_template_name(project: Repository, *, bus: Bus) -> str: 49 | """Return the repository name of the template, given the project.""" 50 | with bus.events.reraise( 51 | events.TemplateNotFound(project.github), 52 | when=(KeyError, TypeError), 53 | ): 54 | context = load_context(project.clone, "HEAD") 55 | 56 | name = parse_repository_name(context["_template"]) 57 | if name is not None: 58 | return name 59 | 60 | bus.events.raise_(events.TemplateNotFound(project.github)) 61 | 62 | 63 | def check_git_version(*, bus: Bus) -> None: 64 | """Check if we have a recent version of git.""" 65 | # newren/git-filter-repo requires git >= 2.22.0 66 | minimum_version = git.Version.parse("2.22.0") 67 | 68 | with bus.events.reraise(events.GitNotFound()): 69 | version = git.version() 70 | 71 | if version < minimum_version: 72 | 
bus.events.raise_(events.BadGitVersion(version, minimum_version)) 73 | 74 | 75 | def import_pull_requests( 76 | pull_requests: Collection[str] = (), 77 | *, 78 | api: github.API, 79 | bus: Bus, 80 | cache: Cache, 81 | errorhandler: ExceptionHandler, 82 | retrocookie: Retrocookie, 83 | repository: Optional[str] = None, 84 | user: Optional[str] = None, 85 | force: bool = False, 86 | ) -> None: 87 | """Import pull requests from a repository into its Cookiecutter template.""" 88 | check_git_version(bus=bus) 89 | 90 | if repository is None: 91 | repository = get_project_name(bus=bus) 92 | elif "/" not in repository: 93 | repository = "/".join([api.me, repository]) 94 | 95 | with bus.contexts.publish(events.LoadProject(repository)): 96 | project = Repository.load(repository, api=api, cache=cache) 97 | 98 | template_name = get_template_name(project, bus=bus) 99 | 100 | with bus.contexts.publish(events.LoadTemplate(template_name)): 101 | template = Repository.load(template_name, api=api, cache=cache) 102 | 103 | importer = Importer( 104 | project, 105 | template, 106 | bus=bus, 107 | cache=cache, 108 | retrocookie=retrocookie, 109 | ) 110 | 111 | for pull in list_pull_requests(project.github, pull_requests, user=user, bus=bus): 112 | with errorhandler: 113 | importer.import_(pull, force=force) 114 | -------------------------------------------------------------------------------- /tests/pr/unit/test_list.py: -------------------------------------------------------------------------------- 1 | """Tests for retrocookie.pr.list.""" 2 | from typing import cast 3 | from typing import Iterable 4 | from typing import List 5 | 6 | import pytest 7 | 8 | from retrocookie.pr.base.bus import Bus 9 | from retrocookie.pr.list import list_pull_requests 10 | from retrocookie.pr.protocols.github import PullRequest as AbstractPullRequest 11 | from tests.pr.unit.fakes import github 12 | 13 | 14 | @pytest.fixture 15 | def repository(api: github.API) -> github.Repository: 16 | """Return a fake 
repository.""" 17 | return api.repository("owner", "name") 18 | 19 | 20 | pr1 = github.PullRequest(1, "title1", "body1", "branch1", "user1") 21 | pr2 = github.PullRequest(2, "title2", "body2", "branch2", "user2") 22 | 23 | 24 | def equal( 25 | expected: Iterable[github.PullRequest], actual: Iterable[AbstractPullRequest] 26 | ) -> bool: 27 | """Return True if both iterables contain the same pull requests.""" 28 | actual_ = [cast(github.PullRequest, item) for item in actual] 29 | return sorted(actual_) == sorted(expected) 30 | 31 | 32 | @pytest.mark.parametrize( 33 | "pull_requests", 34 | [ 35 | [], 36 | [pr1, pr2], 37 | ], 38 | ) 39 | def test_list_all( 40 | repository: github.Repository, 41 | bus: Bus, 42 | pull_requests: List[github.PullRequest], 43 | ) -> None: 44 | """It lists all pull requests.""" 45 | repository._pull_requests = pull_requests 46 | actual = list_pull_requests(repository, bus=bus) 47 | assert equal(pull_requests, actual) 48 | 49 | 50 | @pytest.mark.parametrize( 51 | "pull_requests, user, expected", 52 | [ 53 | ([], pr1.user, []), 54 | ([pr1, pr2], pr2.user, [pr2]), 55 | ([pr1, pr2], "unknown", []), 56 | ], 57 | ) 58 | def test_list_by_user( 59 | repository: github.Repository, 60 | bus: Bus, 61 | pull_requests: List[github.PullRequest], 62 | user: str, 63 | expected: List[github.PullRequest], 64 | ) -> None: 65 | """It lists all pull requests.""" 66 | repository._pull_requests = pull_requests 67 | actual = list_pull_requests(repository, user=user, bus=bus) 68 | assert equal(expected, actual) 69 | 70 | 71 | @pytest.mark.parametrize( 72 | "pull_requests, specs, expected", 73 | [ 74 | ([pr1, pr2], ["1"], [pr1]), 75 | ([pr1, pr2], ["2"], [pr2]), 76 | ([pr1, pr2], ["1", "2"], [pr1, pr2]), 77 | ], 78 | ) 79 | def test_list_by_number( 80 | repository: github.Repository, 81 | bus: Bus, 82 | pull_requests: List[github.PullRequest], 83 | specs: List[str], 84 | expected: List[github.PullRequest], 85 | ) -> None: 86 | """It lists all pull requests.""" 87 | 
repository._pull_requests = pull_requests 88 | actual = list_pull_requests(repository, specs, bus=bus) 89 | assert equal(expected, actual) 90 | 91 | 92 | @pytest.mark.parametrize( 93 | "pull_requests, specs, expected", 94 | [ 95 | ([pr1, pr2], ["branch1"], [pr1]), 96 | ([pr1, pr2], ["branch2"], [pr2]), 97 | ([pr1, pr2], ["branch1", "branch2"], [pr1, pr2]), 98 | ], 99 | ) 100 | def test_list_by_head( 101 | repository: github.Repository, 102 | bus: Bus, 103 | pull_requests: List[github.PullRequest], 104 | specs: List[str], 105 | expected: List[github.PullRequest], 106 | ) -> None: 107 | """It lists all pull requests.""" 108 | repository._pull_requests = pull_requests 109 | actual = list_pull_requests(repository, specs, bus=bus) 110 | assert equal(expected, actual) 111 | 112 | 113 | @pytest.mark.parametrize( 114 | "pull_requests, specs", 115 | [ 116 | ([], ["branch1"]), 117 | ([pr1, pr2], ["unknown"]), 118 | ], 119 | ) 120 | def test_list_not_found( 121 | repository: github.Repository, 122 | bus: Bus, 123 | pull_requests: List[github.PullRequest], 124 | specs: List[str], 125 | ) -> None: 126 | """It lists all pull requests.""" 127 | repository._pull_requests = pull_requests 128 | with pytest.raises(Exception): 129 | for _ in list_pull_requests(repository, specs, bus=bus): 130 | pass 131 | -------------------------------------------------------------------------------- /tests/pr/unit/base/test_exceptionhandlers.py: -------------------------------------------------------------------------------- 1 | """Tests for retrocookie.pr.exceptionhandlers.""" 2 | from functools import reduce 3 | from typing import Tuple 4 | from typing import Union 5 | 6 | import pytest 7 | 8 | from retrocookie.pr.base.exceptionhandlers import ExceptionHandler 9 | from retrocookie.pr.base.exceptionhandlers import exceptionhandler 10 | from retrocookie.pr.base.exceptionhandlers import nullhandler 11 | 12 | 13 | class RedError(Exception): 14 | """Exception for testing.""" 15 | 16 | 17 | class 
GreenError(Exception): 18 | """Exception for testing.""" 19 | 20 | 21 | class BlueError(Exception): 22 | """Exception for testing.""" 23 | 24 | 25 | class IndigoError(BlueError): 26 | """Exception for testing.""" 27 | 28 | 29 | @exceptionhandler 30 | def suppress_red(error: RedError) -> bool: 31 | """Suppress RedError exceptions.""" 32 | return True 33 | 34 | 35 | @exceptionhandler() # empty parentheses are equivalent to none 36 | def suppress_green(error: GreenError) -> bool: 37 | """Suppress GreenError exceptions.""" 38 | return True 39 | 40 | 41 | @exceptionhandler 42 | def suppress_blue(error: BlueError) -> bool: 43 | """Suppress BlueError exceptions.""" 44 | return True 45 | 46 | 47 | @exceptionhandler 48 | def suppress_indigo(error: IndigoError) -> bool: 49 | """Suppress IndigoError exceptions.""" 50 | return True 51 | 52 | 53 | @exceptionhandler(RedError, BlueError) 54 | def suppress_red_and_blue(error: Union[RedError, BlueError]) -> bool: 55 | """Suppress RedError and BlueError exceptions.""" 56 | return True 57 | 58 | 59 | @exceptionhandler(RedError, GreenError) 60 | def suppress_red_and_green(error: Union[RedError, GreenError]) -> bool: 61 | """Suppress RedError and GreenError exceptions.""" 62 | return True 63 | 64 | 65 | @exceptionhandler(RedError, GreenError, BlueError) 66 | def suppress_red_green_blue(error: Union[RedError, GreenError, BlueError]) -> bool: 67 | """Suppress RedError, GreenError, and BlueError exceptions.""" 68 | return True 69 | 70 | 71 | def test_nullhandler() -> None: 72 | """It does not swallow the error.""" 73 | with pytest.raises(BlueError): 74 | with nullhandler: 75 | raise BlueError() 76 | 77 | 78 | def test_decorator_missing_annotation() -> None: 79 | """It raises TypeError.""" 80 | with pytest.raises(TypeError): 81 | 82 | @exceptionhandler 83 | def _(exception): # type: ignore[no-untyped-def] 84 | pass 85 | 86 | 87 | def test_decorator_decorator() -> None: 88 | """It can be used as a decorator.""" 89 | 90 | @suppress_blue 91 | 
def raise_blue() -> None: 92 | raise BlueError() 93 | 94 | raise_blue() # does not throw 95 | 96 | 97 | @pytest.mark.parametrize( 98 | "handler", 99 | [ 100 | suppress_indigo, 101 | suppress_blue, 102 | suppress_red_and_blue, 103 | suppress_red_green_blue, 104 | ], 105 | ) 106 | def test_decorator_positive(handler: ExceptionHandler) -> None: 107 | """The exception is handled.""" 108 | with handler: 109 | raise IndigoError() 110 | 111 | 112 | @pytest.mark.parametrize( 113 | "handler", 114 | [ 115 | nullhandler, 116 | suppress_red, 117 | suppress_green, 118 | suppress_indigo, 119 | suppress_red_and_green, 120 | ], 121 | ) 122 | def test_decorator_negative(handler: ExceptionHandler) -> None: 123 | """The exception is not handled.""" 124 | with pytest.raises(BlueError): 125 | with handler: 126 | raise BlueError() 127 | 128 | 129 | @pytest.mark.parametrize( 130 | "handlers", 131 | [ 132 | (nullhandler, suppress_indigo), 133 | (suppress_indigo, nullhandler), 134 | (suppress_red, suppress_green, suppress_blue), 135 | (suppress_blue, suppress_green, suppress_red), 136 | ], 137 | ) 138 | def test_compose_lshift(handlers: Tuple[ExceptionHandler]) -> None: 139 | """The exception is handled.""" 140 | handler = reduce(lambda a, b: a << b, handlers) 141 | with handler: 142 | raise IndigoError() 143 | 144 | 145 | @pytest.mark.parametrize( 146 | "handlers", 147 | [ 148 | (nullhandler, suppress_indigo), 149 | (suppress_indigo, nullhandler), 150 | (suppress_blue, suppress_red), 151 | (suppress_red, suppress_blue), 152 | (suppress_red, suppress_green, suppress_blue), 153 | (suppress_blue, suppress_green, suppress_red), 154 | ], 155 | ) 156 | def test_compose_rshift(handlers: Tuple[ExceptionHandler]) -> None: 157 | """The exception is handled.""" 158 | handler = reduce(lambda a, b: a >> b, handlers) 159 | with handler: 160 | raise IndigoError() 161 | -------------------------------------------------------------------------------- /tests/pr/unit/fakes/github.py: 
-------------------------------------------------------------------------------- 1 | """Fake GitHub API.""" 2 | from __future__ import annotations 3 | 4 | import dataclasses 5 | from dataclasses import dataclass 6 | from pathlib import Path 7 | from typing import AbstractSet 8 | from typing import Iterable 9 | from typing import List 10 | from typing import Optional 11 | from typing import Set 12 | 13 | from retrocookie import git 14 | 15 | 16 | @dataclass(order=True) 17 | class PullRequest: 18 | """Fake pull request.""" 19 | 20 | number: int 21 | title: str 22 | body: Optional[str] 23 | branch: str 24 | user: str 25 | html_url: str = "" 26 | labels: Set[str] = dataclasses.field(default_factory=set) 27 | 28 | def update(self, title: str, body: Optional[str], labels: AbstractSet[str]) -> None: 29 | """Update the pull request.""" 30 | self.title = title 31 | self.body = body 32 | self.labels = set(labels) 33 | 34 | 35 | @dataclass 36 | class Repository: 37 | """Fake repository.""" 38 | 39 | owner: str 40 | name: str 41 | _pull_requests: List[PullRequest] = dataclasses.field(default_factory=list) 42 | _repository: Optional[git.Repository] = None 43 | 44 | @classmethod 45 | def create( 46 | cls, owner: str, name: str, _backend: Optional[Path] = None 47 | ) -> Repository: 48 | """Create a fake repository.""" 49 | if _backend is None: 50 | return cls(owner, name) 51 | 52 | path = _backend / owner / f"{name}.git" 53 | if path.exists(): 54 | repository = git.Repository(path) 55 | else: 56 | repository = git.Repository.init(path, bare=True) 57 | repository.commit("Initial") 58 | 59 | return cls(owner, name, _repository=repository) 60 | 61 | @property 62 | def full_name(self) -> str: 63 | """The full name of the repository.""" 64 | return "/".join([self.owner, self.name]) 65 | 66 | @property 67 | def clone_url(self) -> str: 68 | """URL for cloning the repository.""" 69 | if self._repository is None: 70 | return f"https://github.com/{self.full_name}.git" 71 | 72 | return 
f"file://{self._repository.path.resolve()}" 73 | 74 | @property 75 | def push_url(self) -> str: 76 | """URL for pushing to the repository.""" 77 | return self.clone_url 78 | 79 | @property 80 | def default_branch(self) -> str: 81 | """The default branch of the repository.""" 82 | return git.get_default_branch() 83 | 84 | def pull_requests(self) -> Iterable[PullRequest]: 85 | """Pull requests open in the repository.""" 86 | return self._pull_requests 87 | 88 | def pull_request(self, number: int) -> PullRequest: 89 | """Return pull request identified by the given number.""" 90 | for pull_request in self._pull_requests: 91 | if pull_request.number == number: 92 | return pull_request 93 | raise ValueError("not found") 94 | 95 | def pull_request_by_head(self, head: str) -> Optional[PullRequest]: 96 | """Return pull request for the given head.""" 97 | user, _, branch = head.rpartition(":") 98 | for pull_request in self._pull_requests: 99 | if self.owner == user and pull_request.branch == branch: 100 | return pull_request 101 | return None 102 | 103 | def create_pull_request( 104 | self, 105 | *, 106 | head: str, 107 | title: str, 108 | body: Optional[str], 109 | labels: AbstractSet[str], 110 | ) -> PullRequest: 111 | """Create a pull request.""" 112 | user, _, branch = head.rpartition(":") 113 | number = 1 + max((pr.number for pr in self._pull_requests), default=0) 114 | html_url = f"https://github.com/{self.full_name}/pull/{number}" 115 | pull_request = PullRequest( 116 | number=number, 117 | title=title, 118 | body=body, 119 | branch=branch, 120 | user=user or self.owner, 121 | html_url=html_url, 122 | labels=set(labels), 123 | ) 124 | self._pull_requests.append(pull_request) 125 | return pull_request 126 | 127 | 128 | @dataclass 129 | class API: 130 | """Fake GitHub API.""" 131 | 132 | _backend: Optional[Path] = None 133 | _repositories: List[Repository] = dataclasses.field(default_factory=list) 134 | 135 | @property 136 | def me(self) -> str: 137 | """Return the 
login of the authenticated user.""" 138 | return "owner" 139 | 140 | def repository(self, owner: str, name: str) -> Repository: 141 | """Return the repository with the given owner and name.""" 142 | for repository in self._repositories: 143 | if repository.owner == owner and repository.name == name: 144 | return repository 145 | 146 | repository = Repository.create(owner, name, _backend=self._backend) 147 | self._repositories.append(repository) 148 | return repository 149 | -------------------------------------------------------------------------------- /src/retrocookie/__main__.py: -------------------------------------------------------------------------------- 1 | """Command-line interface.""" 2 | from pathlib import Path 3 | from typing import Any 4 | from typing import Container 5 | from typing import Iterable 6 | from typing import Optional 7 | 8 | import click 9 | 10 | from . import git 11 | from .core import retrocookie 12 | 13 | 14 | @click.command() 15 | @click.option( 16 | "--branch", 17 | "-b", 18 | metavar="BRANCH", 19 | help="Remote branch to cherry-pick", 20 | ) 21 | @click.option( 22 | "--upstream", 23 | metavar="REF", 24 | help="Remote upstream branch", 25 | ) 26 | @click.option( 27 | "--create", 28 | "-c", 29 | is_flag=True, 30 | help="Create a local branch with the same name as --branch", 31 | ) 32 | @click.option( 33 | "--create-branch", 34 | metavar="BRANCH", 35 | help="Create a local branch for the imported commits", 36 | ) 37 | @click.option( 38 | "--include-variable", 39 | "include_variables", 40 | metavar="VAR", 41 | multiple=True, 42 | help="Only rewrite these Cookiecutter variables", 43 | ) 44 | @click.option( 45 | "--exclude-variable", 46 | "exclude_variables", 47 | metavar="VAR", 48 | multiple=True, 49 | help="Do not rewrite these Cookiecutter variables", 50 | ) 51 | @click.option( 52 | "--directory", 53 | "-C", 54 | metavar="DIR", 55 | help="Path to the repository [default: working directory]", 56 | ) 57 | @click.argument("repository") 
58 | @click.argument("commits", nargs=-1) 59 | @click.version_option() 60 | def main( 61 | branch: Optional[str], 62 | upstream: Optional[str], 63 | create: bool, 64 | create_branch: Optional[str], 65 | include_variables: Container[str], 66 | exclude_variables: Container[str], 67 | directory: Optional[str], 68 | repository: str, 69 | commits: Iterable[str], 70 | ) -> None: 71 | """Retrocookie imports commits into Cookiecutter templates. 72 | 73 | The path to the source repository is passed as a positional argument. 74 | This should be an instance of the Cookiecutter template with a 75 | .cookiecutter.json file. 76 | 77 | Additional positional arguments are the commits to be imported. 78 | See gitrevisions(7) for a list of ways to spell commits and commit 79 | ranges. These arguments are interpreted in the same way as 80 | git-cherry-pick(1) does. If no commits are specified, the HEAD 81 | commit is imported. 82 | 83 | Specifying --branch is equivalent to passing master.. 84 | as a positional argument. Use --upstream to specify a different 85 | upstream branch. Use --create to create the same branch in the 86 | target repository. 87 | 88 | By default, commits are cherry-picked onto the current branch. 89 | Use --create-branch to specify a different branch, which must 90 | not exist yet. As a shorthand, --create is equivalent to passing 91 | --create-branch with the same argument as --branch. 92 | 93 | Commits are rewritten in the following way: 94 | 95 | \b 96 | - Files are moved to the template directory 97 | - Tokens with special meaning in Jinja are escaped 98 | - Values from .cookiecutter.json are replaced by templating tags 99 | with the corresponding variable name 100 | 101 | Use the --include-variable and --exclude-variable options to include or 102 | exclude specific variables from .cookiecutter.json. 
103 | """ 104 | if create: 105 | if create_branch: 106 | raise click.UsageError( 107 | "--create and --create-branch are mutually exclusive" 108 | ) 109 | 110 | if branch is None: 111 | raise click.UsageError("--create requires --branch") 112 | 113 | create_branch = branch 114 | 115 | if not commits and branch is None: 116 | commits = ["HEAD"] 117 | 118 | path = Path(directory) if directory else None 119 | try: 120 | retrocookie( 121 | Path(repository), 122 | commits, 123 | branch=branch, 124 | upstream=upstream, 125 | create_branch=create_branch, 126 | include_variables=include_variables, 127 | exclude_variables=exclude_variables, 128 | path=path, 129 | ) 130 | except git.CommandError as error: 131 | printerror(error) 132 | raise SystemExit(1) from None 133 | 134 | 135 | def printerror(error: git.CommandError) -> None: 136 | """Produce a friendly error message when git failed.""" 137 | styles: dict[str, dict[str, Any]] = { 138 | "==>": {"fg": "bright_black", "bold": True}, 139 | "error": {"fg": "red"}, 140 | "fatal": {"fg": "red"}, 141 | "hint": {"fg": "yellow"}, 142 | } 143 | 144 | for line in f"error: {error}".splitlines(): 145 | for prefix, style in styles.items(): 146 | if line.startswith(prefix): 147 | click.secho(line, **style) 148 | break 149 | else: 150 | click.echo(line) 151 | 152 | 153 | if __name__ == "__main__": 154 | main(prog_name="retrocookie") # pragma: no cover 155 | -------------------------------------------------------------------------------- /src/retrocookie/filter.py: -------------------------------------------------------------------------------- 1 | """Interface for git-filter-repo.""" 2 | import contextlib 3 | import io 4 | import re 5 | from pathlib import Path 6 | from typing import Any 7 | from typing import Container 8 | from typing import Dict 9 | from typing import Iterable 10 | from typing import Iterator 11 | from typing import List 12 | from typing import Optional 13 | from typing import overload 14 | from typing import Sequence 
15 | from typing import Tuple 16 | 17 | from git_filter_repo import Blob 18 | from git_filter_repo import FilteringOptions 19 | from git_filter_repo import RepoFilter 20 | 21 | from . import git 22 | 23 | 24 | def get_replacements( 25 | context: Dict[str, Any], 26 | include_variables: Container[str], 27 | exclude_variables: Container[str], 28 | ) -> List[Tuple[bytes, bytes]]: 29 | """Return replacements to be applied to commits from the template instance.""" 30 | 31 | def ref(key: str) -> str: 32 | return f"{{{{cookiecutter.{key}}}}}" 33 | 34 | return [ 35 | (value.encode(), ref(key).encode()) 36 | for key, value in context.items() 37 | if key not in exclude_variables 38 | and not (include_variables and key not in include_variables) 39 | and isinstance(value, str) 40 | ] 41 | 42 | 43 | def find_token( 44 | string: bytes, pos: int, tokens: Sequence[bytes] 45 | ) -> Tuple[Optional[bytes], int]: 46 | """Find the first occurrence of any of multiple tokens.""" 47 | pattern = re.compile(b"|".join(re.escape(token) for token in tokens)) 48 | match = pattern.search(string, pos) 49 | if match is None: 50 | return None, -1 51 | return match.group(), match.start() 52 | 53 | 54 | def quote_tokens( 55 | text: bytes, quotes: Tuple[bytes, bytes], tokens: Sequence[bytes] 56 | ) -> bytes: 57 | """Wrap tokens in ````.""" 58 | 59 | def _generate() -> Iterator[bytes]: 60 | index = 0 61 | 62 | while index < len(text): 63 | token, found = find_token(text, index, tokens) 64 | 65 | if token is None: 66 | yield text[index:] 67 | break 68 | 69 | if found != 0: 70 | yield text[index:found] 71 | 72 | yield token.join(quotes) 73 | index = found + len(token) 74 | 75 | return b"".join(_generate()) 76 | 77 | 78 | @overload 79 | def to_bytes(args: Tuple[str, str]) -> Tuple[bytes, bytes]: # noqa: D103 80 | ... # pragma: no cover 81 | 82 | 83 | @overload 84 | def to_bytes(args: Tuple[str, ...]) -> Tuple[bytes, ...]: # noqa: D103 85 | ... 
# pragma: no cover 86 | 87 | 88 | def to_bytes(args: Tuple[str, ...]) -> Tuple[bytes, ...]: 89 | """Convert each str to bytes.""" 90 | return tuple(arg.encode() for arg in args) 91 | 92 | 93 | def escape_jinja(text: bytes) -> bytes: 94 | """Escape Jinja tokens.""" 95 | quotes = to_bytes(('{{"', '"}}')) 96 | tokens = to_bytes(("{{", "}}", "{%", "%}", "{#", "#}")) 97 | return quote_tokens(text, quotes, tokens) 98 | 99 | 100 | class RepositoryFilter: 101 | """Perform path and blob replacements on a repository.""" 102 | 103 | def __init__( 104 | self, 105 | repository: git.Repository, 106 | source: git.Repository, 107 | commits: Iterable[str], 108 | template_directory: Path, 109 | context: Dict[str, Any], 110 | include_variables: Container[str], 111 | exclude_variables: Container[str], 112 | ) -> None: 113 | """Initialize.""" 114 | self.repository = repository 115 | self.source = source 116 | self.commits = commits 117 | self.template_directory = str(template_directory).encode() 118 | self.replacements = get_replacements( 119 | context, include_variables, exclude_variables 120 | ) 121 | 122 | def filename_callback(self, filename: bytes) -> bytes: 123 | """Rewrite filenames.""" 124 | for old, new in self.replacements: 125 | filename = filename.replace(old, new) 126 | return b"/".join((self.template_directory, filename)) 127 | 128 | def blob_callback(self, blob: Blob, metadata: Dict[str, Any]) -> None: 129 | """Rewrite blobs.""" 130 | blob.data = escape_jinja(blob.data) 131 | for old, new in self.replacements: 132 | blob.data = blob.data.replace(old, new) 133 | 134 | def _create_filter(self) -> RepoFilter: 135 | """Create the filter.""" 136 | args = FilteringOptions.parse_args( 137 | [ 138 | f"--source={self.source.path}", 139 | f"--target={self.repository.path}", 140 | "--replace-refs=update-and-add", 141 | "--refs", 142 | *self.commits, 143 | ] 144 | ) 145 | return RepoFilter( 146 | args, 147 | filename_callback=self.filename_callback, 148 | 
blob_callback=self.blob_callback, 149 | ) 150 | 151 | def run(self) -> None: 152 | """Run the filter.""" 153 | null = io.StringIO() 154 | with contextlib.redirect_stdout(null), contextlib.redirect_stderr(null): 155 | repofilter = self._create_filter() 156 | repofilter.run() 157 | -------------------------------------------------------------------------------- /tests/pr/unit/test_console.py: -------------------------------------------------------------------------------- 1 | """Tests for the console.""" 2 | import contextlib 3 | import io 4 | from typing import Callable 5 | from typing import Iterator 6 | from typing import Tuple 7 | 8 | import pytest 9 | 10 | from retrocookie import git 11 | from retrocookie.compat.contextlib import contextmanager 12 | from retrocookie.pr import console 13 | from retrocookie.pr import events 14 | from retrocookie.pr.base.bus import Bus 15 | from retrocookie.pr.base.bus import Context 16 | from retrocookie.pr.base.bus import Event 17 | from tests.pr.unit.data import EXAMPLE_PROJECT 18 | from tests.pr.unit.data import EXAMPLE_PROJECT_PULL 19 | from tests.pr.unit.data import EXAMPLE_TEMPLATE 20 | from tests.pr.unit.data import EXAMPLE_TEMPLATE_PULL 21 | 22 | 23 | @pytest.fixture 24 | def bus() -> Iterator[Bus]: 25 | """Fixture for a bus with console handlers.""" 26 | bus = Bus() 27 | console.start(bus=bus) 28 | yield bus 29 | 30 | 31 | @contextmanager 32 | def redirect() -> Iterator[Tuple[io.StringIO, io.StringIO]]: 33 | """Redirect standard output and error.""" 34 | stdout = io.StringIO() 35 | stderr = io.StringIO() 36 | with contextlib.redirect_stdout(stdout): 37 | with contextlib.redirect_stderr(stderr): 38 | yield stdout, stderr 39 | 40 | 41 | PublishEvent = Callable[[Event], Tuple[str, str]] 42 | 43 | 44 | @pytest.fixture 45 | def publish_event(bus: Bus) -> PublishEvent: 46 | """Factory fixture for publishing a event.""" 47 | 48 | def _(event: Event) -> Tuple[str, str]: 49 | with redirect() as (stdout, stderr): 50 | 
bus.events.publish(event) 51 | return stdout.getvalue(), stderr.getvalue() 52 | 53 | return _ 54 | 55 | 56 | def get_class_name(instance: object) -> str: 57 | """Return the class name.""" 58 | return instance.__class__.__name__ 59 | 60 | 61 | @pytest.mark.parametrize( 62 | "event", 63 | [ 64 | events.GitNotFound(), 65 | events.BadGitVersion( 66 | git.Version.parse("0.99"), 67 | git.Version.parse("2.22.0"), 68 | ), 69 | events.ProjectNotFound(), 70 | events.TemplateNotFound(EXAMPLE_PROJECT), 71 | events.RepositoryNotFound("owner/name"), 72 | events.PullRequestNotFound("1"), 73 | events.PullRequestNotFound("branch"), 74 | events.PullRequestAlreadyExists( 75 | EXAMPLE_TEMPLATE, 76 | EXAMPLE_TEMPLATE_PULL, 77 | EXAMPLE_PROJECT_PULL, 78 | ), 79 | events.PullRequestCreated( 80 | EXAMPLE_TEMPLATE, 81 | EXAMPLE_TEMPLATE_PULL, 82 | EXAMPLE_PROJECT_PULL, 83 | ), 84 | events.GitFailed("rev-parse", ["8beecafe"], 128, "8beecafe", "fatal: ..."), 85 | events.GitFailed("cherry-pick", [], 1, "", "could not apply 8beecafe"), 86 | events.GitFailed("foo", [], 1, "", ""), 87 | events.GitHubError( 88 | "https://api.github.com/repos/owner/name", "GET", 401, "Bad credentials", [] 89 | ), 90 | events.ConnectionError( 91 | "https://api.github.com/repos/owner/name", "GET", "pigeon unavailable" 92 | ), 93 | ], 94 | ids=get_class_name, 95 | ) 96 | def test_publish_event(publish_event: PublishEvent, event: Event) -> None: 97 | """It displays the event on standard error.""" 98 | stdout, stderr = publish_event(event) 99 | assert not stdout and stderr 100 | 101 | 102 | PublishContext = Callable[[Context], Tuple[str, str]] 103 | 104 | 105 | @pytest.fixture 106 | def publish_context(bus: Bus) -> PublishContext: 107 | """Factory fixture for publishing a context.""" 108 | 109 | class _ExampleError(Exception): 110 | pass 111 | 112 | def _(context: Context) -> Tuple[str, str]: 113 | with redirect() as (stdout, stderr): 114 | with contextlib.suppress(_ExampleError): 115 | with 
bus.contexts.publish(context): 116 | # The exception ensures that something is written to the 117 | # console before control exits the context. 118 | raise _ExampleError("Boom") 119 | return stdout.getvalue(), stderr.getvalue() 120 | 121 | return _ 122 | 123 | 124 | @pytest.mark.parametrize( 125 | "context", 126 | [ 127 | events.LoadTemplate("owner/name"), 128 | events.LoadProject("owner/name"), 129 | events.CreatePullRequest( 130 | EXAMPLE_TEMPLATE, "retrocookie-pr/branch", EXAMPLE_PROJECT_PULL 131 | ), 132 | events.UpdatePullRequest( 133 | EXAMPLE_TEMPLATE, EXAMPLE_TEMPLATE_PULL, EXAMPLE_PROJECT_PULL 134 | ), 135 | ], 136 | ids=get_class_name, 137 | ) 138 | def test_publish_context(publish_context: PublishContext, context: Context) -> None: 139 | """It displays the context on standard error.""" 140 | stdout, stderr = publish_context(context) 141 | assert not stdout and stderr 142 | 143 | 144 | def test_progress_failure() -> None: 145 | """It displays a failure message.""" 146 | c = console.Console() 147 | with contextlib.suppress(Exception): 148 | with redirect() as (stdout, stderr): 149 | with c.progress("message"): 150 | raise Exception("Boom") 151 | assert "⨯" in stderr.getvalue() 152 | -------------------------------------------------------------------------------- /.github/workflows/tests.yml: -------------------------------------------------------------------------------- 1 | name: Tests 2 | 3 | on: 4 | - push 5 | - pull_request 6 | 7 | jobs: 8 | tests: 9 | name: ${{ matrix.session }} ${{ matrix.python-version }} / ${{ matrix.os }} 10 | runs-on: ${{ matrix.os }} 11 | strategy: 12 | fail-fast: false 13 | matrix: 14 | include: 15 | - { python-version: 3.9, os: ubuntu-latest, session: "pre-commit" } 16 | - { python-version: 3.9, os: ubuntu-latest, session: "safety" } 17 | - { python-version: 3.9, os: ubuntu-latest, session: "mypy" } 18 | - { python-version: 3.8, os: ubuntu-latest, session: "mypy" } 19 | - { python-version: 3.7, os: ubuntu-latest, session: 
"mypy" } 20 | - { python-version: 3.9, os: ubuntu-latest, session: "tests" } 21 | - { python-version: 3.8, os: ubuntu-latest, session: "tests" } 22 | - { python-version: 3.7, os: ubuntu-latest, session: "tests" } 23 | - { python-version: 3.9, os: windows-latest, session: "tests" } 24 | - { python-version: 3.9, os: macos-latest, session: "tests" } 25 | - { python-version: 3.9, os: ubuntu-latest, session: "typeguard" } 26 | - { python-version: 3.9, os: ubuntu-latest, session: "xdoctest" } 27 | - { python-version: 3.8, os: ubuntu-latest, session: "docs-build" } 28 | 29 | env: 30 | NOXSESSION: ${{ matrix.session }} 31 | FORCE_COLOR: "1" 32 | PRE_COMMIT_COLOR: "always" 33 | 34 | steps: 35 | - name: Check out the repository 36 | uses: actions/checkout@v3.0.2 37 | 38 | - name: Set up Python ${{ matrix.python-version }} 39 | uses: actions/setup-python@v3.1.2 40 | with: 41 | python-version: ${{ matrix.python-version }} 42 | 43 | - name: Upgrade pip 44 | run: | 45 | pip install --constraint=.github/workflows/constraints.txt pip 46 | pip --version 47 | 48 | - name: Install Poetry 49 | run: | 50 | pipx install --pip-args=--constraint=.github/workflows/constraints.txt poetry 51 | poetry --version 52 | 53 | - name: Install Nox 54 | run: | 55 | pipx install --pip-args=--constraint=.github/workflows/constraints.txt nox 56 | pipx inject --pip-args=--constraint=.github/workflows/constraints.txt nox nox-poetry 57 | nox --version 58 | 59 | - name: Compute pre-commit cache key 60 | if: matrix.session == 'pre-commit' 61 | id: pre-commit-cache 62 | shell: python 63 | run: | 64 | import hashlib 65 | import sys 66 | 67 | python = "py{}.{}".format(*sys.version_info[:2]) 68 | payload = sys.version.encode() + sys.executable.encode() 69 | digest = hashlib.sha256(payload).hexdigest() 70 | result = "${{ runner.os }}-{}-{}-pre-commit".format(python, digest[:8]) 71 | 72 | print("::set-output name=result::{}".format(result)) 73 | 74 | - name: Restore pre-commit cache 75 | uses: actions/cache@v2.1.7 
76 | if: matrix.session == 'pre-commit' 77 | with: 78 | path: ~/.cache/pre-commit 79 | key: ${{ steps.pre-commit-cache.outputs.result }}-${{ hashFiles('.pre-commit-config.yaml') }} 80 | restore-keys: | 81 | ${{ steps.pre-commit-cache.outputs.result }}- 82 | 83 | - name: Configure git 84 | if: matrix.session == 'tests' || matrix.session == 'typeguard' 85 | run: | 86 | git config --global user.name "GitHub Action" 87 | git config --global user.email "action@github.com" 88 | 89 | - name: Run Nox 90 | run: | 91 | nox --force-color --python=${{ matrix.python-version }} 92 | 93 | - name: Upload coverage data 94 | if: always() && matrix.session == 'tests' 95 | uses: "actions/upload-artifact@v3.0.0" 96 | with: 97 | name: coverage-data 98 | path: ".coverage.*" 99 | 100 | - name: Upload documentation 101 | if: matrix.session == 'docs-build' 102 | uses: actions/upload-artifact@v3.0.0 103 | with: 104 | name: docs 105 | path: docs/_build 106 | 107 | coverage: 108 | runs-on: ubuntu-latest 109 | needs: tests 110 | steps: 111 | - name: Check out the repository 112 | uses: actions/checkout@v3.0.2 113 | 114 | - name: Set up Python 3.9 115 | uses: actions/setup-python@v3.1.2 116 | with: 117 | python-version: 3.9 118 | 119 | - name: Upgrade pip 120 | run: | 121 | pip install --constraint=.github/workflows/constraints.txt pip 122 | pip --version 123 | 124 | - name: Install Poetry 125 | run: | 126 | pipx install --pip-args=--constraint=.github/workflows/constraints.txt poetry 127 | poetry --version 128 | 129 | - name: Install Nox 130 | run: | 131 | pipx install --pip-args=--constraint=.github/workflows/constraints.txt nox 132 | pipx inject --pip-args=--constraint=.github/workflows/constraints.txt nox nox-poetry 133 | nox --version 134 | 135 | - name: Download coverage data 136 | uses: actions/download-artifact@v3.0.0 137 | with: 138 | name: coverage-data 139 | 140 | - name: Combine coverage data and display human readable report 141 | run: | 142 | nox --force-color --session=coverage 
143 | 144 | - name: Create coverage report 145 | run: | 146 | nox --force-color --session=coverage -- xml 147 | 148 | - name: Upload coverage report 149 | uses: codecov/codecov-action@v2.1.0 150 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.rst: -------------------------------------------------------------------------------- 1 | Contributor Covenant Code of Conduct 2 | ==================================== 3 | 4 | Our Pledge 5 | ---------- 6 | 7 | We as members, contributors, and leaders pledge to make participation in our community a harassment-free experience for everyone, regardless of age, body size, visible or invisible disability, ethnicity, sex characteristics, gender identity and expression, level of experience, education, socio-economic status, nationality, personal appearance, race, religion, or sexual identity and orientation. 8 | 9 | We pledge to act and interact in ways that contribute to an open, welcoming, diverse, inclusive, and healthy community. 
10 | 11 | 12 | Our Standards 13 | ------------- 14 | 15 | Examples of behavior that contributes to a positive environment for our community include: 16 | 17 | - Demonstrating empathy and kindness toward other people 18 | - Being respectful of differing opinions, viewpoints, and experiences 19 | - Giving and gracefully accepting constructive feedback 20 | - Accepting responsibility and apologizing to those affected by our mistakes, and learning from the experience 21 | - Focusing on what is best not just for us as individuals, but for the overall community 22 | 23 | Examples of unacceptable behavior include: 24 | 25 | - The use of sexualized language or imagery, and sexual attention or 26 | advances of any kind 27 | - Trolling, insulting or derogatory comments, and personal or political attacks 28 | - Public or private harassment 29 | - Publishing others' private information, such as a physical or email 30 | address, without their explicit permission 31 | - Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | Enforcement Responsibilities 35 | ---------------------------- 36 | 37 | Community leaders are responsible for clarifying and enforcing our standards of acceptable behavior and will take appropriate and fair corrective action in response to any behavior that they deem inappropriate, threatening, offensive, or harmful. 38 | 39 | Community leaders have the right and responsibility to remove, edit, or reject comments, commits, code, wiki edits, issues, and other contributions that are not aligned to this Code of Conduct, and will communicate reasons for moderation decisions when appropriate. 40 | 41 | 42 | Scope 43 | ----- 44 | 45 | This Code of Conduct applies within all community spaces, and also applies when an individual is officially representing the community in public spaces. 
Examples of representing our community include using an official e-mail address, posting via an official social media account, or acting as an appointed representative at an online or offline event. 46 | 47 | 48 | Enforcement 49 | ----------- 50 | 51 | Instances of abusive, harassing, or otherwise unacceptable behavior may be reported to the community leaders responsible for enforcement at mail@claudiojolowicz.com. All complaints will be reviewed and investigated promptly and fairly. 52 | 53 | All community leaders are obligated to respect the privacy and security of the reporter of any incident. 54 | 55 | 56 | Enforcement Guidelines 57 | ---------------------- 58 | 59 | Community leaders will follow these Community Impact Guidelines in determining the consequences for any action they deem in violation of this Code of Conduct: 60 | 61 | 62 | 1. Correction 63 | ~~~~~~~~~~~~~ 64 | 65 | **Community Impact**: Use of inappropriate language or other behavior deemed unprofessional or unwelcome in the community. 66 | 67 | **Consequence**: A private, written warning from community leaders, providing clarity around the nature of the violation and an explanation of why the behavior was inappropriate. A public apology may be requested. 68 | 69 | 70 | 2. Warning 71 | ~~~~~~~~~~ 72 | 73 | **Community Impact**: A violation through a single incident or series of actions. 74 | 75 | **Consequence**: A warning with consequences for continued behavior. No interaction with the people involved, including unsolicited interaction with those enforcing the Code of Conduct, for a specified period of time. This includes avoiding interactions in community spaces as well as external channels like social media. Violating these terms may lead to a temporary or permanent ban. 76 | 77 | 78 | 3. Temporary Ban 79 | ~~~~~~~~~~~~~~~~ 80 | 81 | **Community Impact**: A serious violation of community standards, including sustained inappropriate behavior. 
82 | 83 | **Consequence**: A temporary ban from any sort of interaction or public communication with the community for a specified period of time. No public or private interaction with the people involved, including unsolicited interaction with those enforcing the Code of Conduct, is allowed during this period. Violating these terms may lead to a permanent ban. 84 | 85 | 86 | 4. Permanent Ban 87 | ~~~~~~~~~~~~~~~~ 88 | 89 | **Community Impact**: Demonstrating a pattern of violation of community standards, including sustained inappropriate behavior, harassment of an individual, or aggression toward or disparagement of classes of individuals. 90 | 91 | **Consequence**: A permanent ban from any sort of public interaction within the community. 92 | 93 | 94 | Attribution 95 | ----------- 96 | 97 | This Code of Conduct is adapted from the `Contributor Covenant `__, version 2.0, 98 | available at https://www.contributor-covenant.org/version/2/0/code_of_conduct.html. 99 | 100 | Community Impact Guidelines were inspired by `Mozilla’s code of conduct enforcement ladder `__. 101 | 102 | .. _homepage: https://www.contributor-covenant.org 103 | 104 | For answers to common questions about this code of conduct, see the FAQ at 105 | https://www.contributor-covenant.org/faq. Translations are available at https://www.contributor-covenant.org/translations. 
106 | -------------------------------------------------------------------------------- /tests/pr/functional/conftest.py: -------------------------------------------------------------------------------- 1 | """Fixtures for functional tests.""" 2 | import json 3 | import os 4 | import secrets 5 | import time 6 | from itertools import count 7 | from pathlib import Path 8 | from typing import Callable 9 | from typing import Iterator 10 | 11 | import github3 12 | import pytest 13 | from github3.pulls import PullRequest 14 | from github3.repos.repo import Repository 15 | 16 | from retrocookie.pr import appname 17 | 18 | 19 | # Wait four seconds before creating each pull request. GitHub requests one 20 | # second between each request for a single user, but we run these tests 21 | # concurrently for two platforms and two event types (pull_request and push). 22 | # See https://docs.github.com/en/rest/guides/best-practices-for-integrators 23 | GITHUB_REQUEST_RATE_SECONDS = 4 24 | 25 | 26 | session_fixture = pytest.fixture(scope="session") 27 | skip_without_token = pytest.mark.skipif( 28 | "TEST_GITHUB_TOKEN" not in os.environ, 29 | reason="TEST_GITHUB_TOKEN is not set", 30 | ) 31 | 32 | 33 | @session_fixture 34 | def token() -> str: 35 | """Return the GitHub API token for functional tests.""" 36 | return os.environ["TEST_GITHUB_TOKEN"] 37 | 38 | 39 | @session_fixture 40 | def github(token: str) -> github3.GitHub: 41 | """Return a GitHub API client.""" 42 | return github3.login(token=token) 43 | 44 | 45 | CreateRepository = Callable[[str], Repository] 46 | 47 | 48 | @session_fixture 49 | def create_repository(github: github3.GitHub) -> CreateRepository: 50 | """Create a GitHub repository.""" 51 | owner: str = github.me().login 52 | random = secrets.token_hex(nbytes=3) 53 | 54 | def _create(slug: str) -> Repository: 55 | name = f"{appname}-test-{slug}-{random}" 56 | github.create_repository( 57 | name, 58 | description=f"Generated repository for {appname} tests", 59 | 
has_wiki=False, 60 | ) 61 | 62 | for retry in count(start=1): 63 | try: 64 | return github.repository(owner, name) 65 | except github3.exceptions.NotFoundError: 66 | if retry >= 3: 67 | raise 68 | 69 | time.sleep(GITHUB_REQUEST_RATE_SECONDS) 70 | 71 | raise AssertionError("unreachable") 72 | 73 | return _create 74 | 75 | 76 | @session_fixture 77 | def _template(create_repository: CreateRepository) -> Iterator[Repository]: 78 | repository = create_repository("template") 79 | yield repository 80 | repository.delete() 81 | 82 | 83 | @session_fixture 84 | def _project(create_repository: CreateRepository) -> Iterator[Repository]: 85 | repository = create_repository("project") 86 | yield repository 87 | repository.delete() 88 | 89 | 90 | @session_fixture 91 | def template(_template: Repository) -> Repository: 92 | """Create a GitHub repository for a Cookiecutter template.""" 93 | _template.create_file( 94 | path="cookiecutter.json", 95 | message="Create cookiecutter.json", 96 | content=json.dumps({"project": "example"}).encode(), 97 | ) 98 | _template.create_file( 99 | path="{{cookiecutter.project}}/.cookiecutter.json", 100 | message="Create .cookiecutter.json in project", 101 | content=b"{{cookiecutter | jsonify}}", 102 | ) 103 | _template.create_file( 104 | path="{{cookiecutter.project}}/README.md", 105 | message="Create README.md in project", 106 | content=b"# {{cookiecutter.project}}", 107 | ) 108 | return _template 109 | 110 | 111 | @session_fixture 112 | def project(_project: Repository, template: Repository) -> Repository: 113 | """Create a GitHub repository for a Cookiecutter template project.""" 114 | _project.create_file( 115 | path=".cookiecutter.json", 116 | message="Create .cookiecutter.json", 117 | content=json.dumps( 118 | {"_template": f"gh:{template.full_name}", "project": "example"} 119 | ).encode(), 120 | ) 121 | _project.create_file( 122 | path="README.md", 123 | message="Create README", 124 | content=b"# example", 125 | ) 126 | return _project 127 | 
128 | 129 | @pytest.fixture 130 | def branch() -> str: 131 | """Return a branch name that is unique for every test case.""" 132 | random = secrets.token_hex(nbytes=3) 133 | return f"branch-{random}" 134 | 135 | 136 | CreatePullRequest = Callable[[Path, str], PullRequest] 137 | 138 | 139 | @pytest.fixture 140 | def create_project_pull_request(project: Repository, branch: str) -> CreatePullRequest: 141 | """Return a pull request for the template project.""" 142 | 143 | def _create(path: Path, content: str) -> PullRequest: 144 | time.sleep(GITHUB_REQUEST_RATE_SECONDS) 145 | 146 | project_default_branch = project.branch(project.default_branch) 147 | project.create_ref(f"refs/heads/{branch}", project_default_branch.commit.sha) 148 | 149 | project_file = project.file_contents(str(path)) 150 | project_file.update( 151 | message=f"Update {path}", 152 | content=content.encode(), 153 | branch=branch, 154 | ) 155 | 156 | return project.create_pull( 157 | title=f"Update {path}", 158 | head=f"{project.owner}:{branch}", 159 | base=project.default_branch, 160 | ) 161 | 162 | return _create 163 | 164 | 165 | FindPullRequest = Callable[[], PullRequest] 166 | 167 | 168 | @pytest.fixture 169 | def find_template_pull_request( 170 | template: Repository, 171 | branch: str, 172 | ) -> FindPullRequest: 173 | """Find the generated pull request in the template repository.""" 174 | 175 | def _find() -> PullRequest: 176 | [template_pull] = template.pull_requests( 177 | head=f"{template.owner}:{appname}/{branch}" 178 | ) 179 | return template_pull 180 | 181 | return _find 182 | -------------------------------------------------------------------------------- /tests/test_core.py: -------------------------------------------------------------------------------- 1 | """Tests for core module.""" 2 | from dataclasses import dataclass 3 | from pathlib import Path 4 | 5 | import pytest 6 | 7 | from retrocookie import core 8 | from retrocookie import git 9 | from retrocookie import retrocookie 10 | 11 
| from .helpers import append 12 | from .helpers import branch 13 | from .helpers import touch 14 | from .helpers import write 15 | 16 | 17 | def in_template(path: Path) -> Path: 18 | """Prepend the template directory to the path.""" 19 | return "{{cookiecutter.project_slug}}" / path 20 | 21 | 22 | @dataclass 23 | class Example: 24 | """Example data for the test cases.""" 25 | 26 | path: Path = Path("README.md") 27 | text: str = "Lorem Ipsum\n" 28 | 29 | 30 | @pytest.fixture 31 | def example() -> Example: 32 | """Fixture with example data.""" 33 | return Example() 34 | 35 | 36 | @pytest.mark.parametrize( 37 | "text, expected", 38 | [ 39 | ("Lorem Ipsum\n", "Lorem Ipsum\n"), 40 | ( 41 | "This project is called example.\n", 42 | "This project is called {{cookiecutter.project_slug}}.\n", 43 | ), 44 | ( 45 | "python-version: ${{ matrix.python-version }}", 46 | 'python-version: ${{"{{"}} matrix.python-version {{"}}"}}', 47 | ), 48 | ], 49 | ) 50 | def test_rewrite( 51 | cookiecutter_repository: git.Repository, 52 | cookiecutter_instance_repository: git.Repository, 53 | text: str, 54 | expected: str, 55 | example: Example, 56 | ) -> None: 57 | """It rewrites the file contents as expected.""" 58 | cookiecutter, instance = cookiecutter_repository, cookiecutter_instance_repository 59 | 60 | with branch(instance, "topic", create=True): 61 | append(instance, example.path, text) 62 | 63 | retrocookie( 64 | instance.path, 65 | path=cookiecutter.path, 66 | branch="topic", 67 | create_branch="topic", 68 | ) 69 | 70 | with branch(cookiecutter, "topic"): 71 | assert expected in cookiecutter.read_text(in_template(example.path)) 72 | 73 | 74 | def test_branch( 75 | cookiecutter_repository: git.Repository, 76 | cookiecutter_instance_repository: git.Repository, 77 | example: Example, 78 | ) -> None: 79 | """It creates the specified branch.""" 80 | cookiecutter, instance = cookiecutter_repository, cookiecutter_instance_repository 81 | 82 | with branch(instance, "topic", create=True): 83 
| append(instance, example.path, example.text) 84 | 85 | retrocookie( 86 | instance.path, 87 | path=cookiecutter.path, 88 | branch="topic", 89 | create_branch="just-another-branch", 90 | ) 91 | 92 | with branch(cookiecutter, "just-another-branch"): 93 | assert example.text in cookiecutter.read_text(in_template(example.path)) 94 | 95 | 96 | def test_upstream( 97 | cookiecutter_repository: git.Repository, 98 | cookiecutter_instance_repository: git.Repository, 99 | example: Example, 100 | ) -> None: 101 | """It does not apply changes from the upstream branch.""" 102 | cookiecutter, instance = cookiecutter_repository, cookiecutter_instance_repository 103 | 104 | another = Path("file.txt") 105 | 106 | with branch(instance, "upstream", create=True): 107 | touch(instance, another) 108 | with branch(instance, "topic", create=True): 109 | append(instance, example.path, example.text) 110 | 111 | retrocookie( 112 | instance.path, 113 | path=cookiecutter.path, 114 | upstream="upstream", 115 | branch="topic", 116 | create_branch="topic", 117 | ) 118 | 119 | with branch(cookiecutter, "topic"): 120 | assert not cookiecutter.exists(another) 121 | assert example.text in cookiecutter.read_text(in_template(example.path)) 122 | 123 | 124 | def test_single_commit( 125 | cookiecutter_repository: git.Repository, 126 | cookiecutter_instance_repository: git.Repository, 127 | example: Example, 128 | ) -> None: 129 | """It cherry-picks the specified commit.""" 130 | cookiecutter, instance = cookiecutter_repository, cookiecutter_instance_repository 131 | 132 | append(instance, example.path, example.text) 133 | retrocookie(instance.path, ["HEAD"], path=cookiecutter.path) 134 | 135 | assert example.text in cookiecutter.read_text(in_template(example.path)) 136 | 137 | 138 | def test_multiple_commits_sequential( 139 | cookiecutter_repository: git.Repository, 140 | cookiecutter_instance_repository: git.Repository, 141 | ) -> None: 142 | """It cherry-picks the specified commits.""" 143 | 
cookiecutter, instance = cookiecutter_repository, cookiecutter_instance_repository 144 | names = "first", "second" 145 | 146 | for name in names: 147 | touch(instance, Path(name)) 148 | 149 | retrocookie(instance.path, ["HEAD~2.."], path=cookiecutter.path) 150 | 151 | for name in names: 152 | path = in_template(Path(name)) 153 | assert cookiecutter.exists(path) 154 | 155 | 156 | def test_multiple_commits_parallel( 157 | cookiecutter_repository: git.Repository, 158 | cookiecutter_instance_repository: git.Repository, 159 | ) -> None: 160 | """It cherry-picks the specified commits.""" 161 | cookiecutter, instance = cookiecutter_repository, cookiecutter_instance_repository 162 | names = "first", "second" 163 | 164 | for name in names: 165 | with branch(instance, name, create=True): 166 | touch(instance, Path(name)) 167 | 168 | retrocookie(instance.path, names, path=cookiecutter.path) 169 | 170 | for name in names: 171 | path = in_template(Path(name)) 172 | assert cookiecutter.exists(path) 173 | 174 | 175 | def test_find_template_directory_fails(tmp_path: Path) -> None: 176 | """It raises an exception when there is no template directory.""" 177 | repository = git.Repository.init(tmp_path) 178 | with pytest.raises(Exception): 179 | core.find_template_directory(repository) 180 | 181 | 182 | def test_load_context_error(cookiecutter_instance_repository: git.Repository) -> None: 183 | """It raises an exception when .cookiecutter.json is not JSON dictionary.""" 184 | write(cookiecutter_instance_repository, Path(".cookiecutter.json"), "[]") 185 | with pytest.raises(TypeError): 186 | core.load_context(cookiecutter_instance_repository, "HEAD") 187 | -------------------------------------------------------------------------------- /src/retrocookie/core.py: -------------------------------------------------------------------------------- 1 | """Core module.""" 2 | import json 3 | import tempfile 4 | from pathlib import Path 5 | from typing import Any 6 | from typing import cast 7 | 
from typing import Container 8 | from typing import Dict 9 | from typing import Iterable 10 | from typing import Iterator 11 | from typing import List 12 | from typing import Optional 13 | 14 | from retrocookie.compat import contextlib 15 | 16 | from . import git 17 | from .filter import RepositoryFilter 18 | 19 | 20 | @contextlib.contextmanager 21 | def temporary_repository() -> Iterator[git.Repository]: 22 | """Create repository in temporary directory.""" 23 | with tempfile.TemporaryDirectory() as tmpdir: 24 | path = Path(tmpdir) 25 | yield git.Repository.init(path) 26 | 27 | 28 | def find_template_directory(repository: git.Repository) -> Path: 29 | """Locate the subdirectory with the project template.""" 30 | tokens = "{{", "cookiecutter", "}}" 31 | for path in repository.path.iterdir(): 32 | if path.is_dir() and all(x in path.name for x in tokens): 33 | return path.relative_to(repository.path) 34 | raise Exception("cannot find template directory") 35 | 36 | 37 | def load_context(repository: git.Repository, ref: str) -> Dict[str, Any]: 38 | """Load the context from the .cookiecutter.json file.""" 39 | path = Path(".cookiecutter.json") 40 | text = repository.read_text(path, ref=ref) 41 | data = json.loads(text) 42 | if not isinstance(data, dict) or not all(isinstance(key, str) for key in data): 43 | raise TypeError(".cookiecutter.json does not contain an object") 44 | return cast(Dict[str, Any], data) 45 | 46 | 47 | def get_commits( 48 | repository: git.Repository, 49 | commits: Iterable[str], 50 | branch: Optional[str], 51 | upstream: str, 52 | ) -> List[str]: 53 | """Return hashes of the commits to be picked.""" 54 | 55 | def _generate() -> Iterator[str]: 56 | if branch is not None: 57 | yield from repository.parse_revisions(f"{upstream}..{branch}") 58 | 59 | if commits: 60 | yield from repository.parse_revisions(*commits) 61 | 62 | return list(_generate()) 63 | 64 | 65 | def rewrite_commits( 66 | repository: git.Repository, 67 | source: git.Repository, 68 | 
template_directory: Path, 69 | include_variables: Container[str], 70 | exclude_variables: Container[str], 71 | commits: Iterable[str], 72 | ) -> List[str]: 73 | """Rewrite the repository using template variables.""" 74 | commits = list(commits) 75 | with_parents = source.parse_revisions( 76 | *[f"{commit}^" for commit in commits], *commits 77 | ) 78 | context = load_context(source, commits[-1]) 79 | RepositoryFilter( 80 | repository=repository, 81 | source=source, 82 | commits=with_parents, 83 | template_directory=template_directory, 84 | context=context, 85 | include_variables=include_variables, 86 | exclude_variables=exclude_variables, 87 | ).run() 88 | return [repository.lookup_replacement(commit) for commit in commits] 89 | 90 | 91 | def apply_commits( 92 | repository: git.Repository, 93 | source: git.Repository, 94 | commits: Iterable[str], 95 | create_branch: Optional[str], 96 | ) -> None: 97 | """Create with the specified commits from .""" 98 | repository.fetch_commits(source, *commits) 99 | 100 | if create_branch: 101 | repository.create_branch(create_branch) 102 | repository.switch_branch(create_branch) 103 | 104 | repository.cherrypick(*commits) 105 | 106 | 107 | def retrocookie( 108 | instance_path: Path, 109 | commits: Iterable[str] = (), 110 | *, 111 | path: Optional[Path] = None, 112 | branch: Optional[str] = None, 113 | upstream: Optional[str] = None, 114 | create_branch: Optional[str] = None, 115 | include_variables: Container[str] = (), 116 | exclude_variables: Container[str] = (), 117 | ) -> None: # noqa: DAR101 https://github.com/terrencepreilly/darglint/issues/56 118 | """Import commits from instance repository into template repository. 119 | 120 | This function imports commits from an instance of the Cookiecutter 121 | template, rewriting occurrences of template variables back into the 122 | original templating tags, and prepending the template directory to 123 | filenames. Any tokens with special meaning in Jinja are escaped. 
124 | 125 | Args: 126 | instance_path: The source repository, an instance of the Cookiecutter 127 | template with a ``.cookiecutter.json`` file. 128 | 129 | commits: The commits to be imported from the source repository. 130 | 131 | path: The target repository, a Cookiecutter template. This defaults to 132 | the current working directory. 133 | 134 | branch: The name of a branch to be imported from the source repository. 135 | This is equivalent to passing ``["master..branch"]`` in the 136 | ``commits`` parameter. 137 | 138 | upstream: The upstream for ``branch``. This defaults to the value of the 139 | ``init.defaultBranch`` option in git. If the option is not set, 140 | ``master`` is used. 141 | 142 | create_branch: The name of the branch to be created in the target 143 | repository. By default, commits are imported onto the current 144 | branch. 145 | 146 | include_variables: The Cookiecutter variables which should be 147 | rewritten. If this is not specified, all variables from 148 | cookiecutter.json are rewritten. 149 | 150 | exclude_variables: Any Cookiecutter variables which should not be 151 | rewritten. 152 | 153 | """ 154 | if upstream is None: 155 | upstream = git.get_default_branch() 156 | 157 | repository = git.Repository(path) 158 | instance = git.Repository(instance_path) 159 | template_directory = find_template_directory(repository) 160 | commits = get_commits(instance, commits, branch, upstream) 161 | 162 | with temporary_repository() as scratch: 163 | commits = rewrite_commits( 164 | scratch, 165 | instance, 166 | template_directory, 167 | include_variables, 168 | exclude_variables, 169 | commits, 170 | ) 171 | 172 | apply_commits(repository, scratch, commits, create_branch) 173 | -------------------------------------------------------------------------------- /src/retrocookie/pr/base/bus.py: -------------------------------------------------------------------------------- 1 | """Event bus. 
2 | 3 | This module defines a :class:`Bus` class for publishing events and subscribing 4 | handlers to them. 5 | 6 | Events must derive from :class:`Event`. Handlers are functions that take the 7 | event as an argument. The function must have type annotations to allow the bus 8 | to determine which handlers to invoke when an event is published. 9 | 10 | >>> class Timber(Event): 11 | ... pass 12 | ... 13 | >>> def shout(event: Timber): 14 | ... print("Timber!") 15 | ... 16 | >>> bus = Bus() 17 | >>> bus.events.subscribe(shout) 18 | >>> bus.events.publish(Timber()) 19 | Timber! 20 | 21 | Events can be raised on the bus. A raised event is wrapped in an :class:`Error` 22 | exception, causing the stack to unwind until an error handler is encountered. 23 | 24 | >>> try: 25 | ... with bus.events.errorhandler(): 26 | ... bus.events.raise_(Timber()) 27 | ... except Error: 28 | ... pass 29 | Timber! 30 | 31 | The error handler places the event on the bus, but it does not suppress the 32 | exception. 33 | 34 | Contexts are a special kind of event that has a duration. Handlers for contexts 35 | must return a context manager, which is invoked on entering and exiting the 36 | context. 37 | 38 | >>> class Copy(Context): 39 | ... pass 40 | ... 41 | >>> import contextlib 42 | ... 43 | >>> @contextlib.contextmanager 44 | ... def handle_copy(context: Copy): 45 | ... print("starting copy...") 46 | ... try: 47 | ... yield 48 | ... except Exception: 49 | ... print("copy failed!") 50 | ... raise 51 | ... else: 52 | ... print("copy complete.") 53 | ... 54 | >>> bus.contexts.subscribe(handle_copy) 55 | >>> with bus.contexts.publish(Copy()): 56 | ... print("copy: file a") 57 | ... print("copy: file b") 58 | starting copy... 59 | copy: file a 60 | copy: file b 61 | copy complete.
62 | """ 63 | from collections import defaultdict 64 | from contextlib import ExitStack 65 | from typing import Any 66 | from typing import ContextManager 67 | from typing import Dict 68 | from typing import get_type_hints 69 | from typing import Iterator 70 | from typing import List 71 | from typing import NoReturn 72 | from typing import Tuple 73 | from typing import Type 74 | from typing import TypeVar 75 | from typing import Union 76 | 77 | from retrocookie.compat.contextlib import contextmanager 78 | from retrocookie.pr.base.exceptionhandlers import ExceptionHandler 79 | from retrocookie.pr.base.exceptionhandlers import exceptionhandler 80 | from retrocookie.pr.compat.typing import Protocol 81 | 82 | 83 | __all__ = [ 84 | "Bus", 85 | "Context", 86 | "Error", 87 | "Event", 88 | ] 89 | 90 | 91 | class Event: 92 | """An event can be published on the bus, and subscribed to.""" 93 | 94 | 95 | class Context: 96 | """A context is an event with a duration.""" 97 | 98 | 99 | class Error(Exception): 100 | """An exception for transporting an event.""" 101 | 102 | def __init__(self, event: Event) -> None: 103 | """Initialize.""" 104 | self.event = event 105 | 106 | 107 | E = TypeVar("E", bound=Event, contravariant=True) 108 | C = TypeVar("C", bound=Context, contravariant=True) 109 | 110 | 111 | class _EventHandler(Protocol[E]): 112 | """Handler for an event.""" 113 | 114 | def __call__(self, event: E) -> None: 115 | """Invoke the handler.""" 116 | 117 | 118 | class _ContextHandler(Protocol[C]): 119 | """Handler for a context.""" 120 | 121 | def __call__(self, context: C) -> ContextManager[None]: 122 | """Invoke the handler.""" 123 | 124 | 125 | class _Events: 126 | """Publish and subscribe to events.""" 127 | 128 | def __init__(self) -> None: 129 | """Initialize.""" 130 | self.handlers: Dict[ 131 | Type[Event], 132 | List[_EventHandler[Any]], 133 | ] = defaultdict(list) 134 | 135 | def publish(self, event: Event) -> None: 136 | """Publish an event on the bus.""" 137 | 
handlers = self.handlers[type(event)] 138 | for handler in handlers: 139 | handler(event) 140 | 141 | def subscribe(self, handler: _EventHandler[E]) -> _EventHandler[E]: 142 | """Subscribe to an event.""" 143 | event_type = next(hint for hint in get_type_hints(handler).values()) 144 | self.handlers[event_type].append(handler) 145 | return handler 146 | 147 | def raise_(self, event: Event) -> NoReturn: 148 | """Raise an event.""" 149 | raise Error(event) 150 | 151 | def reraise( 152 | self, 153 | event: Event, 154 | *, 155 | when: Union[Type[BaseException], Tuple[Type[BaseException], ...]] = (), 156 | ) -> ExceptionHandler: 157 | """Raise an event when an exception is caught.""" 158 | if not isinstance(when, tuple): 159 | when = (when,) 160 | elif not when: 161 | when = (Exception,) 162 | 163 | @exceptionhandler(*when) 164 | def _(exception: BaseException) -> NoReturn: 165 | raise Error(event) 166 | 167 | return _ 168 | 169 | def errorhandler(self) -> ExceptionHandler: 170 | """Publish error events on the bus.""" 171 | 172 | @exceptionhandler 173 | def _(error: Error) -> None: 174 | self.publish(error.event) 175 | 176 | return _ 177 | 178 | 179 | class _Contexts: 180 | """Publish and subscribe to contexts.""" 181 | 182 | def __init__(self) -> None: 183 | """Initialize.""" 184 | self.handlers: Dict[ 185 | Type[Context], 186 | List[_ContextHandler[Any]], 187 | ] = defaultdict(list) 188 | 189 | @contextmanager 190 | def publish(self, event: Context) -> Iterator[None]: 191 | """Publish a context.""" 192 | handlers = self.handlers[type(event)] 193 | with ExitStack() as stack: 194 | for handler in handlers: 195 | stack.enter_context(handler(event)) 196 | yield 197 | 198 | def subscribe(self, handler: _ContextHandler[C]) -> _ContextHandler[C]: 199 | """Subscribe to a context.""" 200 | event_type = next(hint for hint in get_type_hints(handler).values()) 201 | self.handlers[event_type].append(handler) 202 | return handler 203 | 204 | 205 | class Bus: 206 | """Event bus.""" 
207 | 208 | def __init__(self) -> None: 209 | """Initialize.""" 210 | self.events = _Events() 211 | self.contexts = _Contexts() 212 | -------------------------------------------------------------------------------- /tests/pr/unit/test_core.py: -------------------------------------------------------------------------------- 1 | """Tests for retrocookie.pr.core.""" 2 | import json 3 | import os 4 | from pathlib import Path 5 | from typing import Iterator 6 | 7 | import pytest 8 | from pytest import MonkeyPatch 9 | 10 | from retrocookie import git 11 | from retrocookie.compat import contextlib 12 | from retrocookie.pr import appname 13 | from retrocookie.pr import events 14 | from retrocookie.pr.base.bus import Bus 15 | from retrocookie.pr.base.exceptionhandlers import nullhandler 16 | from retrocookie.pr.cache import Cache 17 | from retrocookie.pr.core import check_git_version 18 | from retrocookie.pr.core import get_project_name 19 | from retrocookie.pr.core import get_template_name 20 | from retrocookie.pr.core import import_pull_requests 21 | from retrocookie.pr.protocols import github 22 | from retrocookie.pr.repository import Repository 23 | from tests.helpers import write 24 | from tests.pr.unit.fakes.retrocookie import retrocookie 25 | from tests.pr.unit.utils import raises 26 | 27 | 28 | def test_import(bus: Bus, cache: Cache, api: github.API) -> None: 29 | """It imports the pull request.""" 30 | cache.save_token("token") 31 | 32 | template = Repository.load("owner/template", api=api, cache=cache) 33 | project = Repository.load("owner/project", api=api, cache=cache) 34 | 35 | main = project.clone.get_current_branch() 36 | 37 | with cache.worktree(project.clone, main, force=True) as worktree: 38 | write( 39 | worktree, 40 | Path(".cookiecutter.json"), 41 | json.dumps({"_template": "gh:owner/template"}), 42 | ) 43 | 44 | with cache.worktree(project.clone, "readme") as worktree: 45 | write( 46 | worktree, 47 | Path("README.md"), 48 | "# project", 49 | ) 50 | 51 
| project.clone.git("push") 52 | project.github.create_pull_request( 53 | head="owner:readme", title="Add README.md", body="", labels=set() 54 | ) 55 | 56 | import_pull_requests( 57 | api=api, 58 | bus=bus, 59 | cache=cache, 60 | errorhandler=nullhandler, 61 | retrocookie=retrocookie, 62 | repository="owner/project", 63 | ) 64 | 65 | [pull_request] = template.github.pull_requests() 66 | 67 | assert pull_request.branch == f"{appname}/readme" 68 | 69 | 70 | def test_import_no_project_name( 71 | api: github.API, 72 | bus: Bus, 73 | cache: Cache, 74 | repository: git.Repository, 75 | ) -> None: 76 | """It derives the project name.""" 77 | cache.save_token("token") 78 | 79 | repository.repo.remotes.create("origin", "git@github.com:owner/project.git") 80 | 81 | with chdir(repository.path): 82 | template = Repository.load("owner/template", api=api, cache=cache) 83 | project = Repository.load("owner/project", api=api, cache=cache) 84 | 85 | main = project.clone.get_current_branch() 86 | 87 | with cache.worktree(project.clone, main, force=True) as worktree: 88 | write( 89 | worktree, 90 | Path(".cookiecutter.json"), 91 | json.dumps({"_template": "gh:owner/template"}), 92 | ) 93 | 94 | with cache.worktree(project.clone, "readme") as worktree: 95 | write( 96 | worktree, 97 | Path("README.md"), 98 | "# project", 99 | ) 100 | 101 | project.clone.git("push") 102 | project.github.create_pull_request( 103 | head="owner:readme", title="Add README.md", body="", labels=set() 104 | ) 105 | 106 | import_pull_requests( 107 | api=api, 108 | bus=bus, 109 | cache=cache, 110 | errorhandler=nullhandler, 111 | retrocookie=retrocookie, 112 | ) 113 | 114 | [pull_request] = template.github.pull_requests() 115 | 116 | assert pull_request.branch == f"{appname}/readme" 117 | 118 | 119 | @pytest.fixture 120 | def repository(tmp_path: Path) -> git.Repository: 121 | """Return a repository.""" 122 | return git.Repository.init(tmp_path / "repository") 123 | 124 | 125 | @contextlib.contextmanager 126 | 
def chdir(path: Path) -> Iterator[None]: 127 | """Change the directory temporarily.""" 128 | cwd = Path.cwd() 129 | os.chdir(path) 130 | try: 131 | yield 132 | finally: 133 | os.chdir(cwd) 134 | 135 | 136 | @pytest.mark.parametrize( 137 | "remote", 138 | [ 139 | "git@github.com:owner/name.git", 140 | "https://github.com/owner/name.git", 141 | "https://github.com/owner/name", 142 | ], 143 | ) 144 | def test_get_project_name(repository: git.Repository, remote: str, bus: Bus) -> None: 145 | """It returns the project name.""" 146 | repository.repo.remotes.create("origin", remote) 147 | with chdir(repository.path): 148 | assert get_project_name(bus=bus) == "owner/name" 149 | 150 | 151 | @pytest.mark.parametrize( 152 | "remote", 153 | [ 154 | "https://example.com/owner/name", 155 | ], 156 | ) 157 | def test_get_project_name_fails( 158 | repository: git.Repository, remote: str, bus: Bus 159 | ) -> None: 160 | """It fails to return the project name.""" 161 | repository.repo.remotes.create("origin", remote) 162 | with chdir(repository.path): 163 | with raises(events.ProjectNotFound): 164 | get_project_name(bus=bus) 165 | 166 | 167 | @pytest.mark.parametrize( 168 | "remote", 169 | [ 170 | "https://example.com/owner/name", 171 | ], 172 | ) 173 | def test_get_template_name_fails( 174 | bus: Bus, cache: Cache, api: github.API, remote: str 175 | ) -> None: 176 | """It fails to return the template name.""" 177 | project = Repository.load("owner/project", api=api, cache=cache) 178 | main = project.clone.get_current_branch() 179 | 180 | with cache.worktree(project.clone, main, force=True) as worktree: 181 | write( 182 | worktree, 183 | Path(".cookiecutter.json"), 184 | json.dumps({"_template": remote}), 185 | ) 186 | 187 | with raises(events.TemplateNotFound): 188 | get_template_name(project, bus=bus) 189 | 190 | 191 | def test_check_git_version(bus: Bus, monkeypatch: MonkeyPatch) -> None: 192 | """It raises an exception when git is too old.""" 193 | monkeypatch.setattr(git, 
"version", lambda: git.Version.parse("0.99")) 194 | with raises(events.BadGitVersion): 195 | check_git_version(bus=bus) 196 | -------------------------------------------------------------------------------- /noxfile.py: -------------------------------------------------------------------------------- 1 | """Nox sessions.""" 2 | import os 3 | import shutil 4 | import sys 5 | from pathlib import Path 6 | from textwrap import dedent 7 | 8 | import nox 9 | 10 | 11 | try: 12 | from nox_poetry import Session 13 | from nox_poetry import session 14 | except ImportError: 15 | message = f"""\ 16 | Nox failed to import the 'nox-poetry' package. 17 | 18 | Please install it using the following command: 19 | 20 | {sys.executable} -m pip install nox-poetry""" 21 | raise SystemExit(dedent(message)) from None 22 | 23 | 24 | package = "retrocookie" 25 | python_versions = ["3.9", "3.8", "3.7"] 26 | nox.needs_version = ">= 2021.6.6" 27 | nox.options.sessions = ( 28 | "pre-commit", 29 | "safety", 30 | "mypy", 31 | "tests", 32 | "typeguard", 33 | "xdoctest", 34 | "docs-build", 35 | ) 36 | 37 | 38 | def activate_virtualenv_in_precommit_hooks(session: Session) -> None: 39 | """Activate virtualenv in hooks installed by pre-commit. 40 | 41 | This function patches git hooks installed by pre-commit to activate the 42 | session's virtual environment. This allows pre-commit to locate hooks in 43 | that environment when invoked from git. 44 | 45 | Args: 46 | session: The Session object. 
47 | """ 48 | assert session.bin is not None # noqa: S101 49 | 50 | virtualenv = session.env.get("VIRTUAL_ENV") 51 | if virtualenv is None: 52 | return 53 | 54 | hookdir = Path(".git") / "hooks" 55 | if not hookdir.is_dir(): 56 | return 57 | 58 | for hook in hookdir.iterdir(): 59 | if hook.name.endswith(".sample") or not hook.is_file(): 60 | continue 61 | 62 | text = hook.read_text() 63 | bindir = repr(session.bin)[1:-1] # strip quotes 64 | if not ( 65 | Path("A") == Path("a") and bindir.lower() in text.lower() or bindir in text 66 | ): 67 | continue 68 | 69 | lines = text.splitlines() 70 | if not (lines[0].startswith("#!") and "python" in lines[0].lower()): 71 | continue 72 | 73 | header = dedent( 74 | f"""\ 75 | import os 76 | os.environ["VIRTUAL_ENV"] = {virtualenv!r} 77 | os.environ["PATH"] = os.pathsep.join(( 78 | {session.bin!r}, 79 | os.environ.get("PATH", ""), 80 | )) 81 | """ 82 | ) 83 | 84 | lines.insert(1, header) 85 | hook.write_text("\n".join(lines)) 86 | 87 | 88 | @session(name="pre-commit", python="3.9") 89 | def precommit(session: Session) -> None: 90 | """Lint using pre-commit.""" 91 | args = session.posargs or ["run", "--all-files", "--show-diff-on-failure"] 92 | session.install( 93 | "black", 94 | "darglint", 95 | "flake8", 96 | "flake8-bandit", 97 | "flake8-bugbear", 98 | "flake8-docstrings", 99 | "flake8-rst-docstrings", 100 | "isort", 101 | "pep8-naming", 102 | "pre-commit", 103 | "pre-commit-hooks", 104 | "pyupgrade", 105 | ) 106 | session.run("pre-commit", *args) 107 | if args and args[0] == "install": 108 | activate_virtualenv_in_precommit_hooks(session) 109 | 110 | 111 | @session(python="3.9") 112 | def safety(session: Session) -> None: 113 | """Scan dependencies for insecure packages.""" 114 | requirements = session.poetry.export_requirements() 115 | session.install("safety") 116 | session.run("safety", "check", "--full-report", f"--file={requirements}") 117 | 118 | 119 | @session(python=python_versions) 120 | def mypy(session: Session) 
-> None: 121 | """Type-check using mypy.""" 122 | args = session.posargs or ["src", "tests", "docs/conf.py"] 123 | session.install(".[pr]") 124 | session.install("mypy", "pytest", "types-requests") 125 | session.run("mypy", *args) 126 | if not session.posargs: 127 | session.run("mypy", f"--python-executable={sys.executable}", "noxfile.py") 128 | 129 | 130 | @session(python=python_versions) 131 | def tests(session: Session) -> None: 132 | """Run the test suite.""" 133 | session.install(".[pr]") 134 | session.install("cookiecutter", "coverage[toml]", "pygments", "pytest") 135 | try: 136 | session.run("coverage", "run", "--parallel", "-m", "pytest", *session.posargs) 137 | finally: 138 | if session.interactive: 139 | session.notify("coverage", posargs=[]) 140 | 141 | 142 | @session 143 | def coverage(session: Session) -> None: 144 | """Produce the coverage report.""" 145 | args = session.posargs or ["report"] 146 | 147 | session.install("coverage[toml]") 148 | 149 | if not session.posargs and any(Path().glob(".coverage.*")): 150 | session.run("coverage", "combine") 151 | 152 | session.run("coverage", *args) 153 | 154 | 155 | @session(python=python_versions) 156 | def typeguard(session: Session) -> None: 157 | """Runtime type checking using Typeguard.""" 158 | session.install(".[pr]") 159 | session.install("cookiecutter", "pytest", "typeguard", "pygments") 160 | session.run("pytest", f"--typeguard-packages={package}", *session.posargs) 161 | 162 | 163 | @session(python=python_versions) 164 | def xdoctest(session: Session) -> None: 165 | """Run examples with xdoctest.""" 166 | if session.posargs: 167 | args = [package, *session.posargs] 168 | else: 169 | args = [f"--modname={package}", "--command=all"] 170 | if "FORCE_COLOR" in os.environ: 171 | args.append("--colored=1") 172 | 173 | session.install(".[pr]") 174 | session.install("xdoctest[colors]") 175 | session.run("python", "-m", "xdoctest", *args) 176 | 177 | 178 | @session(name="docs-build", python="3.8") 179 | def 
docs_build(session: Session) -> None: 180 | """Build the documentation.""" 181 | args = session.posargs or ["docs", "docs/_build"] 182 | if not session.posargs and "FORCE_COLOR" in os.environ: 183 | args.insert(0, "--color") 184 | 185 | session.install(".[pr]") 186 | session.install("sphinx", "sphinx-click", "furo") 187 | 188 | build_dir = Path("docs", "_build") 189 | if build_dir.exists(): 190 | shutil.rmtree(build_dir) 191 | 192 | session.run("sphinx-build", *args) 193 | 194 | 195 | @session(python="3.8") 196 | def docs(session: Session) -> None: 197 | """Build and serve the documentation with live reloading on file changes.""" 198 | args = session.posargs or ["--open-browser", "docs", "docs/_build"] 199 | session.install(".[pr]") 200 | session.install("sphinx", "sphinx-autobuild", "sphinx-click", "furo") 201 | 202 | build_dir = Path("docs", "_build") 203 | if build_dir.exists(): 204 | shutil.rmtree(build_dir) 205 | 206 | session.run("sphinx-autobuild", *args) 207 | -------------------------------------------------------------------------------- /src/retrocookie/pr/__main__.py: -------------------------------------------------------------------------------- 1 | """Command-line interface.""" 2 | import contextlib 3 | import subprocess # noqa: S404 4 | import webbrowser 5 | from pathlib import Path 6 | from typing import Collection 7 | from typing import Iterator 8 | from typing import List 9 | from typing import NoReturn 10 | from typing import Optional 11 | 12 | import appdirs 13 | import click 14 | 15 | from retrocookie.compat.contextlib import contextmanager 16 | from retrocookie.pr import appname 17 | from retrocookie.pr import console 18 | from retrocookie.pr import events 19 | from retrocookie.pr.adapters import github 20 | from retrocookie.pr.adapters.retrocookie import retrocookie 21 | from retrocookie.pr.base.bus import Bus 22 | from retrocookie.pr.base.bus import Error 23 | from retrocookie.pr.base.exceptionhandlers import ExceptionHandler 24 | from 
retrocookie.pr.base.exceptionhandlers import exceptionhandler 25 | from retrocookie.pr.base.exceptionhandlers import nullhandler 26 | from retrocookie.pr.cache import Cache 27 | from retrocookie.pr.core import import_pull_requests 28 | 29 | 30 | def get_token(cache: Cache) -> str: 31 | """Obtain the token from the cache or by prompting the user.""" 32 | with contextlib.suppress(FileNotFoundError): 33 | return cache.load_token() 34 | 35 | token: str = click.prompt("Please enter a personal access token for the GitHub API") 36 | cache.save_token(token) 37 | return token 38 | 39 | 40 | def giterrorhandler(*, bus: Bus) -> ExceptionHandler: 41 | """Handle errors from a git subprocess.""" 42 | 43 | @exceptionhandler 44 | def _(error: subprocess.CalledProcessError) -> None: 45 | if isinstance(error.cmd, list) and len(error.cmd) >= 2: 46 | command = Path(error.cmd[0]).name 47 | subcommand, *args = error.cmd[1:] 48 | if command == "git": 49 | bus.events.raise_( 50 | events.GitFailed( 51 | subcommand, args, error.returncode, error.stdout, error.stderr 52 | ) 53 | ) 54 | return None 55 | 56 | return _ 57 | 58 | 59 | @exceptionhandler 60 | def exithandler(exception: Error) -> NoReturn: 61 | """Exit on application errors.""" 62 | raise SystemExit(1) 63 | 64 | 65 | def collect(errors: List[Error]) -> ExceptionHandler: 66 | """Collect errors.""" 67 | 68 | @exceptionhandler 69 | def _(error: Error) -> bool: 70 | errors.append(error) 71 | return True 72 | 73 | return _ 74 | 75 | 76 | def register_pull_request_viewer(*, bus: Bus) -> None: 77 | """Register an event handler for viewing pull requests in a browser.""" 78 | 79 | @bus.events.subscribe 80 | def _(event: events.PullRequestCreated) -> None: 81 | webbrowser.open(event.template_pull.html_url) 82 | 83 | @bus.contexts.subscribe 84 | @contextmanager 85 | def _(event: events.UpdatePullRequest) -> Iterator[None]: 86 | yield 87 | webbrowser.open(event.template_pull.html_url) 88 | 89 | 90 | @click.command() 91 | 
@click.argument("pull_requests", metavar="pull-request", nargs=-1) 92 | @click.option( 93 | "-R", 94 | "--repository", 95 | help="GitHub repository containing the pull requests.", 96 | ) 97 | @click.option( 98 | "--token", 99 | metavar="token", 100 | envvar="GITHUB_TOKEN", 101 | help="GitHub token", 102 | ) 103 | @click.option("-u", "--user", help="Import pull requests opened by this GitHub user.") 104 | @click.option("--all", is_flag=True, help="Import all pull requests.") 105 | @click.option("--force", is_flag=True, help="Overwrite existing pull requests.") 106 | @click.option( 107 | "--keep-going", 108 | "-k", 109 | is_flag=True, 110 | help="Keep going when some pull requests cannot be imported.", 111 | ) 112 | @click.option( 113 | "--open", is_flag=True, help="View imported pull requests in a web browser." 114 | ) 115 | @click.option("--debug", is_flag=True, help="Print tracebacks for all errors.") 116 | @click.version_option() 117 | def main( 118 | pull_requests: Collection[str], 119 | repository: Optional[str], 120 | token: Optional[str], 121 | user: Optional[str], 122 | all: bool, 123 | force: bool, 124 | keep_going: bool, 125 | open: bool, 126 | debug: bool, 127 | ) -> None: 128 | """Import pull requests from generated projects into their template. 129 | 130 | Command-line arguments specify pull requests to import, by number or by 131 | branch. Alternatively, select pull requests by specifying the user that 132 | opened them, via the --user option. Use the --all option to import all open 133 | pull requests in the generated project. 134 | 135 | Use the --repository option to specify the GitHub repository of the 136 | generated project from which the pull requests should be imported. Provide 137 | the full name of the repository on GitHub in the form owner/name. The owner 138 | can be omitted if the repository is owned by the authenticated user. This 139 | option can be omitted when the command is invoked from a local clone. 
140 | 141 | Update previously imported pull requests by specifying --force. By default, 142 | this program refuses to overwrite existing pull requests. 143 | 144 | The program needs a personal access token (PAT) to access the GitHub API, 145 | and prompts for the token when invoked for the first time. Subsequent 146 | invocations read the token from the application cache. Alternatively, you 147 | can specify the token using the --token option or the GITHUB_TOKEN 148 | environment variable; both of these methods bypass the cache. 149 | """ 150 | if not (pull_requests or user or all): 151 | raise click.UsageError("Please specify which pull requests to import.") 152 | 153 | if all and pull_requests: 154 | raise click.UsageError("Do not specify --all with individual pull requests.") 155 | 156 | bus = Bus() 157 | console.start(bus=bus) 158 | 159 | if open: # pragma: no cover 160 | register_pull_request_viewer(bus=bus) 161 | 162 | errors: List[Error] = [] 163 | 164 | corehandler = ( 165 | giterrorhandler(bus=bus) 166 | >> github.errorhandler(bus=bus) 167 | >> bus.events.errorhandler() 168 | ) 169 | mainerrorhandler = corehandler >> (exithandler if not debug else nullhandler) 170 | nestederrorhandler = (corehandler >> collect(errors)) if keep_going else nullhandler 171 | 172 | with mainerrorhandler: 173 | user_cache_dir = appdirs.user_cache_dir(appauthor=appname, appname=appname) 174 | cache = Cache(Path(user_cache_dir)) 175 | api = github.API.login(token or get_token(cache), bus=bus) 176 | 177 | import_pull_requests( 178 | pull_requests, 179 | api=api, 180 | bus=bus, 181 | cache=cache, 182 | errorhandler=nestederrorhandler, 183 | retrocookie=retrocookie, 184 | repository=repository, 185 | user=user, 186 | force=force, 187 | ) 188 | 189 | if errors: 190 | raise SystemExit(1) # pragma: no cover 191 | 192 | 193 | if __name__ == "__main__": # pragma: no cover 194 | main() 195 | -------------------------------------------------------------------------------- 
/src/retrocookie/pr/adapters/github.py: -------------------------------------------------------------------------------- 1 | """GitHub interface.""" 2 | from __future__ import annotations 3 | 4 | from typing import AbstractSet 5 | from typing import cast 6 | from typing import Iterable 7 | from typing import Optional 8 | from typing import Set 9 | from typing import Union 10 | 11 | import github3.exceptions 12 | import requests 13 | import tenacity 14 | 15 | from retrocookie.pr import events 16 | from retrocookie.pr.base.bus import Bus 17 | from retrocookie.pr.base.exceptionhandlers import ExceptionHandler 18 | from retrocookie.pr.base.exceptionhandlers import exceptionhandler 19 | 20 | 21 | class PullRequest: 22 | """Pull request in a GitHub repository.""" 23 | 24 | def __init__( 25 | self, 26 | pull_request: Union[ 27 | github3.pulls.PullRequest, 28 | github3.pulls.ShortPullRequest, 29 | ], 30 | ) -> None: 31 | """Initialize.""" 32 | self._pull_request = pull_request 33 | 34 | @property 35 | def number(self) -> int: 36 | """Number of the pull request.""" 37 | return cast(int, self._pull_request.number) 38 | 39 | @property 40 | def title(self) -> str: 41 | """Title of the pull request description.""" 42 | return cast(str, self._pull_request.title) 43 | 44 | @property 45 | def body(self) -> Optional[str]: 46 | """Body of the pull request description.""" 47 | return cast(Optional[str], self._pull_request.body) 48 | 49 | @property 50 | def branch(self) -> str: 51 | """Branch merged by the pull request.""" 52 | return cast(str, self._pull_request.head.ref) 53 | 54 | @property 55 | def user(self) -> str: 56 | """Login of the user that opened the pull request.""" 57 | return cast(str, self._pull_request.user.login) 58 | 59 | @property 60 | def html_url(self) -> str: 61 | """URL for viewing the pull request in a browser.""" 62 | return cast(str, self._pull_request.html_url) 63 | 64 | @property 65 | def labels(self) -> Set[str]: 66 | """The labels associated with the pull 
request.""" 67 | issue = self._pull_request.issue() 68 | return {label.name for label in issue.labels()} 69 | 70 | def update(self, title: str, body: Optional[str], labels: AbstractSet[str]) -> None: 71 | """Update the pull request.""" 72 | self._pull_request.update(title=title, body=body) 73 | self._pull_request.issue().replace_labels([*labels]) 74 | 75 | 76 | class Repository: 77 | """GitHub Repository.""" 78 | 79 | def __init__( 80 | self, repository: github3.repos.repo.Repository, *, token: str, bus: Bus 81 | ) -> None: 82 | """Initialize.""" 83 | self._repository = repository 84 | self._token = token 85 | self._bus = bus 86 | 87 | @property 88 | def owner(self) -> str: 89 | """The user who owns the repository.""" 90 | return cast(str, self._repository.owner.login) 91 | 92 | @property 93 | def full_name(self) -> str: 94 | """The full name of the repository.""" 95 | return cast(str, self._repository.full_name) 96 | 97 | @property 98 | def clone_url(self) -> str: 99 | """URL for cloning the repository.""" 100 | return cast(str, self._repository.clone_url) 101 | 102 | @property 103 | def push_url(self) -> str: 104 | """URL for pushing to the repository.""" 105 | return f"https://{self._token}:x-oauth-basic@github.com/{self.full_name}.git" 106 | 107 | @property 108 | def default_branch(self) -> str: 109 | """The default branch of the repository.""" 110 | return cast(str, self._repository.default_branch) 111 | 112 | def pull_request(self, number: int) -> PullRequest: 113 | """Return pull request identified by the given number.""" 114 | with self._bus.events.reraise( 115 | events.PullRequestNotFound(str(number)), 116 | when=github3.exceptions.NotFoundError, 117 | ): 118 | pull_request = self._repository.pull_request(number) 119 | 120 | return PullRequest(pull_request) 121 | 122 | def pull_request_by_head(self, head: str) -> Optional[PullRequest]: 123 | """Return pull request for the given head.""" 124 | for pull_request in self._repository.pull_requests(head=head): 
125 | return PullRequest(pull_request) 126 | return None 127 | 128 | def pull_requests(self) -> Iterable[PullRequest]: 129 | """Pull requests open in the repository.""" 130 | for pull_request in self._repository.pull_requests(state="open"): 131 | yield PullRequest(pull_request) 132 | 133 | def create_pull_request( 134 | self, 135 | *, 136 | head: str, 137 | title: str, 138 | body: Optional[str], 139 | labels: AbstractSet[str], 140 | ) -> PullRequest: 141 | """Create a pull request.""" 142 | pull_request = self._repository.create_pull( 143 | title=title, 144 | body=body, 145 | head=head, 146 | base=self._repository.default_branch, 147 | ) 148 | 149 | # The GitHub API sometimes responds with 404 when the issue for 150 | # a newly created pull request is requested. 151 | for attempt in tenacity.Retrying( 152 | reraise=True, 153 | stop=tenacity.stop_after_attempt(3), # type: ignore[attr-defined] 154 | wait=tenacity.wait_fixed(3), # type: ignore[attr-defined] 155 | ): 156 | with attempt: 157 | pull_request.issue().add_labels(*labels) 158 | 159 | return PullRequest(pull_request) 160 | 161 | 162 | class API: 163 | """GitHub API.""" 164 | 165 | def __init__(self, github: github3.GitHub, *, token: str, bus: Bus) -> None: 166 | """Initialize.""" 167 | self._github = github 168 | self._token = token 169 | self._bus = bus 170 | 171 | @classmethod 172 | def login(cls, token: str, *, bus: Bus) -> API: 173 | """Login to GitHub.""" 174 | github = github3.login(token=token) 175 | return cls(github, token=token, bus=bus) 176 | 177 | @property 178 | def me(self) -> str: 179 | """Return the login of the authenticated user.""" 180 | return cast(str, self._github.me().login) 181 | 182 | def repository(self, owner: str, name: str) -> Repository: 183 | """Return the repository with the given owner and name.""" 184 | fullname = "/".join([owner, name]) 185 | with self._bus.events.reraise( 186 | events.RepositoryNotFound(fullname), 187 | when=github3.exceptions.NotFoundError, 188 | ): 189 | 
_repository = self._github.repository(owner, name) 190 | 191 | return Repository(_repository, token=self._token, bus=self._bus) 192 | 193 | 194 | def errorhandler(*, bus: Bus) -> ExceptionHandler: 195 | """Handle GitHub error responses.""" 196 | 197 | @exceptionhandler 198 | def _handler1(error: github3.exceptions.GitHubError) -> None: 199 | if error.response is not None and error.response.request is not None: 200 | bus.events.raise_( 201 | events.GitHubError( 202 | error.response.request.url, 203 | error.response.request.method, 204 | error.code, 205 | error.message, # noqa: B306 206 | error.errors, 207 | ) 208 | ) 209 | return None 210 | 211 | @exceptionhandler 212 | def _handler2(error: github3.exceptions.ConnectionError) -> None: 213 | for arg in error.args: 214 | if isinstance(arg, requests.RequestException): 215 | bus.events.raise_( 216 | events.ConnectionError( 217 | arg.request.url, 218 | arg.request.method, 219 | str(error), 220 | ) 221 | ) 222 | return None 223 | 224 | return _handler1 >> _handler2 225 | -------------------------------------------------------------------------------- /src/retrocookie/pr/console.py: -------------------------------------------------------------------------------- 1 | """Application console.""" 2 | from typing import ContextManager 3 | from typing import Iterator 4 | 5 | import rich.console 6 | import rich.theme 7 | import rich.traceback 8 | from rich.markup import escape 9 | 10 | from retrocookie.compat import contextlib 11 | from retrocookie.pr import events 12 | from retrocookie.pr.base.bus import Bus 13 | from retrocookie.pr.compat import shlex 14 | 15 | 16 | class Console: 17 | """Console.""" 18 | 19 | def __init__(self) -> None: 20 | """Create the console.""" 21 | theme = rich.theme.Theme( 22 | { 23 | "success": "green", 24 | "failure": "red", 25 | "hint": "yellow", 26 | "repository": "blue", 27 | "pull": "bold bright_white", 28 | "oldpull": "bright_black", 29 | "title": "cyan", 30 | "command": "bold", 31 | "status": 
"bold", 32 | "stdout": "bright_black", 33 | "stderr": "red", 34 | "version": "green", 35 | "http.method": "blue", 36 | "http.status": "bright_blue", 37 | } 38 | ) 39 | self.console = rich.console.Console(stderr=True, highlight=False, theme=theme) 40 | 41 | def success(self, message: str) -> None: 42 | """Display a success message.""" 43 | self.console.print(f"[success]✓[/] {message}") 44 | 45 | def failure(self, message: str) -> None: 46 | """Display a failure message.""" 47 | self.console.print(f"[failure]⨯[/] {message}") 48 | 49 | def hint(self, message: str) -> None: 50 | """Display a hint for the user.""" 51 | self.console.print(f"[hint]☞ {message}[/]") 52 | 53 | def highlight(self, text: str) -> None: 54 | """Output a text with highlighting.""" 55 | console = rich.console.Console(stderr=True) 56 | console.print(text) 57 | 58 | @contextlib.contextmanager 59 | def progress(self, message: str) -> Iterator[None]: 60 | """Display a status message with a spinner.""" 61 | try: 62 | with self.console.status(message): 63 | yield 64 | except Exception: 65 | self.failure(message) 66 | raise 67 | 68 | 69 | def start(*, bus: Bus) -> None: 70 | """Create the console and subscribe to events.""" 71 | rich.traceback.install() 72 | 73 | console = Console() 74 | 75 | _subscribe(console, bus) 76 | 77 | 78 | def _subscribe(console: Console, bus: Bus) -> None: # noqa: C901 79 | @bus.events.subscribe 80 | def _(event: events.GitNotFound) -> None: 81 | console.failure("git not found") 82 | console.hint("Do you have git installed?") 83 | console.hint("Is git on your PATH?") 84 | console.hint("Can you run [command]git version[/]?") 85 | 86 | @bus.events.subscribe 87 | def _(event: events.BadGitVersion) -> None: 88 | console.failure(f"This program requires git [version]{event.expected}[/].") 89 | console.hint(f"You have git [version]{event.version}[/], which is too old.") 90 | 91 | @bus.events.subscribe 92 | def _(event: events.GitFailed) -> None: 93 | def _lines() -> Iterator[str]: 94 
| yield ( 95 | f"The command [command]git {event.command}[/] failed" 96 | f" with status [status]{event.status}[/]." 97 | ) 98 | 99 | command = escape(shlex.join(["git", event.command, *event.options])) 100 | 101 | yield "" 102 | yield f"❯ {command}" 103 | 104 | for stream in ["stdout", "stderr"]: 105 | output = getattr(event, stream) 106 | if output: 107 | yield "" 108 | for line in output.splitlines(): 109 | line = escape(line) 110 | yield f"[{stream}]{line}[/]" 111 | 112 | yield "" 113 | 114 | console.failure("\n".join(_lines())) 115 | 116 | if event.command == "cherry-pick": 117 | console.hint("Looks like the changes did not apply cleanly.") 118 | console.hint( 119 | "Try running [command]retrocookie[/] on a local clone instead." 120 | ) 121 | 122 | @bus.events.subscribe 123 | def _(event: events.GitHubError) -> None: 124 | def _lines() -> Iterator[str]: 125 | yield "The GitHub API returned an error response." 126 | yield f"[http.method]{event.method}[/] [repr.url]{event.url}[/]" 127 | yield f"[http.status]{event.code}[/] {event.message}" 128 | yield from event.errors 129 | 130 | console.failure("\n".join(_lines())) 131 | 132 | @bus.events.subscribe 133 | def _(event: events.ConnectionError) -> None: 134 | def _lines() -> Iterator[str]: 135 | yield "Connection to GitHub API could not be established." 
136 | yield f"[http.method]{event.method}[/] [repr.url]{event.url}[/]" 137 | 138 | console.failure("\n".join(_lines())) 139 | console.highlight(event.error) 140 | 141 | @bus.events.subscribe 142 | def _(event: events.ProjectNotFound) -> None: 143 | console.failure("Project not found") 144 | console.hint("Does the current directory contain a repository?") 145 | console.hint("Is the repository on GitHub?") 146 | 147 | @bus.events.subscribe 148 | def _(event: events.TemplateNotFound) -> None: 149 | console.failure( 150 | f"Project template for [repository]{event.project.full_name}[/] not found" 151 | ) 152 | console.hint("Does the project contain a .cookiecutter.json file?") 153 | console.hint( 154 | 'Does the .cookiecutter.json contain the repository URL under "_template"?' 155 | ) 156 | console.hint("Is the template repository on GitHub?") 157 | 158 | @bus.contexts.subscribe # type: ignore[arg-type] 159 | def _(event: events.LoadProject) -> ContextManager[None]: 160 | return console.progress(f"Loading project [repository]{event.repository}[/]") 161 | 162 | @bus.contexts.subscribe # type: ignore[arg-type] 163 | def _(event: events.LoadTemplate) -> ContextManager[None]: 164 | return console.progress(f"Loading template [repository]{event.repository}[/]") 165 | 166 | @bus.events.subscribe 167 | def _(event: events.RepositoryNotFound) -> None: 168 | console.failure(f"Repository [repository]{event.repository}[/] not found") 169 | 170 | @bus.events.subscribe 171 | def _(event: events.PullRequestNotFound) -> None: 172 | spec = event.pull_request 173 | if spec.isnumeric(): 174 | spec = f"#{spec}" 175 | console.failure(f"Pull request [pull]{spec}[/] not found") 176 | 177 | @bus.events.subscribe 178 | def _(event: events.PullRequestAlreadyExists) -> None: 179 | console.failure( 180 | "Already imported" 181 | f" [pull]#{event.project_pull.number}[/]" 182 | f" [title]{escape(event.project_pull.title)}[/]" 183 | ) 184 | console.hint(f"See [pull]#{event.template_pull.number}[/]") 
185 | console.hint("Use --force to update an existing pull request.") 186 | 187 | @bus.contexts.subscribe # type: ignore[arg-type] 188 | def _(event: events.CreatePullRequest) -> ContextManager[None]: 189 | return console.progress( 190 | f"[pull]#{event.project_pull.number}[/]" 191 | f" [title]{escape(event.project_pull.title)}[/]" 192 | ) 193 | 194 | @bus.contexts.subscribe 195 | @contextlib.contextmanager 196 | def _(event: events.UpdatePullRequest) -> Iterator[None]: 197 | with console.progress( 198 | f"[pull]#{event.project_pull.number}[/]" 199 | f" [title]{escape(event.project_pull.title)}[/]" 200 | f" [oldpull]#{event.template_pull.number}[/]" 201 | ): 202 | yield 203 | 204 | console.success( 205 | f"[pull]#{event.project_pull.number}[/]" 206 | f" [title]{escape(event.project_pull.title)}[/]" 207 | f" [pull]#{event.template_pull.number}[/]" 208 | ) 209 | 210 | @bus.events.subscribe 211 | def _(event: events.PullRequestCreated) -> None: 212 | console.success( 213 | f"[pull]#{event.project_pull.number}[/]" 214 | f" [title]{escape(event.project_pull.title)}[/]" 215 | f" [pull]#{event.template_pull.number}[/]" 216 | ) 217 | -------------------------------------------------------------------------------- /tests/test_git.py: -------------------------------------------------------------------------------- 1 | """Tests for git interface.""" 2 | from pathlib import Path 3 | from typing import Dict 4 | 5 | import pytest 6 | from pytest import MonkeyPatch 7 | 8 | from retrocookie import git 9 | from retrocookie.utils import chdir 10 | 11 | from .helpers import branch 12 | from .helpers import commit 13 | from .helpers import touch 14 | from .helpers import write 15 | 16 | 17 | @pytest.fixture 18 | def repository(tmp_path: Path) -> git.Repository: 19 | """Initialize repository in a temporary directory.""" 20 | return git.Repository.init(tmp_path / "repository") 21 | 22 | 23 | @pytest.mark.parametrize( 24 | "config,expected", 25 | [ 26 | ({}, "master"), 27 | 
({"init.defaultBranch": "main"}, "main"), 28 | ], 29 | ) 30 | def test_get_default_branch( 31 | config: Dict[str, str], expected: str, monkeypatch: MonkeyPatch 32 | ) -> None: 33 | """It returns the default branch.""" 34 | monkeypatch.setattr("pygit2.Config.get_global_config", dict) 35 | monkeypatch.setattr("pygit2.Config.get_system_config", lambda: config) 36 | 37 | assert expected == git.get_default_branch() 38 | 39 | 40 | def test_commit(repository: git.Repository) -> None: 41 | """It creates a commit.""" 42 | repository.commit("Empty") 43 | 44 | 45 | def test_read_text(repository: git.Repository) -> None: 46 | """It returns the file contents.""" 47 | path = repository.path / "README.md" 48 | path.write_text("# example\n") 49 | 50 | repository.add(path) 51 | repository.commit(f"Add {path}") 52 | 53 | assert path.read_text() == repository.read_text(path) 54 | 55 | 56 | def test_exists_true(repository: git.Repository) -> None: 57 | """It returns True if the file exists.""" 58 | path = Path("README.md") 59 | touch(repository, path) 60 | 61 | assert repository.exists(path) 62 | 63 | 64 | def test_exists_false(repository: git.Repository) -> None: 65 | """It returns False if the file does not exist.""" 66 | path = Path("README.md") 67 | commit(repository) # ensure HEAD exists 68 | 69 | assert not repository.exists(path) 70 | 71 | 72 | def test_cherrypick(repository: git.Repository) -> None: 73 | """It applies a commit from another branch.""" 74 | # Add README on master. 75 | readme = repository.path / "README.md" 76 | readme.write_text("# example\n") 77 | 78 | repository.add() 79 | repository.commit("Initial") 80 | 81 | # Add .gitignore on topic. 82 | with branch(repository, "topic", create=True): 83 | ignore = repository.path / ".gitignore" 84 | ignore_text = "*.pyc\n" 85 | ignore.write_text(ignore_text) 86 | 87 | repository.add() 88 | repository.commit("Ignore *.pyc") 89 | 90 | # Append to README, on master. 
91 | with readme.open(mode="a") as io: 92 | message = "\nHello, world!\n" 93 | io.write(message) 94 | 95 | repository.add() 96 | repository.commit("Update") 97 | 98 | # Cherry-pick the .gitignore onto master. 99 | repository.cherrypick("topic") 100 | 101 | assert ignore_text == repository.read_text(ignore) 102 | assert message in repository.read_text(readme) 103 | 104 | 105 | def test_cherrypick_index(repository: git.Repository) -> None: 106 | """It updates the index from the cherry-pick.""" 107 | readme, install = map(Path, ("README", "INSTALL")) 108 | 109 | touch(repository, readme) 110 | 111 | with branch(repository, "install", create=True): 112 | touch(repository, install) 113 | 114 | repository.cherrypick("install") 115 | 116 | repository.repo.index.read() # refresh stale index 117 | assert "INSTALL" in {e.path for e in repository.repo.index} 118 | 119 | 120 | def test_cherrypick_worktree(repository: git.Repository) -> None: 121 | """It updates the worktree from the cherry-pick.""" 122 | readme, install = map(Path, ("README", "INSTALL")) 123 | 124 | touch(repository, readme) 125 | 126 | with branch(repository, "install", create=True): 127 | touch(repository, install) 128 | 129 | repository.cherrypick("install") 130 | 131 | assert (repository.path / install).exists() 132 | 133 | 134 | def test_cherrypick_conflict(repository: git.Repository) -> None: 135 | """It raises an exception if the cherry-pick results in conflicts.""" 136 | path = Path("README") 137 | 138 | touch(repository, path) 139 | 140 | with branch(repository, "topic", create=True): 141 | write(repository, path, "a") 142 | 143 | write(repository, path, "b") 144 | 145 | with pytest.raises(Exception, match="conflict"): 146 | repository.cherrypick("topic") 147 | 148 | 149 | def test_parse_revisions(repository: git.Repository) -> None: 150 | """It returns the hashes on topic for the range expression ``..topic``.""" 151 | commit(repository) 152 | 153 | with branch(repository, "topic", create=True): 154 | 
expected = [commit(repository), commit(repository)] 155 | 156 | revisions = repository.parse_revisions("..topic") 157 | 158 | assert revisions == expected 159 | 160 | 161 | def test_lookup_replacement(repository: git.Repository) -> None: 162 | """It returns the replacement ref for a replaced ref.""" 163 | first, second = commit(repository), commit(repository) 164 | 165 | repository.git("replace", second, first) 166 | 167 | assert first == repository.lookup_replacement(second) 168 | 169 | 170 | def test_fetch_relative_path(repository: git.Repository) -> None: 171 | """It fetches from a relative path.""" 172 | source_path = Path("..") / "another" 173 | path = Path("README") 174 | 175 | with chdir(repository.path): 176 | source = git.Repository.init(source_path) 177 | commit = touch(source, path) 178 | 179 | repository.fetch_commits(source, commit) 180 | 181 | assert repository.read_text(path, ref=commit) == "" 182 | 183 | 184 | def test_add_worktree(repository: git.Repository, tmp_path: Path) -> None: 185 | """It creates a worktree.""" 186 | commit(repository) 187 | 188 | path = tmp_path / "worktree" 189 | repository.add_worktree("branch", path) 190 | 191 | assert path.exists() 192 | 193 | 194 | @pytest.mark.parametrize("force", [False, True]) 195 | def test_remove_worktree( 196 | repository: git.Repository, tmp_path: Path, force: bool 197 | ) -> None: 198 | """It removes the worktree.""" 199 | commit(repository) 200 | 201 | path = tmp_path / "worktree" 202 | repository.add_worktree("branch", path) 203 | repository.remove_worktree(path, force=force) 204 | 205 | assert not path.exists() 206 | 207 | 208 | def test_worktree(repository: git.Repository, tmp_path: Path) -> None: 209 | """It creates and removes a worktree.""" 210 | commit(repository) 211 | 212 | path = tmp_path / "worktree" 213 | with repository.worktree("branch", path): 214 | assert path.exists() 215 | 216 | assert not path.exists() 217 | 218 | 219 | @pytest.mark.parametrize( 220 | "version,expected", 221 
| [ 222 | # fmt: off 223 | ("0.99", git.Version(major=0, minor=99, patch=0)), 224 | ("0.99.9n", git.Version(major=0, minor=99, patch=9)), 225 | ("1.0rc6", git.Version(major=1, minor=0, patch=0)), 226 | ("1.0.0", git.Version(major=1, minor=0, patch=0)), 227 | ("1.0.0b", git.Version(major=1, minor=0, patch=0)), 228 | ("1.8.5.6", git.Version(major=1, minor=8, patch=5)), 229 | ("1.9-rc2", git.Version(major=1, minor=9, patch=0)), 230 | ("2.4.12", git.Version(major=2, minor=4, patch=12)), 231 | ("2.29.2.windows.3", git.Version(major=2, minor=29, patch=2)), # GitHub Actions 232 | ("2.30.0", git.Version(major=2, minor=30, patch=0)), 233 | ("2.30.0-rc0", git.Version(major=2, minor=30, patch=0)), 234 | ("2.30.0-rc2", git.Version(major=2, minor=30, patch=0)), 235 | # fmt: on 236 | ], 237 | ) 238 | def test_valid_version(version: str, expected: git.Version) -> None: 239 | """It produces the expected version.""" 240 | assert expected == git.Version.parse(version) 241 | 242 | 243 | @pytest.mark.parametrize( 244 | "version", 245 | [ 246 | "", 247 | "0", 248 | "1", 249 | "a", 250 | "1a", 251 | "1.a", 252 | "1a.0.0", 253 | "lorem.1.0.0", 254 | "1:1.0.0", 255 | ], 256 | ) 257 | def test_invalid_version(version: str) -> None: 258 | """It raises an exception.""" 259 | with pytest.raises(ValueError): 260 | git.Version.parse(version) 261 | -------------------------------------------------------------------------------- /src/retrocookie/git.py: -------------------------------------------------------------------------------- 1 | """Git interface.""" 2 | from __future__ import annotations 3 | 4 | import contextlib 5 | import functools 6 | import operator 7 | import re 8 | import subprocess # noqa: S404 9 | from dataclasses import dataclass 10 | from dataclasses import field 11 | from pathlib import Path 12 | from typing import Any 13 | from typing import cast 14 | from typing import Iterator 15 | from typing import List 16 | from typing import Optional 17 | 18 | import pygit2 19 | 20 | 
from retrocookie.compat.contextlib import contextmanager 21 | from retrocookie.utils import removeprefix 22 | 23 | 24 | class CommandError(Exception): 25 | """The command exited with a non-zero status.""" 26 | 27 | 28 | def git( 29 | *args: str, check: bool = True, **kwargs: Any 30 | ) -> subprocess.CompletedProcess[str]: 31 | """Invoke git.""" 32 | try: 33 | return subprocess.run( # noqa: S603,S607 34 | ["git", *args], check=check, text=True, capture_output=True, **kwargs 35 | ) 36 | except subprocess.CalledProcessError as error: 37 | command = " ".join(["git", *args[:1]]) 38 | message = f"{command} exited with status {error.returncode}" 39 | 40 | for stream in ["stderr", "stdout"]: 41 | output = getattr(error, stream) 42 | output = output.strip() 43 | if output: 44 | message = "\n\n".join([message, f"==> {stream} <==", output]) 45 | 46 | raise CommandError(message) from error 47 | 48 | 49 | VERSION_PATTERN = re.compile( 50 | r""" 51 | (?P\d+)\. 52 | (?P\d+) 53 | (\.(?P\d+))? 54 | """, 55 | re.VERBOSE, 56 | ) 57 | 58 | 59 | @dataclass(frozen=True, order=True) 60 | class Version: 61 | """Simplistic representation of git versions.""" 62 | 63 | major: int 64 | minor: int 65 | patch: int 66 | _text: Optional[str] = field(default=None, compare=False) 67 | 68 | @classmethod 69 | def parse(cls, text: str) -> Version: 70 | """Extract major.minor[.patch] from the start of the text.""" 71 | match = VERSION_PATTERN.match(text) 72 | 73 | if match is None: 74 | raise ValueError(f"invalid version {text!r}") 75 | 76 | parts = match.groupdict(default="0") 77 | 78 | return cls( 79 | int(parts["major"]), int(parts["minor"]), int(parts["patch"]), _text=text 80 | ) 81 | 82 | def __str__(self) -> str: 83 | """Return the original representation.""" 84 | return ( 85 | self._text 86 | if self._text is not None 87 | else f"{self.major}.{self.minor}.{self.patch}" 88 | ) 89 | 90 | 91 | def version() -> Version: 92 | """Return the git version.""" 93 | text = git("version").stdout.strip() 94 
| text = removeprefix(text, "git version ") 95 | return Version.parse(text) 96 | 97 | 98 | def get_default_branch() -> str: 99 | """Return the default branch for new repositories.""" 100 | get_configs = [ 101 | pygit2.Config.get_global_config, 102 | pygit2.Config.get_system_config, 103 | ] 104 | for get_config in get_configs: 105 | with contextlib.suppress(IOError, KeyError): 106 | config = get_config() 107 | branch = config["init.defaultBranch"] 108 | assert isinstance(branch, str) # noqa: S101 109 | return branch 110 | 111 | return "master" 112 | 113 | 114 | class Repository: 115 | """Git repository.""" 116 | 117 | def __init__( 118 | self, path: Optional[Path] = None, *, repo: Optional[pygit2.Repository] = None 119 | ) -> None: 120 | """Initialize.""" 121 | if repo is None: 122 | self.path = path or Path.cwd() 123 | self.repo = pygit2.Repository(self.path) 124 | else: 125 | self.path = Path(repo.workdir or repo.path) 126 | self.repo = repo 127 | 128 | def git(self, *args: str, **kwargs: Any) -> subprocess.CompletedProcess[str]: 129 | """Invoke git.""" 130 | return git(*args, cwd=self.path, **kwargs) 131 | 132 | @classmethod 133 | def init(cls, path: Path, *, bare: bool = False) -> Repository: 134 | """Create a repository.""" 135 | # https://github.com/libgit2/libgit2/issues/2849 136 | path.parent.mkdir(exist_ok=True, parents=True) 137 | repo = pygit2.init_repository(path, bare=bare) 138 | return cls(path, repo=repo) 139 | 140 | @classmethod 141 | def clone(cls, url: str, path: Path, *, mirror: bool = False) -> Repository: 142 | """Clone a repository.""" 143 | options = ["--mirror"] if mirror else [] 144 | git("clone", *options, url, str(path)) 145 | return cls(path) 146 | 147 | def create_branch(self, branch: str, ref: str = "HEAD") -> None: 148 | """Create a branch.""" 149 | commit = self.repo.revparse_single(ref) 150 | self.repo.branches.create(branch, commit) 151 | 152 | def get_current_branch(self) -> str: 153 | """Return the current branch.""" 154 | return 
self.repo.head.shorthand # type: ignore[no-any-return] 155 | 156 | def exists_branch(self, branch: str) -> bool: 157 | """Return True if the branch exists.""" 158 | return branch in self.repo.branches 159 | 160 | def switch_branch(self, branch: str) -> None: 161 | """Switch the current branch.""" 162 | self.repo.checkout(self.repo.branches[branch]) 163 | 164 | def update_remote(self) -> None: 165 | """Update the remotes.""" 166 | self.git("remote", "update") 167 | 168 | def fetch_commits(self, source: Repository, *commits: str) -> None: 169 | """Fetch the given commits and their immediate parents.""" 170 | path = source.path.resolve() 171 | self.git("fetch", "--no-tags", "--depth=2", str(path), *commits) 172 | 173 | def push(self, remote: str, *refs: str, force: bool = False) -> None: 174 | """Update remote refs.""" 175 | options = ["--force-with-lease"] if force else [] 176 | self.git("push", *options, remote, *refs) 177 | 178 | def parse_revisions(self, *revisions: str) -> List[str]: 179 | """Parse revisions using the format specified in gitrevisions(7).""" 180 | process = self.git("rev-list", "--no-walk", *revisions) 181 | result = process.stdout.split() 182 | result.reverse() 183 | return result 184 | 185 | def lookup_replacement(self, commit: str) -> str: 186 | """Lookup the replace ref for the given commit.""" 187 | refname = f"refs/replace/{commit}" 188 | ref = self.repo.lookup_reference(refname) 189 | return cast(str, ref.target.hex) 190 | 191 | def _ensure_relative(self, path: Path) -> Path: 192 | """Interpret the path relative to the repository root.""" 193 | return path.relative_to(self.path) if path.is_absolute() else path 194 | 195 | def read_text(self, path: Path, *, ref: str = "HEAD") -> str: 196 | """Return the contents of the blob at the given path.""" 197 | commit = self.repo.revparse_single(ref) 198 | path = self._ensure_relative(path) 199 | blob = functools.reduce(operator.truediv, path.parts, commit.tree) 200 | return cast(str, 
blob.data.decode()) 201 | 202 | def exists(self, path: Path, *, ref: str = "HEAD") -> bool: 203 | """Return True if a blob exists at the given path.""" 204 | commit = self.repo.revparse_single(ref) 205 | path = self._ensure_relative(path) 206 | try: 207 | functools.reduce(operator.truediv, path.parts, commit.tree) 208 | return True 209 | except KeyError: 210 | return False 211 | 212 | def add(self, *paths: Path) -> None: 213 | """Add paths to the index.""" 214 | for path in paths: 215 | path = self._ensure_relative(path) 216 | self.repo.index.add(path) 217 | else: 218 | self.repo.index.add_all() 219 | self.repo.index.write() 220 | 221 | def commit(self, message: str) -> None: 222 | """Create a commit.""" 223 | try: 224 | head = self.repo.head 225 | refname = head.name 226 | parents = [head.target] 227 | except pygit2.GitError: 228 | branch = get_default_branch() 229 | refname = f"refs/heads/{branch}" 230 | parents = [] 231 | 232 | tree = self.repo.index.write_tree() 233 | author = committer = self.repo.default_signature 234 | 235 | self.repo.create_commit(refname, author, committer, message, tree, parents) 236 | 237 | def cherrypick(self, *refs: str) -> None: 238 | """Cherry-pick the given commits.""" 239 | self.git("cherry-pick", *refs) 240 | 241 | @contextmanager 242 | def worktree( 243 | self, 244 | branch: str, 245 | path: Path, 246 | *, 247 | base: str = "HEAD", 248 | force: bool = False, 249 | force_remove: bool = False, 250 | ) -> Iterator[Repository]: 251 | """Context manager to add and remove a worktree.""" 252 | repository = self.add_worktree(branch, path, base=base, force=force) 253 | try: 254 | yield repository 255 | finally: 256 | self.remove_worktree(path, force=force_remove) 257 | 258 | def add_worktree( 259 | self, 260 | branch: str, 261 | path: Path, 262 | *, 263 | base: str = "HEAD", 264 | force: bool = False, 265 | ) -> Repository: 266 | """Add a worktree.""" 267 | self.git( 268 | "worktree", 269 | "add", 270 | str(path), 271 | "--no-track", 272 
| "-B" if force else "-b", 273 | branch, 274 | base, 275 | ) 276 | 277 | return Repository(path) 278 | 279 | def remove_worktree(self, path: Path, *, force: bool = False) -> None: 280 | """Remove a worktree.""" 281 | if force: 282 | self.git("worktree", "remove", "--force", str(path)) 283 | else: 284 | self.git("worktree", "remove", str(path)) 285 | -------------------------------------------------------------------------------- /README.rst: -------------------------------------------------------------------------------- 1 | Retrocookie 2 | =========== 3 | 4 | |PyPI| |Python Version| |License| |Read the Docs| |Tests| |Codecov| 5 | 6 | .. |PyPI| image:: https://img.shields.io/pypi/v/retrocookie.svg 7 | :target: https://pypi.org/project/retrocookie/ 8 | :alt: PyPI 9 | .. |Python Version| image:: https://img.shields.io/pypi/pyversions/retrocookie 10 | :target: https://pypi.org/project/retrocookie 11 | :alt: Python Version 12 | .. |License| image:: https://img.shields.io/pypi/l/retrocookie 13 | :target: https://opensource.org/licenses/MIT 14 | :alt: License 15 | .. |Read the Docs| image:: https://img.shields.io/readthedocs/retrocookie/latest.svg?label=Read%20the%20Docs 16 | :target: https://retrocookie.readthedocs.io/ 17 | :alt: Read the documentation at https://retrocookie.readthedocs.io/ 18 | .. |Tests| image:: https://github.com/cjolowicz/retrocookie/workflows/Tests/badge.svg 19 | :target: https://github.com/cjolowicz/retrocookie/actions?workflow=Tests 20 | :alt: Tests 21 | .. |Codecov| image:: https://codecov.io/gh/cjolowicz/retrocookie/branch/main/graph/badge.svg 22 | :target: https://codecov.io/gh/cjolowicz/retrocookie 23 | :alt: Codecov 24 | 25 | 26 | Retrocookie updates Cookiecutter_ templates with changes from generated projects. 27 | 28 | When developing Cookiecutter templates, 29 | you often need to work in a generated project rather than the template itself. 
30 | Reasons for this include the following: 31 | 32 | - You need to run the Continuous Integration suite for the generated project 33 | - Your development tools choke when running on the templated project 34 | 35 | Any changes you make in the generated project 36 | need to be backported into the template, 37 | carefully replacing expanded variables from ``cookiecutter.json`` by templating tags, 38 | and escaping any use of ``{{`` and ``}}`` 39 | or other tokens with special meaning in Jinja. 40 | 41 | ✨ Retrocookie helps you in this situation. ✨ 42 | 43 | It is designed to fetch commits from the repository of a generated project, 44 | and import them into your Cookiecutter repository, 45 | rewriting them on the fly to insert templating tags, 46 | escape Jinja-special constructs, 47 | and place files in the template directory. 48 | 49 | Under the hood, 50 | Retrocookie rewrites the selected commits using git-filter-repo_, 51 | saving them to a temporary repository. 52 | It then fetches and cherry-picks the rewritten commits 53 | from the temporary repository into the Cookiecutter template, 54 | using pygit2_. 55 | 56 | Maybe you're thinking, 57 | how can this possibly work? 58 | One cannot reconstruct a Jinja template from its rendered output. 59 | However, simple replacements of template variables work well in practice 60 | when you're only importing a handful of commits at a time. 61 | 62 | **Important:** 63 | 64 | Retrocookie relies on a ``.cookiecutter.json`` file in the generated project 65 | to work out how to rewrite commits. 66 | This file is similar to the ``cookiecutter.json`` file in the template, 67 | but contains the specific values chosen during project generation. 68 | You can generate this file by putting it into the template directory in the Cookiecutter, 69 | with the following contents: 70 | 71 | .. 
code:: jinja 72 | 73 | {{ cookiecutter | jsonify }} 74 | 75 | 76 | Requirements 77 | ------------ 78 | 79 | * Python 3.7+ 80 | * git >= 2.22.0 81 | 82 | 83 | Installation 84 | ------------ 85 | 86 | You can install *Retrocookie* via pip_ from PyPI_: 87 | 88 | .. code:: console 89 | 90 | $ pip install retrocookie 91 | 92 | Optionally, install the ``pr`` extra for the retrocookie-pr_ command: 93 | 94 | .. code:: console 95 | 96 | $ pip install retrocookie[pr] 97 | 98 | 99 | Example 100 | ------- 101 | 102 | Here's an example to demonstrate the general workflow. 103 | 104 | To start with, we clone the repository of your Cookiecutter template. 105 | For this example, I'll use my own `Hypermodern Python Cookiecutter`_. 106 | 107 | .. code:: console 108 | 109 | $ git clone https://github.com/cjolowicz/cookiecutter-hypermodern-python 110 | 111 | Next, we'll create a project from the template, 112 | and set up a git repository for it: 113 | 114 | .. code:: console 115 | 116 | $ cookiecutter --no-input cookiecutter-hypermodern-python 117 | $ cd hypermodern-python 118 | $ git init 119 | $ git add . 120 | $ git commit --message="Initial commit" 121 | 122 | Let's open a feature branch in the project repository, 123 | and make a fictitious change involving the default project name *hypermodern-python*: 124 | 125 | .. code:: console 126 | 127 | $ git switch --create add-example 128 | $ echo '# hypermodern-python' > EXAMPLE.md 129 | $ git add EXAMPLE.md 130 | $ git commit --message="Add example" 131 | 132 | Back in the Cookiecutter repository, 133 | we can now invoke retrocookie to import the changes from the feature branch: 134 | 135 | .. 
code:: console 136 | 137 | $ cd ../cookiecutter-hypermodern-python 138 | $ retrocookie --branch add-example --create ../hypermodern-python 139 | 140 | A ``git show`` in the Cookiecutter shows the file under the template directory, 141 | on a branch named as in the original repository, 142 | with the project name replaced by a Jinja tag: 143 | 144 | .. code:: diff 145 | 146 | commit abb4f823b9f1760e3a678c927ec9797c0a40a9b6 (HEAD -> add-example) 147 | Author: Your Name 148 | Date: Fri Dec 4 23:40:41 2020 +0100 149 | 150 | Add example 151 | 152 | diff --git a/{{cookiecutter.project_name}}/EXAMPLE.md b/{{cookiecutter.project_name}}/EXAMPLE.md 153 | new file mode 100644 154 | index 0000000..a158618 155 | --- /dev/null 156 | +++ b/{{cookiecutter.project_name}}/EXAMPLE.md 157 | @@ -0,0 +1 @@ 158 | +# {{cookiecutter.project_name}} 159 | 160 | 161 | Usage 162 | ----- 163 | 164 | The basic form: 165 | 166 | .. code:: 167 | 168 | $ retrocookie <repository> [<commits>...] 169 | $ retrocookie <repository> -b <branch> [--create] 170 | 171 | The ``<repository>`` is a filesystem path to the source repository. 172 | For ``<commits>``, see `gitrevisions(7)`__. 173 | 174 | __ https://git-scm.com/docs/gitrevisions 175 | 176 | Import ``HEAD`` from ``<repository>``: 177 | 178 | .. code:: 179 | 180 | $ retrocookie <repository> 181 | 182 | Import the last two commits: 183 | 184 | .. code:: 185 | 186 | $ retrocookie <repository> HEAD~2.. 187 | 188 | Import by commit hash: 189 | 190 | .. code:: 191 | 192 | $ retrocookie <repository> 53268f7 6a3368a c0b4c6c 193 | 194 | Import commits from branch ``topic``: 195 | 196 | .. code:: 197 | 198 | $ retrocookie --branch=topic <repository> 199 | 200 | Equivalently: 201 | 202 | .. code:: 203 | 204 | $ retrocookie <repository> master..topic 205 | 206 | Import commits from ``topic`` into a branch with the same name: 207 | 208 | .. code:: 209 | 210 | $ retrocookie --branch=topic --create <repository> 211 | 212 | Equivalently, using short options: 213 | 214 | .. code:: 215 | 216 | $ retrocookie -cb topic <repository> 217 | 218 | Import commits from branch ``topic``, which was branched off ``1.0``: 219 | 220 | .. 
code:: 221 | 222 | $ retrocookie --branch=topic --upstream=1.0 <repository> 223 | 224 | Equivalently: 225 | 226 | .. code:: 227 | 228 | $ retrocookie <repository> 1.0..topic 229 | 230 | Import ``HEAD`` into a new branch ``topic``: 231 | 232 | .. code:: 233 | 234 | $ retrocookie --create-branch=topic <repository> 235 | 236 | Please see the `Command-line Reference <Usage_>`_ for further details. 237 | 238 | 239 | .. _retrocookie-pr: 240 | 241 | Importing pull requests from generated projects with retrocookie-pr 242 | ------------------------------------------------------------------- 243 | 244 | You can import pull requests from a generated project to the project template, 245 | assuming their repositories are on GitHub_. 246 | This requires activating the ``pr`` extra when installing with pip_: 247 | 248 | .. code:: 249 | 250 | $ pip install retrocookie[pr] 251 | 252 | The command ``retrocookie-pr`` has the basic form: 253 | 254 | .. code:: 255 | 256 | $ retrocookie-pr [-R <repository>] [<pr>...] 257 | $ retrocookie-pr [-R <repository>] --user=<user> 258 | $ retrocookie-pr [-R <repository>] --all 259 | 260 | Command-line arguments specify pull requests to import, by number or by branch. 261 | Pull requests from forks are currently not supported. 262 | 263 | Use the ``-R <repository>`` option to specify the GitHub repository of the generated project 264 | from which the pull requests should be imported. 265 | Provide the full name of the repository on GitHub in the form ``owner/name``. 266 | The owner can be omitted if the repository is owned by the authenticated user. 267 | This option can be omitted when the command is invoked from a local clone. 268 | 269 | You can also select pull requests by specifying the user that opened them, via the ``--user`` option. 270 | This is handy for importing automated pull requests, such as dependency updates from Dependabot_. 271 | 272 | Use the ``--all`` option to import all open pull requests in the generated project. 273 | 274 | You can update previously imported pull requests by specifying ``--force``.
275 | By default, ``retrocookie-pr`` refuses to overwrite existing pull requests. 276 | 277 | The command needs a `personal access token`_ to access the GitHub API. 278 | (This token is also used to push to the GitHub repository of the project template.) 279 | You will be prompted for the token when you invoke the command for the first time. 280 | On subsequent invocations, the token is read from the application cache. 281 | Alternatively, you can specify the token using the ``--token`` option or the ``GITHUB_TOKEN`` environment variable; 282 | both of these methods bypass the cache. 283 | 284 | Use the ``--open`` option to open each imported pull request in a web browser. 285 | 286 | Please see the `Command-line Reference <Usage_>`_ for further details. 287 | 288 | 289 | Contributing 290 | ------------ 291 | 292 | Contributions are very welcome. 293 | To learn more, see the `Contributor Guide`_. 294 | 295 | 296 | License 297 | ------- 298 | 299 | Distributed under the terms of the MIT_ license, 300 | *Retrocookie* is free and open source software. 301 | 302 | 303 | Issues 304 | ------ 305 | 306 | If you encounter any problems, 307 | please `file an issue`_ along with a detailed description. 308 | 309 | 310 | Credits 311 | ------- 312 | 313 | This project was generated from `@cjolowicz`_'s `Hypermodern Python Cookiecutter`_ template. 314 | 315 | 316 | .. _@cjolowicz: https://github.com/cjolowicz 317 | .. _Cookiecutter: https://github.com/audreyr/cookiecutter 318 | .. _Dependabot: https://dependabot.com/ 319 | .. _GitHub: https://github.com/ 320 | .. _Hypermodern Python Cookiecutter: https://github.com/cjolowicz/cookiecutter-hypermodern-python 321 | .. _MIT: http://opensource.org/licenses/MIT 322 | .. _PyPI: https://pypi.org/ 323 | .. _file an issue: https://github.com/cjolowicz/retrocookie/issues 324 | .. _git-filter-repo: https://github.com/newren/git-filter-repo 325 | .. _git rebase: https://git-scm.com/docs/git-rebase 326 | .. _pip: https://pip.pypa.io/ 327 | .. 
_personal access token: https://docs.github.com/en/free-pro-team@latest/github/authenticating-to-github/creating-a-personal-access-token 328 | .. _pygit2: https://github.com/libgit2/pygit2 329 | .. github-only 330 | .. _Contributor Guide: CONTRIBUTING.rst 331 | .. _Usage: https://retrocookie.readthedocs.io/en/latest/usage.html 332 | -------------------------------------------------------------------------------- /src/retrocookie/pr/base/exceptionhandlers.py: -------------------------------------------------------------------------------- 1 | """Exception handlers as first-class citizens. 2 | 3 | The utilities in this module allow the creation of exception handlers as 4 | first-class objects that can be passed around and combined. 5 | 6 | Consider the following imaginary application: 7 | 8 | >>> class ApplicationError(Exception): 9 | ... pass 10 | ... 11 | >>> def main(): 12 | ... pass 13 | 14 | Conventionally, application errors would be handled using a ``try...except`` 15 | block: 16 | 17 | >>> try: 18 | ... main() 19 | ... except ApplicationError as error: 20 | ... sys.exit(f"fatal: {error}") 21 | 22 | Instead, this module lets you create an exception handler object: 23 | 24 | >>> @exceptionhandler(ApplicationError) 25 | ... def fatal(error): 26 | ... sys.exit(f"fatal: {error}") 27 | 28 | The :func:`exceptionhandler` decorator creates an instance of 29 | :class:`ExceptionHandler`, a context manager that invokes the wrapped function 30 | when an exception of type ApplicationError is caught. 31 | 32 | The exception handler can now be used anywhere in the code, like this: 33 | 34 | >>> with fatal: 35 | ... main() 36 | 37 | Exception handlers can also be used as decorators to apply exception handling to 38 | a function. So instead of wrapping every call to ``main`` in a ``with`` 39 | statement, you could apply the exception handler to ``main`` once: 40 | 41 | >>> @fatal 42 | ... def main(): 43 | ... ... 
44 | 45 | If your code uses type annotations, you can omit the arguments to 46 | ``exceptionhandler``, and define the handler like this: 47 | 48 | >>> @exceptionhandler 49 | ... def fatal(error: ApplicationError): 50 | ... sys.exit(f"fatal: {error}") 51 | 52 | An important difference between an ``except`` clause and an ``ExceptionHandler`` 53 | is that ``except`` clauses suppress exceptions by default, and require an 54 | explicit ``raise`` statement to re-raise the exception. By contrast, exception 55 | handlers re-raise exceptions by default, and require an explicit ``return True`` 56 | statement to suppress the exception: 57 | 58 | >>> @exceptionhandler(Exception) 59 | ... def suppress(exception): 60 | ... return True 61 | 62 | Exception handlers can be combined using the ``<<`` and ``>>`` operators: 63 | 64 | >>> with fatal >> suppress: 65 | ... ... 66 | 67 | The operator points in the direction of exception flow, so in this example 68 | ``fatal`` sees exceptions before ``suppress``; in other words, the code exits on 69 | application errors, while any remaining exceptions are silently discarded. 70 | 71 | In general, ``a >> b`` is equivalent to ``b << a``. For example, the code above 72 | could also have been written like this: 73 | 74 | >>> with suppress << fatal: 75 | ... ... 76 | 77 | This style is closer to how the code would look with multiple context managers: 78 | 79 | >>> with suppress, fatal: 80 | ... ... 81 | 82 | While this form is shorter, composing exception handlers explicitly gives us an 83 | ``ExceptionHandler`` object. The combined exception handler can be stored in a 84 | variable, passed around, or combined with further handlers: 85 | 86 | >>> handler = fatal >> suppress 87 | 88 | This module defines only one actual exception handler of its own: 89 | ``nullhandler`` is a noop that does not handle any exceptions. 
A practical use 90 | of ``nullhandler`` is when exception handling should be optional, letting you 91 | either pass the real handler or ``nullhandler``. 92 | 93 | In algebraic terms, ``nullhandler`` is the identity element of composition 94 | (``<<`` and ``>>``): Combining any handler with ``nullhandler`` gives you a 95 | handler that behaves exactly like the original handler. As composition is also 96 | associative, this makes exception handlers with composition a simple monoid_. 97 | 98 | .. _monoid: https://en.wikipedia.org/wiki/Monoid 99 | """ 100 | from __future__ import annotations 101 | 102 | import contextlib 103 | from types import TracebackType 104 | from typing import Callable 105 | from typing import Generic 106 | from typing import get_type_hints 107 | from typing import List 108 | from typing import Optional 109 | from typing import overload 110 | from typing import Tuple 111 | from typing import Type 112 | from typing import TypeVar 113 | from typing import Union 114 | 115 | 116 | __all__ = [ 117 | "ExceptionHandler", 118 | "exceptionhandler", 119 | "nullhandler", 120 | ] 121 | 122 | 123 | class ExceptionHandler(contextlib.ContextDecorator): 124 | """Context manager for handling exceptions. 125 | 126 | Apply exception handlers to a block of code using a with block:: 127 | 128 | with myhandler: 129 | ... 130 | 131 | Note that there are no parentheses, because there is no need to invoke the 132 | exception handler. Exception handlers are `reentrant context managers`_, 133 | which means that you do not need to create them afresh for each use, even 134 | when they are used multiple times in nested ``with`` statements. 135 | 136 | .. _reentrant context managers: 137 | https://docs.python.org/3/library/contextlib.html 138 | #single-use-reusable-and-reentrant-context-managers 139 | 140 | Exception handlers can also be used as decorators:: 141 | 142 | @myhandler 143 | def f(): 144 | ... 
145 | 146 | This is equivalent to enclosing the body of the function in a ``with`` 147 | block:: 148 | 149 | def f(): 150 | with myhandler: 151 | ... 152 | 153 | This class should not be instantiated directly. Use the 154 | :func:`exceptionhandler` decorator to create an exception handler. 155 | """ 156 | 157 | def __enter__(self) -> None: 158 | """Enter the runtime context.""" 159 | 160 | def __exit__( 161 | self, 162 | exception_type: Optional[Type[BaseException]], 163 | exception: Optional[BaseException], 164 | traceback: Optional[TracebackType], 165 | ) -> Optional[bool]: 166 | """Exit the runtime context.""" 167 | return None 168 | 169 | def __lshift__(self, other: ExceptionHandler) -> ExceptionHandler: 170 | """Compose this exception handler with another one. 171 | 172 | The expression ``a << b`` is equivalent to these nested ``with`` 173 | blocks:: 174 | 175 | with a: 176 | with b: 177 | ... 178 | 179 | As a mnemonic, the operator points in the direction of exception flow. 180 | In ``a << b``, the right operand sees the exception first. 181 | """ 182 | return _Compose(self, other) 183 | 184 | def __rshift__(self, other: ExceptionHandler) -> ExceptionHandler: 185 | """Compose this exception handler with another one. 186 | 187 | The expression ``a >> b`` is equivalent to these nested ``with`` 188 | blocks:: 189 | 190 | with b: 191 | with a: 192 | ... 193 | 194 | As a mnemonic, the operator points in the direction of exception flow. 195 | In ``a >> b``, the left operand sees the exception first. 196 | """ 197 | return other << self 198 | 199 | 200 | nullhandler = ExceptionHandler() 201 | 202 | 203 | class _Compose(ExceptionHandler): 204 | """Exception handler composed of other exception handlers.""" 205 | 206 | def __init__(self, first: ExceptionHandler, second: ExceptionHandler) -> None: 207 | """Initialize.""" 208 | # Reduce object count in expressions like ``a << b << c``. 
209 | self.stacks: List[contextlib.ExitStack] = [] 210 | self.handlers: List[ExceptionHandler] = [ 211 | handler 212 | for arg in (first, second) 213 | for handler in (arg.handlers if isinstance(arg, _Compose) else [arg]) 214 | ] 215 | 216 | def __enter__(self) -> None: 217 | """Enter the runtime context.""" 218 | stack = contextlib.ExitStack() 219 | self.stacks.append(stack) 220 | 221 | for handler in self.handlers: 222 | stack.enter_context(handler) 223 | 224 | stack.__enter__() 225 | 226 | def __exit__( 227 | self, 228 | exception_type: Optional[Type[BaseException]], 229 | exception: Optional[BaseException], 230 | traceback: Optional[TracebackType], 231 | ) -> Optional[bool]: 232 | """Exit the runtime context.""" 233 | stack = self.stacks.pop() 234 | return stack.__exit__(exception_type, exception, traceback) 235 | 236 | 237 | E = TypeVar("E", bound=BaseException, contravariant=True) 238 | 239 | 240 | _Callback = Callable[[E], Optional[bool]] 241 | 242 | 243 | class _Decorator(Generic[E], ExceptionHandler): 244 | """Exception handler wrapping a callback function. 245 | 246 | This class is a helper for the @exceptionhandler decorator. 
247 | """ 248 | 249 | def __init__( 250 | self, 251 | callback: _Callback[E], 252 | *exception_types: Type[BaseException], 253 | ) -> None: 254 | """Initialize.""" 255 | self.callback = callback 256 | self.exception_types = exception_types 257 | 258 | def __exit__( 259 | self, 260 | exception_type: Optional[Type[BaseException]], 261 | exception: Optional[BaseException], 262 | traceback: Optional[TracebackType], 263 | ) -> Optional[bool]: 264 | """Exit the context.""" 265 | return ( 266 | self.callback(exception) # type: ignore[arg-type] 267 | if any( 268 | isinstance(exception, exception_type) 269 | for exception_type in self.exception_types 270 | ) 271 | else None 272 | ) 273 | 274 | 275 | def _exceptionhandler( 276 | *exception_types: Type[BaseException], 277 | ) -> Callable[[_Callback[E]], ExceptionHandler]: 278 | """Decorator factory implementing @exceptionhandler.""" 279 | 280 | def _decorator(callback: _Callback[E]) -> ExceptionHandler: 281 | if exception_types: 282 | return _Decorator(callback, *exception_types) 283 | 284 | try: 285 | exception_type = next( 286 | hint 287 | for key, hint in get_type_hints(callback).items() 288 | if key != "return" 289 | ) 290 | except StopIteration: 291 | raise TypeError(f"missing type annotation on {callback}") from None 292 | 293 | return _Decorator(callback, exception_type) 294 | 295 | return _decorator 296 | 297 | 298 | @overload 299 | def exceptionhandler(__callback: _Callback[E]) -> ExceptionHandler: 300 | """Overload for the plain decorator.""" # noqa: D418 301 | 302 | 303 | @overload 304 | def exceptionhandler( 305 | *exception_types: Type[BaseException], 306 | ) -> Callable[[_Callback[E]], ExceptionHandler]: 307 | """Overload for the decorator factory.""" # noqa: D418 308 | 309 | 310 | def exceptionhandler( 311 | *args: Union[_Callback[E], Type[BaseException]], 312 | ) -> Union[ExceptionHandler, Callable[[_Callback[E]], ExceptionHandler]]: 313 | """Decorator for creating exception handlers. 
314 | 315 | This decorator creates an exception handler that invokes the wrapped 316 | function when an exception of the specified type is caught:: 317 | 318 | @exceptionhandler(SomeError) 319 | def handler(error): 320 | return True 321 | 322 | Multiple exception types may be specified, or none. In the latter case, the 323 | exception type is derived from the type annotation for the function 324 | parameter. Parentheses for the decorator are optional in this case. The 325 | following example is equivalent to the first one:: 326 | 327 | @exceptionhandler 328 | def handler(error: SomeError) -> bool: 329 | return True 330 | 331 | A return value of True suppresses the exception. If the function returns 332 | False or None, the exception is reraised when the handler returns:: 333 | 334 | @exceptionhandler 335 | def print_handler(exception: SomeException) -> None: 336 | print(exception) # the exception is printed but not swallowed 337 | 338 | Another option is to raise another exception:: 339 | 340 | @exceptionhandler 341 | def exit_handler(exception: SomeException) -> NoReturn: 342 | raise SystemExit(str(exception)) 343 | 344 | """ 345 | if all(isinstance(arg, type) and issubclass(arg, BaseException) for arg in args): 346 | exception_types: Tuple[Type[BaseException]] = args # type: ignore[assignment] 347 | return _exceptionhandler(*exception_types) 348 | 349 | callback: _Callback[E] 350 | [callback] = args # type: ignore[assignment] 351 | return _exceptionhandler()(callback) 352 | --------------------------------------------------------------------------------