├── tests ├── __init__.py ├── conftest.py ├── test_utils.py ├── test_serializer.py ├── test_entry.py ├── test_node.py ├── test_memory.py └── test_tree.py ├── MANIFEST.in ├── .gitignore ├── bplustree ├── __init__.py ├── utils.py ├── const.py ├── serializer.py ├── entry.py ├── node.py ├── tree.py └── memory.py ├── Makefile ├── .travis.yml ├── LICENSE ├── setup.py └── README.rst /tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include LICENSE 2 | include README.rst -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .pyc 2 | __pycache__ 3 | .cache 4 | build 5 | dist 6 | bplustree.egg-info 7 | .coverage 8 | .pytest_cache 9 | -------------------------------------------------------------------------------- /bplustree/__init__.py: -------------------------------------------------------------------------------- 1 | from .tree import BPlusTree 2 | from .serializer import ( 3 | IntSerializer, StrSerializer, UUIDSerializer, DatetimeUTCSerializer 4 | ) 5 | from .const import VERSION 6 | 7 | __version__ = VERSION 8 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | publish: 2 | pip install -U pip setuptools wheel twine 3 | python setup.py sdist 4 | python setup.py bdist_wheel 5 | twine upload dist/* 6 | rm -fr build dist bplustree.egg-info 7 | 8 | clean: 9 | rm -fr build dist bplustree.egg-info 10 | 11 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: 
import itertools
from typing import Iterable


def pairwise(iterable: Iterable):
    """Iterate over elements two by two.

    s -> (s0,s1), (s1,s2), (s2, s3), ...
    """
    a, b = itertools.tee(iterable)
    next(b, None)
    return zip(a, b)


def iter_slice(iterable: bytes, n: int):
    """Yield slices of size n and say if each slice is the last one.

    s -> (b'123', False), (b'45', True)

    An empty input yields nothing. Raises ValueError when n is not
    positive, which would otherwise make the loop spin forever since
    the offsets could never advance.
    """
    if n <= 0:
        raise ValueError('Slice size must be a positive integer')

    final_offset = len(iterable)
    start = 0
    while start < final_offset:
        stop = start + n
        rv = iterable[start:stop]
        start = stop
        # The flag tells the caller whether this is the final slice.
        yield rv, start >= final_offset
from collections import namedtuple

# Package version, re-exported as bplustree.__version__ and read by setup.py.
VERSION = '0.0.4.dev1'

# Endianness for storing numbers
ENDIAN = 'little'

# Bytes used for storing references to pages
# Can address 16 TB of memory with 4 KB pages (2**32 pages * 4 KB each)
PAGE_REFERENCE_BYTES = 4

# Bytes used for storing the type of the node in page header
NODE_TYPE_BYTES = 1

# Bytes used for storing the length of the page payload in page header
USED_PAGE_LENGTH_BYTES = 3

# Bytes used for storing the length of the key or value payload in record
# header. Limits the maximum length of a key or value to 64 KB.
USED_KEY_LENGTH_BYTES = 2
USED_VALUE_LENGTH_BYTES = 2

# Max 256 types of frames in the write-ahead log
FRAME_TYPE_BYTES = 1

# Bytes used for storing general purpose integers like file metadata
OTHERS_BYTES = 4


# Immutable configuration shared by every component of a tree.
TreeConf = namedtuple('TreeConf', [
    'page_size',  # Size of a page within the tree in bytes
    'order',  # Branching factor of the tree
    'key_size',  # Maximum size of a key in bytes
    'value_size',  # Maximum size of a value in bytes
    'serializer',  # Instance of a Serializer
])
== dt 46 | assert repr(s) == 'DatetimeUTCSerializer()' 47 | 48 | 49 | @mock.patch.dict('bplustree.serializer.__dict__', {'temporenc': None}) 50 | def test_datetime_utc_serializer_no_temporenc(): 51 | with pytest.raises(RuntimeError): 52 | DatetimeUTCSerializer() 53 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | from codecs import open 3 | from os import path 4 | 5 | here = path.abspath(path.dirname(__file__)) 6 | 7 | with open(path.join(here, 'README.rst'), encoding='utf-8') as f: 8 | long_description = f.read() 9 | 10 | with open(path.join(here, 'LICENSE'), encoding='utf-8') as f: 11 | long_description += f.read() 12 | 13 | with open(path.join(here, 'bplustree', 'const.py'), encoding='utf-8') as fp: 14 | version = dict() 15 | exec(fp.read(), version) 16 | version = version['VERSION'] 17 | 18 | setup( 19 | name='bplustree', 20 | version=version, 21 | description='On-disk B+tree for Python 3', 22 | long_description=long_description, 23 | url='https://github.com/NicolasLM/bplustree', 24 | author='Nicolas Le Manchet', 25 | author_email='nicolas@lemanchet.fr', 26 | license='MIT', 27 | # See https://pypi.python.org/pypi?%3Aaction=list_classifiers 28 | classifiers=[ 29 | 'Development Status :: 4 - Beta', 30 | 'Intended Audience :: Developers', 31 | 'Topic :: Software Development :: Libraries', 32 | 'Topic :: Database', 33 | 'License :: OSI Approved :: MIT License', 34 | 'Natural Language :: English', 35 | 'Programming Language :: Python :: 3', 36 | 'Programming Language :: Python :: 3 :: Only', 37 | 'Programming Language :: Python :: 3.5', 38 | 'Programming Language :: Python :: 3.6', 39 | 'Programming Language :: Python :: 3.7', 40 | ], 41 | keywords='bplustree B+tree Btree database index', 42 | 43 | packages=find_packages(include=('bplustree', 'bplustree.*')), 44 | install_requires=[ 45 | 
import abc
from datetime import datetime, timezone
from uuid import UUID

try:
    import temporenc
except ImportError:
    temporenc = None

from .const import ENDIAN


class Serializer(metaclass=abc.ABCMeta):
    """Convert keys to bytes and back for on-disk storage."""

    __slots__ = []

    @abc.abstractmethod
    def serialize(self, obj: object, key_size: int) -> bytes:
        """Serialize a key to bytes."""

    @abc.abstractmethod
    def deserialize(self, data: bytes) -> object:
        """Create a key object from bytes."""

    def __repr__(self):
        return '{}()'.format(self.__class__.__name__)


class IntSerializer(Serializer):
    """Store integers as fixed-size byte strings."""

    __slots__ = []

    def serialize(self, obj: int, key_size: int) -> bytes:
        return obj.to_bytes(key_size, ENDIAN)

    def deserialize(self, data: bytes) -> int:
        return int.from_bytes(data, ENDIAN)


class StrSerializer(Serializer):
    """Store strings as UTF-8, within the key_size budget."""

    __slots__ = []

    def serialize(self, obj: str, key_size: int) -> bytes:
        encoded = obj.encode(encoding='utf-8')
        # The encoded form must fit in the fixed key slot of a record.
        assert len(encoded) <= key_size
        return encoded

    def deserialize(self, data: bytes) -> str:
        return data.decode(encoding='utf-8')


class UUIDSerializer(Serializer):
    """Store UUIDs in their 16-byte binary form."""

    __slots__ = []

    def serialize(self, obj: UUID, key_size: int) -> bytes:
        return obj.bytes

    def deserialize(self, data: bytes) -> UUID:
        return UUID(bytes=data)


class DatetimeUTCSerializer(Serializer):
    """Store timezone-aware datetimes via the optional temporenc library."""

    __slots__ = []

    def __init__(self):
        # Fail early when the optional dependency is missing.
        if temporenc is None:
            raise RuntimeError('Serialization to/from datetime needs the '
                               'third-party library "temporenc"')

    def serialize(self, obj: datetime, key_size: int) -> bytes:
        if obj.tzinfo is None:
            raise ValueError('DatetimeUTCSerializer needs a timezone aware '
                             'datetime')
        return temporenc.packb(obj, type='DTS')

    def deserialize(self, data: bytes) -> datetime:
        # temporenc returns a naive datetime, re-attach UTC explicitly.
        parsed = temporenc.unpackb(data).datetime()
        return parsed.replace(tzinfo=timezone.utc)
48 | r1.overflow_page = 5 49 | assert repr(r1) == "" 50 | 51 | 52 | def test_record_slots(): 53 | r1 = Record(tree_conf, 42, b'foo') 54 | with pytest.raises(AttributeError): 55 | r1.foo = True 56 | 57 | 58 | def test_record_lazy_load(): 59 | data = Record(tree_conf, 42, b'foo').dump() 60 | r = Record(tree_conf, data=data) 61 | 62 | assert r._data == data 63 | assert r._key == NOT_LOADED 64 | assert r._value == NOT_LOADED 65 | assert r._overflow_page == NOT_LOADED 66 | 67 | _ = r.key 68 | assert r._key == 42 69 | assert r._value == b'foo' 70 | assert r._overflow_page is None 71 | assert r._data == data 72 | 73 | r.key = 27 74 | assert r._key == 27 75 | assert r._data is None 76 | 77 | 78 | def test_reference_int_serialization(): 79 | r1 = Reference(tree_conf, 42, 1, 2) 80 | data = r1.dump() 81 | 82 | r2 = Reference(tree_conf, data=data) 83 | assert r1 == r2 84 | assert r1.before == r2.before 85 | assert r1.after == r2.after 86 | 87 | 88 | def test_reference_str_serialization(): 89 | tree_conf = TreeConf(4096, 4, 40, 40, StrSerializer()) 90 | r1 = Reference(tree_conf, 'foo', 1, 2) 91 | data = r1.dump() 92 | 93 | r2 = Reference(tree_conf, data=data) 94 | assert r1 == r2 95 | assert r1.before == r2.before 96 | assert r1.after == r2.after 97 | 98 | 99 | def test_reference_repr(): 100 | r1 = Reference(tree_conf, 42, 1, 2) 101 | assert repr(r1) == '' 102 | 103 | 104 | def test_reference_lazy_load(): 105 | data = Reference(tree_conf, 42, 1, 2).dump() 106 | r = Reference(tree_conf, data=data) 107 | 108 | assert r._data == data 109 | assert r._key == NOT_LOADED 110 | assert r._before == NOT_LOADED 111 | assert r._after == NOT_LOADED 112 | 113 | _ = r.key 114 | assert r._key == 42 115 | assert r._before == 1 116 | assert r._after == 2 117 | assert r._data == data 118 | 119 | r.key = 27 120 | assert r._key == 27 121 | assert r._data is None 122 | 123 | 124 | def test_opaque_data(): 125 | data = b'foo' 126 | o = OpaqueData(data=data) 127 | assert o.data == data 128 | assert 
o.dump() == data 129 | 130 | o = OpaqueData() 131 | o.load(data) 132 | assert o.data == data 133 | assert o.dump() == data 134 | 135 | assert repr(o) == "" 136 | -------------------------------------------------------------------------------- /README.rst: -------------------------------------------------------------------------------- 1 | Bplustree 2 | ========= 3 | 4 | .. image:: https://travis-ci.org/NicolasLM/bplustree.svg?branch=master 5 | :target: https://travis-ci.org/NicolasLM/bplustree 6 | .. image:: https://coveralls.io/repos/github/NicolasLM/bplustree/badge.svg?branch=master 7 | :target: https://coveralls.io/github/NicolasLM/bplustree?branch=master 8 | 9 | An on-disk B+tree for Python 3. 10 | 11 | It feels like a dict, but stored on disk. When to use it? 12 | 13 | - When the data to store does not fit in memory 14 | - When the data needs to be persisted 15 | - When keeping the keys in order is important 16 | 17 | This project is under development: the format of the file may change between 18 | versions. Do not use as your primary source of data. 19 | 20 | Quickstart 21 | ---------- 22 | 23 | Install Bplustree with pip:: 24 | 25 | pip install bplustree 26 | 27 | Create a B+tree index stored on a file and use it with: 28 | 29 | .. code:: python 30 | 31 | >>> from bplustree import BPlusTree 32 | >>> tree = BPlusTree('/tmp/bplustree.db', order=50) 33 | >>> tree[1] = b'foo' 34 | >>> tree[2] = b'bar' 35 | >>> tree[1] 36 | b'foo' 37 | >>> tree.get(3) 38 | >>> tree.close() 39 | 40 | Keys and values 41 | --------------- 42 | 43 | Keys must have a natural order and must be serializable to bytes. Some default 44 | serializers for the most common types are provided. For example to index UUIDs: 45 | 46 | .. 
code:: python 47 | 48 | >>> import uuid 49 | >>> from bplustree import BPlusTree, UUIDSerializer 50 | >>> tree = BPlusTree('/tmp/bplustree.db', serializer=UUIDSerializer(), key_size=16) 51 | >>> tree.insert(uuid.uuid1(), b'foo') 52 | >>> list(tree.keys()) 53 | [UUID('48f2553c-de23-4d20-95bf-6972a89f3bc0')] 54 | 55 | Values on the other hand are always bytes. They can be of arbitrary length, 56 | the parameter ``value_size=128`` defines the upper bound of value sizes that 57 | can be stored in the tree itself. Values exceeding this limit are stored in 58 | overflow pages. Each overflowing value occupies at least a full page. 59 | 60 | Iterating 61 | --------- 62 | 63 | Since keys are kept in order, it is very efficient to retrieve elements in 64 | order: 65 | 66 | .. code:: python 67 | 68 | >>> for i in tree: 69 | ... print(i) 70 | ... 71 | 1 72 | 2 73 | >>> for key, value in tree.items(): 74 | ... print(key, value) 75 | ... 76 | 1 b'foo' 77 | 2 b'bar' 78 | 79 | It is also possible to iterate over a subset of the tree by giving a Python 80 | slice: 81 | 82 | .. code:: python 83 | 84 | >>> for key, value in tree.items(slice(start=0, stop=10)): 85 | ... print(key, value) 86 | ... 87 | 1 b'foo' 88 | 2 b'bar' 89 | 90 | Both methods use a generator so they don't require loading the whole content 91 | in memory, but copying a slice of the tree into a dict is also possible: 92 | 93 | .. code:: python 94 | 95 | >>> tree[0:10] 96 | {1: b'foo', 2: b'bar'} 97 | 98 | 99 | Concurrency 100 | ----------- 101 | 102 | The tree is thread-safe, it follows the multiple readers/single writer pattern. 
103 | 104 | It is safe to: 105 | 106 | - Share an instance of a ``BPlusTree`` between multiple threads 107 | 108 | It is NOT safe to: 109 | 110 | - Share an instance of a ``BPlusTree`` between multiple processes 111 | - Create multiple instances of ``BPlusTree`` pointing to the same file 112 | 113 | Durability 114 | ---------- 115 | 116 | A write-ahead log (WAL) is used to ensure that the data is safe. All changes 117 | made to the tree are appended to the WAL and only merged into the tree in an 118 | operation called a checkpoint, usually when the tree is closed. This approach 119 | is heavily inspired by other databases like SQLite. 120 | 121 | If the tree doesn't get closed properly (power outage, process killed...) the WAL 122 | file is merged the next time the tree is opened. 123 | 124 | Performance 125 | ------------ 126 | 127 | Like any database, there are many knobs to finely tune the engine and get the 128 | best performance out of it: 129 | 130 | - ``order``, or branching factor, defines how many entries each node will hold 131 | - ``page_size`` is the number of bytes allocated to a node and the length of 132 | read and write operations. It is best to keep it close to the block size of 133 | the disk 134 | - ``cache_size`` to keep frequently used nodes at hand.
Big caches prevent the 135 | expensive operation of creating Python objects from raw pages but use more 136 | memory 137 | 138 | Some advices to efficiently use the tree: 139 | 140 | - Insert elements in ascending order if possible, prefer UUID v1 to UUID v4 141 | - Insert in batch with ``tree.batch_insert(iterator)`` instead of using 142 | ``tree.insert()`` in a loop 143 | - Let the tree iterate for you instead of using ``tree.get()`` in a loop 144 | - Use ``tree.checkpoint()`` from time to time if you insert a lot, this will 145 | prevent the WAL from growing unbounded 146 | - Use small keys and values, set their limit and overflow values accordingly 147 | - Store the file and WAL on a fast disk 148 | 149 | License 150 | ------- 151 | 152 | MIT 153 | -------------------------------------------------------------------------------- /tests/test_node.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from bplustree.const import TreeConf, ENDIAN 4 | from bplustree.entry import Record, Reference, OpaqueData 5 | from bplustree.node import (Node, LonelyRootNode, RootNode, InternalNode, 6 | LeafNode, FreelistNode, OverflowNode) 7 | from bplustree.serializer import IntSerializer 8 | 9 | tree_conf = TreeConf(4096, 7, 16, 16, IntSerializer()) 10 | 11 | 12 | @pytest.mark.parametrize('klass,order,min_children,max_children', [ 13 | (LonelyRootNode, 7, 0, 6), 14 | (LonelyRootNode, 100, 0, 99), 15 | (RootNode, 7, 2, 7), 16 | (RootNode, 100, 2, 100), 17 | (InternalNode, 7, 4, 7), 18 | (InternalNode, 100, 50, 100), 19 | (LeafNode, 7, 3, 6), 20 | (LeafNode, 100, 49, 99), 21 | ]) 22 | def test_node_limit_children(klass, order, min_children, max_children): 23 | node = klass(TreeConf(4096, order, 16, 16, IntSerializer())) 24 | assert node.min_children == min_children 25 | assert node.max_children == max_children 26 | 27 | 28 | @pytest.mark.parametrize('klass', [ 29 | LeafNode, InternalNode, RootNode, LonelyRootNode, 30 | ]) 31 
| def test_empty_node_serialization(klass): 32 | n1 = klass(tree_conf) 33 | data = n1.dump() 34 | 35 | n2 = klass(tree_conf, data=data) 36 | assert n1.entries == n2.entries 37 | 38 | n3 = Node.from_page_data(tree_conf, data) 39 | assert isinstance(n3, klass) 40 | assert n1.entries == n3.entries 41 | 42 | 43 | def test_leaf_node_serialization(): 44 | n1 = LeafNode(tree_conf, next_page=66) 45 | n1.insert_entry(Record(tree_conf, 43, b'43')) 46 | n1.insert_entry(Record(tree_conf, 42, b'42')) 47 | assert n1.entries == [Record(tree_conf, 42, b'42'), 48 | Record(tree_conf, 43, b'43')] 49 | data = n1.dump() 50 | 51 | n2 = LeafNode(tree_conf, data=data) 52 | assert n1.entries == n2.entries 53 | assert n1.next_page == n2.next_page == 66 54 | 55 | 56 | def test_leaf_node_serialization_no_next_page(): 57 | n1 = LeafNode(tree_conf) 58 | data = n1.dump() 59 | 60 | n2 = LeafNode(tree_conf, data=data) 61 | assert n1.next_page is n2.next_page is None 62 | 63 | 64 | def test_root_node_serialization(): 65 | n1 = RootNode(tree_conf) 66 | n1.insert_entry(Reference(tree_conf, 43, 2, 3)) 67 | n1.insert_entry(Reference(tree_conf, 42, 1, 2)) 68 | assert n1.entries == [Reference(tree_conf, 42, 1, 2), 69 | Reference(tree_conf, 43, 2, 3)] 70 | data = n1.dump() 71 | 72 | n2 = RootNode(tree_conf, data=data) 73 | assert n1.entries == n2.entries 74 | assert n1.next_page is n2.next_page is None 75 | 76 | 77 | def test_node_slots(): 78 | n1 = RootNode(tree_conf) 79 | with pytest.raises(AttributeError): 80 | n1.foo = True 81 | 82 | 83 | def test_get_node_from_page_data(): 84 | data = (2).to_bytes(1, ENDIAN) + bytes(4096 - 1) 85 | tree_conf = TreeConf(4096, 7, 16, 16, IntSerializer()) 86 | assert isinstance( 87 | Node.from_page_data(tree_conf, data, 4), 88 | RootNode 89 | ) 90 | 91 | 92 | def test_insert_find_get_remove_entries(): 93 | node = RootNode(tree_conf) 94 | 95 | # Test empty _find_entry_index, get and remove 96 | with pytest.raises(ValueError): 97 | node._find_entry_index(42) 98 | with 
pytest.raises(ValueError): 99 | node.get_entry(42) 100 | with pytest.raises(ValueError): 101 | node.remove_entry(42) 102 | 103 | # Test insert_entry 104 | r42, r43 = Reference(tree_conf, 42, 1, 2), Reference(tree_conf, 43, 2, 3) 105 | node.insert_entry_at_the_end(r43) 106 | node.insert_entry(r42) 107 | assert sorted(node.entries) == node.entries 108 | 109 | # Test _find_entry_index 110 | assert node._find_entry_index(42) == 0 111 | assert node._find_entry_index(43) == 1 112 | 113 | # Test _get_entry 114 | assert node.get_entry(42) == r42 115 | assert node.get_entry(43) == r43 116 | 117 | node.remove_entry(43) 118 | assert node.entries == [r42] 119 | node.remove_entry(42) 120 | assert node.entries == [] 121 | 122 | 123 | def test_smallest_biggest(): 124 | node = RootNode(tree_conf) 125 | 126 | with pytest.raises(IndexError): 127 | node.pop_smallest() 128 | 129 | r42, r43 = Reference(tree_conf, 42, 1, 2), Reference(tree_conf, 43, 2, 3) 130 | node.insert_entry(r43) 131 | node.insert_entry(r42) 132 | 133 | # Smallest 134 | assert node.smallest_entry == r42 135 | assert node.smallest_key == 42 136 | 137 | # Biggest 138 | assert node.biggest_entry == r43 139 | assert node.biggest_key == 43 140 | 141 | assert node.pop_smallest() == r42 142 | assert node.entries == [r43] 143 | 144 | 145 | def test_freelist_node_serialization(): 146 | n1 = FreelistNode(tree_conf, next_page=3) 147 | data = n1.dump() 148 | 149 | n2 = FreelistNode(tree_conf, data=data) 150 | assert n1.next_page == n2.next_page 151 | 152 | 153 | def test_freelist_node_serialization_no_next_page(): 154 | n1 = FreelistNode(tree_conf, next_page=None) 155 | data = n1.dump() 156 | 157 | n2 = FreelistNode(tree_conf, data=data) 158 | assert n1.next_page is n2.next_page is None 159 | 160 | 161 | def test_overflow_node_serialization(): 162 | n1 = OverflowNode(tree_conf, next_page=3) 163 | n1.insert_entry_at_the_end(OpaqueData(data=b'foo')) 164 | data = n1.dump() 165 | 166 | n2 = OverflowNode(tree_conf, data=data) 167 | 
assert n1.next_page == n2.next_page 168 | 169 | 170 | def test_overflow_node_serialization_no_next_page(): 171 | n1 = OverflowNode(tree_conf, next_page=None) 172 | n1.insert_entry_at_the_end(OpaqueData(data=b'foo')) 173 | data = n1.dump() 174 | 175 | n2 = OverflowNode(tree_conf, data=data) 176 | assert n1.next_page is n2.next_page is None 177 | -------------------------------------------------------------------------------- /tests/test_memory.py: -------------------------------------------------------------------------------- 1 | import io 2 | import os 3 | import platform 4 | from unittest import mock 5 | 6 | import pytest 7 | 8 | from bplustree.node import LeafNode, FreelistNode 9 | from bplustree.memory import ( 10 | FileMemory, open_file_in_dir, WAL, ReachedEndOfFile, write_to_file 11 | ) 12 | from bplustree.const import TreeConf 13 | from .conftest import filename 14 | from bplustree.serializer import IntSerializer 15 | 16 | tree_conf = TreeConf(4096, 4, 16, 16, IntSerializer()) 17 | node = LeafNode(tree_conf, page=3) 18 | 19 | 20 | def test_file_memory_node(): 21 | mem = FileMemory(filename, tree_conf) 22 | 23 | with pytest.raises(ReachedEndOfFile): 24 | mem.get_node(3) 25 | 26 | mem.set_node(node) 27 | assert node == mem.get_node(3) 28 | 29 | mem.close() 30 | 31 | 32 | def test_file_memory_metadata(): 33 | mem = FileMemory(filename, tree_conf) 34 | with pytest.raises(ValueError): 35 | mem.get_metadata() 36 | mem.set_metadata(6, tree_conf) 37 | assert mem.get_metadata() == (6, tree_conf) 38 | 39 | 40 | def test_file_memory_next_available_page(): 41 | mem = FileMemory(filename, tree_conf) 42 | for i in range(1, 100): 43 | assert mem.next_available_page == i 44 | 45 | 46 | def test_file_memory_freelist(): 47 | mem = FileMemory(filename, tree_conf) 48 | assert mem.next_available_page == 1 49 | assert mem._traverse_free_list() == (None, None) 50 | 51 | mem.del_page(1) 52 | assert mem._traverse_free_list() == ( 53 | None, FreelistNode(tree_conf, page=1, 
next_page=None) 54 | ) 55 | assert mem.next_available_page == 1 56 | assert mem._traverse_free_list() == (None, None) 57 | 58 | mem.del_page(1) 59 | mem.del_page(2) 60 | assert mem._traverse_free_list() == ( 61 | FreelistNode(tree_conf, page=1, next_page=2), 62 | FreelistNode(tree_conf, page=2, next_page=None) 63 | ) 64 | mem.del_page(3) 65 | assert mem._traverse_free_list() == ( 66 | FreelistNode(tree_conf, page=2, next_page=3), 67 | FreelistNode(tree_conf, page=3, next_page=None) 68 | ) 69 | 70 | assert mem._pop_from_freelist() == 3 71 | assert mem._pop_from_freelist() == 2 72 | assert mem._pop_from_freelist() == 1 73 | assert mem._pop_from_freelist() is None 74 | 75 | 76 | def test_open_file_in_dir(): 77 | with pytest.raises(ValueError): 78 | open_file_in_dir('/foo/bar/does/not/exist') 79 | 80 | # Create file and re-open 81 | for _ in range(2): 82 | file_fd, dir_fd = open_file_in_dir(filename) 83 | 84 | assert isinstance(file_fd, io.FileIO) 85 | file_fd.close() 86 | 87 | if platform.system() == 'Windows': 88 | assert dir_fd is None 89 | else: 90 | assert isinstance(dir_fd, int) 91 | os.close(dir_fd) 92 | 93 | 94 | def test_write_to_file_multi_times(): 95 | def side_effect(*args, **kwargs): 96 | if len(args) == 1: 97 | data = args[0] 98 | if len(data) > 5: 99 | return 5 100 | else: 101 | return len(data) 102 | 103 | mock_fd = mock.MagicMock() 104 | mock_fd.write.side_effect = side_effect 105 | 106 | write_to_file(mock_fd, None, b'abcdefg') 107 | 108 | 109 | @mock.patch('bplustree.memory.platform.system', return_value='Windows') 110 | def test_open_file_in_dir_windows(_): 111 | file_fd, dir_fd = open_file_in_dir(filename) 112 | assert isinstance(file_fd, io.FileIO) 113 | file_fd.close() 114 | assert dir_fd is None 115 | 116 | 117 | def test_file_memory_write_transaction(): 118 | mem = FileMemory(filename, tree_conf) 119 | mem._lock = mock.Mock() 120 | 121 | assert mem._wal._not_committed_pages == {} 122 | assert mem._wal._committed_pages == {} 123 | 124 | with 
def test_file_memory_write_transaction_error():
    """A failed write transaction must roll back the WAL and drop the cache."""
    mem = FileMemory(filename, tree_conf)
    mem._lock = mock.Mock()
    mem._cache[424242] = node

    with pytest.raises(ValueError):
        with mem.write_transaction:
            mem.set_node(node)
            assert mem._wal._not_committed_pages == {3: 9}
            assert mem._wal._committed_pages == {}
            assert mem._lock.writer_lock.acquire.call_count == 1
            raise ValueError('Foo')

    # Nothing must have been committed and the cache must be invalidated
    assert mem._wal._not_committed_pages == {}
    assert mem._wal._committed_pages == {}
    assert mem._lock.writer_lock.release.call_count == 1
    assert mem._cache.get(424242) is None


def test_file_memory_repr():
    mem = FileMemory(filename, tree_conf)
    # NOTE(review): expected repr reconstructed, the angle-bracket text was
    # lost in extraction -- confirm against FileMemory.__repr__
    assert repr(mem) == '<FileMemory: {}>'.format(filename)
    mem.close()


def test_wal_create_reopen_empty():
    WAL(filename, 64)

    wal = WAL(filename, 64)
    assert wal._page_size == 64


def test_wal_create_reopen_uncommitted():
    wal = WAL(filename, 64)
    wal.set_page(1, b'1' * 64)
    wal.commit()
    wal.set_page(2, b'2' * 64)
    assert wal.get_page(1) == b'1' * 64
    assert wal.get_page(2) == b'2' * 64

    # Reopening must keep committed pages and discard uncommitted ones
    wal = WAL(filename, 64)
    assert wal.get_page(1) == b'1' * 64
    assert wal.get_page(2) is None


def test_wal_rollback():
    wal = WAL(filename, 64)
    wal.set_page(1, b'1' * 64)
    wal.commit()
    wal.set_page(2, b'2' * 64)
    assert wal.get_page(1) == b'1' * 64
    assert wal.get_page(2) == b'2' * 64

    # Rollback discards only the pages written since the last commit
    wal.rollback()
    assert wal.get_page(1) == b'1' * 64
    assert wal.get_page(2) is None


def test_wal_checkpoint():
    wal = WAL(filename, 64)
    wal.set_page(1, b'1' * 64)
    wal.commit()
    wal.set_page(2, b'2' * 64)

    # Checkpoint yields only committed pages and closes the WAL file
    rv = wal.checkpoint()
    assert list(rv) == [(1, b'1' * 64)]

    with pytest.raises(ValueError):
        wal.set_page(3, b'3' * 64)

    assert os.path.isfile(filename + '-wal') is False


def test_wal_repr():
    wal = WAL(filename, 64)
    # NOTE(review): expected repr reconstructed -- confirm against WAL.__repr__
    assert repr(wal) == '<WAL: {}>'.format(filename)


import abc
from typing import Optional

from .const import (ENDIAN, PAGE_REFERENCE_BYTES,
                    USED_KEY_LENGTH_BYTES, USED_VALUE_LENGTH_BYTES, TreeConf)


# Sentinel value indicating that a lazy loaded attribute is not yet loaded
NOT_LOADED = object()


class Entry(metaclass=abc.ABCMeta):
    """Abstract element serializable to and from a fixed layout of bytes."""

    __slots__ = []

    @abc.abstractmethod
    def load(self, data: bytes):
        """Deserialize data into an object."""

    @abc.abstractmethod
    def dump(self) -> bytes:
        """Serialize object to data."""


class ComparableEntry(Entry, metaclass=abc.ABCMeta):
    """Entry that can be sorted against other entries based on their key."""

    __slots__ = []

    def __eq__(self, other):
        return self.key == other.key

    def __lt__(self, other):
        return self.key < other.key

    def __le__(self, other):
        return self.key <= other.key

    def __gt__(self, other):
        return self.key > other.key

    def __ge__(self, other):
        return self.key >= other.key
class Record(ComparableEntry):
    """A container for the actual data the tree stores.

    Attributes are lazy loaded: when constructed from raw ``data`` the
    fields hold the ``NOT_LOADED`` sentinel until first accessed.
    """

    __slots__ = ['_tree_conf', 'length', '_key', '_value', '_overflow_page',
                 '_data']

    def __init__(self, tree_conf: TreeConf, key=None,
                 value: Optional[bytes]=None, data: Optional[bytes]=None,
                 overflow_page: Optional[int]=None):
        self._tree_conf = tree_conf
        # Fixed serialized size: key length prefix + key + value length
        # prefix + value + overflow page reference
        self.length = (
            USED_KEY_LENGTH_BYTES + self._tree_conf.key_size +
            USED_VALUE_LENGTH_BYTES + self._tree_conf.value_size +
            PAGE_REFERENCE_BYTES
        )
        self._data = data

        if self._data:
            self._key = NOT_LOADED
            self._value = NOT_LOADED
            self._overflow_page = NOT_LOADED
        else:
            self._key = key
            self._value = value
            self._overflow_page = overflow_page

    @property
    def key(self):
        # Identity check: NOT_LOADED is a sentinel, "==" would invoke the
        # key type's own equality
        if self._key is NOT_LOADED:
            self.load(self._data)
        return self._key

    @key.setter
    def key(self, v):
        self._data = None
        self._key = v

    @property
    def value(self):
        if self._value is NOT_LOADED:
            self.load(self._data)
        return self._value

    @value.setter
    def value(self, v):
        self._data = None
        self._value = v

    @property
    def overflow_page(self):
        if self._overflow_page is NOT_LOADED:
            self.load(self._data)
        return self._overflow_page

    @overflow_page.setter
    def overflow_page(self, v):
        self._data = None
        self._overflow_page = v

    def load(self, data: bytes):
        assert len(data) == self.length

        end_used_key_length = USED_KEY_LENGTH_BYTES
        used_key_length = int.from_bytes(data[0:end_used_key_length], ENDIAN)
        assert 0 <= used_key_length <= self._tree_conf.key_size

        end_key = end_used_key_length + used_key_length
        self._key = self._tree_conf.serializer.deserialize(
            data[end_used_key_length:end_key]
        )

        start_used_value_length = (
            end_used_key_length + self._tree_conf.key_size
        )
        end_used_value_length = (
            start_used_value_length + USED_VALUE_LENGTH_BYTES
        )
        used_value_length = int.from_bytes(
            data[start_used_value_length:end_used_value_length], ENDIAN
        )
        assert 0 <= used_value_length <= self._tree_conf.value_size

        end_value = end_used_value_length + used_value_length

        start_overflow = end_used_value_length + self._tree_conf.value_size
        end_overflow = start_overflow + PAGE_REFERENCE_BYTES
        overflow_page = int.from_bytes(
            data[start_overflow:end_overflow], ENDIAN
        )

        # A record stores either an inline value or an overflow reference,
        # never both
        if overflow_page:
            self._overflow_page = overflow_page
            self._value = None
        else:
            self._overflow_page = None
            self._value = data[end_used_value_length:end_value]

    def dump(self) -> bytes:

        if self._data:
            return self._data

        assert self._value is None or self._overflow_page is None
        key_as_bytes = self._tree_conf.serializer.serialize(
            self._key, self._tree_conf.key_size
        )
        used_key_length = len(key_as_bytes)
        overflow_page = self._overflow_page or 0
        if overflow_page:
            value = b''
        else:
            value = self._value
        used_value_length = len(value)

        # Bugfix: the key length field must be USED_KEY_LENGTH_BYTES wide,
        # matching what load() reads; it was written with
        # USED_VALUE_LENGTH_BYTES, breaking the round-trip if the two
        # constants ever differ
        data = (
            used_key_length.to_bytes(USED_KEY_LENGTH_BYTES, ENDIAN) +
            key_as_bytes +
            bytes(self._tree_conf.key_size - used_key_length) +
            used_value_length.to_bytes(USED_VALUE_LENGTH_BYTES, ENDIAN) +
            value +
            bytes(self._tree_conf.value_size - used_value_length) +
            overflow_page.to_bytes(PAGE_REFERENCE_BYTES, ENDIAN)
        )
        return data

    def __repr__(self):
        # NOTE(review): repr strings reconstructed, the angle-bracket text
        # was lost in extraction
        if self.overflow_page:
            return '<Record: {} overflowing value>'.format(self.key)
        if self.value:
            return '<Record: {} value={}>'.format(
                self.key, self.value[0:16]
            )
        return '<Record: {} unknown value>'.format(self.key)


class Reference(ComparableEntry):
    """A container for a reference to other nodes."""

    __slots__ = ['_tree_conf', 'length', '_key', '_before', '_after', '_data']

    def __init__(self, tree_conf: TreeConf, key=None, before=None, after=None,
                 data: bytes=None):
        self._tree_conf = tree_conf
        # Serialized size: before page ref + key length prefix + key +
        # after page ref
        self.length = (
            2 * PAGE_REFERENCE_BYTES +
            USED_KEY_LENGTH_BYTES +
            self._tree_conf.key_size
        )
        self._data = data

        if self._data:
            self._key = NOT_LOADED
            self._before = NOT_LOADED
            self._after = NOT_LOADED
        else:
            self._key = key
            self._before = before
            self._after = after

    @property
    def key(self):
        if self._key is NOT_LOADED:
            self.load(self._data)
        return self._key

    @key.setter
    def key(self, v):
        self._data = None
        self._key = v

    @property
    def before(self):
        if self._before is NOT_LOADED:
            self.load(self._data)
        return self._before

    @before.setter
    def before(self, v):
        self._data = None
        self._before = v

    @property
    def after(self):
        if self._after is NOT_LOADED:
            self.load(self._data)
        return self._after

    @after.setter
    def after(self, v):
        self._data = None
        self._after = v

    def load(self, data: bytes):
        assert len(data) == self.length
        end_before = PAGE_REFERENCE_BYTES
        self._before = int.from_bytes(data[0:end_before], ENDIAN)

        end_used_key_length = end_before + USED_KEY_LENGTH_BYTES
        used_key_length = int.from_bytes(
            data[end_before:end_used_key_length], ENDIAN
        )
        assert 0 <= used_key_length <= self._tree_conf.key_size

        end_key = end_used_key_length + used_key_length
        self._key = self._tree_conf.serializer.deserialize(
            data[end_used_key_length:end_key]
        )

        start_after = end_used_key_length + self._tree_conf.key_size
        end_after = start_after + PAGE_REFERENCE_BYTES
        self._after = int.from_bytes(data[start_after:end_after], ENDIAN)

    def dump(self) -> bytes:

        if self._data:
            return self._data

        assert isinstance(self._before, int)
        assert isinstance(self._after, int)

        key_as_bytes = self._tree_conf.serializer.serialize(
            self._key, self._tree_conf.key_size
        )
        used_key_length = len(key_as_bytes)

        # Bugfix: same as Record.dump, the key length field must be
        # USED_KEY_LENGTH_BYTES wide to match load()
        data = (
            self._before.to_bytes(PAGE_REFERENCE_BYTES, ENDIAN) +
            used_key_length.to_bytes(USED_KEY_LENGTH_BYTES, ENDIAN) +
            key_as_bytes +
            bytes(self._tree_conf.key_size - used_key_length) +
            self._after.to_bytes(PAGE_REFERENCE_BYTES, ENDIAN)
        )
        return data

    def __repr__(self):
        # NOTE(review): repr string reconstructed
        return '<Reference: key={} before={} after={}>'.format(
            self.key, self.before, self.after
        )


class OpaqueData(Entry):
    """Entry holding opaque data."""

    __slots__ = ['data']

    def __init__(self, tree_conf: TreeConf=None, data: bytes=None):
        self.data = data

    def load(self, data: bytes):
        self.data = data

    def dump(self) -> bytes:
        return self.data

    def __repr__(self):
        # NOTE(review): repr string reconstructed
        return '<OpaqueData: {}>'.format(self.data)


from datetime import datetime, timezone, timedelta
import itertools
from unittest import mock
import uuid

import pytest

from bplustree.memory import FileMemory
from bplustree.node import LonelyRootNode, LeafNode
from bplustree.tree import BPlusTree
from bplustree.serializer import (
    IntSerializer, StrSerializer, UUIDSerializer, DatetimeUTCSerializer
)
from .conftest import filename
@pytest.fixture
def b():
    """Small tree (order 4) backed by the temporary test file."""
    b = BPlusTree(filename, key_size=16, value_size=16, order=4)
    yield b
    b.close()


def test_create_and_load_file():
    b = BPlusTree(filename)
    assert isinstance(b._mem, FileMemory)
    b.insert(5, b'foo')
    b.close()

    # Data must survive a close/reopen cycle
    b = BPlusTree(filename)
    assert isinstance(b._mem, FileMemory)
    assert b.get(5) == b'foo'
    b.close()


@mock.patch('bplustree.tree.BPlusTree.close')
def test_closing_context_manager(mock_close):
    with BPlusTree(filename, page_size=512, value_size=128) as b:
        pass
    mock_close.assert_called_once_with()


def test_initial_values():
    b = BPlusTree(filename, page_size=512, value_size=128)
    assert b._tree_conf.page_size == 512
    assert b._tree_conf.order == 100
    assert b._tree_conf.key_size == 8
    assert b._tree_conf.value_size == 128
    b.close()


def test_partial_constructors(b):
    node = b.RootNode()
    record = b.Record()
    assert node._tree_conf == b._tree_conf
    assert record._tree_conf == b._tree_conf


def test_insert_setitem_tree(b):
    b.insert(1, b'foo')

    # Duplicate keys are rejected unless replace=True
    with pytest.raises(ValueError):
        b.insert(1, b'bar')
    assert b.get(1) == b'foo'

    b.insert(1, b'baz', replace=True)
    assert b.get(1) == b'baz'

    # __setitem__ always replaces
    b[1] = b'foo'
    assert b.get(1) == b'foo'


def test_get_tree(b):
    b.insert(1, b'foo')
    assert b.get(1) == b'foo'
    assert b.get(2) is None
    assert b.get(2, 'bar') == 'bar'


def test_getitem_tree(b):
    b.insert(1, b'foo')
    b.insert(2, b'bar')
    b.insert(5, b'baz')

    assert b[1] == b'foo'
    with pytest.raises(KeyError):
        _ = b[4]

    # Slicing returns a dict of the keys within the range
    assert b[1:3] == {1: b'foo', 2: b'bar'}
    assert b[0:10] == {1: b'foo', 2: b'bar', 5: b'baz'}


def test_contains_tree(b):
    b.insert(1, b'foo')
    assert 1 in b
    assert 2 not in b


def test_len_tree(b):
    assert len(b) == 0
    b.insert(1, b'foo')
    assert len(b) == 1
    for i in range(2, 101):
        b.insert(i, str(i).encode())
    assert len(b) == 100


def test_length_hint_tree():
    b = BPlusTree(filename, key_size=16, value_size=16, order=100)
    # __length_hint__ is only an estimate, the exact values depend on the
    # node fill factor
    assert b.__length_hint__() == 49
    b.insert(1, b'foo')
    assert b.__length_hint__() == 49
    for i in range(2, 10001):
        b.insert(i, str(i).encode())
    assert b.__length_hint__() == 7242
    b.close()


def test_bool_tree(b):
    assert not b
    b.insert(1, b'foo')
    assert b


def test_iter_keys_values_items_tree(b):
    # Empty tree
    # Fix: local was named "iter", shadowing the builtin
    it = iter(b)
    with pytest.raises(StopIteration):
        next(it)

    # Insert in reverse...
    for i in range(1000, 0, -1):
        b.insert(i, str(i).encode())
    # ...iter in order
    previous = 0
    for i in b:
        assert i == previous + 1
        previous += 1

    # Test .keys()
    previous = 0
    for i in b.keys():
        assert i == previous + 1
        previous += 1

    # Test slice .keys()
    assert list(b.keys(slice(10, 13))) == [10, 11, 12]

    # Test .values()
    previous = 0
    for i in b.values():
        assert int(i.decode()) == previous + 1
        previous += 1

    # Test slice .values()
    assert list(b.values(slice(10, 13))) == [b'10', b'11', b'12']

    # Test .items()
    previous = 0
    for k, v in b.items():
        expected = previous + 1
        assert (k, int(v.decode())) == (expected, expected)
        previous += 1

    # Test slice .items()
    expected = [(10, b'10'), (11, b'11'), (12, b'12')]
    assert list(b.items(slice(10, 13))) == expected
def test_iter_slice(b):
    # Reversed and inverted slices are rejected
    with pytest.raises(ValueError):
        next(b._iter_slice(slice(None, None, -1)))

    with pytest.raises(ValueError):
        next(b._iter_slice(slice(10, 0, None)))

    # Contains from 0 to 9 included
    for i in range(10):
        b.insert(i, str(i).encode())

    # Fix: locals were named "iter", shadowing the builtin
    it = b._iter_slice(slice(None, 2))
    assert next(it).key == 0
    assert next(it).key == 1
    with pytest.raises(StopIteration):
        next(it)

    it = b._iter_slice(slice(5, 7))
    assert next(it).key == 5
    assert next(it).key == 6
    with pytest.raises(StopIteration):
        next(it)

    it = b._iter_slice(slice(8, 9))
    assert next(it).key == 8
    with pytest.raises(StopIteration):
        next(it)

    it = b._iter_slice(slice(9, 12))
    assert next(it).key == 9
    with pytest.raises(StopIteration):
        next(it)

    it = b._iter_slice(slice(15, 17))
    with pytest.raises(StopIteration):
        next(it)

    it = b._iter_slice(slice(-2, 17))
    assert next(it).key == 0

    b.close()

    # Contains from 10, 20, 30 .. 200
    b = BPlusTree(filename, order=5)
    for i in range(10, 201, 10):
        b.insert(i, str(i).encode())

    it = b._iter_slice(slice(65, 85))
    assert next(it).key == 70
    assert next(it).key == 80
    with pytest.raises(StopIteration):
        next(it)


def test_checkpoint(b):
    b.checkpoint()
    b.insert(1, b'foo')
    assert not b._mem._wal._not_committed_pages
    assert b._mem._wal._committed_pages

    # A checkpoint flushes all committed WAL pages to the tree file
    b.checkpoint()
    assert not b._mem._wal._not_committed_pages
    assert not b._mem._wal._committed_pages


def test_left_record_node_in_tree():
    b = BPlusTree(filename, order=3)
    assert b._left_record_node == b._root_node
    assert isinstance(b._left_record_node, LonelyRootNode)
    b.insert(1, b'1')
    b.insert(2, b'2')
    b.insert(3, b'3')
    # After the first split the leftmost record node is a proper leaf
    assert isinstance(b._left_record_node, LeafNode)
    b.close()


# Insertion orders: ascending, descending and interleaved
iterators = [
    range(0, 1000, 1),
    range(1000, 0, -1),
    list(range(0, 1000, 2)) + list(range(1, 1000, 2))
]
orders = [3, 4, 50]
page_sizes = [4096, 8192]
key_sizes = [4, 16]
values_sizes = [1, 16]
serializer_class = [IntSerializer, StrSerializer]
cache_sizes = [0, 50]
matrix = itertools.product(iterators, orders, page_sizes, key_sizes,
                           values_sizes, serializer_class, cache_sizes)


@pytest.mark.parametrize(
    'iterator,order,page_size,k_size,v_size,serialize_class,cache_size', matrix
)
def test_insert_split_in_tree(iterator, order, page_size, k_size, v_size,
                              serialize_class, cache_size):

    inserted = list()
    for i in iterator:
        v = str(i).encode()
        k = i
        if serialize_class is StrSerializer:
            k = str(i)
        inserted.append((k, v))

    b = BPlusTree(filename, order=order, page_size=page_size,
                  key_size=k_size, value_size=v_size, cache_size=cache_size,
                  serializer=serialize_class())

    # batch_insert requires ascending keys, fall back to one-by-one
    if sorted(inserted) == inserted:
        b.batch_insert(inserted)
    else:
        for k, v in inserted:
            b.insert(k, v)

    # Reload tree from file before checking values
    b.close()
    b = BPlusTree(filename, order=order, page_size=page_size,
                  key_size=k_size, value_size=v_size, cache_size=cache_size,
                  serializer=serialize_class())

    for k, v in inserted:
        assert b.get(k) == v

    b.close()
def test_insert_split_in_tree_uuid():
    # Not in the test matrix because the iterators don't really make sense
    test_insert_split_in_tree(
        [uuid.uuid4() for _ in range(1000)],
        20, 4096, 16, 40, UUIDSerializer, 50,
    )


def test_insert_split_in_tree_datetime_utc():
    dt = datetime(2018, 1, 6, 21, 42, 2, 424739, tzinfo=timezone.utc)
    test_insert_split_in_tree(
        [dt + timedelta(minutes=i) for i in range(1000)],
        20, 2048, 8, 40, DatetimeUTCSerializer, 50,
    )


def test_overflow(b):
    payload = b'f' * 323343

    # A value larger than value_size is chained across overflow pages
    with b._mem.write_transaction:
        first_overflow_page = b._create_overflow(payload)
        assert b._read_from_overflow(first_overflow_page) == payload

    with b._mem.read_transaction:
        assert b._read_from_overflow(first_overflow_page) == payload

    assert b._mem.last_page == 81

    with b._mem.write_transaction:
        b._delete_overflow(first_overflow_page)

    # Deleted overflow pages land on the freelist, newest first
    with b._mem.write_transaction:
        for expected_page in range(81, 2, -1):
            assert b._mem.next_available_page == expected_page


def test_batch_insert(b):
    def generate(from_, to):
        yield from ((i, str(i).encode()) for i in range(from_, to))

    b.batch_insert(generate(0, 1000))
    b.batch_insert(generate(1000, 2000))

    expected = 0
    for key, value in b.items():
        assert key == expected
        assert value == str(expected).encode()
        expected += 1
    assert expected == 2000


def test_batch_insert_no_in_order(b):
    # Keys out of ascending order abort the whole batch
    with pytest.raises(ValueError):
        b.batch_insert([(2, b'2'), (1, b'1')])
    assert b.get(1) is None
    assert b.get(2) is None

    b.insert(2, b'2')

    # Keys not greater than the current maximum are rejected too
    with pytest.raises(ValueError):
        b.batch_insert([(1, b'1')])

    with pytest.raises(ValueError):
        b.batch_insert([(2, b'2')])

    assert b.get(1) is None
    assert b.get(2) == b'2'


import abc
import bisect
import math
from typing import Optional

from .const import (ENDIAN, NODE_TYPE_BYTES, USED_PAGE_LENGTH_BYTES,
                    PAGE_REFERENCE_BYTES, TreeConf)
from .entry import Entry, Record, Reference, OpaqueData
class Node(metaclass=abc.ABCMeta):
    """A page of the tree file, holding zero or more entries.

    Serialized layout: node type byte, used page length, next page
    reference, the entries, then zero padding up to page_size.
    """

    __slots__ = ['_tree_conf', 'entries', 'page', 'parent', 'next_page']

    # Attributes to redefine in inherited classes
    _node_type_int = 0
    max_children = 0
    min_children = 0
    _entry_class = None

    def __init__(self, tree_conf: TreeConf, data: Optional[bytes]=None,
                 page: int=None, parent: 'Node'=None, next_page: int=None):
        self._tree_conf = tree_conf
        self.entries = list()
        self.page = page
        self.parent = parent
        self.next_page = next_page
        if data:
            self.load(data)

    def load(self, data: bytes):
        """Deserialize a full page into this node's entries."""
        assert len(data) == self._tree_conf.page_size
        end_used_page_length = NODE_TYPE_BYTES + USED_PAGE_LENGTH_BYTES
        used_page_length = int.from_bytes(
            data[NODE_TYPE_BYTES:end_used_page_length], ENDIAN
        )
        end_header = end_used_page_length + PAGE_REFERENCE_BYTES
        self.next_page = int.from_bytes(
            data[end_used_page_length:end_header], ENDIAN
        )
        # Page reference 0 means "no next page"
        if self.next_page == 0:
            self.next_page = None

        if self._entry_class is None:
            # For Nodes that cannot hold Entries
            return

        try:
            # For Nodes that can hold multiple sized Entries
            entry_length = self._entry_class(self._tree_conf).length
        except AttributeError:
            # For Nodes that can hold a single variable sized Entry
            entry_length = used_page_length - end_header

        for start_offset in range(end_header, used_page_length, entry_length):
            entry_data = data[start_offset:start_offset+entry_length]
            entry = self._entry_class(self._tree_conf, data=entry_data)
            self.entries.append(entry)

    def dump(self) -> bytearray:
        """Serialize this node to a full page of bytes."""
        data = bytearray()
        for record in self.entries:
            data.extend(record.dump())

        # used_page_length = len(header) + len(data), but the header is
        # generated later.  Use the named constants instead of the former
        # magic "4" so the layout stays in sync with load().
        used_page_length = (
            len(data) + NODE_TYPE_BYTES + USED_PAGE_LENGTH_BYTES +
            PAGE_REFERENCE_BYTES
        )
        assert 0 < used_page_length <= self._tree_conf.page_size
        assert len(data) <= self.max_payload

        next_page = 0 if self.next_page is None else self.next_page
        header = (
            self._node_type_int.to_bytes(NODE_TYPE_BYTES, ENDIAN) +
            used_page_length.to_bytes(USED_PAGE_LENGTH_BYTES, ENDIAN) +
            next_page.to_bytes(PAGE_REFERENCE_BYTES, ENDIAN)
        )

        data = bytearray(header) + data

        padding = self._tree_conf.page_size - used_page_length
        assert padding >= 0
        data.extend(bytearray(padding))
        assert len(data) == self._tree_conf.page_size

        return data

    @property
    def max_payload(self) -> int:
        """Size in bytes of serialized payload a Node can carry."""
        return (
            self._tree_conf.page_size - NODE_TYPE_BYTES -
            USED_PAGE_LENGTH_BYTES - PAGE_REFERENCE_BYTES
        )

    @property
    def can_add_entry(self) -> bool:
        return self.num_children < self.max_children

    @property
    def can_delete_entry(self) -> bool:
        return self.num_children > self.min_children

    @property
    def smallest_key(self):
        return self.smallest_entry.key

    @property
    def smallest_entry(self):
        return self.entries[0]

    @property
    def biggest_key(self):
        return self.biggest_entry.key

    @property
    def biggest_entry(self):
        return self.entries[-1]

    @property
    def num_children(self) -> int:
        """Number of entries or other nodes connected to the node."""
        return len(self.entries)

    def pop_smallest(self) -> Entry:
        """Remove and return the smallest entry."""
        return self.entries.pop(0)

    def insert_entry(self, entry: Entry):
        """Insert an entry, keeping the entry list sorted by key."""
        bisect.insort(self.entries, entry)

    def insert_entry_at_the_end(self, entry: Entry):
        """Insert an entry at the end of the entry list.

        This is an optimized version of `insert_entry` when it is known that
        the key to insert is bigger than any other entries.
        """
        self.entries.append(entry)

    def remove_entry(self, key):
        self.entries.pop(self._find_entry_index(key))

    def get_entry(self, key) -> Entry:
        return self.entries[self._find_entry_index(key)]

    def _find_entry_index(self, key) -> int:
        """Binary search for the entry holding key, raise ValueError if absent."""
        entry = self._entry_class(
            self._tree_conf,
            key=key  # Hack to compare and order
        )
        i = bisect.bisect_left(self.entries, entry)
        if i != len(self.entries) and self.entries[i] == entry:
            return i
        raise ValueError('No entry for key {}'.format(key))

    def split_entries(self) -> list:
        """Split the entries in half.

        Keep the lower part in the node and return the upper one.
        """
        len_entries = len(self.entries)
        rv = self.entries[len_entries//2:]
        self.entries = self.entries[:len_entries//2]
        assert len(self.entries) + len(rv) == len_entries
        return rv

    @classmethod
    def from_page_data(cls, tree_conf: TreeConf, data: bytes,
                       page: int=None) -> 'Node':
        """Instantiate the concrete Node subclass encoded in a raw page."""
        node_type_byte = data[0:NODE_TYPE_BYTES]
        node_type_int = int.from_bytes(node_type_byte, ENDIAN)
        if node_type_int == 1:
            return LonelyRootNode(tree_conf, data, page)
        elif node_type_int == 2:
            return RootNode(tree_conf, data, page)
        elif node_type_int == 3:
            return InternalNode(tree_conf, data, page)
        elif node_type_int == 4:
            return LeafNode(tree_conf, data, page)
        elif node_type_int == 5:
            return OverflowNode(tree_conf, data, page)
        elif node_type_int == 6:
            return FreelistNode(tree_conf, data, page)
        else:
            # Raise instead of the former `assert False`: asserts are
            # stripped under -O and a corrupted page must not pass silently
            raise ValueError(
                'No Node with type {} exists'.format(node_type_int)
            )

    def __repr__(self):
        return '<{}: page={} entries={}>'.format(
            self.__class__.__name__, self.page, len(self.entries)
        )

    def __eq__(self, other):
        return (
            self.__class__ is other.__class__ and
            self.page == other.page and
            self.entries == other.entries
        )


class RecordNode(Node):
    """Base class for nodes whose entries are Records."""

    __slots__ = ['_entry_class']

    def __init__(self, tree_conf: TreeConf, data: Optional[bytes]=None,
                 page: int=None, parent: 'Node'=None, next_page: int=None):
        self._entry_class = Record
        super().__init__(tree_conf, data, page, parent, next_page)


class LonelyRootNode(RecordNode):
    """A Root node that holds records.

    It is an exception for when there is only a single node in the tree.
    """

    __slots__ = ['_node_type_int', 'min_children', 'max_children']

    def __init__(self, tree_conf: TreeConf, data: Optional[bytes]=None,
                 page: int=None, parent: 'Node'=None):
        self._node_type_int = 1
        self.min_children = 0
        self.max_children = tree_conf.order - 1
        super().__init__(tree_conf, data, page, parent)

    def convert_to_leaf(self):
        """Turn this node into a LeafNode once the tree grows a real root."""
        leaf = LeafNode(self._tree_conf, page=self.page)
        leaf.entries = self.entries
        return leaf
class LeafNode(RecordNode):
    """Node that holds the actual records within the tree."""

    __slots__ = ['_node_type_int', 'min_children', 'max_children']

    def __init__(self, tree_conf: TreeConf, data: Optional[bytes]=None,
                 page: int=None, parent: 'Node'=None, next_page: int=None):
        self._node_type_int = 4
        self.max_children = tree_conf.order - 1
        self.min_children = math.ceil(tree_conf.order / 2) - 1
        super().__init__(tree_conf, data, page, parent, next_page)


class ReferenceNode(Node):
    """Base class for nodes whose entries are References to other nodes."""

    __slots__ = ['_entry_class']

    def __init__(self, tree_conf: TreeConf, data: Optional[bytes]=None,
                 page: int=None, parent: 'Node'=None):
        self._entry_class = Reference
        super().__init__(tree_conf, data, page, parent)

    @property
    def num_children(self) -> int:
        # n references separate n+1 children; an empty node has none
        if not self.entries:
            return 0
        return len(self.entries) + 1

    def insert_entry(self, entry: 'Reference'):
        """Make sure that after of a reference matches before of the next one.

        Probably very inefficient approach.
        """
        super().insert_entry(entry)
        i = self.entries.index(entry)
        if i > 0:
            self.entries[i-1].after = entry.before
        if i + 1 < len(self.entries):
            self.entries[i+1].before = entry.after


class RootNode(ReferenceNode):
    """The first node at the top of the tree."""

    __slots__ = ['_node_type_int', 'min_children', 'max_children']

    def __init__(self, tree_conf: TreeConf, data: Optional[bytes]=None,
                 page: int=None, parent: 'Node'=None):
        self._node_type_int = 2
        self.max_children = tree_conf.order
        self.min_children = 2
        super().__init__(tree_conf, data, page, parent)

    def convert_to_internal(self):
        """Turn this node into an InternalNode when a new root is created."""
        internal = InternalNode(self._tree_conf, page=self.page)
        internal.entries = self.entries
        return internal


class InternalNode(ReferenceNode):
    """Node that only holds references to other Internal nodes or Leaves."""

    __slots__ = ['_node_type_int', 'min_children', 'max_children']

    def __init__(self, tree_conf: TreeConf, data: Optional[bytes]=None,
                 page: int=None, parent: 'Node'=None):
        self._node_type_int = 3
        self.max_children = tree_conf.order
        self.min_children = math.ceil(tree_conf.order / 2)
        super().__init__(tree_conf, data, page, parent)


class OverflowNode(Node):
    """Node that holds a single Record value too large for its Node."""

    def __init__(self, tree_conf: TreeConf, data: Optional[bytes]=None,
                 page: int=None, next_page: int=None):
        self._node_type_int = 5
        self.max_children = 1
        self.min_children = 1
        self._entry_class = OpaqueData
        super().__init__(tree_conf, data, page, next_page=next_page)

    def __repr__(self):
        return '<{}: page={} next_page={}>'.format(
            self.__class__.__name__, self.page, self.next_page
        )
class FreelistNode(Node):
    """Node that is a marker for a deallocated page."""

    def __init__(self, tree_conf: TreeConf, data: Optional[bytes]=None,
                 page: int=None, next_page: int=None):
        self._node_type_int = 6
        self.max_children = 0
        self.min_children = 0
        super().__init__(tree_conf, data, page, next_page=next_page)

    def __repr__(self):
        return '<{}: page={} next_page={}>'.format(
            self.__class__.__name__, self.page, self.next_page
        )


from functools import partial
from logging import getLogger
from typing import Optional, Union, Iterator, Iterable

from . import utils
from .const import TreeConf
from .entry import Record, Reference, OpaqueData
from .memory import FileMemory
from .node import (
    Node, LonelyRootNode, RootNode, InternalNode, LeafNode, OverflowNode
)
from .serializer import Serializer, IntSerializer


logger = getLogger(__name__)


class BPlusTree:
    """On-disk B+tree mapping keys to bytes values."""

    __slots__ = ['_filename', '_tree_conf', '_mem', '_root_node_page',
                 '_is_open', 'LonelyRootNode', 'RootNode', 'InternalNode',
                 'LeafNode', 'OverflowNode', 'Record', 'Reference']

    # ######################### Public API ################################

    def __init__(self, filename: str, page_size: int=4096, order: int=100,
                 key_size: int=8, value_size: int=32, cache_size: int=64,
                 serializer: Optional[Serializer]=None):
        self._filename = filename
        self._tree_conf = TreeConf(
            page_size, order, key_size, value_size,
            serializer or IntSerializer()
        )
        self._create_partials()
        self._mem = FileMemory(filename, self._tree_conf,
                               cache_size=cache_size)
        try:
            metadata = self._mem.get_metadata()
        except ValueError:
            # Brand new file, no metadata stored yet
            self._initialize_empty_tree()
        else:
            # Existing file: its stored configuration wins
            self._root_node_page, self._tree_conf = metadata
        self._is_open = True

    def close(self):
        """Flush and close the tree file; safe to call more than once."""
        with self._mem.write_transaction:
            if not self._is_open:
                logger.info('Tree is already closed')
                return

            self._mem.close()
            self._is_open = False

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()

    def checkpoint(self):
        """Move committed WAL pages into the tree file."""
        with self._mem.write_transaction:
            self._mem.perform_checkpoint(reopen_wal=True)

    def insert(self, key, value: bytes, replace=False):
        """Insert a value in the tree.

        :param key: The key at which the value will be recorded, must be of
                    the same type used by the Serializer
        :param value: The value to record in bytes
        :param replace: If True, already existing value will be overridden,
                        otherwise a ValueError is raised.
        """
        if not isinstance(value, bytes):
            # Bugfix: the exception was constructed but never raised, so
            # non-bytes values slipped through the validation
            raise ValueError('Values must be bytes objects')

        with self._mem.write_transaction:
            node = self._search_in_tree(key, self._root_node)

            # Check if a record with the key already exists
            try:
                existing_record = node.get_entry(key)
            except ValueError:
                pass
            else:
                if not replace:
                    raise ValueError('Key {} already exists'.format(key))

                if existing_record.overflow_page:
                    self._delete_overflow(existing_record.overflow_page)

                if len(value) <= self._tree_conf.value_size:
                    existing_record.value = value
                    existing_record.overflow_page = None
                else:
                    existing_record.value = None
                    existing_record.overflow_page = self._create_overflow(
                        value
                    )
                self._mem.set_node(node)
                return

            if len(value) <= self._tree_conf.value_size:
                record = self.Record(key, value=value)
            else:
                # Record values exceeding the max value_size must be placed
                # into overflow pages
                first_overflow_page = self._create_overflow(value)
                record = self.Record(key, value=None,
                                     overflow_page=first_overflow_page)

            if node.can_add_entry:
                node.insert_entry(record)
                self._mem.set_node(node)
            else:
                node.insert_entry(record)
                self._split_leaf(node)
125 | """ 126 | node = None 127 | with self._mem.write_transaction: 128 | 129 | for key, value in iterable: 130 | 131 | if node is None: 132 | node = self._search_in_tree(key, self._root_node) 133 | 134 | try: 135 | biggest_entry = node.biggest_entry 136 | except IndexError: 137 | biggest_entry = None 138 | if biggest_entry and key <= biggest_entry.key: 139 | raise ValueError('Keys to batch insert must be sorted and ' 140 | 'bigger than keys currently in the tree') 141 | 142 | if len(value) <= self._tree_conf.value_size: 143 | record = self.Record(key, value=value) 144 | else: 145 | # Record values exceeding the max value_size must be placed 146 | # into overflow pages 147 | first_overflow_page = self._create_overflow(value) 148 | record = self.Record(key, value=None, 149 | overflow_page=first_overflow_page) 150 | 151 | if node.can_add_entry: 152 | node.insert_entry_at_the_end(record) 153 | else: 154 | node.insert_entry_at_the_end(record) 155 | self._split_leaf(node) 156 | node = None 157 | 158 | if node is not None: 159 | self._mem.set_node(node) 160 | 161 | def get(self, key, default=None) -> bytes: 162 | with self._mem.read_transaction: 163 | node = self._search_in_tree(key, self._root_node) 164 | try: 165 | record = node.get_entry(key) 166 | except ValueError: 167 | return default 168 | else: 169 | rv = self._get_value_from_record(record) 170 | assert isinstance(rv, bytes) 171 | return rv 172 | 173 | def __contains__(self, item): 174 | with self._mem.read_transaction: 175 | o = object() 176 | return False if self.get(item, default=o) is o else True 177 | 178 | def __setitem__(self, key, value): 179 | self.insert(key, value, replace=True) 180 | 181 | def __getitem__(self, item): 182 | with self._mem.read_transaction: 183 | 184 | if isinstance(item, slice): 185 | # Returning a dict is the most sensible thing to do 186 | # as a method cannot return a sometimes a generator 187 | # and sometimes a normal value 188 | rv = dict() 189 | for record in 
self._iter_slice(item): 190 | rv[record.key] = self._get_value_from_record(record) 191 | return rv 192 | 193 | else: 194 | rv = self.get(item) 195 | if rv is None: 196 | raise KeyError(item) 197 | return rv 198 | 199 | def __len__(self): 200 | with self._mem.read_transaction: 201 | node = self._left_record_node 202 | rv = 0 203 | while True: 204 | rv += len(node.entries) 205 | if not node.next_page: 206 | return rv 207 | node = self._mem.get_node(node.next_page) 208 | 209 | def __length_hint__(self): 210 | with self._mem.read_transaction: 211 | node = self._root_node 212 | if isinstance(node, LonelyRootNode): 213 | # Assume that the lonely root node is half full 214 | return node.max_children // 2 215 | # Assume that there are no holes in pages 216 | last_page = self._mem.last_page 217 | # Assume that 70% of nodes in a tree carry values 218 | num_leaf_nodes = int(last_page * 0.70) 219 | # Assume that every leaf node is half full 220 | num_records_per_leaf_node = int( 221 | (node.max_children + node.min_children) / 2 222 | ) 223 | return num_leaf_nodes * num_records_per_leaf_node 224 | 225 | def __iter__(self, slice_: Optional[slice]=None): 226 | if not slice_: 227 | slice_ = slice(None) 228 | with self._mem.read_transaction: 229 | for record in self._iter_slice(slice_): 230 | yield record.key 231 | 232 | keys = __iter__ 233 | 234 | def items(self, slice_: Optional[slice]=None) -> Iterator[tuple]: 235 | if not slice_: 236 | slice_ = slice(None) 237 | with self._mem.read_transaction: 238 | for record in self._iter_slice(slice_): 239 | yield record.key, self._get_value_from_record(record) 240 | 241 | def values(self, slice_: Optional[slice]=None) -> Iterator[bytes]: 242 | if not slice_: 243 | slice_ = slice(None) 244 | with self._mem.read_transaction: 245 | for record in self._iter_slice(slice_): 246 | yield self._get_value_from_record(record) 247 | 248 | def __bool__(self): 249 | with self._mem.read_transaction: 250 | for _ in self: 251 | return True 252 | return 
False 253 | 254 | def __repr__(self): 255 | return ''.format(self._filename, self._tree_conf) 256 | 257 | # ####################### Implementation ############################## 258 | 259 | def _initialize_empty_tree(self): 260 | self._root_node_page = self._mem.next_available_page 261 | with self._mem.write_transaction: 262 | self._mem.set_node(self.LonelyRootNode(page=self._root_node_page)) 263 | self._mem.set_metadata(self._root_node_page, self._tree_conf) 264 | 265 | def _create_partials(self): 266 | self.LonelyRootNode = partial(LonelyRootNode, self._tree_conf) 267 | self.RootNode = partial(RootNode, self._tree_conf) 268 | self.InternalNode = partial(InternalNode, self._tree_conf) 269 | self.LeafNode = partial(LeafNode, self._tree_conf) 270 | self.OverflowNode = partial(OverflowNode, self._tree_conf) 271 | self.Record = partial(Record, self._tree_conf) 272 | self.Reference = partial(Reference, self._tree_conf) 273 | 274 | @property 275 | def _root_node(self) -> Union['LonelyRootNode', 'RootNode']: 276 | root_node = self._mem.get_node(self._root_node_page) 277 | assert isinstance(root_node, (LonelyRootNode, RootNode)) 278 | return root_node 279 | 280 | @property 281 | def _left_record_node(self) -> Union['LonelyRootNode', 'LeafNode']: 282 | node = self._root_node 283 | while not isinstance(node, (LonelyRootNode, LeafNode)): 284 | node = self._mem.get_node(node.smallest_entry.before) 285 | return node 286 | 287 | def _iter_slice(self, slice_: slice) -> Iterator[Record]: 288 | if slice_.step is not None: 289 | raise ValueError('Cannot iterate with a custom step') 290 | 291 | if (slice_.start is not None and slice_.stop is not None and 292 | slice_.start >= slice_.stop): 293 | raise ValueError('Cannot iterate backwards') 294 | 295 | if slice_.start is None: 296 | node = self._left_record_node 297 | else: 298 | node = self._search_in_tree(slice_.start, self._root_node) 299 | 300 | while True: 301 | for entry in node.entries: 302 | if slice_.start is not None and 
entry.key < slice_.start: 303 | continue 304 | 305 | if slice_.stop is not None and entry.key >= slice_.stop: 306 | return 307 | 308 | yield entry 309 | 310 | if node.next_page: 311 | node = self._mem.get_node(node.next_page) 312 | else: 313 | return 314 | 315 | def _search_in_tree(self, key, node) -> 'Node': 316 | if isinstance(node, (LonelyRootNode, LeafNode)): 317 | return node 318 | 319 | page = None 320 | 321 | if key < node.smallest_key: 322 | page = node.smallest_entry.before 323 | 324 | elif node.biggest_key <= key: 325 | page = node.biggest_entry.after 326 | 327 | else: 328 | for ref_a, ref_b in utils.pairwise(node.entries): 329 | if ref_a.key <= key < ref_b.key: 330 | page = ref_a.after 331 | break 332 | 333 | assert page is not None 334 | 335 | child_node = self._mem.get_node(page) 336 | child_node.parent = node 337 | return self._search_in_tree(key, child_node) 338 | 339 | def _split_leaf(self, old_node: 'Node'): 340 | """Split a leaf Node to allow the tree to grow.""" 341 | parent = old_node.parent 342 | new_node = self.LeafNode(page=self._mem.next_available_page, 343 | next_page=old_node.next_page) 344 | new_entries = old_node.split_entries() 345 | new_node.entries = new_entries 346 | ref = self.Reference(new_node.smallest_key, 347 | old_node.page, new_node.page) 348 | 349 | if isinstance(old_node, LonelyRootNode): 350 | # Convert the LonelyRoot into a Leaf 351 | old_node = old_node.convert_to_leaf() 352 | self._create_new_root(ref) 353 | elif parent.can_add_entry: 354 | parent.insert_entry(ref) 355 | self._mem.set_node(parent) 356 | else: 357 | parent.insert_entry(ref) 358 | self._split_parent(parent) 359 | 360 | old_node.next_page = new_node.page 361 | 362 | self._mem.set_node(old_node) 363 | self._mem.set_node(new_node) 364 | 365 | def _split_parent(self, old_node: Node): 366 | parent = old_node.parent 367 | new_node = self.InternalNode(page=self._mem.next_available_page) 368 | new_entries = old_node.split_entries() 369 | new_node.entries = 
new_entries 370 | 371 | ref = new_node.pop_smallest() 372 | ref.before = old_node.page 373 | ref.after = new_node.page 374 | 375 | if isinstance(old_node, RootNode): 376 | # Convert the Root into an Internal 377 | old_node = old_node.convert_to_internal() 378 | self._create_new_root(ref) 379 | elif parent.can_add_entry: 380 | parent.insert_entry(ref) 381 | self._mem.set_node(parent) 382 | else: 383 | parent.insert_entry(ref) 384 | self._split_parent(parent) 385 | 386 | self._mem.set_node(old_node) 387 | self._mem.set_node(new_node) 388 | 389 | def _create_new_root(self, reference: Reference): 390 | new_root = self.RootNode(page=self._mem.next_available_page) 391 | new_root.insert_entry(reference) 392 | self._root_node_page = new_root.page 393 | self._mem.set_metadata(self._root_node_page, self._tree_conf) 394 | self._mem.set_node(new_root) 395 | 396 | def _create_overflow(self, value: bytes) -> int: 397 | first_overflow_page = self._mem.next_available_page 398 | next_overflow_page = first_overflow_page 399 | 400 | iterator = utils.iter_slice(value, self.OverflowNode().max_payload) 401 | for slice_value, is_last in iterator: 402 | current_overflow_page = next_overflow_page 403 | 404 | if is_last: 405 | next_overflow_page = None 406 | else: 407 | next_overflow_page = self._mem.next_available_page 408 | 409 | overflow_node = self.OverflowNode( 410 | page=current_overflow_page, next_page=next_overflow_page 411 | ) 412 | overflow_node.insert_entry_at_the_end(OpaqueData(data=slice_value)) 413 | self._mem.set_node(overflow_node) 414 | 415 | return first_overflow_page 416 | 417 | def _traverse_overflow(self, first_overflow_page: int): 418 | """Yield all Nodes of an overflow chain.""" 419 | next_overflow_page = first_overflow_page 420 | while True: 421 | overflow_node = self._mem.get_node(next_overflow_page) 422 | yield overflow_node 423 | 424 | next_overflow_page = overflow_node.next_page 425 | if next_overflow_page is None: 426 | break 427 | 428 | def 
_read_from_overflow(self, first_overflow_page: int) -> bytes: 429 | """Collect all values of an overflow chain.""" 430 | rv = bytearray() 431 | for overflow_node in self._traverse_overflow(first_overflow_page): 432 | rv.extend(overflow_node.smallest_entry.data) 433 | 434 | return bytes(rv) 435 | 436 | def _delete_overflow(self, first_overflow_page: int): 437 | """Delete all Nodes in an overflow chain.""" 438 | for overflow_node in self._traverse_overflow(first_overflow_page): 439 | self._mem.del_node(overflow_node) 440 | 441 | def _get_value_from_record(self, record: Record) -> bytes: 442 | if record.value is not None: 443 | return record.value 444 | 445 | return self._read_from_overflow(record.overflow_page) 446 | -------------------------------------------------------------------------------- /bplustree/memory.py: -------------------------------------------------------------------------------- 1 | import enum 2 | import io 3 | from logging import getLogger 4 | import os 5 | import platform 6 | from typing import Union, Tuple, Optional 7 | 8 | import cachetools 9 | import rwlock 10 | 11 | from .node import Node, FreelistNode 12 | from .const import ( 13 | ENDIAN, PAGE_REFERENCE_BYTES, OTHERS_BYTES, TreeConf, FRAME_TYPE_BYTES 14 | ) 15 | 16 | logger = getLogger(__name__) 17 | 18 | 19 | class ReachedEndOfFile(Exception): 20 | """Read a file until its end.""" 21 | 22 | 23 | def open_file_in_dir(path: str) -> Tuple[io.FileIO, Optional[int]]: 24 | """Open a file and its directory. 25 | 26 | The file is opened in binary mode and created if it does not exist. 27 | Both file descriptors must be closed after use to prevent them from 28 | leaking. 29 | 30 | On Windows, the directory is not opened, as it is useless. 
31 | """ 32 | directory = os.path.dirname(path) 33 | if not os.path.isdir(directory): 34 | raise ValueError('No directory {}'.format(directory)) 35 | 36 | if not os.path.exists(path): 37 | file_fd = open(path, mode='x+b', buffering=0) 38 | else: 39 | file_fd = open(path, mode='r+b', buffering=0) 40 | 41 | if platform.system() == 'Windows': 42 | # Opening a directory is not possible on Windows, but that is not 43 | # a problem since Windows does not need to fsync the directory in 44 | # order to persist metadata 45 | dir_fd = None 46 | else: 47 | dir_fd = os.open(directory, os.O_RDONLY) 48 | 49 | return file_fd, dir_fd 50 | 51 | 52 | def write_to_file(file_fd: io.FileIO, dir_fileno: Optional[int], 53 | data: bytes, fsync: bool=True): 54 | length_to_write = len(data) 55 | written = 0 56 | while written < length_to_write: 57 | written += file_fd.write(data[written:]) 58 | if fsync: 59 | fsync_file_and_dir(file_fd.fileno(), dir_fileno) 60 | 61 | 62 | def fsync_file_and_dir(file_fileno: int, dir_fileno: Optional[int]): 63 | os.fsync(file_fileno) 64 | if dir_fileno is not None: 65 | os.fsync(dir_fileno) 66 | 67 | 68 | def read_from_file(file_fd: io.FileIO, start: int, stop: int) -> bytes: 69 | length = stop - start 70 | assert length >= 0 71 | file_fd.seek(start) 72 | data = bytes() 73 | while file_fd.tell() < stop: 74 | read_data = file_fd.read(stop - file_fd.tell()) 75 | if read_data == b'': 76 | raise ReachedEndOfFile('Read until the end of file') 77 | data += read_data 78 | assert len(data) == length 79 | return data 80 | 81 | 82 | class FakeCache: 83 | """A cache that doesn't cache anything. 84 | 85 | Because cachetools does not work with maxsize=0. 
86 | """ 87 | 88 | def get(self, k): 89 | pass 90 | 91 | def __setitem__(self, key, value): 92 | pass 93 | 94 | def clear(self): 95 | pass 96 | 97 | 98 | class FileMemory: 99 | 100 | __slots__ = ['_filename', '_tree_conf', '_lock', '_cache', '_fd', 101 | '_dir_fd', '_wal', 'last_page', '_freelist_start_page', 102 | '_root_node_page'] 103 | 104 | def __init__(self, filename: str, tree_conf: TreeConf, 105 | cache_size: int=512): 106 | self._filename = filename 107 | self._tree_conf = tree_conf 108 | self._lock = rwlock.RWLock() 109 | 110 | if cache_size == 0: 111 | self._cache = FakeCache() 112 | else: 113 | self._cache = cachetools.LRUCache(maxsize=cache_size) 114 | 115 | self._fd, self._dir_fd = open_file_in_dir(filename) 116 | 117 | self._wal = WAL(filename, tree_conf.page_size) 118 | if self._wal.needs_recovery: 119 | self.perform_checkpoint(reopen_wal=True) 120 | 121 | # Get the next available page 122 | self._fd.seek(0, io.SEEK_END) 123 | last_byte = self._fd.tell() 124 | self.last_page = int(last_byte / self._tree_conf.page_size) 125 | self._freelist_start_page = 0 126 | 127 | # Todo: Remove this, it should only be in Tree 128 | self._root_node_page = 0 129 | 130 | def get_node(self, page: int): 131 | """Get a node from storage. 132 | 133 | The cache is not there to prevent hitting the disk, the OS is already 134 | very good at it. It is there to avoid paying the price of deserializing 135 | the data to create the Node object and its entry. This is a very 136 | expensive operation in Python. 137 | 138 | Since we have at most a single writer we can write to cache on 139 | `set_node` if we invalidate the cache when a transaction is rolled 140 | back. 
141 | """ 142 | node = self._cache.get(page) 143 | if node is not None: 144 | return node 145 | 146 | data = self._wal.get_page(page) 147 | if not data: 148 | data = self._read_page(page) 149 | 150 | node = Node.from_page_data(self._tree_conf, data=data, page=page) 151 | self._cache[node.page] = node 152 | return node 153 | 154 | def set_node(self, node: Node): 155 | self._wal.set_page(node.page, node.dump()) 156 | self._cache[node.page] = node 157 | 158 | def del_node(self, node: Node): 159 | self._insert_in_freelist(node.page) 160 | 161 | def del_page(self, page: int): 162 | self._insert_in_freelist(page) 163 | 164 | @property 165 | def read_transaction(self): 166 | 167 | class ReadTransaction: 168 | 169 | def __enter__(self2): 170 | self._lock.reader_lock.acquire() 171 | 172 | def __exit__(self2, exc_type, exc_val, exc_tb): 173 | self._lock.reader_lock.release() 174 | 175 | return ReadTransaction() 176 | 177 | @property 178 | def write_transaction(self): 179 | 180 | class WriteTransaction: 181 | 182 | def __enter__(self2): 183 | self._lock.writer_lock.acquire() 184 | 185 | def __exit__(self2, exc_type, exc_val, exc_tb): 186 | if exc_type: 187 | # When an error happens in the middle of a write 188 | # transaction we must roll it back and clear the cache 189 | # because the writer may have partially modified the Nodes 190 | self._wal.rollback() 191 | self._cache.clear() 192 | else: 193 | self._wal.commit() 194 | self._lock.writer_lock.release() 195 | 196 | return WriteTransaction() 197 | 198 | @property 199 | def next_available_page(self) -> int: 200 | last_freelist_page = self._pop_from_freelist() 201 | if last_freelist_page is not None: 202 | return last_freelist_page 203 | 204 | self.last_page += 1 205 | return self.last_page 206 | 207 | def _traverse_free_list(self) -> Tuple[Optional[FreelistNode], 208 | Optional[FreelistNode]]: 209 | if self._freelist_start_page == 0: 210 | return None, None 211 | 212 | second_to_last_node = None 213 | last_node = 
self.get_node(self._freelist_start_page) 214 | 215 | while last_node.next_page is not None: 216 | second_to_last_node = last_node 217 | last_node = self.get_node(second_to_last_node.next_page) 218 | 219 | return second_to_last_node, last_node 220 | 221 | def _insert_in_freelist(self, page: int): 222 | """Insert a page at the end of the freelist.""" 223 | _, last_node = self._traverse_free_list() 224 | 225 | self.set_node(FreelistNode(self._tree_conf, page=page, next_page=None)) 226 | 227 | if last_node is None: 228 | # Write in metadata that the freelist got a new starting point 229 | self._freelist_start_page = page 230 | self.set_metadata(None, None) 231 | else: 232 | last_node.next_page = page 233 | self.set_node(last_node) 234 | 235 | def _pop_from_freelist(self) -> Optional[int]: 236 | """Remove the last page from the freelist and return its page.""" 237 | second_to_last_node, last_node = self._traverse_free_list() 238 | 239 | if last_node is None: 240 | # Freelist is completely empty, nothing to pop 241 | return None 242 | 243 | if second_to_last_node is None: 244 | # Write in metadata that the freelist is empty 245 | self._freelist_start_page = 0 246 | self.set_metadata(None, None) 247 | else: 248 | second_to_last_node.next_page = None 249 | self.set_node(second_to_last_node) 250 | 251 | return last_node.page 252 | 253 | # Todo: make metadata as a normal Node 254 | def get_metadata(self) -> tuple: 255 | try: 256 | data = self._read_page(0) 257 | except ReachedEndOfFile: 258 | raise ValueError('Metadata not set yet') 259 | end_root_node_page = PAGE_REFERENCE_BYTES 260 | root_node_page = int.from_bytes( 261 | data[0:end_root_node_page], ENDIAN 262 | ) 263 | end_page_size = end_root_node_page + OTHERS_BYTES 264 | page_size = int.from_bytes( 265 | data[end_root_node_page:end_page_size], ENDIAN 266 | ) 267 | end_order = end_page_size + OTHERS_BYTES 268 | order = int.from_bytes( 269 | data[end_page_size:end_order], ENDIAN 270 | ) 271 | end_key_size = end_order + 
OTHERS_BYTES 272 | key_size = int.from_bytes( 273 | data[end_order:end_key_size], ENDIAN 274 | ) 275 | end_value_size = end_key_size + OTHERS_BYTES 276 | value_size = int.from_bytes( 277 | data[end_key_size:end_value_size], ENDIAN 278 | ) 279 | end_freelist_start_page = end_value_size + PAGE_REFERENCE_BYTES 280 | self._freelist_start_page = int.from_bytes( 281 | data[end_value_size:end_freelist_start_page], ENDIAN 282 | ) 283 | self._tree_conf = TreeConf( 284 | page_size, order, key_size, value_size, self._tree_conf.serializer 285 | ) 286 | self._root_node_page = root_node_page 287 | return root_node_page, self._tree_conf 288 | 289 | def set_metadata(self, root_node_page: Optional[int], 290 | tree_conf: Optional[TreeConf]): 291 | 292 | if root_node_page is None: 293 | root_node_page = self._root_node_page 294 | 295 | if tree_conf is None: 296 | tree_conf = self._tree_conf 297 | 298 | length = 2 * PAGE_REFERENCE_BYTES + 4 * OTHERS_BYTES 299 | data = ( 300 | root_node_page.to_bytes(PAGE_REFERENCE_BYTES, ENDIAN) + 301 | tree_conf.page_size.to_bytes(OTHERS_BYTES, ENDIAN) + 302 | tree_conf.order.to_bytes(OTHERS_BYTES, ENDIAN) + 303 | tree_conf.key_size.to_bytes(OTHERS_BYTES, ENDIAN) + 304 | tree_conf.value_size.to_bytes(OTHERS_BYTES, ENDIAN) + 305 | self._freelist_start_page.to_bytes(PAGE_REFERENCE_BYTES, ENDIAN) + 306 | bytes(tree_conf.page_size - length) 307 | ) 308 | self._write_page_in_tree(0, data, fsync=True) 309 | 310 | self._tree_conf = tree_conf 311 | self._root_node_page = root_node_page 312 | 313 | def close(self): 314 | self.perform_checkpoint() 315 | self._fd.close() 316 | if self._dir_fd is not None: 317 | os.close(self._dir_fd) 318 | 319 | def perform_checkpoint(self, reopen_wal=False): 320 | logger.info('Performing checkpoint of %s', self._filename) 321 | for page, page_data in self._wal.checkpoint(): 322 | self._write_page_in_tree(page, page_data, fsync=False) 323 | fsync_file_and_dir(self._fd.fileno(), self._dir_fd) 324 | if reopen_wal: 325 | self._wal 
= WAL(self._filename, self._tree_conf.page_size) 326 | 327 | def _read_page(self, page: int) -> bytes: 328 | start = page * self._tree_conf.page_size 329 | stop = start + self._tree_conf.page_size 330 | assert stop - start == self._tree_conf.page_size 331 | return read_from_file(self._fd, start, stop) 332 | 333 | def _write_page_in_tree(self, page: int, data: Union[bytes, bytearray], 334 | fsync: bool=True): 335 | """Write a page of data in the tree file itself. 336 | 337 | To be used during checkpoints and other non-standard uses. 338 | """ 339 | assert len(data) == self._tree_conf.page_size 340 | self._fd.seek(page * self._tree_conf.page_size) 341 | write_to_file(self._fd, self._dir_fd, data, fsync=fsync) 342 | 343 | def __repr__(self): 344 | return ''.format(self._filename) 345 | 346 | 347 | class FrameType(enum.Enum): 348 | PAGE = 1 349 | COMMIT = 2 350 | ROLLBACK = 3 351 | 352 | 353 | class WAL: 354 | 355 | __slots__ = ['filename', '_fd', '_dir_fd', '_page_size', 356 | '_committed_pages', '_not_committed_pages', 'needs_recovery'] 357 | 358 | FRAME_HEADER_LENGTH = ( 359 | FRAME_TYPE_BYTES + PAGE_REFERENCE_BYTES 360 | ) 361 | 362 | def __init__(self, filename: str, page_size: int): 363 | self.filename = filename + '-wal' 364 | self._fd, self._dir_fd = open_file_in_dir(self.filename) 365 | self._page_size = page_size 366 | self._committed_pages = dict() 367 | self._not_committed_pages = dict() 368 | 369 | self._fd.seek(0, io.SEEK_END) 370 | if self._fd.tell() == 0: 371 | self._create_header() 372 | self.needs_recovery = False 373 | else: 374 | logger.warning('Found an existing WAL file, ' 375 | 'the B+Tree was not closed properly') 376 | self.needs_recovery = True 377 | self._load_wal() 378 | 379 | def checkpoint(self): 380 | """Transfer the modified data back to the tree and close the WAL.""" 381 | if self._not_committed_pages: 382 | logger.warning('Closing WAL with uncommitted data, discarding it') 383 | 384 | fsync_file_and_dir(self._fd.fileno(), self._dir_fd) 
385 | 386 | for page, page_start in self._committed_pages.items(): 387 | page_data = read_from_file( 388 | self._fd, 389 | page_start, 390 | page_start + self._page_size 391 | ) 392 | yield page, page_data 393 | 394 | self._fd.close() 395 | os.unlink(self.filename) 396 | if self._dir_fd is not None: 397 | os.fsync(self._dir_fd) 398 | os.close(self._dir_fd) 399 | 400 | def _create_header(self): 401 | data = self._page_size.to_bytes(OTHERS_BYTES, ENDIAN) 402 | self._fd.seek(0) 403 | write_to_file(self._fd, self._dir_fd, data, True) 404 | 405 | def _load_wal(self): 406 | self._fd.seek(0) 407 | header_data = read_from_file(self._fd, 0, OTHERS_BYTES) 408 | assert int.from_bytes(header_data, ENDIAN) == self._page_size 409 | 410 | while True: 411 | try: 412 | self._load_next_frame() 413 | except ReachedEndOfFile: 414 | break 415 | if self._not_committed_pages: 416 | logger.warning('WAL has uncommitted data, discarding it') 417 | self._not_committed_pages = dict() 418 | 419 | def _load_next_frame(self): 420 | start = self._fd.tell() 421 | stop = start + self.FRAME_HEADER_LENGTH 422 | data = read_from_file(self._fd, start, stop) 423 | 424 | frame_type = int.from_bytes(data[0:FRAME_TYPE_BYTES], ENDIAN) 425 | page = int.from_bytes( 426 | data[FRAME_TYPE_BYTES:FRAME_TYPE_BYTES+PAGE_REFERENCE_BYTES], 427 | ENDIAN 428 | ) 429 | 430 | frame_type = FrameType(frame_type) 431 | if frame_type is FrameType.PAGE: 432 | self._fd.seek(stop + self._page_size) 433 | 434 | self._index_frame(frame_type, page, stop) 435 | 436 | def _index_frame(self, frame_type: FrameType, page: int, page_start: int): 437 | if frame_type is FrameType.PAGE: 438 | self._not_committed_pages[page] = page_start 439 | elif frame_type is FrameType.COMMIT: 440 | self._committed_pages.update(self._not_committed_pages) 441 | self._not_committed_pages = dict() 442 | elif frame_type is FrameType.ROLLBACK: 443 | self._not_committed_pages = dict() 444 | else: 445 | assert False 446 | 447 | def _add_frame(self, frame_type: 
FrameType, page: Optional[int]=None, 448 | page_data: Optional[bytes]=None): 449 | if frame_type is FrameType.PAGE and (not page or not page_data): 450 | raise ValueError('PAGE frame without page data') 451 | if page_data and len(page_data) != self._page_size: 452 | raise ValueError('Page data is different from page size') 453 | if not page: 454 | page = 0 455 | if frame_type is not FrameType.PAGE: 456 | page_data = b'' 457 | data = ( 458 | frame_type.value.to_bytes(FRAME_TYPE_BYTES, ENDIAN) + 459 | page.to_bytes(PAGE_REFERENCE_BYTES, ENDIAN) + 460 | page_data 461 | ) 462 | self._fd.seek(0, io.SEEK_END) 463 | write_to_file(self._fd, self._dir_fd, data, 464 | fsync=frame_type != FrameType.PAGE) 465 | self._index_frame(frame_type, page, self._fd.tell() - self._page_size) 466 | 467 | def get_page(self, page: int) -> Optional[bytes]: 468 | page_start = None 469 | for store in (self._not_committed_pages, self._committed_pages): 470 | page_start = store.get(page) 471 | if page_start: 472 | break 473 | 474 | if not page_start: 475 | return None 476 | 477 | return read_from_file(self._fd, page_start, 478 | page_start + self._page_size) 479 | 480 | def set_page(self, page: int, page_data: bytes): 481 | self._add_frame(FrameType.PAGE, page, page_data) 482 | 483 | def commit(self): 484 | # Commit is a no-op when there is no uncommitted pages 485 | if self._not_committed_pages: 486 | self._add_frame(FrameType.COMMIT) 487 | 488 | def rollback(self): 489 | # Rollback is a no-op when there is no uncommitted pages 490 | if self._not_committed_pages: 491 | self._add_frame(FrameType.ROLLBACK) 492 | 493 | def __repr__(self): 494 | return ''.format(self.filename) 495 | --------------------------------------------------------------------------------