├── .github ├── dependabot.yml ├── renovate.json5 └── workflows │ ├── release.yml │ └── test.yml ├── .gitignore ├── LICENSE.md ├── README.md ├── example.py ├── prosemirror ├── __init__.py ├── model │ ├── __init__.py │ ├── comparedeep.py │ ├── content.py │ ├── diff.py │ ├── dom.py │ ├── fragment.py │ ├── from_dom.py │ ├── mark.py │ ├── node.py │ ├── replace.py │ ├── resolvedpos.py │ ├── schema.py │ └── to_dom.py ├── py.typed ├── schema │ ├── basic │ │ ├── __init__.py │ │ └── schema_basic.py │ └── list │ │ ├── __init__.py │ │ └── schema_list.py ├── test_builder │ ├── __init__.py │ └── build.py ├── transform │ ├── __init__.py │ ├── attr_step.py │ ├── doc_attr_step.py │ ├── map.py │ ├── mark.py │ ├── mark_step.py │ ├── replace.py │ ├── replace_step.py │ ├── step.py │ ├── structure.py │ └── transform.py └── utils.py ├── pyproject.toml └── tests ├── __init__.py ├── conftest.py ├── prosemirror_model ├── __init__.py └── tests │ ├── __init__.py │ ├── test_content.py │ ├── test_diff.py │ ├── test_dom.py │ ├── test_mark.py │ ├── test_node.py │ ├── test_resolve.py │ └── test_slice.py └── prosemirror_transform ├── __init__.py └── tests ├── __init__.py ├── conftest.py ├── test_mapping.py ├── test_step.py ├── test_structure.py └── test_trans.py /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | updates: 3 | - package-ecosystem: pip 4 | directory: "/" 5 | schedule: 6 | interval: daily 7 | time: "10:00" 8 | open-pull-requests-limit: 10 9 | -------------------------------------------------------------------------------- /.github/renovate.json5: -------------------------------------------------------------------------------- 1 | { 2 | "$schema": "https://docs.renovatebot.com/renovate-schema.json", 3 | "extends": [ 4 | "config:recommended", 5 | ":disableDependencyDashboard", 6 | ] 7 | } 8 | -------------------------------------------------------------------------------- /.github/workflows/release.yml: -------------------------------------------------------------------------------- 1 | name: Release 2 | 3 | on: 4 | release: 5 | types: 6 | - created 7 | 8 | jobs: 9 | build: 10 | runs-on: ubuntu-latest 11 | permissions: 12 | # IMPORTANT: this permission is mandatory for trusted publishing 13 | id-token: write 14 | steps: 15 | - uses: actions/checkout@v4 16 | - uses: actions/setup-python@v5 17 | with: 18 | python-version: "3.13" 19 | - run: pip install uv 20 | - run: uv build 21 | - run: uv publish 22 | -------------------------------------------------------------------------------- /.github/workflows/test.yml: -------------------------------------------------------------------------------- 1 | name: CI 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | pull_request: 8 | 9 | jobs: 10 | test: 11 | runs-on: ubuntu-latest 12 | strategy: 13 | matrix: 14 | python-version: ["3.10", "3.11", "3.12", "3.13"] 15 | steps: 16 | - uses: actions/checkout@v4 17 | - uses: actions/setup-python@v5 18 | with: 19 | python-version: ${{ matrix.python-version }} 20 | - run: pip install uv 21 | - run: uv venv 22 | - run: uv sync 23 | - run: uv run pytest --cov=./prosemirror/ 24 | - run: uv run codecov 25 | if: matrix.python-version == '3.11' 26 | env: 27 | CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} 28 | 29 | lint: 30 | runs-on: ubuntu-latest 31 | steps: 32 | - uses: actions/checkout@v4 33 | - uses: actions/setup-python@v5 34 | with: 35 | python-version: "3.13" 36 | - run: pip install uv 37 | - run: uv venv 38 | - run: uv sync 39 | - run: uv run ruff format --check prosemirror tests 40 | - run: uv run ruff check prosemirror tests 41 | - run: uv run pyright prosemirror 42 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | 53 | # Translations 54 | *.mo 55 | *.pot 56 | 57 | # Django stuff: 58 | *.log 59 | local_settings.py 60 | db.sqlite3 61 | 62 | # Flask stuff: 63 | instance/ 64 | .webassets-cache 65 | 66 | # Scrapy stuff: 67 | .scrapy 68 | 69 | # Sphinx documentation 70 | docs/_build/ 71 | 72 | # PyBuilder 73 | target/ 74 | 75 | # Jupyter Notebook 76 | .ipynb_checkpoints 77 | 78 | # IPython 79 | profile_default/ 80 | ipython_config.py 81 | 82 | # pyenv 83 | .python-version 84 | 85 | # pipenv 86 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 87 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 88 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 89 | # install all needed dependencies. 90 | #Pipfile.lock 91 | 92 | # celery beat schedule file 93 | celerybeat-schedule 94 | 95 | # SageMath parsed files 96 | *.sage.py 97 | 98 | # Environments 99 | .env 100 | .venv 101 | env/ 102 | venv/ 103 | ENV/ 104 | env.bak/ 105 | venv.bak/ 106 | 107 | # Spyder project settings 108 | .spyderproject 109 | .spyproject 110 | 111 | # Rope project settings 112 | .ropeproject 113 | 114 | # mkdocs documentation 115 | /site 116 | 117 | # mypy 118 | .mypy_cache/ 119 | .dmypy.json 120 | dmypy.json 121 | 122 | # Pyre type checker 123 | .pyre/ 124 | .DS_Store 125 | .vscode/ 126 | -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2019, Fellow Insights Inc. 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | 1. Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | 2. Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | 3. Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 30 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # prosemirror-py 2 | 3 | [![CI](https://github.com/fellowapp/prosemirror-py/actions/workflows/test.yml/badge.svg)](https://github.com/fellowapp/prosemirror-py/actions/workflows/test.yml) 4 | [![Code Coverage](https://codecov.io/gh/fellowapp/prosemirror-py/branch/master/graph/badge.svg?style=flat)](https://codecov.io/gh/fellowapp/prosemirror-py) 5 | [![PyPI Package](https://img.shields.io/pypi/v/prosemirror.svg?style=flat)](https://pypi.org/project/prosemirror/) 6 | [![License](https://img.shields.io/pypi/l/prosemirror.svg?style=flat)](https://github.com/fellowapp/prosemirror-py/blob/master/LICENSE.md) 7 | [![Fellow Careers](https://img.shields.io/badge/fellow.app-hiring-576cf7.svg?style=flat)](https://fellow.app/careers/) 8 | 9 | This package provides Python implementations of the following 10 | [ProseMirror](https://prosemirror.net/) packages: 11 | 12 | - [`prosemirror-model`](https://github.com/ProseMirror/prosemirror-model) version 1.18.1 13 | - [`prosemirror-transform`](https://github.com/ProseMirror/prosemirror-transform) version 1.8.0 14 | - [`prosemirror-test-builder`](https://github.com/ProseMirror/prosemirror-test-builder) 15 | - [`prosemirror-schema-basic`](https://github.com/ProseMirror/prosemirror-schema-basic) version 1.1.2 16 | - [`prosemirror-schema-list`](https://github.com/ProseMirror/prosemirror-schema-list) 17 | 18 | The original implementation has been followed as closely as possible during 19 | translation to simplify keeping this package up-to-date with any upstream 20 | changes. 21 | 22 | ## Why? 23 | 24 | ProseMirror provides a powerful toolkit for building rich-text editors, but it's 25 | JavaScript-only. Until now, the only option for manipulating and working with 26 | ProseMirror documents from Python was to embed a JS runtime. With this 27 | translation, you can now define schemas, parse documents, and apply transforms 28 | directly via a native Python API. 29 | 30 | ## Status 31 | 32 | The full ProseMirror test suite has been translated and passes. This project 33 | only supports Python 3. The code has type annotations to support mypy or other 34 | typechecking tools. 35 | 36 | ## Usage 37 | 38 | Since this library is a direct port, the best place to learn how to use it is 39 | the [official ProseMirror documentation](https://prosemirror.net/docs/guide/). 40 | Here is a simple example using the included "basic" schema: 41 | 42 | ```python 43 | from prosemirror.transform import Transform 44 | from prosemirror.schema.basic import schema 45 | 46 | # Create a document containing a single paragraph with the text "Hello, world!" 47 | doc = schema.node("doc", {}, [ 48 | schema.node("paragraph", {}, [ 49 | schema.text("Hello, world!") 50 | ]) 51 | ]) 52 | 53 | # Create a Transform which will be applied to the document. 54 | tr = Transform(doc) 55 | 56 | # Delete the text from position 3 to 5. Adds a ReplaceStep to the transform. 57 | tr.delete(3, 5) 58 | 59 | # Make the first three characters bold. Adds an AddMarkStep to the transform. 60 | tr.add_mark(1, 4, schema.mark("strong")) 61 | 62 | # This transform can be converted to JSON to be sent and applied elsewhere. 63 | assert [step.to_json() for step in tr.steps] == [{ 64 | 'stepType': 'replace', 65 | 'from': 3, 66 | 'to': 5 67 | }, { 68 | 'stepType': 'addMark', 69 | 'mark': {'type': 'strong', 'attrs': {}}, 70 | 'from': 1, 71 | 'to': 4 72 | }] 73 | 74 | # The resulting document can also be converted to JSON. 75 | assert tr.doc.to_json() == { 76 | 'type': 'doc', 77 | 'content': [{ 78 | 'type': 'paragraph', 79 | 'content': [{ 80 | 'type': 'text', 81 | 'marks': [{'type': 'strong', 'attrs': {}}], 82 | 'text': 'Heo' 83 | }, { 84 | 'type': 'text', 85 | 'text': ', world!' 86 | }] 87 | }] 88 | } 89 | ``` 90 | -------------------------------------------------------------------------------- /example.py: -------------------------------------------------------------------------------- 1 | from prosemirror.model import Node, Schema 2 | 3 | basic_spec = { 4 | "nodes": { 5 | "doc": {"content": "block+"}, 6 | "paragraph": { 7 | "content": "inline*", 8 | "group": "block", 9 | "parseDOM": [{"tag": "p"}], 10 | }, 11 | "blockquote": { 12 | "content": "block+", 13 | "group": "block", 14 | "defining": True, 15 | "parseDOM": [{"tag": "blockquote"}], 16 | }, 17 | "horizontal_rule": {"group": "block", "parseDOM": [{"tag": "hr"}]}, 18 | "heading": { 19 | "attrs": {"level": {"default": 1}}, 20 | "content": "inline*", 21 | "group": "block", 22 | "defining": True, 23 | "parseDOM": [ 24 | {"tag": "h1", "attrs": {"level": 1}}, 25 | {"tag": "h2", "attrs": {"level": 2}}, 26 | {"tag": "h3", "attrs": {"level": 3}}, 27 | {"tag": "h4", "attrs": {"level": 4}}, 28 | {"tag": "h5", "attrs": {"level": 5}}, 29 | {"tag": "h6", "attrs": {"level": 6}}, 30 | ], 31 | }, 32 | "code_block": { 33 | "content": "text*", 34 | "marks": "", 35 | "group": "block", 36 | "code": True, 37 | "defining": True, 38 | "parseDOM": [{"tag": "pre", "preserveWhitespace": "full"}], 39 | }, 40 | "text": {"group": "inline"}, 41 | "image": { 42 | "inline": True, 43 | "attrs": {"src": {}, "alt": {"default": None}, "title": {"default": None}}, 44 | "group": "inline", 45 | "draggable": True, 46 | "parseDOM": [{"tag": "img[src]"}], 47 | }, 48 | "hard_break": { 49 | "inline": True, 50 | "group": "inline", 51 | "selectable": False, 52 | "parseDOM": [{"tag": "br"}], 53 | }, 54 | "ordered_list": { 55 | "attrs": {"order": {"default": 1}}, 56 | "parseDOM": [{"tag": "ol"}], 57 | "content": "list_item+", 58 | "group": "block", 59 | }, 60 | "bullet_list": { 61 | "parseDOM": [{"tag": "ul"}], 62 | "content": "list_item+", 63 | "group": "block", 64 | }, 65 | "list_item": { 66 | "parseDOM": [{"tag": "li"}], 67 | "defining": True, 68 | "content": "paragraph block*", 69 | }, 70 | }, 71 | "marks": { 72 | "link": { 73 | "attrs": {"href": {}, "title": {"default": None}}, 74 | "inclusive": False, 75 | "parseDOM": [{"tag": "a[href]"}], 76 | }, 77 | "em": { 78 | "parseDOM": [{"tag": "i"}, {"tag": "em"}, {"style": "font-style=italic"}], 79 | }, 80 | "strong": { 81 | "parseDOM": [{"tag": "strong"}, {"tag": "b"}, {"style": "font-weight"}], 82 | }, 83 | "code": {"parseDOM": [{"tag": "code"}]}, 84 | }, 85 | } 86 | 87 | 88 | basic_schema = Schema(basic_spec) 89 | basic_doc = { 90 | "type": "doc", 91 | "content": [ 92 | { 93 | "type": "heading", 94 | "attrs": {"level": 1}, 95 | "content": [{"type": "text", "text": "Fellow"}], 96 | }, 97 | { 98 | "type": "paragraph", 99 | "content": [ 100 | {"type": "text", "text": "Test "}, 101 | {"type": "text", "marks": [{"type": "strong"}], "text": "this"}, 102 | {"type": "text", "text": " text"}, 103 | ], 104 | }, 105 | ], 106 | } 107 | doc_node = Node.from_json(basic_schema, basic_doc) 108 | -------------------------------------------------------------------------------- /prosemirror/__init__.py: -------------------------------------------------------------------------------- 1 | from .model import Fragment, Mark, Node, ResolvedPos, Schema, Slice 2 | from .schema.basic import schema as basic_schema 3 | from .transform import Mapping, Step, Transform 4 | 5 | __all__ = [ 6 | "Fragment", 7 | "Mapping", 8 | "Mark", 9 | "Node", 10 | "ResolvedPos", 11 | "Schema", 12 | "Slice", 13 | "Step", 14 | "Transform", 15 | "basic_schema", 16 | ] 17 | -------------------------------------------------------------------------------- /prosemirror/model/__init__.py: -------------------------------------------------------------------------------- 1 | from .content import ContentMatch 2 | from .fragment import Fragment 3 | from .from_dom import DOMParser 4 | from .mark import Mark 5 | from .node import Node 6 | from .replace import ReplaceError, Slice 7 | from .resolvedpos import NodeRange, ResolvedPos 8 | from .schema import MarkType, NodeType, Schema 9 | from .to_dom import DOMSerializer 10 | 11 | __all__ = [ 12 | "ContentMatch", 13 | "DOMParser", 14 | "DOMSerializer", 15 | "Fragment", 16 | "Mark", 17 | "MarkType", 18 | "Node", 19 | "NodeRange", 20 | "NodeType", 21 | "ReplaceError", 22 | "ResolvedPos", 23 | "Schema", 24 | "Slice", 25 | ] 26 | -------------------------------------------------------------------------------- /prosemirror/model/comparedeep.py: -------------------------------------------------------------------------------- 1 | from prosemirror.utils import JSON 2 | 3 | 4 | def compare_deep(a: JSON, b: JSON) -> bool: 5 | return a == b 6 | -------------------------------------------------------------------------------- /prosemirror/model/diff.py: -------------------------------------------------------------------------------- 1 | from typing import TYPE_CHECKING, TypedDict 2 | 3 | from prosemirror.utils import text_length 4 | 5 | from . import node as pm_node 6 | 7 | if TYPE_CHECKING: 8 | from prosemirror.model.fragment import Fragment 9 | 10 | 11 | class Diff(TypedDict): 12 | a: int 13 | b: int 14 | 15 | 16 | def find_diff_start(a: "Fragment", b: "Fragment", pos: int) -> int | None: 17 | i = 0 18 | while True: 19 | if a.child_count == i or b.child_count == i: 20 | return None if a.child_count == b.child_count else pos 21 | child_a, child_b = a.child(i), b.child(i) 22 | if child_a == child_b: 23 | pos += child_a.node_size 24 | continue 25 | if not child_a.same_markup(child_b): 26 | return pos 27 | if child_a.is_text: 28 | assert isinstance(child_a, pm_node.TextNode) 29 | assert isinstance(child_b, pm_node.TextNode) 30 | if child_a.text != child_b.text: 31 | if child_b.text.startswith(child_a.text): 32 | return pos + text_length(child_a.text) 33 | if child_a.text.startswith(child_b.text): 34 | return pos + text_length(child_b.text) 35 | next_index = next( 36 | ( 37 | index_a 38 | for ((index_a, char_a), (_, char_b)) in zip( 39 | enumerate(child_a.text), 40 | enumerate(child_b.text), 41 | strict=True, 42 | ) 43 | if char_a != char_b 44 | ), 45 | None, 46 | ) 47 | if next_index is not None: 48 | return pos + next_index 49 | if child_a.content.size or child_b.content.size: 50 | inner = find_diff_start(child_a.content, child_b.content, pos + 1) 51 | if inner: 52 | return inner 53 | pos += child_a.node_size 54 | i += 1 55 | 56 | 57 | def find_diff_end(a: "Fragment", b: "Fragment", pos_a: int, pos_b: int) -> Diff | None: 58 | i_a, i_b = a.child_count, b.child_count 59 | while True: 60 | if i_a == 0 or i_b == 0: 61 | if i_a == i_b: 62 | return None 63 | else: 64 | return {"a": pos_a, "b": pos_b} 65 | i_a -= 1 66 | i_b -= 1 67 | child_a, child_b = a.child(i_a), b.child(i_b) 68 | size = child_a.node_size 69 | if child_a == child_b: 70 | pos_a -= size 71 | pos_b -= size 72 | continue 73 | 74 | if not child_a.same_markup(child_b): 75 | return {"a": pos_a, "b": pos_b} 76 | 77 | if child_a.is_text: 78 | assert isinstance(child_a, pm_node.TextNode) 79 | assert isinstance(child_b, pm_node.TextNode) 80 | if child_a.text != child_b.text: 81 | same, min_size = ( 82 | 0, 83 | min(text_length(child_a.text), text_length(child_b.text)), 84 | ) 85 | while ( 86 | same < min_size 87 | and child_a.text[text_length(child_a.text) - same - 1] 88 | == child_b.text[text_length(child_b.text) - same - 1] 89 | ): 90 | same += 1 91 | pos_a -= 1 92 | pos_b -= 1 93 | return {"a": pos_a, "b": pos_b} 94 | 95 | if child_a.content.size or child_b.content.size: 96 | inner = find_diff_end( 97 | child_a.content, 98 | child_b.content, 99 | pos_a - 1, 100 | pos_b - 1, 101 | ) 102 | if inner: 103 | return inner 104 | 105 | pos_a -= size 106 | pos_b -= size 107 | -------------------------------------------------------------------------------- /prosemirror/model/dom.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fellowapp/prosemirror-py/c996d5e23a8d6ef7360db26bf91f815a86a1587a/prosemirror/model/dom.py -------------------------------------------------------------------------------- /prosemirror/model/fragment.py: -------------------------------------------------------------------------------- 1 | from collections.abc import Callable, Iterable, Sequence 2 | from typing import ( 3 | TYPE_CHECKING, 4 | Any, 5 | ClassVar, 6 | Optional, 7 | Union, 8 | cast, 9 | ) 10 | 11 | from prosemirror.utils import JSON, JSONDict, JSONList, text_length 12 | 13 | if TYPE_CHECKING: 14 | from prosemirror.model.schema import Schema 15 | 16 | from .diff import Diff 17 | from .node import Node, TextNode 18 | 19 | 20 | def ret_index(index: int, offset: int) -> dict[str, int]: 21 | return {"index": index, "offset": offset} 22 | 23 | 24 | class Fragment: 25 | empty: ClassVar["Fragment"] 26 | content: list["Node"] 27 | size: int 28 | 29 | def __init__(self, content: list["Node"], size: int | None = None) -> None: 30 | self.content = content 31 | self.size = size if size is not None else sum(c.node_size for c in content) 32 | 33 | def nodes_between( 34 | self, 35 | from_: int, 36 | to: int, 37 | f: Callable[["Node", int, Optional["Node"], int], bool | None], 38 | node_start: int = 0, 39 | parent: Optional["Node"] = None, 40 | ) -> None: 41 | i = 0 42 | pos = 0 43 | while pos < to: 44 | child = self.content[i] 45 | end = pos + child.node_size 46 | if ( 47 | end > from_ 48 | and f(child, node_start + pos, parent, i) is not False 49 | and getattr(child.content, "size", None) 50 | ): 51 | start = pos + 1 52 | child.nodes_between( 53 | max(0, from_ - start), 54 | min(child.content.size, to - start), 55 | f, 56 | node_start + start, 57 | ) 58 | pos = end 59 | i += 1 60 | 61 | def descendants( 62 | self, 63 | f: Callable[["Node", int, Optional["Node"], int], bool | None], 64 | ) -> None: 65 | self.nodes_between(0, self.size, f) 66 | 67 | def text_between( 68 | self, 69 | from_: int, 70 | to: int, 71 | block_separator: str = "", 72 | leaf_text: Callable[["Node"], str] | str = "", 73 | ) -> str: 74 | text = [] 75 | separated = True 76 | 77 | def iteratee( 78 | node: "Node", 79 | pos: int, 80 | _parent: Optional["Node"], 81 | _to: int, 82 | ) -> None: 83 | nonlocal text 84 | nonlocal separated 85 | if node.is_text: 86 | text_node = cast("TextNode", node) 87 | text.append(text_node.text[max(from_, pos) - pos : to - pos]) 88 | separated = not block_separator 89 | elif node.is_leaf: 90 | if leaf_text: 91 | text.append(leaf_text(node) if callable(leaf_text) else leaf_text) 92 | elif (node_leaf_text := node.type.spec.get("leafText")) is not None: 93 | text.append(node_leaf_text(node)) 94 | separated = not block_separator 95 | elif not separated and node.is_block: 96 | text.append(block_separator) 97 | separated = True 98 | 99 | self.nodes_between(from_, to, iteratee, 0) 100 | return "".join(text) 101 | 102 | def append(self, other: "Fragment") -> "Fragment": 103 | if not other.size: 104 | return self 105 | if not self.size: 106 | return other 107 | last, first, content, i = ( 108 | self.last_child, 109 | other.first_child, 110 | self.content.copy(), 111 | 0, 112 | ) 113 | assert last is not None 114 | assert first is not None 115 | if pm_node.is_text(last) and last.same_markup(first): 116 | assert isinstance(first, pm_node.TextNode) 117 | content[len(content) - 1] = last.with_text(last.text + first.text) 118 | i = 1 119 | while i < len(other.content): 120 | content.append(other.content[i]) 121 | i += 1 122 | return Fragment(content, self.size + other.size) 123 | 124 | def cut(self, from_: int, to: int | None = None) -> "Fragment": 125 | if to is None: 126 | to = self.size 127 | if from_ == 0 and to == self.size: 128 | return self 129 | result: list[Node] = [] 130 | size = 0 131 | if to <= from_: 132 | return Fragment(result, size) 133 | i, pos = 0, 0 134 | while pos < to: 135 | child = self.content[i] 136 | end = pos + child.node_size 137 | if end > from_: 138 | if pos < from_ or end > to: 139 | if pm_node.is_text(child): 140 | child = child.cut( 141 | max(0, from_ - pos), 142 | min(text_length(child.text), to - pos), 143 | ) 144 | else: 145 | child = child.cut( 146 | max(0, from_ - pos - 1), 147 | min(child.content.size, to - pos - 1), 148 | ) 149 | result.append(child) 150 | size += child.node_size 151 | pos = end 152 | i += 1 153 | return Fragment(result, size) 154 | 155 | def cut_by_index(self, from_: int, to: int | None = None) -> "Fragment": 156 | if from_ == to: 157 | return Fragment.empty 158 | if from_ == 0 and to == len(self.content): 159 | return self 160 | return Fragment(self.content[from_:to]) 161 | 162 | def replace_child(self, index: int, node: "Node") -> "Fragment": 163 | current = self.content[index] 164 | if current == node: 165 | return self 166 | copy = self.content.copy() 167 | size = self.size + node.node_size - current.node_size 168 | copy[index] = node 169 | return Fragment(copy, size) 170 | 171 | def add_to_start(self, node: "Node") -> "Fragment": 172 | return Fragment([node, *self.content], self.size + node.node_size) 173 | 174 | def add_to_end(self, node: "Node") -> "Fragment": 175 | return Fragment([*self.content, node], self.size + node.node_size) 176 | 177 | def eq(self, other: "Fragment") -> bool: 178 | if len(self.content) != len(other.content): 179 | return False 180 | return all(a.eq(b) for (a, b) in zip(self.content, other.content, strict=True)) 181 | 182 | @property 183 | def first_child(self) -> Optional["Node"]: 184 | return self.content[0] if self.content else None 185 | 186 | @property 187 | def last_child(self) -> Optional["Node"]: 188 | return self.content[-1] if self.content else None 189 | 190 | @property 191 | def child_count(self) -> int: 192 | return len(self.content) 193 | 194 | def child(self, index: int) -> "Node": 195 | return self.content[index] 196 | 197 | def maybe_child(self, index: int) -> Optional["Node"]: 198 | try: 199 | return self.content[index] 200 | except IndexError: 201 | return None 202 | 203 | def for_each(self, f: Callable[["Node", int, int], Any]) -> None: 204 | i = 0 205 | p = 0 206 | while i < len(self.content): 207 | child = self.content[i] 208 | f(child, p, i) 209 | p += child.node_size 210 | i += 1 211 | 212 | def find_diff_start(self, other: "Fragment", pos: int = 0) -> int | None: 213 | from .diff import find_diff_start 214 | 215 | return find_diff_start(self, other, pos) 216 | 217 | def find_diff_end( 218 | self, 219 | other: "Fragment", 220 | pos: int | None = None, 221 | other_pos: int | None = None, 222 | ) -> Optional["Diff"]: 223 | from .diff import find_diff_end 224 | 225 | if pos is None: 226 | pos = self.size 227 | if other_pos is None: 228 | other_pos = other.size 229 | return find_diff_end(self, other, pos, other_pos) 230 | 231 | def find_index(self, pos: int, round: int = -1) -> dict[str, int]: 232 | if pos == 0: 233 | return ret_index(0, pos) 234 | if pos == self.size: 235 | return ret_index(len(self.content), pos) 236 | if pos > self.size or pos < 0: 237 | msg = f"Position {pos} outside of fragment ({self})" 238 | raise ValueError(msg) 239 | i = 0 240 | cur_pos = 0 241 | while True: 242 | cur = self.child(i) 243 | end = cur_pos + cur.node_size 244 | if end >= pos: 245 | if end == pos or round > 0: 246 | return ret_index(i + 1, end) 247 | return ret_index(i, cur_pos) 248 | i += 1 249 | cur_pos = end 250 | 251 | def to_json(self) -> JSONList | None: 252 | if self.content: 253 | return [item.to_json() for item in self.content] 254 | return None 255 | 256 | @classmethod 257 | def from_json(cls, schema: "Schema[Any, Any]", value: JSON) -> "Fragment": 258 | if not value: 259 | return cls.empty 260 | 261 | if isinstance(value, str): 262 | import json 263 | 264 | value = json.loads(value) 265 | 266 | if not isinstance(value, list): 267 | msg = "Invalid input for Fragment.from_json" 268 | raise ValueError(msg) 269 | 270 | return cls([schema.node_from_json(cast(JSONDict, item)) for item in value]) 271 | 272 | @classmethod 273 | def from_array(cls, array: list["Node"]) -> "Fragment": 274 | if not array: 275 | return cls.empty 276 | joined: list[Node] | None = None 277 | size = 0 278 | for i in range(len(array)): 279 | node = array[i] 280 | size += node.node_size 281 | if i and pm_node.is_text(node) and array[i - 1].same_markup(node): 282 | if not joined: 283 | joined = array[0:i] 284 | last = joined[-1] 285 | assert isinstance(last, pm_node.TextNode) 286 | joined[-1] = node.with_text(last.text + node.text) 287 | elif joined: 288 | joined.append(node) 289 | return cls(joined or array, size) 290 | 291 | @classmethod 292 | def from_( 293 | cls, 294 | nodes: Union["Fragment", "Node", Sequence["Node"], None], 295 | ) -> "Fragment": 296 | if not nodes: 297 | return cls.empty 298 | if isinstance(nodes, Fragment): 299 | return nodes 300 | if isinstance(nodes, Iterable): 301 | return cls.from_array(list(nodes)) 302 | if hasattr(nodes, "attrs"): 303 | return cls([nodes], nodes.node_size) 304 | msg = f"cannot convert {nodes!r} to a fragment" 305 | raise ValueError(msg) 306 | 307 | def to_string_inner(self) -> str: 308 | return ", ".join([str(i) for i in self.content]) 309 | 310 | def __str__(self) -> str: 311 | return f"<{self.to_string_inner()}>" 312 | 313 | def __repr__(self) -> str: 314 | return f"<{self.__class__.__name__} {self.__str__()}>" 315 | 316 | 317 | Fragment.empty = Fragment([], 0) 318 | 319 | from . import node as pm_node # noqa: E402 320 | -------------------------------------------------------------------------------- /prosemirror/model/mark.py: -------------------------------------------------------------------------------- 1 | import copy 2 | from typing import TYPE_CHECKING, Any, Final, Union, cast 3 | 4 | from prosemirror.utils import Attrs, JSONDict 5 | 6 | if TYPE_CHECKING: 7 | from .schema import MarkType, Schema 8 | 9 | 10 | class Mark: 11 | none: Final[list["Mark"]] = [] 12 | 13 | def __init__(self, type: "MarkType", attrs: Attrs) -> None: 14 | self.type = type 15 | self.attrs = attrs 16 | 17 | def add_to_set(self, set: list["Mark"]) -> list["Mark"]: 18 | copy: list[Mark] | None = None 19 | placed = False 20 | for i in range(len(set)): 21 | other = set[i] 22 | if self.eq(other): 23 | return set 24 | if self.type.excludes(other.type): 25 | if copy is None: 26 | copy = set[0:i] 27 | elif other.type.excludes(self.type): 28 | return set 29 | else: 30 | if not placed and other.type.rank > self.type.rank: 31 | if copy is None: 32 | copy = set[0:i] 33 | copy.append(self) 34 | placed = True 35 | if copy: 36 | copy.append(other) 37 | if copy is None: 38 | copy = set[:] 39 | if not placed: 40 | copy.append(self) 41 | return copy 42 | 43 | def remove_from_set(self, set: list["Mark"]) -> list["Mark"]: 44 | return [item for item in set if not item.eq(self)] 45 | 46 | def is_in_set(self, set: list["Mark"]) -> bool: 47 | return any(item.eq(self) for item in set) 48 | 49 | def eq(self, other: "Mark") -> bool: 50 | if self == other: 51 | return True 52 | return self.type.name == other.type.name and self.attrs == other.attrs 53 | 54 | def to_json(self) -> JSONDict: 55 | return {"type": self.type.name, "attrs": copy.deepcopy(self.attrs)} 56 | 57 | @classmethod 58 | def from_json( 59 | cls, 60 | schema: "Schema[Any, Any]", 61 | json_data: JSONDict, 62 | ) -> "Mark": 63 | if not json_data: 64 | msg = "Invalid input for Mark.fromJSON" 65 | raise ValueError(msg) 66 | name = json_data["type"] 67 | type = schema.marks.get(name) 68 | if not type: 69 | msg = f"There is no mark type {name} in this schema" 70 | raise ValueError(msg) 71 | return type.create(cast(JSONDict | None, json_data.get("attrs"))) 72 | 73 | @classmethod 74 | def same_set(cls, a: list["Mark"], b: list["Mark"]) -> bool: 75 | if a == b: 76 | return True 77 | if len(a) != len(b): 78 | return False 79 | return all(item_a.eq(item_b) for (item_a, item_b) in zip(a, b, strict=True)) 80 | 81 | @classmethod 82 | def set_from(cls, marks: Union[list["Mark"], "Mark", None]) -> list["Mark"]: 83 | if not marks: 84 | return cls.none 85 | if isinstance(marks, Mark): 86 | return [marks] 87 | copy = marks[:] 88 | return sorted(copy, key=lambda item: item.type.rank) 89 | -------------------------------------------------------------------------------- /prosemirror/model/node.py: -------------------------------------------------------------------------------- 1 | import copy 2 | from collections.abc import Callable 3 | from typing import TYPE_CHECKING, Any, Optional, TypedDict, TypeGuard, Union, cast 4 | 5 | from prosemirror.utils import Attrs, JSONDict, text_length 6 | 7 | from .comparedeep import compare_deep 8 | from .fragment import Fragment 9 | from .mark import Mark 10 | from .replace import Slice, replace 11 | from .resolvedpos import ResolvedPos 12 | 13 | if TYPE_CHECKING: 14 | from .content import ContentMatch 15 | from .schema import MarkType, NodeType, Schema 16 | 17 | 18 | empty_attrs: JSONDict = {} 19 | 20 | 21 | class ChildInfo(TypedDict): 22 | node: Optional["Node"] 23 | index: int 24 | offset: int 25 | 26 | 27 | class Node: 28 | def __init__( 29 | self, 30 | type: "NodeType", 31 | attrs: "Attrs", 32 | content: Fragment | None, 33 | marks: list[Mark], 34 | ) -> None: 35 | self.type = type 36 | self.attrs = attrs 37 | self.content = content or Fragment.empty 38 | self.marks = marks or Mark.none 39 | 40 | @property 41 | def node_size(self) -> int: 42 | return 1 if self.is_leaf else 2 + self.content.size 43 | 44 | @property 45 | def child_count(self) -> int: 46 | return self.content.child_count 47 | 48 | def child(self, index: int) -> "Node": 49 | return self.content.child(index) 50 | 51 | def maybe_child(self, index: int) -> Optional["Node"]: 52 | return self.content.maybe_child(index) 53 | 54 | def for_each(self, f: Callable[["Node", int, int], None]) -> None: 55 | self.content.for_each(f) 56 | 57 | def nodes_between( 58 | self, 59 | from_: int, 60 | to: int, 61 | f: Callable[["Node", int, Optional["Node"], int], bool | None], 62 | start_pos: int = 0, 63 | ) -> None: 64 | self.content.nodes_between(from_, to, f, start_pos, self) 65 | 66 | def descendants( 67 | self, 68 | f: Callable[["Node", int, Optional["Node"], int], bool | None], 69 | ) -> None: 70 | self.nodes_between(0, self.content.size, f) 71 | 72 | @property 73 | def text_content(self) -> str: 74 | if ( 75 | self.is_leaf 76 | and (node_leaf_text := self.type.spec.get("leafText")) is not None 77 | ): 78 | return node_leaf_text(self) 79 | return self.text_between(0, self.content.size, "") 80 | 81 | def text_between( 82 | self, 83 | from_: int, 84 | to: int, 85 | block_separator: str = "", 86 | leaf_text: Callable[["Node"], str] | str = "", 87 | ) -> str: 88 | return self.content.text_between(from_, to, block_separator, leaf_text) 89 | 90 | @property 91 | def first_child(self) -> Optional["Node"]: 92 | return self.content.first_child 93 | 94 | @property 95 | def last_child(self) -> Optional["Node"]: 96 | return self.content.last_child 97 | 98 | def eq(self, other: "Node") -> bool: 99 | return self == other or ( 100 | self.same_markup(other) and self.content.eq(other.content) 101 | ) 102 | 103 | def same_markup(self, other: "Node") -> bool: 104 | return self.has_markup(other.type, other.attrs, other.marks) 105 | 106 | def has_markup( 107 | self, 108 | type: "NodeType", 109 | attrs: Optional["Attrs"] = None, 110 | marks: list[Mark] | None = None, 111 | ) -> bool: 112 | return ( 113 | self.type.name == type.name 114 | and (compare_deep(self.attrs, attrs or type.default_attrs or empty_attrs)) 115 | and (Mark.same_set(self.marks, marks or Mark.none)) 116 | ) 117 | 118 | def copy(self, content: Fragment | None = None) -> "Node": 119 | if content == self.content: 120 | return self 121 | return self.__class__(self.type, self.attrs, content, self.marks) 122 | 123 | def mark(self, marks: list[Mark]) -> "Node": 124 | if marks == self.marks: 125 | return self 126 | return self.__class__(self.type, self.attrs, self.content, marks) 127 | 128 | def cut(self, from_: int, to: int | None = None) -> "Node": 129 | if from_ == 0 and to == self.content.size: 130 | return self 131 | return self.copy(self.content.cut(from_, to)) 132 | 133 | def slice( 134 | self, 135 | from_: int, 136 | to: int | None = None, 137 | include_parents: bool = False, 138 | ) -> Slice: 139 | if to is None: 140 | to = self.content.size 141 | if from_ == to: 142 | return Slice.empty 143 | from__ = self.resolve(from_) 144 | to_ = self.resolve(to) 145 | depth = 0 if include_parents else from__.shared_depth(to) 146 | start = from__.start(depth) 147 | node = from__.node(depth) 148 | content = node.content.cut(from__.pos - start, to_.pos - start) 149 | return Slice(content, from__.depth - depth, to_.depth - depth) 150 | 151 | def replace(self, from_: int, to: int, slice: Slice) -> "Node": 152 | return replace(self.resolve(from_), self.resolve(to), slice) 153 | 154 | def node_at(self, pos: int) -> Optional["Node"]: 155 | node = self 156 | while True: 157 | index_info = node.content.find_index(pos) 158 | index, offset = index_info["index"], index_info["offset"] 159 | next_node = node.maybe_child(index) 160 | if not next_node: 161 | return None 162 | node = next_node 163 | if offset == pos or node.is_text: 164 | return node 165 | pos -= offset + 1 166 | 167 | def child_after(self, pos: int) -> ChildInfo: 168 | index_info = self.content.find_index(pos) 169 | index, offset = index_info["index"], index_info["offset"] 170 | return { 171 | "node": self.content.maybe_child(index), 172 | "index": index, 173 | "offset": offset, 174 | } 175 | 176 | def child_before(self, pos: int) -> ChildInfo: 177 | if pos == 0: 178 | return {"node": None, "index": 0, "offset": 0} 179 | index_info = self.content.find_index(pos) 180 | index, offset = index_info["index"], index_info["offset"] 181 | if offset < pos: 182 | return {"node": self.content.child(index), "index": index, "offset": offset} 183 | node = self.content.child(index - 1) 184 | return {"node": node, "index": index - 1, "offset": offset - node.node_size} 185 | 186 | def resolve(self, pos: int) -> ResolvedPos: 187 | return ResolvedPos.resolve_cached(self, pos) 188 | 189 | def resolve_no_cache(self, pos: int) -> ResolvedPos: 190 | return ResolvedPos.resolve(self, pos) 191 | 192 | def range_has_mark( 193 | self, 194 | from_: int, 195 | to: int, 196 | type: Union["Mark", "MarkType"], 197 | ) -> bool: 198 | found = False 199 | if to > from_: 200 | 201 | def iteratee( 202 | node: "Node", 203 | pos: int, 204 | parent: Optional["Node"], 205 | index: int, 206 | ) -> bool: 207 | nonlocal found 208 | if type.is_in_set(node.marks): 209 | found = True 210 | return not found 211 | 212 | self.nodes_between(from_, to, iteratee) 213 | return found 214 | 215 | @property 216 | def is_block(self) -> bool: 217 | return self.type.is_block 218 | 219 | @property 220 | def is_textblock(self) -> bool: 221 | return self.type.is_textblock 222 | 223 | @property 224 | def inline_content(self) -> bool: 225 | return self.type.inline_content 226 | 227 | @property 228 | def is_inline(self) -> bool: 229 | return self.type.is_inline 230 | 231 | @property 232 | def is_text(self) -> bool: 233 | return self.type.is_text 234 | 235 | @property 236 | def is_leaf(self) -> bool: 237 | return self.type.is_leaf 238 | 239 | @property 240 | def is_atom(self) -> bool: 241 | return self.type.is_atom 242 | 243 | def __str__(self) -> str: 244 | to_debug_string = self.type.spec.get("toDebugString", None) 245 | if to_debug_string: 246 | return to_debug_string(self) 247 | name = self.type.name 248 | if self.content.size: 249 | name += f"({self.content.to_string_inner()})" 250 | return wrap_marks(self.marks, name) 251 | 252 | def __repr__(self) -> str: 253 | return f"<{self.__class__.__name__} {self.__str__()}>" 254 | 255 | def content_match_at(self, index: int) -> "ContentMatch": 256 | match = self.type.content_match.match_fragment(self.content, 0, index) 257 | if not match: 258 | msg = "Called contentMatchAt on a node with invalid content" 259 | raise ValueError(msg) 260 | return match 261 | 262 | def can_replace( 263 | self, 264 | from_: int, 265 | to: int, 266 | replacement: Fragment = Fragment.empty, 267 | start: int = 0, 268 | end: int | None = None, 269 | ) -> bool: 270 | if end is None: 271 | end = replacement.child_count 272 | one = self.content_match_at(from_).match_fragment(replacement, start, end) 273 | two: ContentMatch | None = None 274 | if one: 275 | two = one.match_fragment(self.content, to) 276 | if not two or not two.valid_end: 277 | return False 278 | for i in range(start, end): 279 | if not self.type.allows_marks(replacement.child(i).marks): 280 | return False 281 | return True 282 | 283 | def can_replace_with( 284 | self, 285 | from_: int, 286 | to: int, 287 | type: "NodeType", 288 | marks: list[Mark] | None = None, 289 | ) -> bool: 290 | if marks and not self.type.allows_marks(marks): 291 | return False 292 | start = self.content_match_at(from_).match_type(type) 293 | end: ContentMatch | None = None 294 | if start: 295 | end = start.match_fragment(self.content, to) 296 | return end.valid_end if end else False 297 | 298 | def can_append(self, other: "Node") -> bool: 299 | if other.content.size: 300 | return self.can_replace(self.child_count, self.child_count, other.content) 301 | else: 302 | return self.type.compatible_content(other.type) 303 | 304 | def check(self) -> None: 305 | if not self.type.valid_content(self.content): 306 | msg = f"Invalid content for node {self.type.name}: {str(self.content)[:50]}" 307 | raise ValueError(msg) 308 | copy = Mark.none 309 | for mark in self.marks: 310 | copy = mark.add_to_set(copy) 311 | if not Mark.same_set(copy, self.marks): 312 | msg = ( 313 | f"Invalid collection of marks for node {self.type.name}:" 314 | f" {[m.type.name for m in self.marks]!r}" 315 | ) 316 | raise ValueError(msg) 317 | 318 | def iteratee(node: "Node", offset: int, index: int) -> None: 319 | node.check() 320 | 321 | return self.content.for_each(iteratee) 322 | 323 | def to_json(self) -> JSONDict: 324 | obj: JSONDict = {"type": self.type.name} 325 | if self.attrs: 326 | obj = { 327 | **obj, 328 | "attrs": copy.deepcopy(self.attrs), 329 | } 330 | if getattr(self.content, "size", None): 331 | obj = { 332 | **obj, 333 | "content": self.content.to_json(), 334 | } 335 | if len(self.marks): 336 | obj = { 337 | **obj, 338 | "marks": [n.to_json() for n in self.marks], 339 | } 340 | return obj 341 | 342 | @classmethod 343 | def from_json(cls, schema: "Schema[Any, Any]", json_data: JSONDict | str) -> "Node": 344 | if isinstance(json_data, str): 345 | import json 346 | 347 | json_data = cast(JSONDict, json.loads(json_data)) 348 | 349 | if not json_data: 350 | msg = "Invalid input for Node.from_json" 351 | raise ValueError(msg) 352 | marks = None 353 | if json_data.get("marks"): 354 | if not isinstance(json_data["marks"], list): 355 | msg = "Invalid mark data for Node.fromJSON" 356 | raise ValueError(msg) 357 | marks = [ 358 | schema.mark_from_json(cast(JSONDict, item)) 359 | for item in json_data["marks"] 360 | ] 361 | if json_data["type"] == "text": 362 | return schema.text(str(json_data["text"]), marks) 363 | content = Fragment.from_json(schema, json_data.get("content")) 364 | return schema.node_type(str(json_data["type"])).create( 365 | cast("Attrs", json_data.get("attrs")), 366 | content, 367 | marks, 368 | ) 369 | 370 | 371 | class TextNode(Node): 372 | def __init__( 373 | self, 374 | type: "NodeType", 375 | attrs: "Attrs", 376 | content: str, 377 | marks: list[Mark], 378 | ) -> None: 379 | super().__init__(type, attrs, None, marks) 380 | if not content: 381 | msg = "Empty text nodes are not allowed" 382 | raise ValueError(msg) 383 | self.text = content 384 | 385 | def __str__(self) -> str: 386 | import json 387 | 388 | to_debug_string = self.type.spec.get("toDebugString", None) 389 | if to_debug_string: 390 | return to_debug_string(self) 391 | return wrap_marks(self.marks, json.dumps(self.text)) 392 | 393 | @property 394 | def text_content(self) -> str: 395 | return self.text 396 | 397 | def text_between( 398 | self, 399 | from_: int, 400 | to: int, 401 | block_separator: str = "", 402 | leaf_text: Callable[["Node"], str] | str = "", 403 | ) -> str: 404 | return self.text[from_:to] 405 | 406 | @property 407 | def node_size(self) -> int: 408 | return text_length(self.text) 409 | 410 | def mark(self, marks: list[Mark]) -> "TextNode": 411 | return ( 412 | self 413 | if marks == self.marks 414 | else TextNode(self.type, self.attrs, self.text, marks) 415 | ) 416 | 417 | def with_text(self, text: str) -> "TextNode": 418 | if text == self.text: 419 | return self 420 | return TextNode(self.type, self.attrs, text, self.marks) 421 | 422 | def cut(self, from_: int = 0, to: int | None = None) -> "TextNode": 423 | if to is None: 424 | to = text_length(self.text) 425 | if from_ == 0 and to == text_length(self.text): 426 | return self 427 | substring = self.text.encode("utf-16-le")[2 * from_ : 2 * to].decode( 428 | "utf-16-le", 429 | ) 430 | return self.with_text(substring) 431 | 432 | def eq(self, other: Node) -> bool: 433 | return self.same_markup(other) and self.text == getattr(other, "text", None) 434 | 435 | def to_json( 436 | self, 437 | ) -> JSONDict: 438 | return {**super().to_json(), "text": self.text} 439 | 440 | 441 | def wrap_marks(marks: list[Mark], str: str) -> str: 442 | i = len(marks) - 1 443 | while i >= 0: 444 | str = marks[i].type.name + "(" + str + ")" 445 | i -= 1 446 | return str 447 | 448 | 449 | def is_text(node: Node) -> TypeGuard[TextNode]: 450 | """ 451 | Helper function to check if a node is a text node, but with 452 | type narrowing. (TypeGuard cannot narrow the type of `self`; see 453 | https://mypy.readthedocs.io/en/stable/type_narrowing.html#typeguards-as-methods) 454 | """ 455 | return node.is_text 456 | -------------------------------------------------------------------------------- /prosemirror/model/replace.py: -------------------------------------------------------------------------------- 1 | from typing import TYPE_CHECKING, Any, ClassVar, Optional, cast 2 | 3 | from prosemirror.utils import JSONDict 4 | 5 | from .fragment import Fragment 6 | 7 | if TYPE_CHECKING: 8 | from .node import Node, TextNode 9 | from .resolvedpos import ResolvedPos 10 | from .schema import Schema 11 | 12 | 13 | class ReplaceError(ValueError): 14 | pass 15 | 16 | 17 | def remove_range(content: Fragment, from_: int, to: int) -> Fragment: 18 | from_index_info = content.find_index(from_) 19 | index, offset = from_index_info["index"], from_index_info["offset"] 20 | child = content.maybe_child(index) 21 | to_index_info = content.find_index(to) 22 | index_to, offset_to = to_index_info["index"], to_index_info["offset"] 23 | if offset == from_ or cast("Node", child).is_text: 24 | if offset_to != to and not content.child(index_to).is_text: 25 | msg = "removing non-flat range" 26 | raise ValueError(msg) 27 | return content.cut(0, from_).append(content.cut(to)) 28 | assert child 29 | if index != index_to: 30 | msg = "removing non-flat range" 31 | raise ValueError(msg) 32 | return content.replace_child( 33 | index, 34 | child.copy(remove_range(child.content, from_ - offset - 1, to - offset - 1)), 35 | ) 36 | 37 | 38 | def insert_into( 39 | content: Fragment, 40 | dist: int, 41 | insert: Fragment, 42 | parent: Optional["Node"], 43 | ) -> Fragment | None: 44 | a = content.find_index(dist) 45 | index, offset = a["index"], a["offset"] 46 | child = content.maybe_child(index) 47 | if offset == dist or cast("Node", child).is_text: 48 | if parent and not parent.can_replace(index, index, insert): 49 | return None 50 | return content.cut(0, dist).append(insert).append(content.cut(dist)) 51 | assert child 52 | inner = insert_into(child.content, dist - offset - 1, insert, None) 53 | if inner: 54 | return content.replace_child(index, child.copy(inner)) 55 | return None 56 | 57 | 58 | class Slice: 59 | empty: ClassVar["Slice"] 60 | 61 | def __init__(self, content: Fragment, open_start: int, open_end: int) -> None: 62 | self.content = content 63 | self.open_start = open_start 64 | self.open_end = open_end 65 | 66 | @property 67 | def size(self) -> int: 68 | return self.content.size - self.open_start - self.open_end 69 | 70 | def insert_at(self, pos: int, fragment: Fragment) -> Optional["Slice"]: 71 | content = insert_into(self.content, pos + self.open_start, fragment, None) 72 | if content: 73 | return Slice(content, self.open_start, self.open_end) 74 | return None 75 | 76 | def remove_between(self, from_: int, to: int) -> "Slice": 77 | return Slice( 78 | remove_range(self.content, from_ + self.open_start, to + self.open_start), 79 | self.open_start, 80 | self.open_end, 81 | ) 82 | 83 | def eq(self, other: "Slice") -> bool: 84 | return ( 85 | self.content.eq(other.content) 86 | and self.open_start == other.open_start 87 | and self.open_end == other.open_end 88 | ) 89 | 90 | def __str__(self) -> str: 91 | return f"{self.content}({self.open_start},{self.open_end})" 92 | 93 | def to_json(self) -> JSONDict | None: 94 | if not self.content.size: 95 | return None 96 | json: JSONDict = {"content": self.content.to_json()} 97 | if self.open_start > 0: 98 | json = { 99 | **json, 100 | "openStart": self.open_start, 101 | } 102 | if self.open_end > 0: 103 | json = { 104 | **json, 105 | "openEnd": self.open_end, 106 | } 107 | return json 108 | 109 | @classmethod 110 | def from_json( 111 | cls, 112 | schema: "Schema[Any, Any]", 113 | json_data: JSONDict | None, 114 | ) -> "Slice": 115 | if not json_data: 116 | return cls.empty 117 | open_start = json_data.get("openStart", 0) or 0 118 | open_end = json_data.get("openEnd", 0) or 0 119 | if not isinstance(open_start, int) or not isinstance(open_end, int): 120 | msg = "invalid input for Slice.from_json" 121 | raise ValueError(msg) 122 | return cls( 123 | Fragment.from_json(schema, json_data.get("content")), 124 | open_start, 125 | open_end, 126 | ) 127 | 128 | @classmethod 129 | def max_open(cls, fragment: Fragment, open_isolating: bool = True) -> "Slice": 130 | open_start = 0 131 | open_end = 0 132 | n = fragment.first_child 133 | while n and not n.is_leaf and (open_isolating or n.type.spec.get("isolating")): 134 | open_start += 1 135 | n = n.first_child 136 | n = fragment.last_child 137 | while n and not n.is_leaf and (open_isolating or n.type.spec.get("isolating")): 138 | open_end += 1 139 | n = n.last_child 140 | return cls(fragment, open_start, open_end) 141 | 142 | 143 | Slice.empty = Slice(Fragment.empty, 0, 0) 144 | 145 | 146 | def replace(from_: "ResolvedPos", to: "ResolvedPos", slice: Slice) -> "Node": 147 | if slice.open_start > from_.depth: 148 | msg = "Inserted content deeper than insertion position" 149 | raise ReplaceError(msg) 150 | if from_.depth - slice.open_start != to.depth - slice.open_end: 151 | msg = "Inconsistent open depths" 152 | raise ReplaceError(msg) 153 | return replace_outer(from_, to, slice, 0) 154 | 155 | 156 | def replace_outer( 157 | from_: "ResolvedPos", 158 | to: "ResolvedPos", 159 | slice: Slice, 160 | depth: int, 161 | ) -> "Node": 162 | index = from_.index(depth) 163 | node = from_.node(depth) 164 | if index == to.index(depth) and depth < from_.depth - slice.open_start: 165 | inner = replace_outer(from_, to, slice, depth + 1) 166 | return node.copy(node.content.replace_child(index, inner)) 167 | elif not slice.content.size: 168 | return close(node, replace_two_way(from_, to, depth)) 169 | elif ( 170 | not slice.open_start 171 | and not slice.open_end 172 | and from_.depth == depth 173 | and to.depth == depth 174 | ): 175 | parent = from_.parent 176 | content = parent.content 177 | return close( 178 | parent, 179 | content.cut(0, from_.parent_offset) 180 | .append(slice.content) 181 | .append(content.cut(to.parent_offset)), 182 | ) 183 | else: 184 | prepare = prepare_slice_for_replace(slice, from_) 185 | start, end = prepare["start"], prepare["end"] 186 | return close(node, replace_three_way(from_, start, end, to, depth)) 187 | 188 | 189 | def check_join(main: "Node", sub: "Node") -> None: 190 | if not sub.type.compatible_content(main.type): 191 | msg = f"Cannot join {sub.type.name} onto {main.type.name}" 192 | raise ReplaceError(msg) 193 | 194 | 195 | def joinable(before: "ResolvedPos", after: "ResolvedPos", depth: int) -> "Node": 196 | node = before.node(depth) 197 | check_join(node, after.node(depth)) 198 | return node 199 | 200 | 201 | def add_node(child: "Node", target: list["Node"]) -> None: 202 | last = len(target) - 1 203 | if last >= 0 and pm_node.is_text(child) and child.same_markup(target[last]): 204 | target[last] = child.with_text(cast("TextNode", target[last]).text + child.text) 205 | else: 206 | target.append(child) 207 | 208 | 209 | def add_range( 210 | start: Optional["ResolvedPos"], 211 | end: Optional["ResolvedPos"], 212 | depth: int, 213 | target: list["Node"], 214 | ) -> None: 215 | node = cast("ResolvedPos", end or start).node(depth) 216 | start_index = 0 217 | end_index = end.index(depth) if end else node.child_count 218 | if start: 219 | start_index = start.index(depth) 220 | if start.depth > depth: 221 | start_index += 1 222 | elif start.text_offset: 223 | add_node(cast("Node", start.node_after), target) 224 | start_index += 1 225 | i = start_index 226 | while i < end_index: 227 | add_node(node.child(i), target) 228 | i += 1 229 | if end and end.depth == depth and end.text_offset: 230 | add_node(cast("Node", end.node_before), target) 231 | 232 | 233 | def close(node: "Node", content: Fragment) -> "Node": 234 | if not node.type.valid_content(content): 235 | msg = f"Invalid content for node {node.type.name}" 236 | raise ReplaceError(msg) 237 | return node.copy(content) 238 | 239 | 240 | def replace_three_way( 241 | from_: "ResolvedPos", 242 | start: "ResolvedPos", 243 | end: "ResolvedPos", 244 | to: "ResolvedPos", 245 | depth: int, 246 | ) -> Fragment: 247 | open_start = joinable(from_, start, depth + 1) if from_.depth > depth else None 248 | open_end = joinable(end, to, depth + 1) if to.depth > depth else None 249 | content: list[Node] = [] 250 | add_range(None, from_, depth, content) 251 | if open_start and open_end and start.index(depth) == end.index(depth): 252 | check_join(open_start, open_end) 253 | add_node( 254 | close(open_start, replace_three_way(from_, start, end, to, depth + 1)), 255 | content, 256 | ) 257 | else: 258 | if open_start: 259 | add_node( 260 | close(open_start, replace_two_way(from_, start, depth + 1)), 261 | content, 262 | ) 263 | add_range(start, end, depth, content) 264 | if open_end: 265 | add_node(close(open_end, replace_two_way(end, to, depth + 1)), content) 266 | add_range(to, None, depth, content) 267 | return Fragment(content) 268 | 269 | 270 | def replace_two_way(from_: "ResolvedPos", to: "ResolvedPos", depth: int) -> Fragment: 271 | content: list[Node] = [] 272 | add_range(None, from_, depth, content) 273 | if from_.depth > depth: 274 | type = joinable(from_, to, depth + 1) 275 | add_node(close(type, replace_two_way(from_, to, depth + 1)), content) 276 | add_range(to, None, depth, content) 277 | return Fragment(content) 278 | 279 | 280 | def prepare_slice_for_replace( 281 | slice: Slice, 282 | along: "ResolvedPos", 283 | ) -> dict[str, "ResolvedPos"]: 284 | extra = along.depth - slice.open_start 285 | parent = along.node(extra) 286 | node = parent.copy(slice.content) 287 | for i in range(extra - 1, -1, -1): 288 | node = along.node(i).copy(Fragment.from_(node)) 289 | return { 290 | "start": node.resolve_no_cache(slice.open_start + extra), 291 | "end": node.resolve_no_cache(node.content.size - slice.open_end - extra), 292 | } 293 | 294 | 295 | from . import node as pm_node # noqa: E402 296 | -------------------------------------------------------------------------------- /prosemirror/model/resolvedpos.py: -------------------------------------------------------------------------------- 1 | from collections.abc import Callable 2 | from typing import TYPE_CHECKING, Optional, Union, cast 3 | 4 | from .mark import Mark 5 | 6 | if TYPE_CHECKING: 7 | from .node import Node 8 | 9 | 10 | class ResolvedPos: 11 | def __init__( 12 | self, 13 | pos: int, 14 | path: list[Union["Node", int]], 15 | parent_offset: int, 16 | ) -> None: 17 | self.pos = pos 18 | self.path = path 19 | self.depth = int(len(path) / 3 - 1) 20 | self.parent_offset = parent_offset 21 | 22 | def resolve_depth(self, val: int | None = None) -> int: 23 | if val is None: 24 | return self.depth 25 | return self.depth + val if val < 0 else val 26 | 27 | @property 28 | def parent(self) -> "Node": 29 | return self.node(self.depth) 30 | 31 | @property 32 | def doc(self) -> "Node": 33 | return self.node(0) 34 | 35 | def node(self, depth: int) -> "Node": 36 | return cast("Node", self.path[self.resolve_depth(depth) * 3]) 37 | 38 | def index(self, depth: int | None = None) -> int: 39 | return cast(int, self.path[self.resolve_depth(depth) * 3 + 1]) 40 | 41 | def index_after(self, depth: int) -> int: 42 | depth = self.resolve_depth(depth) 43 | return self.index(depth) + ( 44 | 0 if depth == self.depth and not self.text_offset else 1 45 | ) 46 | 47 | def start(self, depth: int | None = None) -> int: 48 | depth = self.resolve_depth(depth) 49 | return 0 if depth == 0 else cast(int, self.path[depth * 3 - 1]) + 1 50 | 51 | def end(self, depth: int | None = None) -> int: 52 | depth = self.resolve_depth(depth) 53 | return self.start(depth) + self.node(depth).content.size 54 | 55 | def before(self, depth: int | None = None) -> int: 56 | depth = self.resolve_depth(depth) 57 | if not depth: 58 | msg = "There is no position before the top level node" 59 | raise ValueError(msg) 60 | return ( 61 | self.pos if depth == self.depth + 1 else cast(int, self.path[depth * 3 - 1]) 62 | ) 63 | 64 | def after(self, depth: int | None = None) -> int: 65 | depth = self.resolve_depth(depth) 66 | if not depth: 67 | msg = "There is no position after the top level node" 68 | raise ValueError(msg) 69 | return ( 70 | self.pos 71 | if depth == self.depth + 1 72 | else cast(int, self.path[depth * 3 - 1]) 73 | + cast("Node", self.path[depth * 3]).node_size 74 | ) 75 | 76 | @property 77 | def text_offset(self) -> int: 78 | return self.pos - cast(int, self.path[-1]) 79 | 80 | @property 81 | def node_after(self) -> Optional["Node"]: 82 | parent = self.parent 83 | index = self.index(self.depth) 84 | if index == parent.child_count: 85 | return None 86 | d_off = self.pos - cast(int, self.path[-1]) 87 | child = parent.child(index) 88 | return parent.child(index).cut(d_off) if d_off else child 89 | 90 | @property 91 | def node_before(self) -> Optional["Node"]: 92 | index = self.index(self.depth) 93 | d_off = self.pos - cast(int, self.path[-1]) 94 | if d_off: 95 | return self.parent.child(index).cut(0, d_off) 96 | return None if index == 0 else self.parent.child(index - 1) 97 | 98 | def pos_at_index(self, index: int, depth: int | None = None) -> int: 99 | depth = self.resolve_depth(depth) 100 | node = cast("Node", self.path[depth * 3]) 101 | pos = 0 if depth == 0 else cast(int, self.path[depth * 3 - 1]) + 1 102 | for i in range(index): 103 | pos += node.child(i).node_size 104 | return pos 105 | 106 | def marks(self) -> list["Mark"]: 107 | parent = self.parent 108 | index = self.index() 109 | if parent.content.size == 0: 110 | return Mark.none 111 | if self.text_offset: 112 | return parent.child(index).marks 113 | main = parent.maybe_child(index - 1) 114 | other = parent.maybe_child(index) 115 | if not main: 116 | main, other = other, main 117 | marks = cast("Node", main).marks 118 | i = 0 119 | while i < len(marks): 120 | if marks[i].type.spec.get("inclusive") is False and ( 121 | not other or not marks[i].is_in_set(other.marks) 122 | ): 123 | marks = marks[i].remove_from_set(marks) 124 | i -= 1 125 | i += 1 126 | return marks 127 | 128 | def marks_across(self, end: "ResolvedPos") -> list["Mark"] | None: 129 | after = self.parent.maybe_child(self.index()) 130 | if not after or not after.is_inline: 131 | return None 132 | marks = after.marks 133 | next = end.parent.maybe_child(end.index()) 134 | i = 0 135 | while i < len(marks): 136 | if marks[i].type.spec.get("inclusive") is False and ( 137 | not next or not marks[i].is_in_set(next.marks) 138 | ): 139 | marks = marks[i].remove_from_set(marks) 140 | i -= 1 141 | i += 1 142 | return marks 143 | 144 | def shared_depth(self, pos: int) -> int: 145 | depth = self.depth 146 | while depth > 0: 147 | if self.start(depth) <= pos and self.end(depth) >= pos: 148 | return depth 149 | depth -= 1 150 | return 0 151 | 152 | def block_range( 153 | self, 154 | other: Optional["ResolvedPos"] = None, 155 | pred: Callable[["Node"], bool] | None = None, 156 | ) -> Optional["NodeRange"]: 157 | if other is None: 158 | other = self 159 | if other.pos < self.pos: 160 | return other.block_range(self) 161 | d = self.depth - ( 162 | self.parent.inline_content or (1 if self.pos == other.pos else 0) 163 | ) 164 | while d >= 0: 165 | if other.pos <= self.end(d) and (not pred or pred(self.node(d))): 166 | return NodeRange(self, other, d) 167 | d -= 1 168 | return None 169 | 170 | def same_parent(self, other: "ResolvedPos") -> bool: 171 | return self.pos - self.parent_offset == other.pos - other.parent_offset 172 | 173 | def max(self, other: "ResolvedPos") -> "ResolvedPos": 174 | return other if other.pos > self.pos else self 175 | 176 | def min(self, other: "ResolvedPos") -> "ResolvedPos": 177 | return other if other.pos < self.pos else self 178 | 179 | def __str__(self) -> str: 180 | path = "/".join([ 181 | f"{self.node(i).type.name}_{self.index(i - 1)}" 182 | for i in range(1, self.depth + 1) 183 | ]) 184 | return f"{path}:{self.parent_offset}" 185 | 186 | @classmethod 187 | def resolve(cls, doc: "Node", pos: int) -> "ResolvedPos": 188 | if not (pos >= 0 and pos <= doc.content.size): 189 | msg = f"Position {pos} out of range" 190 | raise ValueError(msg) 191 | path: list[Node | int] = [] 192 | start = 0 193 | parent_offset = pos 194 | node = doc 195 | while True: 196 | index_info = node.content.find_index(parent_offset) 197 | index, offset = index_info["index"], index_info["offset"] 198 | rem = parent_offset - offset 199 | path.extend([node, index, start + offset]) 200 | if not rem: 201 | break 202 | node = node.child(index) 203 | if node.is_text: 204 | break 205 | parent_offset = rem - 1 206 | start += offset + 1 207 | return cls(pos, path, parent_offset) 208 | 209 | @classmethod 210 | def resolve_cached(cls, doc: "Node", pos: int) -> "ResolvedPos": 211 | # no cache for now 212 | return cls.resolve(doc, pos) 213 | 214 | 215 | class NodeRange: 216 | def __init__(self, from_: ResolvedPos, to: ResolvedPos, depth: int) -> None: 217 | self.from_ = from_ 218 | self.to = to 219 | self.depth = depth 220 | 221 | @property 222 | def start(self) -> int: 223 | return self.from_.before(self.depth + 1) 224 | 225 | @property 226 | def end(self) -> int: 227 | return self.to.after(self.depth + 1) 228 | 229 | @property 230 | def parent(self) -> "Node": 231 | return self.from_.node(self.depth) 232 | 233 | @property 234 | def start_index(self) -> int: 235 | return self.from_.index(self.depth) 236 | 237 | @property 238 | def end_index(self) -> int: 239 | return self.to.index_after(self.depth) 240 | -------------------------------------------------------------------------------- /prosemirror/model/to_dom.py: -------------------------------------------------------------------------------- 1 | import html 2 | from collections.abc import Callable, Mapping, Sequence 3 | from typing import ( 4 | Any, 5 | Union, 6 | cast, 7 | ) 8 | 9 | from .fragment import Fragment 10 | from .mark import Mark 11 | from .node import Node 12 | from .schema import MarkType, NodeType, Schema 13 | 14 | HTMLNode = Union["Element", "str"] 15 | 16 | 17 | class DocumentFragment: 18 | def __init__(self, children: list[HTMLNode]) -> None: 19 | self.children = children 20 | 21 | def __str__(self) -> str: 22 | return "".join([str(c) for c in self.children]) 23 | 24 | 25 | SELF_CLOSING_ELEMENTS = frozenset({ 26 | "area", 27 | "base", 28 | "br", 29 | "col", 30 | "embed", 31 | "hr", 32 | "img", 33 | "input", 34 | "keygen", 35 | "link", 36 | "meta", 37 | "param", 38 | "source", 39 | "track", 40 | "wbr", 41 | }) 42 | 43 | 44 | class Element(DocumentFragment): 45 | def __init__( 46 | self, 47 | name: str, 48 | attrs: dict[str, str], 49 | children: list[HTMLNode], 50 | ) -> None: 51 | self.name = name 52 | self.attrs = attrs 53 | super().__init__(children) 54 | 55 | def __str__(self) -> str: 56 | attrs_str = " ".join([f'{k}="{html.escape(v)}"' for k, v in self.attrs.items()]) 57 | open_tag_str = " ".join([s for s in [self.name, attrs_str] if s]) 58 | if self.name in SELF_CLOSING_ELEMENTS: 59 | assert not self.children, "self-closing elements should not have children" 60 | return f"<{open_tag_str}>" 61 | children_str = "".join([str(c) for c in self.children]) 62 | return f"<{open_tag_str}>{children_str}" 63 | 64 | 65 | HTMLOutputSpec = str | Sequence[Any] | Element 66 | 67 | 68 | class DOMSerializer: 69 | def __init__( 70 | self, 71 | nodes: dict[str, Callable[[Node], HTMLOutputSpec]], 72 | marks: dict[str, Callable[[Mark, bool], HTMLOutputSpec]], 73 | ) -> None: 74 | self.nodes = nodes 75 | self.marks = marks 76 | 77 | def serialize_fragment( 78 | self, 79 | fragment: Fragment, 80 | target: Element | DocumentFragment | None = None, 81 | ) -> DocumentFragment: 82 | tgt: DocumentFragment = target or DocumentFragment(children=[]) 83 | 84 | top = tgt 85 | active: list[tuple[Mark, DocumentFragment]] | None = None 86 | 87 | def each(node: Node, offset: int, index: int) -> None: 88 | nonlocal top, active 89 | 90 | if active or node.marks: 91 | if not active: 92 | active = [] 93 | keep = 0 94 | rendered = 0 95 | while keep < len(active) and rendered < len(node.marks): 96 | next = node.marks[rendered] 97 | if not self.marks.get(next.type.name): 98 | rendered += 1 99 | continue 100 | if ( 101 | not next.eq(active[keep][0]) 102 | or next.type.spec.get("spanning") is False 103 | ): 104 | break 105 | keep += 1 106 | rendered += 1 107 | while keep < len(active): 108 | top = active.pop()[1] 109 | while rendered < len(node.marks): 110 | add = node.marks[rendered] 111 | rendered += 1 112 | mark_dom = self.serialize_mark(add, node.is_inline) 113 | if mark_dom: 114 | active.append((add, top)) 115 | top.children.append(mark_dom[0]) 116 | top = cast(DocumentFragment, mark_dom[1] or mark_dom[0]) 117 | top.children.append(self.serialize_node_inner(node)) 118 | 119 | fragment.for_each(each) 120 | return tgt 121 | 122 | def serialize_node_inner(self, node: Node) -> HTMLNode: 123 | dom, content_dom = type(self).render_spec(self.nodes[node.type.name](node)) 124 | if content_dom: 125 | if node.is_leaf: 126 | msg = "Content hole not allowed in a leaf node spec" 127 | raise Exception(msg) 128 | self.serialize_fragment(node.content, content_dom) 129 | return dom 130 | 131 | def serialize_node(self, node: Node) -> HTMLNode: 132 | dom = self.serialize_node_inner(node) 133 | for mark in reversed(node.marks): 134 | wrap = self.serialize_mark(mark, node.is_inline) 135 | if wrap: 136 | inner, content_dom = wrap 137 | cast(DocumentFragment, content_dom or inner).children.append(dom) 138 | dom = inner 139 | return dom 140 | 141 | def serialize_mark( 142 | self, 143 | mark: Mark, 144 | inline: bool, 145 | ) -> tuple[HTMLNode, Element | None] | None: 146 | to_dom = self.marks.get(mark.type.name) 147 | if to_dom: 148 | return type(self).render_spec(to_dom(mark, inline)) 149 | return None 150 | 151 | @classmethod 152 | def render_spec(cls, structure: HTMLOutputSpec) -> tuple[HTMLNode, Element | None]: 153 | if isinstance(structure, str): 154 | return html.escape(structure), None 155 | if isinstance(structure, Element): 156 | return structure, None 157 | tag_name = structure[0] 158 | if " " in tag_name[1:]: 159 | msg = "XML namespaces are not supported" 160 | raise NotImplementedError(msg) 161 | content_dom: Element | None = None 162 | dom = Element(name=tag_name, attrs={}, children=[]) 163 | attrs = structure[1] if len(structure) > 1 else None 164 | start = 1 165 | if isinstance(attrs, dict): 166 | start = 2 167 | for name, value in attrs.items(): 168 | if value is None: 169 | continue 170 | if " " in name[1:]: 171 | msg = "XML namespaces are not supported" 172 | raise NotImplementedError(msg) 173 | dom.attrs[name] = value 174 | for i in range(start, len(structure)): 175 | child = structure[i] 176 | if child == 0: 177 | if i < len(structure) - 1 or i > start: 178 | msg = "Content hole must be the only child of its parent node" 179 | raise Exception(msg) 180 | return dom, dom 181 | inner, inner_content = cls.render_spec(child) 182 | dom.children.append(inner) 183 | if inner_content: 184 | if content_dom: 185 | msg = "Multiple content holes" 186 | raise Exception(msg) 187 | content_dom = inner_content 188 | return dom, content_dom 189 | 190 | @classmethod 191 | def from_schema(cls, schema: Schema[Any, Any]) -> "DOMSerializer": 192 | return cls(cls.nodes_from_schema(schema), cls.marks_from_schema(schema)) 193 | 194 | @classmethod 195 | def nodes_from_schema( 196 | cls, 197 | schema: Schema[str, Any], 198 | ) -> dict[str, Callable[["Node"], HTMLOutputSpec]]: 199 | result = gather_to_dom(schema.nodes) 200 | if "text" not in result: 201 | result["text"] = lambda node: node.text 202 | return result 203 | 204 | @classmethod 205 | def marks_from_schema( 206 | cls, 207 | schema: Schema[Any, Any], 208 | ) -> dict[str, Callable[["Mark", bool], HTMLOutputSpec]]: 209 | return gather_to_dom(schema.marks) 210 | 211 | 212 | def gather_to_dom( 213 | obj: Mapping[str, NodeType | MarkType], 214 | ) -> dict[str, Callable[..., Any]]: 215 | result = {} 216 | for name in obj: 217 | to_dom = obj[name].spec.get("toDOM") 218 | if to_dom: 219 | result[name] = to_dom 220 | return result 221 | -------------------------------------------------------------------------------- /prosemirror/py.typed: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fellowapp/prosemirror-py/c996d5e23a8d6ef7360db26bf91f815a86a1587a/prosemirror/py.typed -------------------------------------------------------------------------------- /prosemirror/schema/basic/__init__.py: -------------------------------------------------------------------------------- 1 | from .schema_basic import * # noqa 2 | -------------------------------------------------------------------------------- /prosemirror/schema/basic/schema_basic.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | from prosemirror.model import Schema 4 | from prosemirror.model.schema import MarkSpec, NodeSpec 5 | 6 | p_dom = ["p", 0] 7 | blockquote_dom = ["blockquote", 0] 8 | hr_dom = ["hr"] 9 | pre_dom = ["pre", ["code", 0]] 10 | br_dom = ["br"] 11 | 12 | nodes: dict[str, NodeSpec] = { 13 | "doc": {"content": "block+"}, 14 | "paragraph": { 15 | "content": "inline*", 16 | "group": "block", 17 | "parseDOM": [{"tag": "p"}], 18 | "toDOM": lambda _: p_dom, 19 | }, 20 | "blockquote": { 21 | "content": "block+", 22 | "group": "block", 23 | "defining": True, 24 | "parseDOM": [{"tag": "blockquote"}], 25 | "toDOM": lambda _: blockquote_dom, 26 | }, 27 | "horizontal_rule": { 28 | "group": "block", 29 | "parseDOM": [{"tag": "hr"}], 30 | "toDOM": lambda _: hr_dom, 31 | }, 32 | "heading": { 33 | "attrs": {"level": {"default": 1}}, 34 | "content": "inline*", 35 | "group": "block", 36 | "defining": True, 37 | "parseDOM": [ 38 | {"tag": "h1", "attrs": {"level": 1}}, 39 | {"tag": "h2", "attrs": {"level": 2}}, 40 | {"tag": "h3", "attrs": {"level": 3}}, 41 | {"tag": "h4", "attrs": {"level": 4}}, 42 | {"tag": "h5", "attrs": {"level": 5}}, 43 | {"tag": "h6", "attrs": {"level": 6}}, 44 | ], 45 | "toDOM": lambda node: [f"h{node.attrs['level']}", 0], 46 | }, 47 | "code_block": { 48 | "content": "text*", 49 | "marks": "", 50 | "group": "block", 51 | "code": True, 52 | "defining": True, 53 | "parseDOM": [{"tag": "pre", "preserveWhitespace": "full"}], 54 | "toDOM": lambda _: pre_dom, 55 | }, 56 | "text": {"group": "inline"}, 57 | "image": { 58 | "inline": True, 59 | "attrs": {"src": {}, "alt": {"default": None}, "title": {"default": None}}, 60 | "group": "inline", 61 | "draggable": True, 62 | "parseDOM": [ 63 | { 64 | "tag": "img", 65 | "getAttrs": lambda dom_: { 66 | "src": dom_.get("src"), 67 | "title": dom_.get("title"), 68 | }, 69 | }, 70 | ], 71 | "toDOM": lambda node: [ 72 | "img", 73 | { 74 | "src": node.attrs["src"], 75 | "alt": node.attrs["alt"], 76 | "title": node.attrs["title"], 77 | }, 78 | ], 79 | }, 80 | "hard_break": { 81 | "inline": True, 82 | "group": "inline", 83 | "selectable": False, 84 | "parseDOM": [{"tag": "br"}], 85 | "toDOM": lambda _: br_dom, 86 | }, 87 | } 88 | 89 | em_dom = ["em", 0] 90 | strong_dom = ["strong", 0] 91 | code_dom = ["code", 0] 92 | 93 | marks: dict[str, MarkSpec] = { 94 | "link": { 95 | "attrs": {"href": {}, "title": {"default": None}}, 96 | "inclusive": False, 97 | "parseDOM": [{"tag": "a", "getAttrs": lambda d: {"href": d.get("href")}}], 98 | "toDOM": lambda node, _: [ 99 | "a", 100 | {"href": node.attrs["href"], "title": node.attrs["title"]}, 101 | 0, 102 | ], 103 | }, 104 | "em": { 105 | "parseDOM": [{"tag": "i"}, {"tag": "em"}, {"style": "font-style=italic"}], 106 | "toDOM": lambda _, __: em_dom, 107 | }, 108 | "strong": { 109 | "parseDOM": [{"tag": "strong"}, {"tag": "b"}, {"style": "font-weight"}], 110 | "toDOM": lambda _, __: strong_dom, 111 | }, 112 | "code": {"parseDOM": [{"tag": "code"}], "toDOM": lambda _, __: code_dom}, 113 | } 114 | 115 | 116 | schema: Schema[Any, Any] = Schema({"nodes": nodes, "marks": marks}) 117 | -------------------------------------------------------------------------------- /prosemirror/schema/list/__init__.py: -------------------------------------------------------------------------------- 1 | from .schema_list import * # noqa 2 | -------------------------------------------------------------------------------- /prosemirror/schema/list/schema_list.py: -------------------------------------------------------------------------------- 1 | from typing import cast 2 | 3 | from prosemirror.model.schema import Nodes, NodeSpec 4 | 5 | OL_DOM = ["ol", 0] 6 | UL_DOM = ["ul", 0] 7 | LI_DOM = ["li", 0] 8 | 9 | 10 | orderd_list = NodeSpec( 11 | attrs={"order": {"default": 1}}, 12 | parseDOM=[{"tag": "ol"}], 13 | toDOM=lambda node: ( 14 | OL_DOM 15 | if node.attrs.get("order") == 1 16 | else ["ol", {"start": node.attrs["order"]}, 0] 17 | ), 18 | ) 19 | 20 | bullet_list = NodeSpec(parseDOM=[{"tag": "ul"}], toDOM=lambda _: UL_DOM) 21 | 22 | list_item = NodeSpec(parseDOM=[{"tag": "li"}], defining=True, toDOM=lambda _: LI_DOM) 23 | 24 | 25 | def add(obj: "NodeSpec", props: "NodeSpec") -> "NodeSpec": 26 | return {**obj, **props} 27 | 28 | 29 | def add_list_nodes( 30 | nodes: dict["Nodes", "NodeSpec"], 31 | item_content: str, 32 | list_group: str, 33 | ) -> dict["Nodes", "NodeSpec"]: 34 | copy = nodes.copy() 35 | copy.update({ 36 | cast(Nodes, "ordered_list"): add( 37 | orderd_list, 38 | NodeSpec(content="list_item+", group=list_group), 39 | ), 40 | cast(Nodes, "bullet_list"): add( 41 | bullet_list, 42 | NodeSpec(content="list_item+", group=list_group), 43 | ), 44 | cast(Nodes, "list_item"): add(list_item, NodeSpec(content=item_content)), 45 | }) 46 | return copy 47 | -------------------------------------------------------------------------------- /prosemirror/test_builder/__init__.py: -------------------------------------------------------------------------------- 1 | # type: ignore 2 | 3 | from typing import Any 4 | 5 | from prosemirror.model import Node, Schema 6 | from prosemirror.schema.basic import schema as _schema 7 | from prosemirror.schema.list import add_list_nodes 8 | 9 | from .build import builders 10 | 11 | nodes = add_list_nodes(_schema.spec["nodes"], "paragraph block*", "block") 12 | 13 | nodes.update({ 14 | "doc": { 15 | "content": "block+", 16 | "attrs": {"meta": {"default": None}}, 17 | }, 18 | }) 19 | 20 | test_schema: Schema[Any, Any] = Schema({ 21 | "nodes": nodes, 22 | "marks": _schema.spec["marks"], 23 | }) 24 | 25 | out = builders( 26 | test_schema, 27 | { 28 | "doc": {"nodeType": "doc"}, 29 | "docMetaOne": {"nodeType": "doc", "meta": 1}, 30 | "docMetaTwo": {"nodeType": "doc", "meta": 2}, 31 | "p": {"nodeType": "paragraph"}, 32 | "pre": {"nodeType": "code_block"}, 33 | "h1": {"nodeType": "heading", "level": 1}, 34 | "h2": {"nodeType": "heading", "level": 2}, 35 | "h3": {"nodeType": "heading", "level": 3}, 36 | "li": {"nodeType": "list_item"}, 37 | "ul": {"nodeType": "bullet_list"}, 38 | "ol": {"nodeType": "ordered_list"}, 39 | "br": {"nodeType": "hard_break"}, 40 | "img": {"nodeType": "image", "src": "img.png"}, 41 | "hr": {"nodeType": "horizontal_rule"}, 42 | "a": {"markType": "link", "href": "foo"}, 43 | }, 44 | ) 45 | 46 | 47 | def eq(a: Node, b: Node) -> bool: 48 | return a.eq(b) 49 | -------------------------------------------------------------------------------- /prosemirror/test_builder/build.py: -------------------------------------------------------------------------------- 1 | # type: ignore 2 | 3 | import contextlib 4 | import re 5 | from collections.abc import Callable 6 | from typing import Any 7 | 8 | from prosemirror.model import Node, NodeType, Schema 9 | from prosemirror.utils import Attrs, JSONDict 10 | 11 | NO_TAG = Node.tag = {} 12 | 13 | 14 | def flatten( 15 | schema: Schema[Any, Any], 16 | children: list[Node | JSONDict | str], 17 | f: Callable[[Node], Node], 18 | ) -> tuple[list[Node], dict[str, int]]: 19 | result, pos, tag = [], 0, NO_TAG 20 | 21 | for child in children: 22 | if hasattr(child, "tag") and child.tag != NO_TAG: 23 | if tag == NO_TAG: 24 | tag = {} 25 | for id in child.tag: 26 | tag[id] = child.tag[id] + (0 if child.is_text else 1) + pos 27 | if isinstance(child, dict) and "tag" in child and child["tag"] != Node.tag: 28 | if tag == NO_TAG: 29 | tag = {} 30 | for id in child["tag"]: 31 | tag[id] = child["tag"][id] + (0 if "flat" in child else 1) + pos 32 | if isinstance(child, str): 33 | at = 0 34 | out = "" 35 | for m in re.finditer(r"<(\w+)>", child): 36 | out += child[at : m.start()] 37 | pos += m.start() - at 38 | at = m.start() + len(m[0]) 39 | if tag == NO_TAG: 40 | tag = {} 41 | tag[m[1]] = pos 42 | out += child[at:] 43 | pos += len(child) - at 44 | if out: 45 | result.append(f(schema.text(out))) 46 | elif isinstance(child, dict) and "flat" in child: 47 | for item in child["flat"]: 48 | node = f(item) 49 | pos += node.node_size 50 | result.append(node) 51 | elif getattr(child, "flat", 0): 52 | for item in child.flat: 53 | node = f(item) 54 | pos += node.node_size 55 | result.append(node) 56 | else: 57 | node = f(child) 58 | pos += node.node_size 59 | result.append(node) 60 | return result, tag 61 | 62 | 63 | def block(type: NodeType, attrs: Attrs | None = None): 64 | def result(*args): 65 | my_attrs = attrs 66 | if ( 67 | args 68 | and args[0] 69 | and not isinstance(args[0], str | Node) 70 | and not getattr(args[0], "flat", None) 71 | and "flat" not in args[0] 72 | ): 73 | my_attrs.update(args[0]) 74 | args = args[1:] 75 | nodes, tag = flatten(type.schema, args, lambda x: x) 76 | node = type.create(my_attrs, nodes) 77 | if tag != NO_TAG: 78 | node.tag = tag 79 | return node 80 | 81 | if type.is_leaf: 82 | with contextlib.suppress(ValueError): 83 | result.flat = [type.create(attrs)] 84 | 85 | return result 86 | 87 | 88 | def mark(type: NodeType, attrs: Attrs): 89 | def result(*args): 90 | my_attrs = attrs.copy() 91 | if ( 92 | args 93 | and args[0] 94 | and not isinstance(args[0], str | Node) 95 | and not getattr(args[0], "flat", None) 96 | and "flat" not in args[0] 97 | ): 98 | my_attrs.update(args[0]) 99 | args = args[1:] 100 | mark = type.create(my_attrs) 101 | 102 | def f(n): 103 | return ( 104 | n if mark.type.is_in_set(n.marks) else n.mark(mark.add_to_set(n.marks)) 105 | ) 106 | 107 | nodes, tag = flatten(type.schema, args, f) 108 | return {"flat": nodes, "tag": tag} 109 | 110 | return result 111 | 112 | 113 | def builders(schema: Schema[Any, Any], names): 114 | result = {"schema": schema} 115 | for name in schema.nodes: 116 | result[name] = block(schema.nodes[name], {}) 117 | for name in schema.marks: 118 | result[name] = mark(schema.marks[name], {}) 119 | 120 | if names: 121 | for name in names: 122 | value = names[name] 123 | type_name = value.get("nodeType") or value.get("markType") or name 124 | type = schema.nodes.get(type_name) 125 | if type: 126 | result[name] = block(type, value) 127 | else: 128 | type = schema.marks.get(type_name) 129 | if type: 130 | result[name] = mark(type, value) 131 | return result 132 | -------------------------------------------------------------------------------- /prosemirror/transform/__init__.py: -------------------------------------------------------------------------------- 1 | from .attr_step import AttrStep 2 | from .map import Mapping, MapResult, StepMap 3 | from .mark_step import AddMarkStep, AddNodeMarkStep, RemoveMarkStep, RemoveNodeMarkStep 4 | from .replace import ( 5 | close_fragment, 6 | covered_depths, 7 | fits_trivially, 8 | replace_step, 9 | ) 10 | from .replace_step import ReplaceAroundStep, ReplaceStep 11 | from .step import Step, StepResult 12 | from .structure import ( 13 | can_join, 14 | can_split, 15 | drop_point, 16 | find_wrapping, 17 | insert_point, 18 | join_point, 19 | lift_target, 20 | ) 21 | from .transform import Transform, TransformError 22 | 23 | __all__ = [ 24 | "AddMarkStep", 25 | "AddNodeMarkStep", 26 | "AttrStep", 27 | "MapResult", 28 | "Mapping", 29 | "RemoveMarkStep", 30 | "RemoveNodeMarkStep", 31 | "ReplaceAroundStep", 32 | "ReplaceStep", 33 | "Step", 34 | "StepMap", 35 | "StepResult", 36 | "Transform", 37 | "TransformError", 38 | "can_join", 39 | "can_split", 40 | "close_fragment", 41 | "covered_depths", 42 | "drop_point", 43 | "find_wrapping", 44 | "fits_trivially", 45 | "insert_point", 46 | "join_point", 47 | "lift_target", 48 | "replace_step", 49 | ] 50 | -------------------------------------------------------------------------------- /prosemirror/transform/attr_step.py: -------------------------------------------------------------------------------- 1 | from typing import Any, cast 2 | 3 | from prosemirror.model import Fragment, Node, Schema, Slice 4 | from prosemirror.transform.map import Mappable, StepMap 5 | from prosemirror.transform.step import Step, StepResult, step_json_id 6 | from prosemirror.utils import JSON, JSONDict 7 | 8 | 9 | class AttrStep(Step): 10 | def __init__(self, pos: int, attr: str, value: JSON) -> None: 11 | super().__init__() 12 | self.pos = pos 13 | self.attr = attr 14 | self.value = value 15 | 16 | def apply(self, doc: Node) -> StepResult: 17 | node = doc.node_at(self.pos) 18 | if not node: 19 | return StepResult.fail("No node at attribute step's position") 20 | attrs = {} 21 | for name in node.attrs: 22 | attrs[name] = node.attrs[name] 23 | attrs[self.attr] = self.value 24 | updated = node.type.create(attrs, None, node.marks) 25 | return StepResult.from_replace( 26 | doc, 27 | self.pos, 28 | self.pos + 1, 29 | Slice(Fragment.from_(updated), 0, 0 if node.is_leaf else 1), 30 | ) 31 | 32 | def get_map(self) -> StepMap: 33 | return StepMap.empty 34 | 35 | def invert(self, doc: Node) -> Step: 36 | node_at_pos = doc.node_at(self.pos) 37 | assert node_at_pos is not None 38 | return AttrStep(self.pos, self.attr, node_at_pos.attrs[self.attr]) 39 | 40 | def map(self, mapping: Mappable) -> Step | None: 41 | pos = mapping.map_result(self.pos, 1) 42 | return None if pos.deleted_after else AttrStep(pos.pos, self.attr, self.value) 43 | 44 | def to_json(self) -> JSONDict: 45 | return { 46 | "stepType": "attr", 47 | "pos": self.pos, 48 | "attr": self.attr, 49 | "value": self.value, 50 | } 51 | 52 | @staticmethod 53 | def from_json(schema: Schema[Any, Any], json_data: JSONDict | str) -> "AttrStep": 54 | if isinstance(json_data, str): 55 | import json 56 | 57 | json_data = cast(JSONDict, json.loads(json_data)) 58 | 59 | if not isinstance(json_data["pos"], int) or not isinstance( 60 | json_data["attr"], 61 | str, 62 | ): 63 | msg = "Invalid input for AttrStep.from_json" 64 | raise ValueError(msg) 65 | return AttrStep(json_data["pos"], json_data["attr"], json_data["value"]) 66 | 67 | 68 | step_json_id("attr", AttrStep) 69 | -------------------------------------------------------------------------------- /prosemirror/transform/doc_attr_step.py: -------------------------------------------------------------------------------- 1 | from typing import Any, cast 2 | 3 | from prosemirror.model import Node, Schema 4 | from prosemirror.transform.map import Mappable, StepMap 5 | from prosemirror.transform.step import Step, StepResult, step_json_id 6 | from prosemirror.utils import JSON, JSONDict 7 | 8 | 9 | class DocAttrStep(Step): 10 | def __init__(self, attr: str, value: JSON) -> None: 11 | super().__init__() 12 | self.attr = attr 13 | self.value = value 14 | 15 | def apply(self, doc: Node) -> StepResult: 16 | attrs = {} 17 | for name in doc.attrs: 18 | attrs[name] = doc.attrs[name] 19 | attrs[self.attr] = self.value 20 | updated = doc.type.create(attrs, doc.content, doc.marks) 21 | return StepResult.ok(updated) 22 | 23 | def get_map(self) -> StepMap: 24 | return StepMap.empty 25 | 26 | def invert(self, doc: Node) -> Step: 27 | return DocAttrStep(self.attr, doc.attrs[self.attr]) 28 | 29 | def map(self, mapping: Mappable) -> Step | None: 30 | return self 31 | 32 | def to_json(self) -> JSONDict: 33 | json_data = { 34 | "stepType": "docAttr", 35 | "attr": self.attr, 36 | "value": self.value, 37 | } 38 | 39 | return json_data 40 | 41 | @staticmethod 42 | def from_json(schema: Schema[Any, Any], json_data: JSONDict | str) -> "DocAttrStep": 43 | if isinstance(json_data, str): 44 | import json 45 | 46 | json_data = cast(JSONDict, json.loads(json_data)) 47 | 48 | if not isinstance(json_data["attr"], str): 49 | msg = "Invalid input for DocAttrStep.from_json" 50 | raise ValueError(msg) 51 | return DocAttrStep(json_data["attr"], json_data["value"]) 52 | 53 | 54 | step_json_id("docAttr", DocAttrStep) 55 | -------------------------------------------------------------------------------- /prosemirror/transform/map.py: -------------------------------------------------------------------------------- 1 | import abc 2 | from collections.abc import Callable 3 | from typing import ClassVar, Literal, overload 4 | 5 | lower16 = 0xFFFF 6 | factor16 = 2**16 7 | 8 | 9 | def make_recover(index: float, offset: int) -> int: 10 | return int(index + offset * factor16) 11 | 12 | 13 | def recover_index(value: int) -> int: 14 | return int(value & lower16) 15 | 16 | 17 | def recover_offset(value: int) -> int: 18 | return int((value - (value & lower16)) / factor16) 19 | 20 | 21 | DEL_BEFORE = 1 22 | DEL_AFTER = 2 23 | DEL_ACROSS = 4 24 | DEL_SIDE = 8 25 | 26 | 27 | class MapResult: 28 | def __init__(self, pos: int, del_info: int = 0, recover: int | None = None) -> None: 29 | self.pos = pos 30 | self.del_info = del_info 31 | self.recover = recover 32 | 33 | # get deleted() { return (this.delInfo & DEL_SIDE) > 0 } 34 | 35 | # get deletedBefore() { return (this.delInfo & (DEL_BEFORE | DEL_ACROSS)) > 0 } 36 | 37 | # get deletedAfter() { return (this.delInfo & (DEL_AFTER | DEL_ACROSS)) > 0 } 38 | 39 | # get deletedAcross() { return (this.delInfo & DEL_ACROSS) > 0 } 40 | 41 | @property 42 | def deleted(self) -> bool: 43 | return (self.del_info & DEL_SIDE) > 0 44 | 45 | @property 46 | def deleted_before(self) -> bool: 47 | return (self.del_info & (DEL_BEFORE | DEL_ACROSS)) > 0 48 | 49 | @property 50 | def deleted_after(self) -> bool: 51 | return (self.del_info & (DEL_AFTER | DEL_ACROSS)) > 0 52 | 53 | @property 54 | def deleted_across(self) -> bool: 55 | return (self.del_info & DEL_ACROSS) > 0 56 | 57 | 58 | class Mappable(metaclass=abc.ABCMeta): 59 | @abc.abstractmethod 60 | def map(self, pos: int, assoc: int = 1) -> int: ... 61 | 62 | @abc.abstractmethod 63 | def map_result(self, pos: int, assoc: int = 1) -> MapResult: ... 64 | 65 | 66 | class StepMap(Mappable): 67 | empty: ClassVar["StepMap"] 68 | 69 | def __init__(self, ranges: list[int], inverted: bool = False) -> None: 70 | # prosemirror-transform overrides the constructor to return the 71 | # StepMap.empty singleton when ranges are empty. 72 | # It is not easy to do in Python, and the intent of that is to make sure 73 | # empty stepmaps can eq to each other, which is already the case in Python. 74 | self.ranges = ranges 75 | self.inverted = inverted 76 | 77 | def recover(self, value: int) -> int: 78 | diff = 0 79 | index = recover_index(value) 80 | if not self.inverted: 81 | for i in range(index): 82 | diff += self.ranges[i * 3 + 2] - self.ranges[i * 3 + 1] 83 | return self.ranges[index * 3] + diff + recover_offset(value) 84 | 85 | def map(self, pos: int, assoc: int = 1) -> int: 86 | return self._map(pos, assoc, True) 87 | 88 | def map_result(self, pos: int, assoc: int = 1) -> MapResult: 89 | return self._map(pos, assoc, False) 90 | 91 | @overload 92 | def _map(self, pos: int, assoc: int, simple: Literal[True]) -> int: ... 93 | 94 | @overload 95 | def _map(self, pos: int, assoc: int, simple: Literal[False]) -> MapResult: ... 96 | 97 | def _map(self, pos: int, assoc: int, simple: bool) -> MapResult | int: 98 | diff = 0 99 | old_index = 2 if self.inverted else 1 100 | new_index = 1 if self.inverted else 2 101 | for i in range(0, len(self.ranges), 3): 102 | start = self.ranges[i] - (diff if self.inverted else 0) 103 | if start > pos: 104 | break 105 | old_size = self.ranges[i + old_index] 106 | new_size = self.ranges[i + new_index] 107 | end = start + old_size 108 | if pos <= end: 109 | if not old_size: 110 | side = assoc 111 | elif pos == start: 112 | side = -1 113 | elif pos == end: 114 | side = 1 115 | else: 116 | side = assoc 117 | result = start + diff + (0 if side < 0 else new_size) 118 | if simple: 119 | return result 120 | recover = ( 121 | None 122 | if pos == (start if assoc < 0 else end) 123 | else make_recover(i / 3, pos - start) 124 | ) 125 | del_info = ( 126 | DEL_AFTER 127 | if pos == start 128 | else (DEL_BEFORE if pos == end else DEL_ACROSS) 129 | ) 130 | if pos != start if assoc < 0 else pos != end: 131 | del_info |= DEL_SIDE 132 | return MapResult(result, del_info, recover) 133 | diff += new_size - old_size 134 | return pos + diff if simple else MapResult(pos + diff, 0, None) 135 | 136 | def touches(self, pos: int, recover: int) -> bool: 137 | diff = 0 138 | index = recover_index(recover) 139 | old_index = 2 if self.inverted else 1 140 | new_index = 1 if self.inverted else 2 141 | for i in range(len(self.ranges), 3): 142 | start = self.ranges[i] - (diff if self.inverted else 0) 143 | if start > pos: 144 | break 145 | old_size = self.ranges[i + old_index] 146 | end = start + old_size 147 | if pos <= end and i == index * 3: 148 | return True 149 | diff += self.ranges[i + new_index] - old_size 150 | return False 151 | 152 | def for_each(self, f: Callable[[int, int, int, int], None]) -> None: 153 | old_index = 2 if self.inverted else 1 154 | new_index = 1 if self.inverted else 2 155 | i = 0 156 | diff = 0 157 | while i < len(self.ranges): 158 | start = self.ranges[i] 159 | old_start = start - (diff if self.inverted else 0) 160 | new_start = start + (0 if self.inverted else diff) 161 | old_size = self.ranges[i + old_index] 162 | new_size = self.ranges[i + new_index] 163 | f(old_start, old_start + old_size, new_start, new_start + new_size) 164 | i += 3 165 | 166 | def invert(self) -> "StepMap": 167 | return StepMap(self.ranges, not self.inverted) 168 | 169 | def __str__(self) -> str: 170 | return ("-" if self.inverted else "") + str(self.ranges) 171 | 172 | 173 | StepMap.empty = StepMap([]) 174 | 175 | 176 | class Mapping(Mappable): 177 | def __init__( 178 | self, 179 | maps: list[StepMap] | None = None, 180 | mirror: list[int] | None = None, 181 | from_: int | None = None, 182 | to: int | None = None, 183 | ) -> None: 184 | self.maps = maps or [] 185 | self.from_ = from_ or 0 186 | self.to = len(self.maps) if to is None else to 187 | self.mirror = mirror 188 | 189 | def slice(self, from_: int = 0, to: int | None = None) -> "Mapping": 190 | if to is None: 191 | to = len(self.maps) 192 | return Mapping(self.maps, self.mirror, from_, to) 193 | 194 | def copy(self) -> "Mapping": 195 | return Mapping( 196 | self.maps[:], 197 | (self.mirror[:] if self.mirror else None), 198 | self.from_, 199 | self.to, 200 | ) 201 | 202 | def append_map(self, map: StepMap, mirrors: int | None = None) -> None: 203 | self.maps.append(map) 204 | self.to = len(self.maps) 205 | if mirrors is not None: 206 | self.set_mirror(len(self.maps) - 1, mirrors) 207 | 208 | def append_mapping(self, mapping: "Mapping") -> None: 209 | i = 0 210 | start_size = len(self.maps) 211 | while i < len(mapping.maps): 212 | mirr = mapping.get_mirror(i) 213 | i += 1 214 | self.append_map( 215 | mapping.maps[i], 216 | (start_size + mirr) if (mirr is not None and mirr < i) else None, 217 | ) 218 | 219 | def get_mirror(self, n: int) -> int | None: 220 | if self.mirror: 221 | for i in range(len(self.mirror)): 222 | if (self.mirror[i]) == n: 223 | return self.mirror[i + (-1 if i % 2 else 1)] 224 | return None 225 | 226 | def set_mirror(self, n: int, m: int) -> None: 227 | if not self.mirror: 228 | self.mirror = [] 229 | self.mirror.extend([n, m]) 230 | 231 | def append_mapping_inverted(self, mapping: "Mapping") -> None: 232 | i = len(mapping.maps) - 1 233 | total_size = len(self.maps) + len(mapping.maps) 234 | while i >= 0: 235 | mirr = mapping.get_mirror(i) 236 | self.append_map( 237 | mapping.maps[i].invert(), 238 | (total_size - mirr - 1) if (mirr is not None and mirr > i) else None, 239 | ) 240 | i -= 1 241 | 242 | def invert(self) -> "Mapping": 243 | inverse = Mapping() 244 | inverse.append_mapping_inverted(self) 245 | return inverse 246 | 247 | def map(self, pos: int, assoc: int = 1) -> int: 248 | if self.mirror: 249 | return self._map(pos, assoc, True) 250 | for i in range(self.from_, self.to): 251 | pos = self.maps[i].map(pos, assoc) 252 | return pos 253 | 254 | def map_result(self, pos: int, assoc: int = 1) -> MapResult: 255 | return self._map(pos, assoc, False) 256 | 257 | @overload 258 | def _map(self, pos: int, assoc: int, simple: Literal[True]) -> int: ... 259 | 260 | @overload 261 | def _map(self, pos: int, assoc: int, simple: Literal[False]) -> MapResult: ... 262 | 263 | def _map(self, pos: int, assoc: int, simple: bool) -> MapResult | int: 264 | del_info = 0 265 | 266 | i = self.from_ 267 | while i < self.to: 268 | map = self.maps[i] 269 | result = map.map_result(pos, assoc) 270 | if result.recover is not None: 271 | corr = self.get_mirror(i) 272 | if corr is not None and corr > i and corr < self.to: 273 | i = corr 274 | pos = self.maps[corr].recover(result.recover) 275 | i += 1 276 | continue 277 | del_info |= result.del_info 278 | pos = result.pos 279 | i += 1 280 | return pos if simple else MapResult(pos, del_info, None) 281 | -------------------------------------------------------------------------------- /prosemirror/transform/mark.py: -------------------------------------------------------------------------------- 1 | # Upstream adds methods to the Transform class prototype in this file, instead 2 | # see transform.py for add_mark, remove_mark, and clear_incompatible. 3 | -------------------------------------------------------------------------------- /prosemirror/transform/mark_step.py: -------------------------------------------------------------------------------- 1 | from collections.abc import Callable 2 | from typing import Any, cast 3 | 4 | from prosemirror.model import Fragment, Mark, Node, Schema, Slice 5 | from prosemirror.transform.map import Mappable 6 | from prosemirror.transform.step import Step, StepResult, step_json_id 7 | from prosemirror.utils import JSONDict 8 | 9 | 10 | def map_fragment( 11 | fragment: Fragment, 12 | f: Callable[[Node, Node, int], Node], 13 | parent: Node, 14 | ) -> Fragment: 15 | mapped = [] 16 | for i in range(fragment.child_count): 17 | child = fragment.child(i) 18 | if getattr(child.content, "size", None): 19 | child = child.copy(map_fragment(child.content, f, child)) 20 | if child.is_inline: 21 | child = f(child, parent, i) 22 | mapped.append(child) 23 | return fragment.from_array(mapped) 24 | 25 | 26 | class AddMarkStep(Step): 27 | def __init__(self, from_: int, to: int, mark: Mark) -> None: 28 | super().__init__() 29 | self.from_ = from_ 30 | self.to = to 31 | self.mark = mark 32 | 33 | def apply(self, doc: Node) -> StepResult: 34 | old_slice = doc.slice(self.from_, self.to) 35 | from__ = doc.resolve(self.from_) 36 | parent = from__.node(from__.shared_depth(self.to)) 37 | 38 | def iteratee(node: Node, parent: Node | None, i: int) -> Node: 39 | if parent and ( 40 | not node.is_atom or not parent.type.allows_mark_type(self.mark.type) 41 | ): 42 | return node 43 | return node.mark(self.mark.add_to_set(node.marks)) 44 | 45 | slice = Slice( 46 | map_fragment(old_slice.content, iteratee, parent), 47 | old_slice.open_start, 48 | old_slice.open_end, 49 | ) 50 | return StepResult.from_replace(doc, self.from_, self.to, slice) 51 | 52 | def invert(self, doc: Node | None = None) -> Step: 53 | return RemoveMarkStep(self.from_, self.to, self.mark) 54 | 55 | def map(self, mapping: Mappable) -> Step | None: 56 | from_ = mapping.map_result(self.from_, 1) 57 | to = mapping.map_result(self.to, -1) 58 | if (from_.deleted and to.deleted) or from_.pos > to.pos: 59 | return None 60 | return AddMarkStep(from_.pos, to.pos, self.mark) 61 | 62 | def merge(self, other: Step) -> Step | None: 63 | if ( 64 | isinstance(other, AddMarkStep) 65 | and other.mark.eq(self.mark) 66 | and self.from_ <= other.to 67 | and self.to >= other.from_ 68 | ): 69 | return AddMarkStep( 70 | min(self.from_, other.from_), 71 | max(self.to, other.to), 72 | self.mark, 73 | ) 74 | return None 75 | 76 | def to_json(self) -> JSONDict: 77 | return { 78 | "stepType": "addMark", 79 | "mark": self.mark.to_json(), 80 | "from": self.from_, 81 | "to": self.to, 82 | } 83 | 84 | @staticmethod 85 | def from_json(schema: Schema[Any, Any], json_data: JSONDict | str) -> "AddMarkStep": 86 | if isinstance(json_data, str): 87 | import json 88 | 89 | json_data = cast(JSONDict, json.loads(json_data)) 90 | 91 | if not isinstance(json_data["from"], int) or not isinstance( 92 | json_data["to"], 93 | int, 94 | ): 95 | msg = "Invalid input for AddMarkStep.from_json" 96 | raise ValueError(msg) 97 | return AddMarkStep( 98 | json_data["from"], 99 | json_data["to"], 100 | schema.mark_from_json(cast(JSONDict, json_data["mark"])), 101 | ) 102 | 103 | 104 | step_json_id("addMark", AddMarkStep) 105 | 106 | 107 | class RemoveMarkStep(Step): 108 | def __init__(self, from_: int, to: int, mark: Mark) -> None: 109 | super().__init__() 110 | self.from_ = from_ 111 | self.to = to 112 | self.mark = mark 113 | 114 | def apply(self, doc: Node) -> StepResult: 115 | old_slice = doc.slice(self.from_, self.to) 116 | 117 | def iteratee(node: Node, parent: Node | None, i: int) -> Node: 118 | return node.mark(self.mark.remove_from_set(node.marks)) 119 | 120 | slice = Slice( 121 | map_fragment(old_slice.content, iteratee, doc), 122 | old_slice.open_start, 123 | old_slice.open_end, 124 | ) 125 | return StepResult.from_replace(doc, self.from_, self.to, slice) 126 | 127 | def invert(self, doc: Node | None = None) -> Step: 128 | return AddMarkStep(self.from_, self.to, self.mark) 129 | 130 | def map(self, mapping: Mappable) -> Step | None: 131 | from_ = mapping.map_result(self.from_, 1) 132 | to = mapping.map_result(self.to, -1) 133 | if (from_.deleted and to.deleted) or (from_.pos > to.pos): 134 | return None 135 | return RemoveMarkStep(from_.pos, to.pos, self.mark) 136 | 137 | def merge(self, other: Step) -> Step | None: 138 | if ( 139 | isinstance(other, RemoveMarkStep) 140 | and (other.mark.eq(self.mark)) 141 | and (self.from_ <= other.to) 142 | and self.to >= other.from_ 143 | ): 144 | return RemoveMarkStep( 145 | min(self.from_, other.from_), 146 | max(self.to, other.to), 147 | self.mark, 148 | ) 149 | return None 150 | 151 | def to_json(self) -> JSONDict: 152 | return { 153 | "stepType": "removeMark", 154 | "mark": self.mark.to_json(), 155 | "from": self.from_, 156 | "to": self.to, 157 | } 158 | 159 | @staticmethod 160 | def from_json(schema: Schema[Any, Any], json_data: JSONDict | str) -> Step: 161 | if isinstance(json_data, str): 162 | import json 163 | 164 | json_data = cast(JSONDict, json.loads(json_data)) 165 | 166 | if not isinstance(json_data["from"], int) or not isinstance( 167 | json_data["to"], 168 | int, 169 | ): 170 | msg = "Invalid input for RemoveMarkStep.from_json" 171 | raise ValueError(msg) 172 | return RemoveMarkStep( 173 | json_data["from"], 174 | json_data["to"], 175 | schema.mark_from_json(cast(JSONDict, json_data["mark"])), 176 | ) 177 | 178 | 179 | step_json_id("removeMark", RemoveMarkStep) 180 | 181 | 182 | class AddNodeMarkStep(Step): 183 | def __init__(self, pos: int, mark: Mark) -> None: 184 | super().__init__() 185 | self.pos = pos 186 | self.mark = mark 187 | 188 | def apply(self, doc: Node) -> StepResult: 189 | node = doc.node_at(self.pos) 190 | if not node: 191 | return StepResult.fail("No node at mark step's position") 192 | updated = node.type.create(node.attrs, None, self.mark.add_to_set(node.marks)) 193 | return StepResult.from_replace( 194 | doc, 195 | self.pos, 196 | self.pos + 1, 197 | Slice(Fragment.from_(updated), 0, 0 if node.is_leaf else 1), 198 | ) 199 | 200 | def invert(self, doc: Node) -> Step: 201 | node = doc.node_at(self.pos) 202 | if node: 203 | new_set = self.mark.add_to_set(node.marks) 204 | if len(new_set) == len(node.marks): 205 | for i in range(len(node.marks)): 206 | if not node.marks[i].is_in_set(new_set): 207 | return AddNodeMarkStep(self.pos, node.marks[i]) 208 | return AddNodeMarkStep(self.pos, self.mark) 209 | return RemoveNodeMarkStep(self.pos, self.mark) 210 | 211 | def map(self, mapping: Mappable) -> Step | None: 212 | pos = mapping.map_result(self.pos, 1) 213 | return None if pos.deleted_after else AddNodeMarkStep(pos.pos, self.mark) 214 | 215 | def to_json(self) -> JSONDict: 216 | return { 217 | "stepType": "addNodeMark", 218 | "pos": self.pos, 219 | "mark": self.mark.to_json(), 220 | } 221 | 222 | @staticmethod 223 | def from_json(schema: Schema[Any, Any], json_data: JSONDict | str) -> Step: 224 | if isinstance(json_data, str): 225 | import json 226 | 227 | json_data = cast(JSONDict, json.loads(json_data)) 228 | 229 | if not isinstance(json_data["pos"], int): 230 | msg = "Invalid input for AddNodeMarkStep.from_json" 231 | raise ValueError(msg) 232 | return AddNodeMarkStep( 233 | json_data["pos"], 234 | schema.mark_from_json(cast(JSONDict, json_data["mark"])), 235 | ) 236 | 237 | 238 | step_json_id("addNodeMark", AddNodeMarkStep) 239 | 240 | 241 | class RemoveNodeMarkStep(Step): 242 | def __init__(self, pos: int, mark: Mark) -> None: 243 | super().__init__() 244 | self.pos = pos 245 | self.mark = mark 246 | 247 | def apply(self, doc: Node) -> StepResult: 248 | node = doc.node_at(self.pos) 249 | if not node: 250 | return StepResult.fail("No node at mark step's position") 251 | updated = node.type.create( 252 | node.attrs, 253 | None, 254 | self.mark.remove_from_set(node.marks), 255 | ) 256 | return StepResult.from_replace( 257 | doc, 258 | self.pos, 259 | self.pos + 1, 260 | Slice(Fragment.from_(updated), 0, 0 if node.is_leaf else 1), 261 | ) 262 | 263 | def invert(self, doc: Node) -> Step: 264 | node = doc.node_at(self.pos) 265 | if not node or not self.mark.is_in_set(node.marks): 266 | return self 267 | return AddNodeMarkStep(self.pos, self.mark) 268 | 269 | def map(self, mapping: Mappable) -> Step | None: 270 | pos = mapping.map_result(self.pos, 1) 271 | return None if pos.deleted_after else RemoveNodeMarkStep(pos.pos, self.mark) 272 | 273 | def to_json(self) -> JSONDict: 274 | return { 275 | "stepType": "removeNodeMark", 276 | "pos": self.pos, 277 | "mark": self.mark.to_json(), 278 | } 279 | 280 | @staticmethod 281 | def from_json(schema: Schema[Any, Any], json_data: JSONDict | str) -> Step: 282 | if isinstance(json_data, str): 283 | import json 284 | 285 | json_data = cast(JSONDict, json.loads(json_data)) 286 | 287 | if not isinstance(json_data["pos"], int): 288 | msg = "Invalid input for RemoveNodeMarkStep.from_json" 289 | raise ValueError(msg) 290 | return RemoveNodeMarkStep( 291 | json_data["pos"], 292 | schema.mark_from_json(cast(JSONDict, json_data["mark"])), 293 | ) 294 | 295 | 296 | step_json_id("removeNodeMark", RemoveNodeMarkStep) 297 | -------------------------------------------------------------------------------- /prosemirror/transform/replace_step.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Optional, cast 2 | 3 | from prosemirror.model import Node, Schema, Slice 4 | from prosemirror.transform.map import Mappable, StepMap 5 | from prosemirror.transform.step import Step, StepResult, step_json_id 6 | from prosemirror.utils import JSONDict 7 | 8 | 9 | class ReplaceStep(Step): 10 | def __init__( 11 | self, 12 | from_: int, 13 | to: int, 14 | slice: Slice, 15 | structure: bool | None = None, 16 | ) -> None: 17 | super().__init__() 18 | self.from_ = from_ 19 | self.to = to 20 | self.slice = slice 21 | self.structure = bool(structure) 22 | 23 | def apply(self, doc: Node) -> StepResult: 24 | if self.structure and content_between(doc, self.from_, self.to): 25 | return StepResult.fail("Structure replace would overrite content") 26 | return StepResult.from_replace(doc, self.from_, self.to, self.slice) 27 | 28 | def get_map(self) -> StepMap: 29 | return StepMap([self.from_, self.to - self.from_, self.slice.size]) 30 | 31 | def invert(self, doc: Node) -> "ReplaceStep": 32 | return ReplaceStep( 33 | self.from_, 34 | self.from_ + self.slice.size, 35 | doc.slice(self.from_, self.to), 36 | ) 37 | 38 | def map(self, mapping: Mappable) -> Optional["ReplaceStep"]: 39 | from_ = mapping.map_result(self.from_, 1) 40 | to = mapping.map_result(self.to, -1) 41 | if from_.deleted and to.deleted: 42 | return None 43 | return ReplaceStep(from_.pos, max(from_.pos, to.pos), self.slice) 44 | 45 | def merge(self, other: "Step") -> Optional["ReplaceStep"]: 46 | if not isinstance(other, ReplaceStep) or other.structure or self.structure: 47 | return None 48 | if ( 49 | self.from_ + self.slice.size == other.from_ 50 | and not self.slice.open_end 51 | and not other.slice.open_start 52 | ): 53 | if self.slice.size + other.slice.size == 0: 54 | slice = Slice.empty 55 | else: 56 | slice = Slice( 57 | self.slice.content.append(other.slice.content), 58 | self.slice.open_start, 59 | other.slice.open_end, 60 | ) 61 | return ReplaceStep( 62 | self.from_, 63 | self.to + (other.to - other.from_), 64 | slice, 65 | self.structure, 66 | ) 67 | elif ( 68 | other.to == self.from_ 69 | and not self.slice.open_start 70 | and not other.slice.open_end 71 | ): 72 | if self.slice.size + other.slice.size == 0: 73 | slice = Slice.empty 74 | else: 75 | slice = Slice( 76 | other.slice.content.append(self.slice.content), 77 | other.slice.open_start, 78 | self.slice.open_end, 79 | ) 80 | return ReplaceStep(other.from_, self.to, slice, self.structure) 81 | return None 82 | 83 | def to_json(self) -> JSONDict: 84 | json_data: JSONDict = {"stepType": "replace", "from": self.from_, "to": self.to} 85 | if self.slice.size: 86 | json_data = { 87 | **json_data, 88 | "slice": self.slice.to_json(), 89 | } 90 | if self.structure: 91 | json_data = { 92 | **json_data, 93 | "structure": True, 94 | } 95 | return json_data 96 | 97 | @staticmethod 98 | def from_json(schema: Schema[Any, Any], json_data: JSONDict | str) -> "ReplaceStep": 99 | if isinstance(json_data, str): 100 | import json 101 | 102 | json_data = cast(JSONDict, json.loads(json_data)) 103 | 104 | if not isinstance(json_data["from"], int) or not isinstance( 105 | json_data["to"], 106 | int, 107 | ): 108 | msg = "Invlid input for ReplaceStep.from_json" 109 | raise ValueError(msg) 110 | return ReplaceStep( 111 | json_data["from"], 112 | json_data["to"], 113 | Slice.from_json(schema, cast(JSONDict | None, json_data.get("slice"))), 114 | bool(json_data.get("structure")), 115 | ) 116 | 117 | 118 | step_json_id("replace", ReplaceStep) 119 | 120 | 121 | class ReplaceAroundStep(Step): 122 | def __init__( 123 | self, 124 | from_: int, 125 | to: int, 126 | gap_from: int, 127 | gap_to: int, 128 | slice: Slice, 129 | insert: int, 130 | structure: bool | None = None, 131 | ) -> None: 132 | super().__init__() 133 | self.from_ = from_ 134 | self.to = to 135 | self.gap_from = gap_from 136 | self.gap_to = gap_to 137 | self.slice = slice 138 | self.insert = insert 139 | self.structure = bool(structure) 140 | 141 | def apply(self, doc: Node) -> StepResult: 142 | if self.structure and ( 143 | content_between(doc, self.from_, self.gap_from) 144 | or content_between(doc, self.gap_to, self.to) 145 | ): 146 | return StepResult.fail("Structure gap-replace would overwrite content") 147 | gap = doc.slice(self.gap_from, self.gap_to) 148 | if gap.open_start or gap.open_end: 149 | return StepResult.fail("Gap is not a flat range") 150 | inserted = self.slice.insert_at(self.insert, gap.content) 151 | if not inserted: 152 | return StepResult.fail("Content does not fit in gap") 153 | return StepResult.from_replace(doc, self.from_, self.to, inserted) 154 | 155 | def get_map(self) -> StepMap: 156 | return StepMap([ 157 | self.from_, 158 | self.gap_from - self.from_, 159 | self.insert, 160 | self.gap_to, 161 | self.to - self.gap_to, 162 | self.slice.size - self.insert, 163 | ]) 164 | 165 | def invert(self, doc: Node) -> "ReplaceAroundStep": 166 | gap = self.gap_to - self.gap_from 167 | return ReplaceAroundStep( 168 | self.from_, 169 | self.from_ + self.slice.size + gap, 170 | self.from_ + self.insert, 171 | self.from_ + self.insert + gap, 172 | doc.slice(self.from_, self.to).remove_between( 173 | self.gap_from - self.from_, 174 | self.gap_to - self.from_, 175 | ), 176 | self.gap_from - self.from_, 177 | self.structure, 178 | ) 179 | 180 | def map(self, mapping: Mappable) -> Optional["ReplaceAroundStep"]: 181 | from_ = mapping.map_result(self.from_, 1) 182 | to = mapping.map_result(self.to, -1) 183 | gap_from = mapping.map(self.gap_from, -1) 184 | gap_to = mapping.map(self.gap_to, 1) 185 | if (from_.deleted and to.deleted) or gap_from < from_.pos or gap_to > to.pos: 186 | return None 187 | return ReplaceAroundStep( 188 | from_.pos, 189 | to.pos, 190 | gap_from, 191 | gap_to, 192 | self.slice, 193 | self.insert, 194 | self.structure, 195 | ) 196 | 197 | def to_json(self) -> JSONDict: 198 | json_data: JSONDict = { 199 | "stepType": "replaceAround", 200 | "from": self.from_, 201 | "to": self.to, 202 | "gapFrom": self.gap_from, 203 | "gapTo": self.gap_to, 204 | "insert": self.insert, 205 | } 206 | if self.slice.size: 207 | json_data = { 208 | **json_data, 209 | "slice": self.slice.to_json(), 210 | } 211 | if self.structure: 212 | json_data = { 213 | **json_data, 214 | "structure": True, 215 | } 216 | return json_data 217 | 218 | @staticmethod 219 | def from_json( 220 | schema: Schema[Any, Any], 221 | json_data: JSONDict | str, 222 | ) -> "ReplaceAroundStep": 223 | if isinstance(json_data, str): 224 | import json 225 | 226 | json_data = cast(JSONDict, json.loads(json_data)) 227 | 228 | if ( 229 | not isinstance(json_data["from"], int) 230 | or not isinstance(json_data["to"], int) 231 | or not isinstance(json_data["gapFrom"], int) 232 | or not isinstance(json_data["gapTo"], int) 233 | or not isinstance(json_data["insert"], int) 234 | ): 235 | msg = "Invlid input for ReplaceAroundStep.from_json" 236 | raise ValueError(msg) 237 | return ReplaceAroundStep( 238 | json_data["from"], 239 | json_data["to"], 240 | json_data["gapFrom"], 241 | json_data["gapTo"], 242 | Slice.from_json(schema, cast(JSONDict | None, json_data.get("slice"))), 243 | json_data["insert"], 244 | bool(json_data.get("structure")), 245 | ) 246 | 247 | 248 | step_json_id("replaceAround", ReplaceAroundStep) 249 | 250 | 251 | def content_between(doc: Node, from_: int, to: int) -> bool: 252 | from__ = doc.resolve(from_) 253 | dist = to - from_ 254 | depth = from__.depth 255 | while ( 256 | dist > 0 257 | and depth > 0 258 | and from__.index_after(depth) == from__.node(depth).child_count 259 | ): 260 | depth -= 1 261 | dist -= 1 262 | if dist > 0: 263 | next = from__.node(depth).maybe_child(from__.index_after(depth)) 264 | while dist > 0: 265 | if not next or next.is_leaf: 266 | return True 267 | next = next.first_child 268 | dist -= 1 269 | return False 270 | -------------------------------------------------------------------------------- /prosemirror/transform/step.py: -------------------------------------------------------------------------------- 1 | import abc 2 | from typing import Any, Literal, Optional, TypeVar, cast, overload 3 | 4 | from prosemirror.model import Node, ReplaceError, Schema, Slice 5 | from prosemirror.transform.map import Mappable, StepMap 6 | from prosemirror.utils import JSONDict 7 | 8 | # like a registry 9 | STEPS_BY_ID: dict[str, type["Step"]] = {} 10 | StepSubclass = TypeVar("StepSubclass", bound="Step") 11 | 12 | 13 | class Step(metaclass=abc.ABCMeta): 14 | json_id: str 15 | 16 | @abc.abstractmethod 17 | def apply(self, _doc: Node) -> "StepResult": ... 18 | 19 | def get_map(self) -> StepMap: 20 | return StepMap.empty 21 | 22 | @abc.abstractmethod 23 | def invert(self, _doc: Node) -> "Step": ... 24 | 25 | @abc.abstractmethod 26 | def map(self, _mapping: Mappable) -> Optional["Step"]: ... 27 | 28 | def merge(self, _other: "Step") -> Optional["Step"]: 29 | return None 30 | 31 | @abc.abstractmethod 32 | def to_json(self) -> JSONDict: ... 33 | 34 | @staticmethod 35 | def from_json(schema: Schema[Any, Any], json_data: JSONDict | str) -> "Step": 36 | if isinstance(json_data, str): 37 | import json 38 | 39 | json_data = cast(JSONDict, json.loads(json_data)) 40 | 41 | if not json_data or not json_data.get("stepType"): 42 | msg = "Invalid inpit for Step.from_json" 43 | raise ValueError(msg) 44 | type = STEPS_BY_ID.get(cast(str, json_data["stepType"])) 45 | if not type: 46 | msg = f"no step type {json_data['stepType']} defined" 47 | raise ValueError(msg) 48 | return type.from_json(schema, json_data) 49 | 50 | 51 | def step_json_id(id: str, step_class: type[StepSubclass]) -> type[StepSubclass]: 52 | if id in STEPS_BY_ID: 53 | msg = f"Duplicated JSON ID for step type: {id}" 54 | raise ValueError(msg) 55 | 56 | STEPS_BY_ID[id] = step_class 57 | step_class.json_id = id 58 | 59 | return step_class 60 | 61 | 62 | class StepResult: 63 | @overload 64 | def __init__(self, doc: Node, failed: Literal[None]) -> None: ... 65 | 66 | @overload 67 | def __init__(self, doc: None, failed: str) -> None: ... 68 | 69 | def __init__(self, doc: Node | None, failed: str | None) -> None: 70 | self.doc = doc 71 | self.failed = failed 72 | 73 | @classmethod 74 | def ok(cls, doc: Node) -> "StepResult": 75 | return cls(doc, None) 76 | 77 | @classmethod 78 | def fail(cls, message: str) -> "StepResult": 79 | return cls(None, message) 80 | 81 | @classmethod 82 | def from_replace(cls, doc: Node, from_: int, to: int, slice: Slice) -> "StepResult": 83 | try: 84 | return cls.ok(doc.replace(from_, to, slice)) 85 | except ReplaceError as e: 86 | return cls.fail(e.args[0]) 87 | -------------------------------------------------------------------------------- /prosemirror/transform/structure.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import cast 3 | 4 | from prosemirror.model import ContentMatch, Node, NodeRange, NodeType, Slice 5 | from prosemirror.utils import Attrs 6 | 7 | 8 | def can_cut(node: Node, start: int, end: int) -> bool: 9 | if start == 0 or node.can_replace(start, node.child_count): 10 | return (end == node.child_count) or node.can_replace(0, end) 11 | return False 12 | 13 | 14 | def lift_target(range_: NodeRange) -> int | None: 15 | parent = range_.parent 16 | content = parent.content.cut_by_index(range_.start_index, range_.end_index) 17 | depth = range_.depth 18 | while True: 19 | node = range_.from_.node(depth) 20 | index = range_.from_.index(depth) 21 | end_index = range_.to.index_after(depth) 22 | if depth < range_.depth and node.can_replace(index, end_index, content): 23 | return depth 24 | if ( 25 | depth == 0 26 | or node.type.spec.get("isolating") 27 | or not can_cut(node, index, end_index) 28 | ): 29 | break 30 | depth -= 1 31 | 32 | return None 33 | 34 | 35 | @dataclass 36 | class NodeTypeWithAttrs: 37 | type: NodeType 38 | attrs: Attrs | None = None 39 | 40 | 41 | def find_wrapping( 42 | range_: NodeRange, 43 | node_type: NodeType, 44 | attrs: Attrs | None = None, 45 | inner_range: NodeRange | None = None, 46 | ) -> list[NodeTypeWithAttrs] | None: 47 | if inner_range is None: 48 | inner_range = range_ 49 | 50 | around = find_wrapping_outside(range_, node_type) 51 | inner = None 52 | 53 | if around is not None: 54 | inner = find_wrapping_inside(inner_range, node_type) 55 | else: 56 | return None 57 | 58 | if inner is None: 59 | return None 60 | 61 | return ( 62 | [with_attrs(item) for item in around] 63 | + [NodeTypeWithAttrs(type=node_type, attrs=attrs)] 64 | + [with_attrs(item) for item in inner] 65 | ) 66 | 67 | 68 | def with_attrs(type: NodeType) -> NodeTypeWithAttrs: 69 | return NodeTypeWithAttrs(type=type, attrs=None) 70 | 71 | 72 | def find_wrapping_outside(range_: NodeRange, type: NodeType) -> list[NodeType] | None: 73 | parent = range_.parent 74 | start_index = range_.start_index 75 | end_index = range_.end_index 76 | around = parent.content_match_at(start_index).find_wrapping(type) 77 | if around is None: 78 | return None 79 | outer = around[0] if len(around) and around[0] else type 80 | return around if parent.can_replace_with(start_index, end_index, outer) else None 81 | 82 | 83 | def find_wrapping_inside(range_: NodeRange, type: NodeType) -> list[NodeType] | None: 84 | parent = range_.parent 85 | start_index = range_.start_index 86 | end_index = range_.end_index 87 | inner = parent.child(start_index) 88 | inside = type.content_match.find_wrapping(inner.type) 89 | 90 | if inside is None: 91 | return None 92 | 93 | last_type = inside[-1] if len(inside) else type 94 | inner_match: ContentMatch | None = last_type.content_match 95 | i = start_index 96 | 97 | while inner_match and i < end_index: 98 | inner_match = inner_match.match_type(parent.child(i).type) 99 | i += 1 100 | 101 | if not inner_match or not inner_match.valid_end: 102 | return None 103 | 104 | return inside 105 | 106 | 107 | def can_change_type(doc: Node, pos: int, type: NodeType) -> bool: 108 | pos_ = doc.resolve(pos) 109 | index = pos_.index() 110 | return pos_.parent.can_replace_with(index, index + 1, type) 111 | 112 | 113 | def can_split( 114 | doc: Node, 115 | pos: int, 116 | depth: int | None = None, 117 | types_after: list[NodeTypeWithAttrs] | None = None, 118 | ) -> bool: 119 | if depth is None: 120 | depth = 1 121 | pos_ = doc.resolve(pos) 122 | base = pos_.depth - depth 123 | inner_type: NodeTypeWithAttrs = cast( 124 | NodeTypeWithAttrs, 125 | (types_after and types_after[-1]) or pos_.parent, 126 | ) 127 | 128 | if ( 129 | base < 0 130 | or pos_.parent.type.spec.get("isolating") 131 | or not pos_.parent.can_replace(pos_.index(), pos_.parent.child_count) 132 | or not inner_type.type.valid_content( 133 | pos_.parent.content.cut_by_index(pos_.index(), pos_.parent.child_count), 134 | ) 135 | ): 136 | return False 137 | 138 | d = pos_.depth - 1 139 | i = depth - 2 140 | 141 | while d > base: 142 | node = pos_.node(d) 143 | index = pos_.index(d) 144 | if node.type.spec.get("isolating"): 145 | return False 146 | rest = node.content.cut_by_index(index, node.child_count) 147 | 148 | if types_after and len(types_after) > i + 1: 149 | override_child = types_after[i + 1] 150 | rest = rest.replace_child( 151 | 0, 152 | override_child.type.create(override_child.attrs), 153 | ) 154 | after: NodeTypeWithAttrs = cast( 155 | NodeTypeWithAttrs, 156 | (types_after and len(types_after) > i and types_after[i]) or node, 157 | ) 158 | if not node.can_replace( 159 | index + 1, 160 | node.child_count, 161 | ) or not after.type.valid_content(rest): 162 | return False 163 | d -= 1 164 | i -= 1 165 | index = pos_.index_after(base) 166 | base_type = types_after[0] if types_after else None 167 | return pos_.node(base).can_replace_with( 168 | index, 169 | index, 170 | base_type.type if base_type else pos_.node(base + 1).type, 171 | ) 172 | 173 | 174 | def can_join(doc: Node, pos: int) -> bool | None: 175 | pos_ = doc.resolve(pos) 176 | index = pos_.index() 177 | return ( 178 | pos_.parent.can_replace(index, index + 1) 179 | if joinable(pos_.node_before, pos_.node_after) 180 | else None 181 | ) 182 | 183 | 184 | def joinable(a: Node | None, b: Node | None) -> bool: 185 | if a and b and not a.is_leaf: 186 | return a.can_append(b) 187 | return False 188 | 189 | 190 | def join_point(doc: Node, pos: int, dir: int = -1) -> int | None: 191 | pos_ = doc.resolve(pos) 192 | for d in range(pos_.depth, -1, -1): 193 | before = None 194 | after = None 195 | index = pos_.index(d) 196 | if d == pos_.depth: 197 | before = pos_.node_before 198 | after = pos_.node_after 199 | elif dir > 0: 200 | before = pos_.node(d + 1) 201 | index += 1 202 | after = pos_.node(d).maybe_child(index) 203 | else: 204 | before = pos_.node(d).maybe_child(index - 1) 205 | after = pos_.node(d + 1) 206 | if ( 207 | before 208 | and not before.is_textblock 209 | and joinable(before, after) 210 | and pos_.node(d).can_replace(index, index + 1) 211 | ): 212 | return pos 213 | if d == 0: 214 | break 215 | pos = pos_.before(d) if dir < 0 else pos_.after(d) 216 | 217 | return None 218 | 219 | 220 | def insert_point(doc: Node, pos: int, node_type: NodeType) -> int | None: 221 | pos_ = doc.resolve(pos) 222 | if pos_.parent.can_replace_with(pos_.index(), pos_.index(), node_type): 223 | return pos 224 | if pos_.parent_offset == 0: 225 | for d in range(pos_.depth - 1, -1, -1): 226 | index = pos_.index(d) 227 | if pos_.node(d).can_replace_with(index, index, node_type): 228 | return pos_.before(d + 1) 229 | if index > 0: 230 | return None 231 | if pos_.parent_offset == pos_.parent.content.size: 232 | for d in range(pos_.depth - 1, -1, -1): 233 | index = pos_.index_after(d) 234 | if pos_.node(d).can_replace_with(index, index, node_type): 235 | return pos_.after(d + 1) 236 | if index < pos_.node(d).child_count: 237 | return None 238 | 239 | return None 240 | 241 | 242 | def drop_point(doc: Node, pos: int, slice: Slice) -> int | None: 243 | pos_ = doc.resolve(pos) 244 | if not slice.content.size: 245 | return pos 246 | content = slice.content 247 | for _i in range(slice.open_start): 248 | assert content.first_child is not None 249 | content = content.first_child.content 250 | pass_ = 1 251 | while pass_ <= (2 if slice.open_start == 0 and slice.size else 1): 252 | for d in range(pos_.depth, 0, -1): 253 | if d == pos_.depth: 254 | bias = 0 255 | elif pos_.pos <= (pos_.start(d + 1) + pos_.end(d + 1)) / 2: 256 | bias = -1 257 | else: 258 | bias = 1 259 | insert_pos = pos_.index(d) + (1 if bias > 0 else 0) 260 | parent = pos_.node(d) 261 | fits = False 262 | if pass_ == 1: 263 | fits = parent.can_replace(insert_pos, insert_pos, content) 264 | else: 265 | assert content.first_child is not None 266 | wrapping = parent.content_match_at(insert_pos).find_wrapping( 267 | content.first_child.type, 268 | ) 269 | fits = wrapping is not None and parent.can_replace_with( 270 | insert_pos, 271 | insert_pos, 272 | wrapping[0], 273 | ) 274 | if fits: 275 | if bias == 0: 276 | return pos_.pos 277 | elif bias < 0: 278 | return pos_.before(d + 1) 279 | else: 280 | return pos_.after(d + 1) 281 | pass_ += 1 282 | return None 283 | -------------------------------------------------------------------------------- /prosemirror/utils.py: -------------------------------------------------------------------------------- 1 | from collections.abc import Mapping, Sequence 2 | from typing import TypeAlias 3 | 4 | JSONDict: TypeAlias = Mapping[str, "JSON"] 5 | JSONList: TypeAlias = Sequence["JSON"] 6 | 7 | JSON: TypeAlias = JSONDict | JSONList | str | int | float | bool | None 8 | 9 | Attrs: TypeAlias = JSONDict 10 | 11 | 12 | def text_length(text: str) -> int: 13 | return len(text.encode("utf-16-le")) // 2 14 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["hatchling", "hatch-vcs"] 3 | build-backend = "hatchling.build" 4 | 5 | [tool.hatch.version] 6 | source = "vcs" 7 | 8 | [project] 9 | name = "prosemirror" 10 | dynamic = ["version"] 11 | description = "Python implementation of core ProseMirror modules for collaborative editing" 12 | readme = "README.md" 13 | requires-python = ">=3.10" 14 | authors = [ 15 | { name = "Samuel Cormier-Iijima", email = "sam@fellow.co" }, 16 | { name = "Shen Li", email = "dustet@gmail.com" }, 17 | ] 18 | license = { text = "BSD-3-Clause" } 19 | keywords = ["prosemirror", "collaborative", "editing"] 20 | dependencies = ["typing-extensions>=4.1", "lxml>=4.9", "cssselect>=1.2"] 21 | 22 | classifiers = [ 23 | "Development Status :: 5 - Production/Stable", 24 | "Intended Audience :: Developers", 25 | "License :: OSI Approved :: BSD License", 26 | "Operating System :: OS Independent", 27 | "Programming Language :: Python", 28 | "Programming Language :: Python :: 3.10", 29 | "Programming Language :: Python :: 3.11", 30 | "Programming Language :: Python :: 3.12", 31 | "Programming Language :: Python :: 3.13", 32 | "Programming Language :: Python :: 3 :: Only", 33 | "Topic :: Software Development :: Libraries :: Python Modules", 34 | "Typing :: Typed", 35 | ] 36 | 37 | [project.urls] 38 | Homepage = "https://github.com/fellowapp/prosemirror-py" 39 | Repository = "https://github.com/fellowapp/prosemirror-py" 40 | Changelog = "https://github.com/fellowapp/prosemirror-py/releases" 41 | 42 | [dependency-groups] 43 | dev = [ 44 | "codecov~=2.1", 45 | "coverage~=7.6", 46 | "mypy~=1.15", 47 | "pyright>=1.1.396", 48 | "pytest~=8.3", 49 | "pytest-cov~=6.0", 50 | "ruff~=0.9", 51 | "types-lxml>=2025.2.24", 52 | ] 53 | 54 | [tool.ruff.lint] 55 | select = [ 56 | "ANN", 57 | "B", 58 | "COM", 59 | "E", 60 | "EM", 61 | "F", 62 | "I", 63 | "N", 64 | "PT", 65 | "RSE", 66 | "RUF", 67 | "SIM", 68 | "UP", 69 | "W", 70 | ] 71 | ignore = ["COM812"] 72 | preview = true 73 | 74 | [tool.ruff.lint.per-file-ignores] 75 | "prosemirror/test_builder/**" = ["ANN"] 76 | "tests/**" = ["ANN"] 77 | 78 | [tool.ruff.format] 79 | preview = true 80 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fellowapp/prosemirror-py/c996d5e23a8d6ef7360db26bf91f815a86a1587a/tests/__init__.py -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | 4 | @pytest.fixture 5 | def ist(): 6 | def ist(a, b=None, key=None): 7 | if key is None: 8 | if b is not None: 9 | assert a == b 10 | else: 11 | assert a 12 | else: 13 | if b is not None: 14 | assert key(a, b) 15 | else: 16 | assert key(a) 17 | 18 | return ist 19 | -------------------------------------------------------------------------------- /tests/prosemirror_model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fellowapp/prosemirror-py/c996d5e23a8d6ef7360db26bf91f815a86a1587a/tests/prosemirror_model/__init__.py -------------------------------------------------------------------------------- /tests/prosemirror_model/tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fellowapp/prosemirror-py/c996d5e23a8d6ef7360db26bf91f815a86a1587a/tests/prosemirror_model/tests/__init__.py -------------------------------------------------------------------------------- /tests/prosemirror_model/tests/test_content.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from prosemirror.model import ContentMatch, Node 4 | from prosemirror.test_builder import out 5 | from prosemirror.test_builder import test_schema as schema 6 | 7 | doc = out["doc"] 8 | h1 = out["h1"] 9 | p = out["p"] 10 | pre = out["pre"] 11 | img = out["img"] 12 | br = out["br"] 13 | hr = out["hr"] 14 | 15 | 16 | def get(expr): 17 | return ContentMatch.parse(expr, schema.nodes) 18 | 19 | 20 | def match(expr, types): 21 | m = get(expr) 22 | ts = [schema.nodes[t] for t in types.split(" ")] if types else [] 23 | i = 0 24 | while m and i < len(ts): 25 | m = m.match_type(ts[i]) 26 | i += 1 27 | if m: 28 | return m.valid_end 29 | return False 30 | 31 | 32 | @pytest.mark.parametrize( 33 | ("expr", "types", "valid"), 34 | [ 35 | ("", "", True), 36 | ("", "image", False), 37 | ("image*", "", True), 38 | ("image*", "image", True), 39 | ("image*", "image image image image", True), 40 | ("image*", "image text", False), 41 | ("inline*", "image text", True), 42 | ("inline*", "paragraph", False), 43 | ("(paragraph | heading)", "paragraph", True), 44 | ("(paragraph | heading)", "image", False), 45 | ( 46 | "paragraph horizontal_rule paragraph", 47 | "paragraph horizontal_rule paragraph", 48 | True, 49 | ), 50 | ("paragraph horizontal_rule", "paragraph horizontal_rule paragraph", False), 51 | ("paragraph horizontal_rule paragraph", "paragraph horizontal_rule", False), 52 | ( 53 | "paragraph horizontal_rule", 54 | "horizontal_rule paragraph horizontal_rule", 55 | False, 56 | ), 57 | ("heading paragraph*", "heading", True), 58 | ("heading paragraph*", "heading paragraph paragraph", True), 59 | ("heading paragraph+", "heading paragraph", True), 60 | ("heading paragraph+", "heading paragraph paragraph", True), 61 | ("heading paragraph+", "heading", False), 62 | ("heading paragraph+", "paragraph paragraph", False), 63 | ("image?", "image", True), 64 | ("image?", "", True), 65 | ("image?", "image image", False), 66 | ( 67 | "(heading paragraph+)+", 68 | "heading paragraph heading paragraph paragraph", 69 | True, 70 | ), 71 | ( 72 | "(heading paragraph+)+", 73 | "heading paragraph heading paragraph paragraph horizontal_rule", 74 | False, 75 | ), 76 | ("hard_break{2}", "hard_break hard_break", True), 77 | ("hard_break{2}", "hard_break", False), 78 | ("hard_break{2}", "hard_break hard_break hard_break", False), 79 | ("hard_break{2, 4}", "hard_break hard_break", True), 80 | ("hard_break{2, 4}", "hard_break hard_break hard_break hard_break", True), 81 | ("hard_break{2, 4}", "hard_break hard_break hard_break", True), 82 | ("hard_break{2, 4}", "hard_break", False), 83 | ( 84 | "hard_break{2, 4}", 85 | "hard_break hard_break hard_break hard_break hard_break", 86 | False, 87 | ), 88 | ("hard_break{2, 4} text*", "hard_break hard_break image", False), 89 | ("hard_break{2, 4} image?", "hard_break hard_break image", True), 90 | ("hard_break{2,}", "hard_break hard_break", True), 91 | ("hard_break{2,}", "hard_break hard_break hard_break hard_break", True), 92 | ("hard_break{2,}", "hard_break", False), 93 | ], 94 | ) 95 | def test_match_type(expr, types, valid): 96 | if valid: 97 | assert match(expr, types) 98 | else: 99 | assert not match(expr, types) 100 | 101 | 102 | @pytest.mark.parametrize( 103 | ("expr", "before", "after", "result"), 104 | [ 105 | ( 106 | "paragraph horizontal_rule paragraph", 107 | '{"type":"doc","content":[{"type":"paragraph"},{"type":"horizontal_rule"}]}', 108 | '{"type":"doc","content":[{"type":"paragraph"}]}', 109 | '{"type":"doc"}', 110 | ), 111 | ( 112 | "paragraph horizontal_rule paragraph", 113 | '{"type":"doc","content":[{"type":"paragraph"}]}', 114 | '{"type":"doc","content":[{"type":"paragraph"}]}', 115 | '{"type":"doc","content":[{"type":"horizontal_rule"}]}', 116 | ), 117 | ( 118 | "hard_break*", 119 | '{"type":"paragraph","content":[{"type":"hard_break"}]}', 120 | '{"type":"paragraph","content":[{"type":"hard_break"}]}', 121 | '{"type":"paragraph"}', 122 | ), 123 | ( 124 | "hard_break*", 125 | '{"type":"paragraph","content":[{"type":"hard_break"}]}', 126 | '{"type":"paragraph"}', 127 | '{"type":"paragraph"}', 128 | ), 129 | ( 130 | "hard_break*", 131 | '{"type":"paragraph"}', 132 | '{"type":"paragraph","content":[{"type":"hard_break"}]}', 133 | '{"type":"paragraph"}', 134 | ), 135 | ( 136 | "hard_break*", 137 | '{"type":"paragraph"}', 138 | '{"type":"paragraph"}', 139 | '{"type":"paragraph"}', 140 | ), 141 | ( 142 | "hard_break+", 143 | '{"type":"paragraph","content":[{"type":"hard_break"}]}', 144 | '{"type":"paragraph","content":[{"type":"hard_break"}]}', 145 | '{"type":"paragraph"}', 146 | ), 147 | ( 148 | "hard_break+", 149 | '{"type":"paragraph"}', 150 | '{"type":"paragraph"}', 151 | '{"type":"paragraph","content":[{"type":"hard_break"}]}', 152 | ), 153 | ( 154 | "hard_break+", 155 | '{"type":"paragraph"}', 156 | '{"type":"paragraph","content":[{"type":"image","attrs":{"src":"img.png","alt":null,"title":null}}]}', 157 | None, 158 | ), 159 | ( 160 | "heading* paragraph*", 161 | '{"type":"doc","content":[{"type":"heading","attrs":{"level":1}}]}', 162 | '{"type":"doc","content":[{"type":"paragraph"}]}', 163 | '{"type":"doc"}', 164 | ), 165 | ( 166 | "heading* paragraph*", 167 | '{"type":"doc","content":[{"type":"heading","attrs":{"level":1}}]}', 168 | '{"type":"doc"}', 169 | '{"type":"doc"}', 170 | ), 171 | ( 172 | "heading+ paragraph+", 173 | '{"type":"doc","content":[{"type":"heading","attrs":{"level":1}}]}', 174 | '{"type":"doc","content":[{"type":"paragraph"}]}', 175 | '{"type":"doc"}', 176 | ), 177 | ( 178 | "heading+ paragraph+", 179 | '{"type":"doc","content":[{"type":"heading","attrs":{"level":1}}]}', 180 | '{"type":"doc"}', 181 | '{"type":"doc","content":[{"type":"paragraph"}]}', 182 | ), 183 | ( 184 | "hard_break{3}", 185 | '{"type":"paragraph","content":[{"type":"hard_break"}]}', 186 | '{"type":"paragraph","content":[{"type":"hard_break"}]}', 187 | '{"type":"paragraph","content":[{"type":"hard_break"}]}', 188 | ), 189 | ( 190 | "hard_break{3}", 191 | '{"type":"paragraph","content":[{"type":"hard_break"},{"type":"hard_break"}]}', 192 | '{"type":"paragraph","content":[{"type":"hard_break"},{"type":"hard_break"}]}', 193 | None, 194 | ), 195 | ( 196 | "code_block{2} paragraph{2}", 197 | '{"type":"doc","content":[{"type":"code_block"}]}', 198 | '{"type":"doc","content":[{"type":"paragraph"}]}', 199 | '{"type":"doc","content":[{"type":"code_block"},{"type":"paragraph"}]}', 200 | ), 201 | ( 202 | "heading paragraph? horizontal_rule", 203 | '{"type":"doc","content":[{"type":"heading"}]}', 204 | '{"type":"doc"}', 205 | '{"type":"doc","content":[{"type":"horizontal_rule"}]}', 206 | ), 207 | ], 208 | ) 209 | def test_fill_before(expr, before, after, result): 210 | before = Node.from_json(schema, before) 211 | after = Node.from_json(schema, after) 212 | filled = get(expr).match_fragment(before.content).fill_before(after.content, True) 213 | if result: 214 | result = Node.from_json(schema, result) 215 | assert filled.eq(result.content) 216 | else: 217 | assert not filled 218 | 219 | 220 | @pytest.mark.parametrize( 221 | ("expr", "before", "mid", "after", "left", "right"), 222 | [ 223 | ( 224 | "paragraph horizontal_rule paragraph horizontal_rule paragraph", 225 | '{"type":"doc","content":[{"type":"paragraph"}]}', 226 | '{"type":"doc","content":[{"type":"paragraph"}]}', 227 | '{"type":"doc","content":[{"type":"paragraph"}]}', 228 | '{"type":"doc","content":[{"type":"horizontal_rule"}]}', 229 | '{"type":"doc","content":[{"type":"horizontal_rule"}]}', 230 | ), 231 | ( 232 | "code_block+ paragraph+", 233 | '{"type":"doc","content":[{"type":"code_block"}]}', 234 | '{"type":"doc","content":[{"type":"code_block"}]}', 235 | '{"type":"doc","content":[{"type":"paragraph"}]}', 236 | '{"type":"doc"}', 237 | '{"type":"doc"}', 238 | ), 239 | ( 240 | "code_block+ paragraph+", 241 | '{"type":"doc"}', 242 | '{"type":"doc"}', 243 | '{"type":"doc"}', 244 | '{"type":"doc"}', 245 | '{"type":"doc","content":[{"type":"code_block"},{"type":"paragraph"}]}', 246 | ), 247 | ( 248 | "code_block{3} paragraph{3}", 249 | '{"type":"doc","content":[{"type":"code_block"}]}', 250 | '{"type":"doc","content":[{"type":"paragraph"}]}', 251 | '{"type":"doc"}', 252 | '{"type":"doc","content":[{"type":"code_block"},{"type":"code_block"}]}', 253 | '{"type":"doc","content":[{"type":"paragraph"},{"type":"paragraph"}]}', 254 | ), 255 | ( 256 | "paragraph*", 257 | '{"type":"doc","content":[{"type":"paragraph"}]}', 258 | '{"type":"doc","content":[{"type":"code_block"}]}', 259 | '{"type":"doc","content":[{"type":"paragraph"}]}', 260 | None, 261 | None, 262 | ), 263 | ( 264 | "paragraph{4}", 265 | '{"type":"doc","content":[{"type":"paragraph"}]}', 266 | '{"type":"doc","content":[{"type":"paragraph"}]}', 267 | '{"type":"doc","content":[{"type":"paragraph"}]}', 268 | '{"type":"doc"}', 269 | '{"type":"doc","content":[{"type":"paragraph"}]}', 270 | ), 271 | ( 272 | "paragraph{2}", 273 | '{"type":"doc","content":[{"type":"paragraph"}]}', 274 | '{"type":"doc","content":[{"type":"paragraph"}]}', 275 | '{"type":"doc","content":[{"type":"paragraph"}]}', 276 | None, 277 | None, 278 | ), 279 | ], 280 | ) 281 | def test_fill3_before(expr, before, mid, after, left, right): 282 | before = Node.from_json(schema, before) 283 | mid = Node.from_json(schema, mid) 284 | after = Node.from_json(schema, after) 285 | content = get(expr) 286 | a = content.match_fragment(before.content).fill_before(mid.content) 287 | b = False 288 | if a: 289 | b = content.match_fragment( 290 | before.content.append(a).append(mid.content), 291 | ).fill_before(after.content, True) 292 | if left: 293 | left = Node.from_json(schema, left) 294 | right = Node.from_json(schema, right) 295 | assert a.eq(left.content) 296 | assert b.eq(right.content) 297 | else: 298 | assert not b 299 | -------------------------------------------------------------------------------- /tests/prosemirror_model/tests/test_diff.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from prosemirror.test_builder import out 4 | 5 | doc = out["doc"] 6 | blockquote = out["blockquote"] 7 | h1 = out["h1"] 8 | h2 = out["h2"] 9 | p = out["p"] 10 | em = out["em"] 11 | strong = out["strong"] 12 | 13 | 14 | @pytest.mark.parametrize( 15 | ("a", "b"), 16 | [ 17 | ( 18 | doc(p("a", em("b")), p("hello"), blockquote(h1("bye"))), 19 | doc(p("a", em("b")), p("hello"), blockquote(h1("bye"))), 20 | ), 21 | ( 22 | doc(p("a", em("b")), p("hello"), blockquote(h1("bye")), ""), 23 | doc(p("a", em("b")), p("hello"), blockquote(h1("bye")), p("oops")), 24 | ), 25 | ( 26 | doc(p("a", em("b")), p("hello"), blockquote(h1("bye")), "", p("oops")), 27 | doc(p("a", em("b")), p("hello"), blockquote(h1("bye"))), 28 | ), 29 | (doc(p("a", em("b"))), doc(p("a", strong("b")))), 30 | (doc(p("foobar", em("b"))), doc(p("foo", em("b")))), 31 | (doc(p("foobar")), doc(p("foocar"))), 32 | (doc(p("a"), "", p("b")), doc(p("a"), h1("b"))), 33 | (doc("", p("b")), doc(h1("b"))), 34 | (doc(p("a"), "", h1("foo")), doc(p("a"), h2("foo"))), 35 | ], 36 | ) 37 | def test_find_diff_start(a, b): 38 | assert a.content.find_diff_start(b.content) == a.tag.get("a") 39 | 40 | 41 | @pytest.mark.parametrize( 42 | ("a", "b"), 43 | [ 44 | ( 45 | doc(p("a", em("b")), p("hello"), blockquote(h1("bye"))), 46 | doc(p("a", em("b")), p("hello"), blockquote(h1("bye"))), 47 | ), 48 | ( 49 | doc("", p("a", em("b")), p("hello"), blockquote(h1("bye"))), 50 | doc(p("oops"), p("a", em("b")), p("hello"), blockquote(h1("bye"))), 51 | ), 52 | ( 53 | doc(p("oops"), "", p("a", em("b")), p("hello"), blockquote(h1("bye"))), 54 | doc(p("a", em("b")), p("hello"), blockquote(h1("bye"))), 55 | ), 56 | (doc(p("a", em("b"), "c")), doc(p("a", strong("b"), "c"))), 57 | (doc(p("barfoo", em("b"))), doc(p("foo", em("b")))), 58 | (doc(p("foobar")), doc(p("foocar"))), 59 | (doc(p("a"), "", p("b")), doc(h1("a"), p("b"))), 60 | (doc(p("b"), ""), doc(h1("b"))), 61 | (doc("", p("hello")), doc(p("hey"), p("hello"))), 62 | ], 63 | ) 64 | def test_find_diff_end(a, b): 65 | found = a.content.find_diff_end(b.content) 66 | if a == b: 67 | assert not found 68 | if found: 69 | assert found.get("a") == a.tag.get("a") 70 | -------------------------------------------------------------------------------- /tests/prosemirror_model/tests/test_dom.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from prosemirror.model import DOMSerializer 4 | from prosemirror.model.from_dom import from_html 5 | from prosemirror.schema.basic import schema 6 | from prosemirror.test_builder import out 7 | 8 | doc = out["doc"] 9 | p = out["p"] 10 | li = out["li"] 11 | ul = out["ul"] 12 | em = out["em"] 13 | a = out["a"] 14 | blockquote = out["blockquote"] 15 | strong = out["strong"] 16 | code = out["code"] 17 | img = out["img"] 18 | br = out["br"] 19 | ul = out["ul"] 20 | ol = out["ol"] 21 | h1 = out["h1"] 22 | h2 = out["h2"] 23 | pre = out["pre"] 24 | 25 | serializer = DOMSerializer.from_schema(schema) 26 | _marks_copy = serializer.marks.copy() 27 | del _marks_copy["em"] 28 | no_em = DOMSerializer(serializer.nodes, _marks_copy) 29 | 30 | 31 | @pytest.mark.parametrize( 32 | ("desc", "doc", "html"), 33 | [ 34 | ( 35 | "it can represent simple node", 36 | doc(p("hello")), 37 | "

hello

", 38 | ), 39 | ( 40 | "it can represent a line break", 41 | doc(p("hi", br, "there")), 42 | "

hi
there

", 43 | ), 44 | ( 45 | "it can represent an image", 46 | doc(p("hi", img({"alt": "x"}), "there")), 47 | '

hixthere

', 48 | ), 49 | ( 50 | "it joins styles", 51 | doc(p("one", strong("two", em("three")), em("four"), "five")), 52 | "

onetwothreefourfive

", 53 | ), 54 | ( 55 | "it can represent links", 56 | doc( 57 | p( 58 | "a ", 59 | a({"href": "foo"}, "big ", a({"href": "bar"}, "nested"), " link"), 60 | ), 61 | ), 62 | '

a big nested' 63 | ' link

', 64 | ), 65 | ( 66 | "it can represent an unordered list", 67 | doc( 68 | ul(li(p("one")), li(p("two")), li(p("three", strong("!")))), 69 | p("after"), 70 | ), 71 | "

after

", 73 | ), 74 | ( 75 | "it can represent an ordered list", 76 | doc( 77 | ol(li(p("one")), li(p("two")), li(p("three", strong("!")))), 78 | p("after"), 79 | ), 80 | "
  1. one

  2. two

  3. three" 81 | "!

after

", 82 | ), 83 | ( 84 | "it can represent a blockquote", 85 | doc(blockquote(p("hello"), p("bye"))), 86 | "

hello

bye

", 87 | ), 88 | ( 89 | "it can represent headings", 90 | doc(h1("one"), h2("two"), p("text")), 91 | "

one

two

text

", 92 | ), 93 | ( 94 | "it can represent inline code", 95 | doc(p("text and ", code("code that is ", em("emphasized"), "..."))), 96 | "

text and code that is emphasized" 97 | "...

", 98 | ), 99 | ( 100 | "it can represent a code block", 101 | doc(blockquote(pre("some code")), p("and")), 102 | "
some code

and

", 103 | ), 104 | ( 105 | "it supports leaf nodes in marks", 106 | doc(p(em("hi", br, "x"))), 107 | "

hi
x

", 108 | ), 109 | ( 110 | "it doesn't collapse non-breaking spaces", 111 | doc(p("\u00a0 \u00a0hello\u00a0")), 112 | "

\u00a0 \u00a0hello\u00a0

", 113 | ), 114 | ], 115 | ) 116 | def test_serializer_first(doc, html, desc): 117 | """Parser is not implemented, this is just testing serializer right now""" 118 | schema = doc.type.schema 119 | dom = DOMSerializer.from_schema(schema).serialize_fragment(doc.content) 120 | assert str(dom) == html, desc 121 | 122 | 123 | @pytest.mark.parametrize( 124 | ("desc", "serializer", "doc", "expect"), 125 | [ 126 | ( 127 | "it can omit a mark", 128 | no_em, 129 | p("foo", em("bar"), strong("baz")), 130 | "foobarbaz", 131 | ), 132 | ( 133 | "it doesn't split other marks for omitted marks", 134 | no_em, 135 | p("foo", code("bar"), em(code("baz"), "quux"), "xyz"), 136 | "foobarbazquuxxyz", 137 | ), 138 | ( 139 | "it can render marks with complex structure", 140 | DOMSerializer( 141 | serializer.nodes, 142 | { 143 | **serializer.marks, 144 | "em": lambda *_: ["em", ["i", {"data-emphasis": "true"}, 0]], 145 | }, 146 | ), 147 | p(strong("foo", code("bar"), em(code("baz"))), em("quux"), "xyz"), 148 | "foobar" 149 | 'baz' 150 | "quuxxyz", 151 | ), 152 | ], 153 | ) 154 | def test_serializer(serializer, doc, expect, desc): 155 | assert str(serializer.serialize_fragment(doc.content)) == expect, desc 156 | 157 | 158 | def test_html_is_escaped(): 159 | assert ( 160 | str(serializer.serialize_node(schema.text("bold &"))) 161 | == "<b>bold &</b>" 162 | ) 163 | 164 | 165 | @pytest.mark.parametrize( 166 | ("desc", "doc", "expect"), 167 | [ 168 | ( 169 | "Basic text node", 170 | """

test

""", 171 | { 172 | "type": "doc", 173 | "content": [ 174 | { 175 | "type": "paragraph", 176 | "content": [{"type": "text", "text": "test"}], 177 | }, 178 | ], 179 | }, 180 | ), 181 | ( 182 | "Indented HTML", 183 | """ 184 |
185 |

186 | test 187 |

188 |
189 | """, 190 | { 191 | "type": "doc", 192 | "content": [ 193 | { 194 | "type": "paragraph", 195 | "content": [{"type": "text", "text": "test"}], 196 | }, 197 | ], 198 | }, 199 | ), 200 | ( 201 | "Styled(marks) nodes pt1", 202 | """

test some bolded text

""", 203 | { 204 | "type": "doc", 205 | "content": [ 206 | { 207 | "type": "paragraph", 208 | "content": [ 209 | {"type": "text", "text": "test "}, 210 | { 211 | "type": "text", 212 | "marks": [{"type": "strong", "attrs": {}}], 213 | "text": "some bolded text", 214 | }, 215 | ], 216 | }, 217 | ], 218 | }, 219 | ), 220 | ( 221 | "Styled nodes pt2", 222 | """

test some bolded text

another test """ 223 | """em

""", 224 | { 225 | "type": "doc", 226 | "content": [ 227 | { 228 | "type": "paragraph", 229 | "content": [ 230 | {"type": "text", "text": "test "}, 231 | { 232 | "type": "text", 233 | "marks": [{"type": "strong", "attrs": {}}], 234 | "text": "some bolded text", 235 | }, 236 | ], 237 | }, 238 | { 239 | "type": "paragraph", 240 | "content": [ 241 | {"type": "text", "text": "another test "}, 242 | { 243 | "type": "text", 244 | "marks": [{"type": "em", "attrs": {}}], 245 | "text": "em", 246 | }, 247 | ], 248 | }, 249 | ], 250 | }, 251 | ), 252 | ( 253 | "Slightly more complex test, testing pre and tail text around elements", 254 | """

test google\nsome more text here""" 255 | """

Hello

Test """ 256 | """heading

Test break
Another bit """ 257 | """of testing data.

""", 258 | { 259 | "type": "doc", 260 | "content": [ 261 | { 262 | "type": "paragraph", 263 | "content": [ 264 | {"type": "text", "text": "test "}, 265 | { 266 | "type": "text", 267 | "marks": [ 268 | { 269 | "type": "link", 270 | "attrs": { 271 | "href": "www.google.ca", 272 | "title": None, 273 | }, 274 | }, 275 | ], 276 | "text": "google", 277 | }, 278 | {"type": "text", "text": " some more text here"}, 279 | ], 280 | }, 281 | { 282 | "type": "paragraph", 283 | "content": [ 284 | { 285 | "type": "image", 286 | "attrs": { 287 | "src": "google.ca", 288 | "alt": None, 289 | "title": None, 290 | }, 291 | }, 292 | ], 293 | }, 294 | { 295 | "type": "paragraph", 296 | "content": [ 297 | { 298 | "type": "text", 299 | "marks": [{"type": "strong", "attrs": {}}], 300 | "text": "Hello", 301 | }, 302 | ], 303 | }, 304 | { 305 | "type": "heading", 306 | "attrs": {"level": 1}, 307 | "content": [{"type": "text", "text": "Test heading"}], 308 | }, 309 | { 310 | "type": "paragraph", 311 | "content": [ 312 | { 313 | "type": "text", 314 | "marks": [{"type": "em", "attrs": {}}], 315 | "text": "Test ", 316 | }, 317 | { 318 | "type": "text", 319 | "marks": [ 320 | {"type": "em", "attrs": {}}, 321 | {"type": "strong", "attrs": {}}, 322 | ], 323 | "text": "break", 324 | }, 325 | {"type": "hard_break"}, 326 | {"type": "text", "text": "Another bit of testing data."}, 327 | ], 328 | }, 329 | ], 330 | }, 331 | ), 332 | ( 333 | "Unstructured", 334 | """Testing the result of this""", 335 | { 336 | "type": "doc", 337 | "content": [ 338 | { 339 | "type": "paragraph", 340 | "content": [ 341 | {"type": "text", "text": "Testing the result of this"}, 342 | ], 343 | }, 344 | ], 345 | }, 346 | ), 347 | ( 348 | "Unstructured with tail", 349 | """Testing the

result of

this""", 350 | { 351 | "type": "doc", 352 | "content": [ 353 | { 354 | "type": "paragraph", 355 | "content": [{"type": "text", "text": "Testing the"}], 356 | }, 357 | { 358 | "type": "paragraph", 359 | "content": [ 360 | {"type": "text", "text": "result "}, 361 | { 362 | "type": "text", 363 | "marks": [{"type": "strong", "attrs": {}}], 364 | "text": "o", 365 | }, 366 | { 367 | "type": "text", 368 | "marks": [ 369 | {"type": "em", "attrs": {}}, 370 | {"type": "strong", "attrs": {}}, 371 | ], 372 | "text": "f", 373 | }, 374 | ], 375 | }, 376 | { 377 | "type": "paragraph", 378 | "content": [{"type": "text", "text": " this"}], 379 | }, 380 | ], 381 | }, 382 | ), 383 | ], 384 | ) 385 | def test_parser(doc, expect, desc): 386 | """ 387 | The `expect` dicts are straight copies from the output of the JS lib run in Node, 388 | with 1 exception of 'attrs' key in marks dicts, in JS if blank attrs isn't written, 389 | this library does write out 'attrs' even if it is blank, I didn't want to modify 390 | behavior of existing files with the addition of this 391 | """ 392 | assert from_html(schema, doc) == expect, desc 393 | -------------------------------------------------------------------------------- /tests/prosemirror_model/tests/test_mark.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from prosemirror.model import Mark, Node, Schema 4 | from prosemirror.test_builder import out 5 | from prosemirror.test_builder import test_schema as schema 6 | 7 | doc = out["doc"] 8 | p = out["p"] 9 | em = out["em"] 10 | a = out["a"] 11 | 12 | em_ = schema.mark("em") 13 | strong = schema.mark("strong") 14 | 15 | 16 | def link(href, title=None): 17 | return schema.mark("link", {"href": href, "title": title}) 18 | 19 | 20 | code = schema.mark("code") 21 | 22 | 23 | custom_schema = Schema({ 24 | "nodes": { 25 | "doc": {"content": "paragraph+"}, 26 | "paragraph": {"content": "text*"}, 27 | "text": {}, 28 | }, 29 | "marks": { 30 | "remark": {"attrs": {"id": {}}, "excludes": "", "inclusive": False}, 31 | "user": {"attrs": {"id": {}}, "excludes": "_"}, 32 | "strong": {"excludes": "em-group"}, 33 | "em": {"group": "em-group"}, 34 | }, 35 | }) 36 | 37 | custom = custom_schema.marks 38 | remark1 = custom["remark"].create({"id": 1}) 39 | remark2 = custom["remark"].create({"id": 2}) 40 | user1 = custom["user"].create({"id": 1}) 41 | user2 = custom["user"].create({"id": 2}) 42 | custom_em = custom["em"].create() 43 | custom_strong = custom["strong"].create() 44 | 45 | 46 | @pytest.mark.parametrize( 47 | ("a", "b", "res"), 48 | [ 49 | ([em_, strong], [em_, strong], True), 50 | ([em_, strong], [em_, code], False), 51 | ([em_, strong], [em_, strong, code], False), 52 | ([link("http://foo"), code], [link("http://foo"), code], True), 53 | ([link("http://foo"), code], [link("http://bar"), code], False), 54 | ], 55 | ) 56 | def test_same_set(a, b, res): 57 | assert Mark.same_set(a, b) is res 58 | 59 | 60 | @pytest.mark.parametrize( 61 | ("a", "b", "res"), 62 | [ 63 | (link("http://foo"), (link("http://foo")), True), 64 | (link("http://foo"), link("http://bar"), False), 65 | (link("http://foo", "A"), link("http://foo", "B"), False), 66 | ], 67 | ) 68 | def test_eq(a, b, res): 69 | assert a.eq(b) is res 70 | 71 | 72 | def test_add_to_set(ist): 73 | ist(em_.add_to_set([]), [em_], Mark.same_set) 74 | ist(em_.add_to_set([em_]), [em_], Mark.same_set) 75 | ist(em_.add_to_set([strong]), [em_, strong], Mark.same_set) 76 | ist(strong.add_to_set([em_]), [em_, strong], Mark.same_set) 77 | ist( 78 | link("http://bar").add_to_set([link("http://foo"), em_]), 79 | [link("http://bar"), em_], 80 | Mark.same_set, 81 | ) 82 | ist( 83 | link("http://foo").add_to_set([em_, link("http://foo")]), 84 | [em_, link("http://foo")], 85 | Mark.same_set, 86 | ) 87 | ist( 88 | code.add_to_set([em_, strong, link("http://foo")]), 89 | [em_, strong, link("http://foo"), code], 90 | Mark.same_set, 91 | ) 92 | ist(strong.add_to_set([em_, code]), [em_, strong, code], Mark.same_set) 93 | ist(remark2.add_to_set([remark1]), [remark1, remark2], Mark.same_set) 94 | ist(remark1.add_to_set([remark1]), [remark1], Mark.same_set) 95 | ist(user1.add_to_set([remark1, custom_em]), [user1], Mark.same_set) 96 | ist(custom_em.add_to_set([user1]), [user1], Mark.same_set) 97 | ist(user2.add_to_set([user1]), [user2], Mark.same_set) 98 | ist( 99 | custom_em.add_to_set([remark1, custom_strong]), 100 | [remark1, custom_strong], 101 | Mark.same_set, 102 | ) 103 | ist( 104 | custom_strong.add_to_set([remark1, custom_em]), 105 | [remark1, custom_strong], 106 | Mark.same_set, 107 | ) 108 | 109 | 110 | def test_remove_form_set(ist): 111 | ist(Mark.same_set(em_.remove_from_set([]), [])) 112 | ist(Mark.same_set(em_.remove_from_set([em_]), [])) 113 | ist(Mark.same_set(strong.remove_from_set([em_]), [em_])) 114 | ist(Mark.same_set(link("http://foo").remove_from_set([link("http://foo")]), [])) 115 | ist( 116 | Mark.same_set( 117 | link("http://foo", "title").remove_from_set([link("http://foo")]), 118 | [link("http://foo")], 119 | ), 120 | ) 121 | 122 | 123 | class TestResolvedPosMarks: 124 | custom_doc = Node.from_json( 125 | custom_schema, 126 | { 127 | "type": "doc", 128 | "content": [ 129 | { 130 | "type": "paragraph", 131 | "content": [ 132 | { 133 | "type": "text", 134 | "marks": [ 135 | {"type": "remark", "attrs": {"id": 1}}, 136 | {"type": "strong"}, 137 | ], 138 | "text": "one", 139 | }, 140 | {"type": "text", "text": "two"}, 141 | ], 142 | }, 143 | { 144 | "type": "paragraph", 145 | "content": [ 146 | {"type": "text", "text": "one"}, 147 | { 148 | "type": "text", 149 | "marks": [{"type": "remark", "attrs": {"id": 1}}], 150 | "text": "twothree", 151 | }, 152 | ], 153 | }, 154 | { 155 | "type": "paragraph", 156 | "content": [ 157 | { 158 | "type": "text", 159 | "marks": [{"type": "remark", "attrs": {"id": 2}}], 160 | "text": "one", 161 | }, 162 | { 163 | "type": "text", 164 | "marks": [{"type": "remark", "attrs": {"id": 1}}], 165 | "text": "two", 166 | }, 167 | ], 168 | }, 169 | ], 170 | }, 171 | ) 172 | 173 | @pytest.mark.parametrize( 174 | ("doc", "mark", "result"), 175 | [ 176 | (doc(p(em("foo"))), em_, True), 177 | (doc(p(em("foo"))), strong, False), 178 | (doc(p(em("hi"), " there")), em_, True), 179 | (doc(p("one ", em("two"))), em_, False), 180 | (doc(p(em("one"))), em_, True), 181 | (doc(p(a("link"))), link("http://baz"), False), 182 | ], 183 | ) 184 | def test_is_at(self, doc, mark, result): 185 | assert mark.is_in_set(doc.resolve(doc.tag["a"]).marks()) is result 186 | 187 | @pytest.mark.parametrize( 188 | ("a", "b"), 189 | [ 190 | (custom_doc.resolve(4).marks(), [custom_strong]), 191 | (custom_doc.resolve(3).marks(), [remark1, custom_strong]), 192 | (custom_doc.resolve(20).marks(), []), 193 | (custom_doc.resolve(15).marks(), [remark1]), 194 | (custom_doc.resolve(25).marks(), []), 195 | ], 196 | ) 197 | def test_with_custom_doc(self, a, b): 198 | assert Mark.same_set(a, b) 199 | -------------------------------------------------------------------------------- /tests/prosemirror_model/tests/test_node.py: -------------------------------------------------------------------------------- 1 | from typing import Literal 2 | 3 | from prosemirror.model import Fragment, Schema 4 | from prosemirror.test_builder import eq, out 5 | from prosemirror.test_builder import test_schema as schema 6 | 7 | doc = out["doc"] 8 | blockquote = out["blockquote"] 9 | p = out["p"] 10 | li = out["li"] 11 | ul = out["ul"] 12 | em = out["em"] 13 | strong = out["strong"] 14 | code = out["code"] 15 | a = out["a"] 16 | br = out["br"] 17 | hr = out["hr"] 18 | img = out["img"] 19 | 20 | custom_schema: Schema[ 21 | Literal["doc", "paragraph", "text", "contact", "hard_break"], 22 | str, 23 | ] = Schema({ 24 | "nodes": { 25 | "doc": {"content": "paragraph+"}, 26 | "paragraph": {"content": "(text|contact)*"}, 27 | "text": { 28 | "toDebugString": lambda _: "custom_text", 29 | }, 30 | "contact": { 31 | "inline": True, 32 | "attrs": {"name": {}, "email": {}}, 33 | "leafText": (lambda node: f"{node.attrs['name']} <{node.attrs['email']}>"), 34 | }, 35 | "hard_break": { 36 | "toDebugString": lambda _: "custom_hard_break", 37 | }, 38 | }, 39 | }) 40 | 41 | 42 | class TestToString: 43 | def test_nesting(self): 44 | node = doc(ul(li(p("hey"), p()), li(p("foo")))) 45 | expected = 'doc(bullet_list(list_item(paragraph("hey"), paragraph), list_item(paragraph("foo"))))' # noqa 46 | assert str(node) == expected 47 | 48 | def test_shows_inline_children(self): 49 | node = doc(p("foo", img, br, "bar")) 50 | assert str(node) == 'doc(paragraph("foo", image, hard_break, "bar"))' 51 | 52 | def test_shows_marks(self): 53 | node = doc(p("foo", em("bar", strong("quux")), code("baz"))) 54 | expected = 'doc(paragraph("foo", em("bar"), em(strong("quux")), code("baz")))' 55 | assert str(node) == expected 56 | 57 | def test_has_default_tostring_method_text(self): 58 | assert str(schema.text("hello")) == '"hello"' 59 | 60 | def test_has_default_tostring_method_br(self): 61 | assert str(br()) == "hard_break" 62 | 63 | def test_nodespec_to_debug_string(self): 64 | assert str(custom_schema.text("hello")) == "custom_text" 65 | 66 | def test_respected_by_fragment(self): 67 | f = Fragment.from_array( 68 | [ 69 | custom_schema.text("hello"), 70 | custom_schema.nodes["hard_break"].create_checked(), 71 | custom_schema.text("world"), 72 | ], 73 | ) 74 | assert str(f) == "" 75 | 76 | def test_should_respect_custom_leaf_text_spec(self): 77 | contact = custom_schema.nodes["contact"].create_checked({ 78 | "name": "Bob", 79 | "email": "bob@example.com", 80 | }) 81 | paragraph = custom_schema.nodes["paragraph"].create_checked( 82 | {}, 83 | [custom_schema.text("Hello "), contact], 84 | ) 85 | 86 | assert contact.text_content, "Bob " 87 | assert paragraph.text_content, "Hello Bob " 88 | 89 | 90 | class TestCut: 91 | @staticmethod 92 | def cut(doc, cut): 93 | assert eq(doc.cut(doc.tag.get("a", 0), doc.tag.get("b")), cut) 94 | 95 | def test_extracts_full_block(self): 96 | self.cut(doc(p("foo"), "", p("bar"), "", p("baz")), doc(p("bar"))) 97 | 98 | def test_cuts_text(self): 99 | self.cut(doc(p("0"), p("foobarbaz"), p("2")), doc(p("bar"))) 100 | 101 | def test_cuts_deeply(self): 102 | self.cut( 103 | doc( 104 | blockquote( 105 | ul(li(p("a"), p("bc")), li(p("d")), "", li(p("e"))), 106 | p("3"), 107 | ), 108 | ), 109 | doc(blockquote(ul(li(p("c")), li(p("d"))))), 110 | ) 111 | 112 | def test_works_from_the_left(self): 113 | self.cut(doc(blockquote(p("foobar"))), doc(blockquote(p("foo")))) 114 | 115 | def test_works_to_the_right(self): 116 | self.cut(doc(blockquote(p("foobar"))), doc(blockquote(p("bar")))) 117 | 118 | def test_preserves_marks(self): 119 | self.cut( 120 | doc(p("foo", em("bar", img, strong("baz"), br), "quux", code("xyz"))), 121 | doc(p(em("r", img, strong("baz"), br), "qu")), 122 | ) 123 | 124 | 125 | class TestBetween: 126 | @staticmethod 127 | def between(doc, *nodes): 128 | i = 0 129 | 130 | def iteratee(node, pos, *args): 131 | nonlocal i 132 | 133 | if i == len(nodes): 134 | msg = f"More nodes iterated than list ({node.type.name})" 135 | raise Exception(msg) 136 | compare = node.text if node.is_text else node.type.name 137 | if compare != nodes[i]: 138 | msg = f"Expected {nodes[i]!r}, got {compare!r}" 139 | raise Exception(msg) 140 | i += 1 141 | if not node.is_text and doc.node_at(pos) != node: 142 | msg = f"Pos {pos} does not point at node {node!r} {doc.nodeAt(pos)!r}" 143 | raise Exception(msg) 144 | 145 | doc.nodes_between(doc.tag["a"], doc.tag["b"], iteratee) 146 | 147 | def test_iterates_over_text(self): 148 | self.between(doc(p("foobarbaz")), "paragraph", "foobarbaz") 149 | 150 | def test_descends_multiple_levels(self): 151 | self.between( 152 | doc(blockquote(ul(li(p("foo")), p("b"), ""), p("c"))), 153 | "blockquote", 154 | "bullet_list", 155 | "list_item", 156 | "paragraph", 157 | "foo", 158 | "paragraph", 159 | "b", 160 | ) 161 | 162 | def test_iterates_over_inline_nodes(self): 163 | self.between( 164 | doc( 165 | p( 166 | em("x"), 167 | "foo", 168 | em("bar", img, strong("baz"), br), 169 | "quux", 170 | code("xyz"), 171 | ), 172 | ), 173 | "paragraph", 174 | "foo", 175 | "bar", 176 | "image", 177 | "baz", 178 | "hard_break", 179 | "quux", 180 | "xyz", 181 | ) 182 | 183 | 184 | class TestTextBetween: 185 | def test_passing_custom_function_as_leaf_text(self): 186 | d = doc(p("foo", img, br)) 187 | 188 | def leaf_text(node): 189 | if node.type.name == "image": 190 | return "" 191 | elif node.type.name == "hard_break": 192 | return "" 193 | 194 | text = d.text_between(0, d.content.size, "", leaf_text) 195 | assert text == "foo" 196 | 197 | def test_works_with_leaf_text(self): 198 | d = custom_schema.nodes["doc"].create_checked( 199 | {}, 200 | [ 201 | custom_schema.nodes["paragraph"].create_checked( 202 | {}, 203 | [ 204 | custom_schema.text("Hello "), 205 | custom_schema.nodes["contact"].create_checked({ 206 | "name": "Alice", 207 | "email": "alice@example.com", 208 | }), 209 | ], 210 | ), 211 | ], 212 | ) 213 | assert d.text_between(0, d.content.size) == "Hello Alice " 214 | 215 | def test_should_ignore_leaf_text_spec_when_passing_a_custom_leaf_text(self): 216 | d = custom_schema.nodes["doc"].create_checked( 217 | {}, 218 | [ 219 | custom_schema.nodes["paragraph"].create_checked( 220 | {}, 221 | [ 222 | custom_schema.text("Hello "), 223 | custom_schema.nodes["contact"].create_checked({ 224 | "name": "Alice", 225 | "email": "alice@example.com", 226 | }), 227 | ], 228 | ), 229 | ], 230 | ) 231 | assert ( 232 | d.text_between(0, d.content.size, "", "") == "Hello " 233 | ) 234 | 235 | 236 | class TestTextContent: 237 | def test_whole_doc(self): 238 | assert doc(p("foo")).text_content == "foo" 239 | 240 | def test_text_node(self): 241 | assert schema.text("foo").text_content == "foo" 242 | 243 | def test_nested_element(self): 244 | node = doc(ul(li(p("hi")), li(p(em("a"), "b")))) 245 | assert node.text_content == "hiab" 246 | 247 | 248 | class TestFrom: 249 | @staticmethod 250 | def from_(arg, expect): 251 | assert expect.copy(Fragment.from_(arg)).eq(expect) 252 | 253 | def test_wraps_single_node(self): 254 | self.from_(schema.node("paragraph"), doc(p())) 255 | 256 | def test_wraps_array(self): 257 | self.from_([schema.node("hard_break"), schema.text("foo")], p(br, "foo")) 258 | 259 | def test_preserves_a_fragment(self): 260 | self.from_(doc(p("foo")).content, doc(p("foo"))) 261 | 262 | def test_accepts_null(self): 263 | self.from_(None, p()) 264 | 265 | def test_joins_adjacent_text(self): 266 | self.from_([schema.text("a"), schema.text("b")], p("ab")) 267 | 268 | 269 | class TestToJSON: 270 | @staticmethod 271 | def round_trip(doc): 272 | assert schema.node_from_json(doc.to_json()).eq(doc) 273 | 274 | def test_serialize_simple_node(self): 275 | self.round_trip(doc(p("foo"))) 276 | 277 | def test_serialize_marks(self): 278 | self.round_trip(doc(p("foo", em("bar", strong("baz")), " ", a("x")))) 279 | 280 | def test_serialize_inline_leaf_nodes(self): 281 | self.round_trip(doc(p("foo", em(img, "bar")))) 282 | 283 | def test_serialize_block_leaf_nodes(self): 284 | self.round_trip(doc(p("a"), hr, p("b"), p())) 285 | 286 | def test_serialize_nested_nodes(self): 287 | self.round_trip( 288 | doc(blockquote(ul(li(p("a"), p("b")), li(p(img))), p("c")), p("d")), 289 | ) 290 | -------------------------------------------------------------------------------- /tests/prosemirror_model/tests/test_resolve.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from prosemirror.test_builder import out 4 | 5 | doc = out["doc"] 6 | p = out["p"] 7 | em = out["em"] 8 | blockquote = out["blockquote"] 9 | 10 | test_doc = doc(p("ab"), blockquote(p(em("cd"), "ef"))) 11 | _doc = {"node": test_doc, "start": 0, "end": 12} 12 | _p1 = {"node": test_doc.child(0), "start": 1, "end": 3} 13 | _blk = {"node": test_doc.child(1), "start": 5, "end": 11} 14 | _p2 = {"node": _blk["node"].child(0), "start": 6, "end": 10} 15 | 16 | 17 | @pytest.mark.parametrize( 18 | ("pos", "exp"), 19 | list( 20 | enumerate([ 21 | [_doc, 0, None, _p1["node"]], 22 | [_doc, _p1, 0, None, "ab"], 23 | [_doc, _p1, 1, "a", "b"], 24 | [_doc, _p1, 2, "ab", None], 25 | [_doc, 4, _p1["node"], _blk["node"]], 26 | [_doc, _blk, 0, None, _p2["node"]], 27 | [_doc, _blk, _p2, 0, None, "cd"], 28 | [_doc, _blk, _p2, 1, "c", "d"], 29 | [_doc, _blk, _p2, 2, "cd", "ef"], 30 | [_doc, _blk, _p2, 3, "e", "f"], 31 | [_doc, _blk, _p2, 4, "ef", None], 32 | [_doc, _blk, 6, _p2["node"], None], 33 | [_doc, 12, _blk["node"], None], 34 | ]), 35 | ), 36 | ) 37 | def test_node_resolve(pos, exp): 38 | pos = test_doc.resolve(pos) 39 | assert pos.depth == len(exp) - 4 40 | for i in range(len(exp) - 3): 41 | assert pos.node(i).eq(exp[i]["node"]) 42 | assert pos.start(i) == exp[i]["start"] 43 | assert pos.end(i) == exp[i]["end"] 44 | if i: 45 | assert pos.before(i) == exp[i]["start"] - 1 46 | assert pos.after(i) == exp[i]["end"] + 1 47 | assert pos.parent_offset == exp[len(exp) - 3] 48 | before = pos.node_before 49 | e_before = exp[len(exp) - 2] 50 | if isinstance(e_before, str): 51 | assert before.text_content == e_before 52 | else: 53 | assert before == e_before 54 | after = pos.node_after 55 | e_after = exp[len(exp) - 1] 56 | if isinstance(e_after, str): 57 | assert after.text_content == e_after 58 | else: 59 | assert after == e_after 60 | 61 | 62 | @pytest.mark.parametrize( 63 | ("pos", "result"), 64 | [ 65 | (0, ":0"), 66 | (1, "paragraph_0:0"), 67 | (7, "blockquote_1/paragraph_0:1"), 68 | ], 69 | ) 70 | def test_resolvedpos_str(pos, result): 71 | assert str(test_doc.resolve(pos)) == result 72 | 73 | 74 | @pytest.fixture 75 | def doc_for_pos_at_index(): 76 | return doc(blockquote(p("one"), blockquote(p("two ", em("three")), p("four")))) 77 | 78 | 79 | @pytest.mark.parametrize( 80 | ("index", "depth", "pos"), 81 | [ 82 | (0, None, 8), 83 | (1, None, 12), 84 | (2, None, 17), 85 | (0, 2, 7), 86 | (1, 2, 18), 87 | (2, 2, 24), 88 | (0, 1, 1), 89 | (1, 1, 6), 90 | (2, 1, 25), 91 | (0, 0, 0), 92 | (1, 0, 26), 93 | ], 94 | ) 95 | def test_pos_at_index(index, depth, pos, doc_for_pos_at_index): 96 | d = doc_for_pos_at_index 97 | 98 | p_three = d.resolve(12) 99 | assert p_three.pos_at_index(index, depth) == pos 100 | -------------------------------------------------------------------------------- /tests/prosemirror_model/tests/test_slice.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from prosemirror.test_builder import out 4 | 5 | doc = out["doc"] 6 | p = out["p"] 7 | li = out["li"] 8 | ul = out["ul"] 9 | em = out["em"] 10 | a = out["a"] 11 | blockquote = out["blockquote"] 12 | 13 | 14 | @pytest.mark.parametrize( 15 | ("doc", "expect", "open_start", "open_end"), 16 | [ 17 | (doc(p("hello world")), doc(p("hello")), 0, 1), 18 | (doc(p("hello")), doc(p("hello")), 0, 1), 19 | (doc(p("hello world"), p("rest")), doc(p("hello")), 0, 1), 20 | (doc(p("hello ", em("WORLD"))), doc(p("hello ", em("WOR"))), 0, 1), 21 | (doc(p("a"), p("b")), doc(p("a"), p("b")), 0, 1), 22 | (doc(p("a"), "", p("b")), doc(p("a")), 0, 0), 23 | ( 24 | doc(blockquote(ul(li(p("a")), li(p("b"))))), 25 | doc(blockquote(ul(li(p("a")), li(p("b"))))), 26 | 0, 27 | 4, 28 | ), 29 | (doc(p("hello world")), doc(p(" world")), 1, 0), 30 | (doc(p("hello")), doc(p("hello")), 1, 0), 31 | (doc(p("foo"), p("barbaz")), doc(p("baz")), 1, 0), 32 | ( 33 | doc(p("a sentence with an ", em("emphasized ", a("link")), " in it")), 34 | doc(p(em(a("nk")), " in it")), 35 | 1, 36 | 0, 37 | ), 38 | ( 39 | doc(p("a ", em("sentence"), " with ", em("text"), " in it")), 40 | doc(p("th ", em("text"), " in it")), 41 | 1, 42 | 0, 43 | ), 44 | (doc(p("a"), "", p("b")), doc(p("b")), 0, 0), 45 | ( 46 | doc(blockquote(ul(li(p("a")), li(p("b"))))), 47 | doc(blockquote(ul(li(p("b"))))), 48 | 4, 49 | 0, 50 | ), 51 | (doc(p("hello world")), p("o wo"), 0, 0), 52 | (doc(p("one"), p("two")), doc(p("e"), p("t")), 1, 1), 53 | ( 54 | doc(p("here's nothing and ", em("here's em"))), 55 | p("ing and ", em("here's e")), 56 | 0, 57 | 0, 58 | ), 59 | ( 60 | doc(ul(li(p("hello")), li(p("world")), li(p("x"))), p(em("boo"))), 61 | doc(ul(li(p("rld")), li(p("x"))), p(em("bo"))), 62 | 3, 63 | 1, 64 | ), 65 | ( 66 | doc( 67 | blockquote( 68 | p("foobar"), 69 | ul(li(p("a")), li(p("b"), "", p("c"))), 70 | p("d"), 71 | ), 72 | ), 73 | blockquote(p("bar"), ul(li(p("a")), li(p("b")))), 74 | 1, 75 | 2, 76 | ), 77 | ], 78 | ) 79 | def test_slice_cut(doc, expect, open_start, open_end): 80 | slice = doc.slice(doc.tag.get("a", 0), doc.tag.get("b")) 81 | assert slice.content.eq(expect.content) 82 | assert slice.open_start == open_start 83 | assert slice.open_end == open_end 84 | 85 | 86 | def test_slice_can_include_parents(): 87 | d = doc(blockquote(p("foo"), p("bar"))) 88 | slice = d.slice(d.tag["a"], d.tag["b"], True) 89 | assert str(slice) == '(2,2)' 90 | -------------------------------------------------------------------------------- /tests/prosemirror_transform/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fellowapp/prosemirror-py/c996d5e23a8d6ef7360db26bf91f815a86a1587a/tests/prosemirror_transform/__init__.py -------------------------------------------------------------------------------- /tests/prosemirror_transform/tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fellowapp/prosemirror-py/c996d5e23a8d6ef7360db26bf91f815a86a1587a/tests/prosemirror_transform/tests/__init__.py -------------------------------------------------------------------------------- /tests/prosemirror_transform/tests/conftest.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from prosemirror.model import Fragment, Slice 4 | from prosemirror.test_builder import out 5 | from prosemirror.test_builder import test_schema as schema 6 | from prosemirror.transform import ( 7 | AddMarkStep, 8 | Mapping, 9 | RemoveMarkStep, 10 | ReplaceStep, 11 | Step, 12 | StepMap, 13 | Transform, 14 | ) 15 | 16 | doc = out["doc"] 17 | p = out["p"] 18 | 19 | 20 | @pytest.fixture 21 | def test_mapping(): 22 | def t_mapping(mapping, *cases): 23 | inverted = mapping.invert() 24 | for case in cases: 25 | from_, to, bias, lossy = ( 26 | lambda from_, to, bias=1, lossy=False: (from_, to, bias, lossy) 27 | )(*case) 28 | assert mapping.map(from_, bias) == to 29 | if not lossy: 30 | assert inverted.map(to, bias) == from_ 31 | 32 | return t_mapping 33 | 34 | 35 | @pytest.fixture 36 | def make_mapping(): 37 | def mk(*args): 38 | mapping = Mapping() 39 | for arg in args: 40 | if isinstance(arg, list): 41 | mapping.append_map(StepMap(arg)) 42 | else: 43 | for from_ in arg: 44 | mapping.set_mirror(from_, arg[from_]) 45 | return mapping 46 | 47 | return mk 48 | 49 | 50 | @pytest.fixture 51 | def test_del(): 52 | def t_del(mapping: Mapping, pos: int, side: int, flags: str): 53 | r = mapping.map_result(pos, side) 54 | found = "" 55 | if r.deleted: 56 | found += "d" 57 | if r.deleted_before: 58 | found += "b" 59 | if r.deleted_after: 60 | found += "a" 61 | if r.deleted_across: 62 | found += "x" 63 | assert found == flags 64 | 65 | return t_del 66 | 67 | 68 | @pytest.fixture 69 | def make_step(): 70 | return _make_step 71 | 72 | 73 | def _make_step(from_: int, to: int, val: str | None) -> Step: 74 | if val == "+em": 75 | return AddMarkStep(from_, to, schema.marks["em"].create()) 76 | elif val == "-em": 77 | return RemoveMarkStep(from_, to, schema.marks["em"].create()) 78 | return ReplaceStep( 79 | from_, 80 | to, 81 | Slice.empty if val is None else Slice(Fragment.from_(schema.text(val)), 0, 0), 82 | ) 83 | 84 | 85 | @pytest.fixture 86 | def test_doc(): 87 | return doc(p("foobar")) 88 | 89 | 90 | _test_doc = doc(p("foobar")) 91 | 92 | 93 | @pytest.fixture 94 | def test_transform(): 95 | def invert(transform): 96 | out = Transform(transform.doc) 97 | for i, step in reversed(list(enumerate(transform.steps))): 98 | out.step(step.invert(transform.docs[i])) 99 | return out 100 | 101 | def test_step_json(tr): 102 | new_tr = Transform(tr.before) 103 | for step in tr.steps: 104 | new_tr.step(Step.from_json(tr.doc.type.schema, step.to_json())) 105 | 106 | def test_mapping(mapping, pos, new_pos): 107 | mapped = mapping.map(pos, 1) 108 | assert mapped == new_pos 109 | remap = Mapping([m.invert() for m in mapping.maps]) 110 | for i, map in enumerate(reversed(mapping.maps)): 111 | remap.append_map(map, len(mapping.maps) - 1 - i) 112 | assert remap.map(pos, 1) == pos 113 | 114 | def test_transform(tr, expect): 115 | assert tr.doc.eq(expect) 116 | assert invert(tr).doc.eq(tr.before) 117 | test_step_json(tr) 118 | 119 | for tag in expect.tag: 120 | test_mapping(tr.mapping, tr.before.tag.get(tag), expect.tag.get(tag)) 121 | 122 | return test_transform 123 | -------------------------------------------------------------------------------- /tests/prosemirror_transform/tests/test_mapping.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | 4 | @pytest.mark.parametrize( 5 | ("mapping_info", "cases"), 6 | [ 7 | ([[2, 0, 4]], [[0, 0], [2, 6], [2, 2, -1], [3, 7]]), 8 | ( 9 | [[2, 4, 0]], 10 | [[0, 0], [2, 2, -1], [3, 2, 1, True], [6, 2, 1], [6, 2, -1, True], [7, 3]], 11 | ), 12 | ( 13 | [[2, 4, 4]], 14 | [[0, 0], [2, 2, 1], [4, 6, 1, True], [4, 2, -1, True], [6, 6, -1], [8, 8]], 15 | ), 16 | ([[2, 4, 0], [2, 0, 4], {0: 1}], [[0, 0], [2, 2], [4, 4], [6, 6], [7, 7]]), 17 | ([[2, 0, 4], [2, 4, 0], {0: 1}], [[0, 0], [2, 2], [3, 3]]), 18 | ( 19 | [[2, 4, 0], [1, 0, 1], [3, 0, 4], {0: 2}], 20 | [[0, 0], [1, 2], [4, 5], [6, 7], [7, 8]], 21 | ), 22 | ], 23 | ) 24 | def test_all_mapping_cases(mapping_info, cases, test_mapping, make_mapping): 25 | test_mapping(make_mapping(*mapping_info), *cases) 26 | 27 | 28 | @pytest.mark.parametrize( 29 | ("mapping_info", "pos", "side", "flags"), 30 | [ 31 | (([0, 2, 0],), 2, -1, "db"), 32 | (([0, 2, 0],), 2, 1, "b"), 33 | (([0, 2, 2],), 2, -1, "db"), 34 | ( 35 | ( 36 | [0, 1, 0], 37 | [0, 1, 0], 38 | ), 39 | 2, 40 | -1, 41 | "db", 42 | ), 43 | (([0, 1, 0],), 2, -1, ""), 44 | (([2, 2, 0],), 2, -1, "a"), 45 | (([2, 2, 0],), 2, 1, "da"), 46 | (([2, 2, 2],), 2, 1, "da"), 47 | ( 48 | ( 49 | [2, 1, 0], 50 | [2, 1, 0], 51 | ), 52 | 2, 53 | 1, 54 | "da", 55 | ), 56 | (([3, 2, 0],), 2, -1, ""), 57 | (([0, 4, 0],), 2, -1, "dbax"), 58 | (([0, 4, 0],), 2, 1, "dbax"), 59 | ( 60 | ( 61 | [0, 1, 0], 62 | [4, 1, 0], 63 | [0, 3, 0], 64 | ), 65 | 2, 66 | 1, 67 | "dbax", 68 | ), 69 | ( 70 | ( 71 | [4, 1, 0], 72 | [0, 1, 0], 73 | ), 74 | 2, 75 | -1, 76 | "", 77 | ), 78 | ( 79 | ( 80 | [2, 1, 0], 81 | [0, 2, 0], 82 | ), 83 | 2, 84 | -1, 85 | "dba", 86 | ), 87 | ( 88 | ( 89 | [2, 1, 0], 90 | [0, 1, 0], 91 | ), 92 | 2, 93 | -1, 94 | "a", 95 | ), 96 | ( 97 | ( 98 | [3, 1, 0], 99 | [0, 2, 0], 100 | ), 101 | 2, 102 | -1, 103 | "db", 104 | ), 105 | ], 106 | ) 107 | def test_all_del_cases(mapping_info, pos, side, flags, test_del, make_mapping): 108 | test_del(make_mapping(*mapping_info), pos, side, flags) 109 | -------------------------------------------------------------------------------- /tests/prosemirror_transform/tests/test_step.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from .conftest import _make_step, _test_doc 4 | 5 | 6 | def yes(from1, to1, val1, from2, to2, val2): 7 | def inner(): 8 | step1 = _make_step(from1, to1, val1) 9 | step2 = _make_step(from2, to2, val2) 10 | merged = step1.merge(step2) 11 | assert merged 12 | assert merged.apply(_test_doc).doc.eq( 13 | step2.apply(step1.apply(_test_doc).doc).doc, 14 | ) 15 | 16 | return inner 17 | 18 | 19 | def no(from1, to1, val1, from2, to2, val2): 20 | def inner(): 21 | step1 = _make_step(from1, to1, val1) 22 | step2 = _make_step(from2, to2, val2) 23 | merged = step1.merge(step2) 24 | assert merged is None 25 | 26 | return inner 27 | 28 | 29 | @pytest.mark.parametrize( 30 | ("pass_", "from1", "to1", "val1", "from2", "to2", "val2"), 31 | [ 32 | (yes, 2, 2, "a", 3, 3, "b"), 33 | (yes, 2, 2, "a", 2, 2, "b"), 34 | (no, 2, 2, "a", 4, 4, "b"), 35 | (no, 3, 3, "a", 2, 2, "b"), 36 | (yes, 3, 4, None, 2, 3, None), 37 | (yes, 2, 3, None, 2, 3, None), 38 | (no, 1, 2, None, 2, 3, None), 39 | (yes, 2, 3, None, 2, 2, "x"), 40 | (yes, 2, 2, "quux", 6, 6, "baz"), 41 | (yes, 2, 2, "quux", 2, 2, "baz"), 42 | (yes, 2, 5, None, 2, 4, None), 43 | (yes, 4, 6, None, 2, 4, None), 44 | (yes, 3, 4, "x", 4, 5, "y"), 45 | (yes, 1, 2, "+em", 2, 4, "+em"), 46 | (yes, 1, 3, "+em", 2, 4, "+em"), 47 | (no, 1, 2, "+em", 3, 4, "+em"), 48 | (yes, 1, 2, "-em", 2, 4, "-em"), 49 | (yes, 1, 3, "-em", 2, 4, "-em"), 50 | (no, 1, 2, "-em", 3, 4, "-em"), 51 | ], 52 | ) 53 | def test_all_cases(pass_, from1, to1, val1, from2, to2, val2): 54 | pass_(from1, to1, val1, from2, to2, val2)() 55 | -------------------------------------------------------------------------------- /tests/prosemirror_transform/tests/test_structure.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from prosemirror.model import Schema, Slice 4 | from prosemirror.transform import Transform, can_split, find_wrapping, lift_target 5 | from prosemirror.transform.structure import NodeTypeWithAttrs 6 | 7 | schema = Schema({ 8 | "nodes": { 9 | "doc": {"content": "head? block* sect* closing?"}, 10 | "para": {"content": "text*", "group": "block"}, 11 | "head": {"content": "text*", "marks": ""}, 12 | "figure": {"content": "caption figureimage", "group": "block"}, 13 | "quote": {"content": "block+", "group": "block"}, 14 | "figureimage": {}, 15 | "caption": {"content": "text*", "marks": ""}, 16 | "sect": {"content": "head block* sect*"}, 17 | "closing": {"content": "text*"}, 18 | "text": {"group": "inline"}, 19 | "fixed": {"content": "head para closing", "group": "block"}, 20 | }, 21 | "marks": {"em": {}}, 22 | }) 23 | 24 | 25 | def n(name, *content): 26 | return schema.nodes[name].create(None, list(content)) 27 | 28 | 29 | def t(str, em=None): 30 | return schema.text(str, [schema.mark["em"]] if em else None) 31 | 32 | 33 | doc = n( 34 | "doc", # 0 35 | n("head", t("Head")), # 6 36 | n("para", t("Intro")), # 13 37 | n( 38 | "sect", # 14 39 | n("head", t("Section head")), # 28 40 | n( 41 | "sect", # 29 42 | n("head", t("Subsection head")), # 46 43 | n("para", t("Subtext")), # 55 44 | n( 45 | "figure", 46 | n("caption", t("Figure caption")), 47 | n("figureimage"), 48 | ), # 56 # 72 # 74 49 | n("quote", n("para", t("!"))), 50 | ), 51 | ), # 81 52 | n("sect", n("head", t("S2")), n("para", t("Yes"))), # 82 # 86 # 92 53 | n("closing", t("fin")), 54 | ) # 97 55 | 56 | 57 | def range_(pos, end=None): 58 | return doc.resolve(pos).block_range(None if end is None else doc.resolve(end)) 59 | 60 | 61 | def fill(params, length): 62 | new_params = [] 63 | for item in params: 64 | item = list(item) 65 | diff = length - len(item) 66 | if diff > 0: 67 | item += [None] * diff 68 | new_params.append(item) 69 | return new_params 70 | 71 | 72 | class TestCanSplit: 73 | @pytest.mark.parametrize( 74 | ("pass_", "pos", "depth", "after"), 75 | fill( 76 | [ 77 | (False, 0), 78 | (False, 3), 79 | (True, 3, 1, "para"), 80 | (False, 6), 81 | (True, 8), 82 | (False, 14), 83 | (False, 17), 84 | (True, 17, 2), 85 | (True, 18, 1, "para"), 86 | (False, 46), 87 | (True, 48), 88 | (False, 60), 89 | (False, 62, 2), 90 | (False, 72), 91 | (True, 76), 92 | (True, 77, 2), 93 | (False, 97), 94 | ], 95 | 4, 96 | ), 97 | ) 98 | def test_can_split(self, pass_, pos, depth, after): 99 | res = can_split( 100 | doc, 101 | pos, 102 | depth, 103 | [NodeTypeWithAttrs(type=schema.nodes[after])] if after else None, 104 | ) 105 | if pass_: 106 | assert res 107 | else: 108 | assert not res 109 | 110 | def test_doesnt_return_true_when_split_content_doesnt_fit_in_given_node_type( 111 | self, 112 | ): 113 | s = Schema({ 114 | "nodes": { 115 | "doc": {"content": "chapter+"}, 116 | "para": {"content": "text*", "group": "block"}, 117 | "head": {"content": "text*", "marks": ""}, 118 | "figure": {"content": "caption figureimage", "group": "block"}, 119 | "quote": {"content": "block+", "group": "block"}, 120 | "figureimage": {}, 121 | "caption": {"content": "text*", "marks": ""}, 122 | "sect": {"content": "head block* sect*"}, 123 | "closing": {"content": "text*"}, 124 | "text": {"group": "inline"}, 125 | "fixed": {"content": "head para closing", "group": "block"}, 126 | "title": {"content": "text*"}, 127 | "chapter": {"content": "title scene+"}, 128 | "scene": {"content": "para+"}, 129 | }, 130 | }) 131 | assert not can_split( 132 | s.node( 133 | "doc", 134 | None, 135 | s.node( 136 | "chapter", 137 | None, 138 | [ 139 | s.node("title", None, s.text("title")), 140 | s.node("scene", None, s.node("para", None, s.text("scene"))), 141 | ], 142 | ), 143 | ), 144 | 4, 145 | 1, 146 | [NodeTypeWithAttrs(s.nodes["scene"])], 147 | ) 148 | 149 | 150 | class TestLiftTarget: 151 | @pytest.mark.parametrize( 152 | ("pass_", "pos"), 153 | [(False, 0), (False, 3), (False, 52), (False, 70), (True, 76), (False, 86)], 154 | ) 155 | def test_lift_target(self, pass_, pos): 156 | r = range_(pos) 157 | if pass_: 158 | assert bool(r and lift_target(r)) 159 | else: 160 | assert not bool(r and lift_target(r)) 161 | 162 | 163 | class TestFindWrapping: 164 | @pytest.mark.parametrize( 165 | ("pass_", "pos", "end", "type"), 166 | [ 167 | (True, 0, 92, "sect"), 168 | (False, 4, 4, "sect"), 169 | (True, 8, 8, "quote"), 170 | (False, 18, 18, "quote"), 171 | (True, 55, 74, "quote"), 172 | (False, 90, 90, "figure"), 173 | ], 174 | ) 175 | def test_find_wrapping(self, pass_, pos, end, type): 176 | r = range_(pos, end) 177 | if pass_: 178 | assert find_wrapping(r, schema.nodes[type]) 179 | else: 180 | assert not bool(find_wrapping(r, schema.nodes[type])) 181 | 182 | 183 | @pytest.mark.parametrize( 184 | ("doc", "from_", "to", "content", "open_start", "open_end", "result"), 185 | [ 186 | ( 187 | n("doc", n("sect", n("head", t("foo")), n("para", t("bar")))), 188 | 6, 189 | 6, 190 | n("doc", n("sect"), n("sect")), 191 | 1, 192 | 1, 193 | n( 194 | "doc", 195 | n("sect", n("head", t("foo"))), 196 | n("sect", n("head"), n("para", t("bar"))), 197 | ), 198 | ), 199 | ( 200 | n("doc", n("para", t("a")), n("para", t("b"))), 201 | 3, 202 | 3, 203 | n("doc", n("closing", t("."))), 204 | 0, 205 | 0, 206 | n("doc", n("para", t("a")), n("para", t("b"))), 207 | ), 208 | ( 209 | n("doc", n("sect", n("head", t("foo")), n("para", t("bar")))), 210 | 1, 211 | 3, 212 | n("doc", n("sect"), n("sect", n("head", t("hi")))), 213 | 1, 214 | 2, 215 | n( 216 | "doc", 217 | n("sect", n("head")), 218 | n("sect", n("head", t("hioo")), n("para", t("bar"))), 219 | ), 220 | ), 221 | ( 222 | n("doc"), 223 | 0, 224 | 0, 225 | n("doc", n("figure", n("figureimage"))), 226 | 1, 227 | 0, 228 | n("doc", n("figure", n("caption"), n("figureimage"))), 229 | ), 230 | ( 231 | n("doc"), 232 | 0, 233 | 0, 234 | n("doc", n("figure", n("caption"))), 235 | 0, 236 | 1, 237 | n("doc", n("figure", n("caption"), n("figureimage"))), 238 | ), 239 | ( 240 | n( 241 | "doc", 242 | n("figure", n("caption"), n("figureimage")), 243 | n("figure", n("caption"), n("figureimage")), 244 | ), 245 | 3, 246 | 8, 247 | None, 248 | 0, 249 | 0, 250 | n("doc", n("figure", n("caption"), n("figureimage"))), 251 | ), 252 | ( 253 | n("doc", n("sect", n("head"), n("figure", n("caption"), n("figureimage")))), 254 | 7, 255 | 9, 256 | n("doc", n("para", t("hi"))), 257 | 0, 258 | 0, 259 | n( 260 | "doc", 261 | n( 262 | "sect", 263 | n("head"), 264 | n("figure", n("caption"), n("figureimage")), 265 | n("para", t("hi")), 266 | ), 267 | ), 268 | ), 269 | ], 270 | ) 271 | def test_replace(doc, from_, to, content, open_start, open_end, result): 272 | slice = Slice(content.content, open_start, open_end) if content else Slice.empty 273 | tr = Transform(doc).replace(from_, to, slice) 274 | assert tr.doc.eq(result) 275 | --------------------------------------------------------------------------------