├── .github └── workflows │ ├── conventional_commit.yml │ ├── lint.yml │ ├── secure.yml │ ├── tests.yml │ └── type.yml ├── .gitignore ├── .pre-commit-config.yaml ├── .pylintrc ├── Makefile ├── README.md ├── ci └── Dockerfile ├── pyproject.toml ├── pytype.cfg ├── src └── pydistributedkv │ ├── __init__.py │ ├── configurator │ ├── __init__.py │ └── settings │ │ ├── __init__.py │ │ ├── base.py │ │ └── config.py │ ├── domain │ ├── __init__.py │ └── models.py │ ├── entrypoints │ ├── __init__.py │ ├── cli │ │ ├── __init__.py │ │ ├── run_follower.py │ │ └── run_leader.py │ ├── main.py │ └── web │ │ ├── __init__.py │ │ ├── follower │ │ ├── __init__.py │ │ └── follower.py │ │ └── leader │ │ ├── __init__.py │ │ └── leader.py │ ├── service │ ├── __init__.py │ ├── compaction.py │ ├── heartbeat.py │ ├── request_deduplication.py │ └── storage.py │ └── utils │ ├── __init__.py │ └── common.py └── tests ├── __init__.py ├── conftest.py ├── domain ├── __init__.py ├── test_models.py ├── test_wal_crc.py └── test_wal_segmentation.py ├── integrations ├── __init__.py ├── delete_key.http ├── get_key.http └── set_key.http ├── service ├── __init__.py ├── test_compaction_service.py ├── test_heartbeat.py ├── test_request_deduplication.py └── test_storage_with_segmented_wal.py ├── test_compaction_integration.py ├── test_storage_versioning.py └── test_versioned_value.py /.github/workflows/conventional_commit.yml: -------------------------------------------------------------------------------- 1 | name: Conventional Commits 2 | 3 | on: 4 | pull_request: 5 | branches: 6 | - main 7 | types: [opened, synchronize, reopened, edited] 8 | 9 | jobs: 10 | build: 11 | name: conventional-commit 12 | runs-on: ubuntu-latest 13 | steps: 14 | - uses: actions/checkout@v3 15 | 16 | - uses: webiny/action-conventional-commits@v1.3.0 17 | with: 18 | allowed-commit-types: "feat,fix,docs,test,ci,refactor,perf,chore,revert,release,build,style" -------------------------------------------------------------------------------- /.github/workflows/lint.yml: -------------------------------------------------------------------------------- 1 | name: lint 2 | on: 3 | pull_request: 4 | branches: 5 | - main 6 | types: [opened, synchronize, reopened, edited] 7 | 8 | jobs: 9 | lint: 10 | runs-on: ubuntu-latest 11 | steps: 12 | - uses: actions/checkout@v3 13 | - uses: actions/setup-python@v3 14 | with: 15 | python-version: "3.12" 16 | architecture: x64 17 | - run: pip install --upgrade virtualenv pip setuptools 18 | - run: virtualenv .venv 19 | - run: source .venv/bin/activate; pip install flit==3.8.0 20 | - run: make install-dev 21 | - run: make lint -------------------------------------------------------------------------------- /.github/workflows/secure.yml: -------------------------------------------------------------------------------- 1 | name: secure 2 | on: 3 | pull_request: 4 | branches: 5 | - main 6 | types: [opened, synchronize, reopened, edited] 7 | 8 | jobs: 9 | secure: 10 | runs-on: ubuntu-latest 11 | steps: 12 | - uses: actions/checkout@v3 13 | - uses: actions/setup-python@v3 14 | with: 15 | python-version: "3.12" 16 | architecture: x64 17 | - run: pip install --upgrade virtualenv pip setuptools 18 | - run: virtualenv .venv 19 | - run: source .venv/bin/activate; pip install flit==3.8.0 20 | - run: make install-dev 21 | - run: make secure -------------------------------------------------------------------------------- /.github/workflows/tests.yml: -------------------------------------------------------------------------------- 1 | name: Test and 
Comment 2 | 3 | on: 4 | pull_request: 5 | branches: 6 | - main 7 | types: [opened, synchronize, reopened, edited] 8 | 9 | permissions: 10 | checks: write 11 | pull-requests: write 12 | 13 | jobs: 14 | test: 15 | runs-on: ubuntu-latest 16 | 17 | steps: 18 | - uses: actions/checkout@v3 19 | - name: Set up Python 20 | uses: actions/setup-python@v2 21 | with: 22 | python-version: 3.12 23 | 24 | - name: Install dependencies 25 | run: | 26 | pip install --upgrade pip && pip install virtualenv && virtualenv .venv 27 | make install-flit && FLIT_ROOT_INSTALL=1 make install-dev 28 | 29 | - name: Run Python tests 30 | run: .venv/bin/python -m pytest --junit-xml=test-results/results.xml 31 | 32 | - name: Comment on PR 33 | uses: EnricoMi/publish-unit-test-result-action@v2 34 | if: always() 35 | with: 36 | files: | 37 | test-results/**/*.xml -------------------------------------------------------------------------------- /.github/workflows/type.yml: -------------------------------------------------------------------------------- 1 | name: type 2 | on: 3 | pull_request: 4 | branches: 5 | - main 6 | types: [opened, synchronize, reopened, edited] 7 | 8 | jobs: 9 | type: 10 | runs-on: ubuntu-latest 11 | steps: 12 | - uses: actions/checkout@v3 13 | - uses: actions/setup-python@v3 14 | with: 15 | python-version: "3.12" 16 | architecture: x64 17 | - run: pip install --upgrade virtualenv pip setuptools 18 | - run: virtualenv .venv 19 | - run: source .venv/bin/activate; pip install flit==3.8.0 20 | - run: make install-dev 21 | - run: make type-check -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | clamav_file_scanner.egg-info/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | .DS_Store 30 | data 31 | # PyInstaller 32 | # Usually these files are written by a python script from a template 33 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 34 | *.manifest 35 | *.spec 36 | 37 | # Installer logs 38 | pip-log.txt 39 | pip-delete-this-directory.txt 40 | 41 | # Unit test / coverage reports 42 | htmlcov/ 43 | .tox/ 44 | .nox/ 45 | .coverage 46 | .coverage.* 47 | .cache 48 | nosetests.xml 49 | coverage.xml 50 | *.cover 51 | *.py,cover 52 | .hypothesis/ 53 | .pytest_cache/ 54 | cover/ 55 | 56 | # Translations 57 | *.mo 58 | *.pot 59 | 60 | # Django stuff: 61 | *.log 62 | local_settings.py 63 | db.sqlite3 64 | db.sqlite3-journal 65 | 66 | # Flask stuff: 67 | instance/ 68 | .webassets-cache 69 | 70 | # Scrapy stuff: 71 | .scrapy 72 | 73 | # SSH stuff 74 | start_ssh.sh 75 | 76 | # Sphinx documentation 77 | docs/_build/ 78 | 79 | # PyBuilder 80 | .pybuilder/ 81 | target/ 82 | 83 | # Jupyter Notebook 84 | .ipynb_checkpoints 85 | 86 | # IPython 87 | profile_default/ 88 | ipython_config.py 89 | 90 | # pyenv 91 | # For a library or package, you might want to ignore these files since the code is 92 | # intended to run in multiple environments; otherwise, check them in: 93 | # .python-version 94 | 95 | # pipenv 96 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
97 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 98 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 99 | # install all needed dependencies. 100 | #Pipfile.lock 101 | 102 | # poetry 103 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 104 | # This is especially recommended for binary packages to ensure reproducibility, and is more 105 | # commonly ignored for libraries. 106 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 107 | #poetry.lock 108 | 109 | # pdm 110 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 111 | #pdm.lock 112 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 113 | # in version control. 114 | # https://pdm.fming.dev/#use-with-ide 115 | .pdm.toml 116 | 117 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 118 | __pypackages__/ 119 | 120 | # Celery stuff 121 | celerybeat-schedule 122 | celerybeat.pid 123 | 124 | # SageMath parsed files 125 | *.sage.py 126 | 127 | # Environments 128 | .env 129 | .venv 130 | env/ 131 | venv/ 132 | ENV/ 133 | env.bak/ 134 | venv.bak/ 135 | 136 | # Spyder project settings 137 | .spyderproject 138 | .spyproject 139 | 140 | # Rope project settings 141 | .ropeproject 142 | 143 | # mkdocs documentation 144 | /site 145 | 146 | # mypy 147 | .mypy_cache/ 148 | .dmypy.json 149 | dmypy.json 150 | 151 | # Pyre type checker 152 | .pyre/ 153 | 154 | # pytype static type analyzer 155 | .pytype/ 156 | 157 | # Cython debug symbols 158 | cython_debug/ 159 | 160 | # PyCharm 161 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 162 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 163 | # and can be added to the global gitignore or merged into this file. For a more nuclear 164 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 165 | .idea/ 166 | 167 | # Local dev files 168 | .tranco/ 169 | *.json 170 | *.csv 171 | Screenshots/ 172 | old_benchmarks/ 173 | logs/ -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | default_language_version: 2 | python: python3.12 3 | 4 | default_stages: [pre-commit, pre-push] 5 | 6 | repos: 7 | - repo: https://github.com/MarcoGorelli/absolufy-imports 8 | rev: v0.3.1 9 | hooks: 10 | - id: absolufy-imports 11 | - repo: local 12 | hooks: 13 | - id: lint 14 | name: lint 15 | entry: make lint 16 | language: system 17 | types: [ python ] 18 | pass_filenames: false 19 | - id: secure 20 | name: secure 21 | entry: make secure 22 | language: system 23 | types: [python] 24 | pass_filenames: false 25 | - id: pytype 26 | name: pytype 27 | entry: make type-check 28 | language: system 29 | types: [python] 30 | pass_filenames: false 31 | -------------------------------------------------------------------------------- /.pylintrc: -------------------------------------------------------------------------------- 1 | [MAIN] 2 | 3 | # Analyse import fallback blocks. This can be used to support both Python 2 and 4 | # 3 compatible code, which means that the block might have code that exists 5 | # only in one or another interpreter, leading to false positives when analysed. 
6 | analyse-fallback-blocks=no 7 | 8 | # Load and enable all available extensions. Use --list-extensions to see a list 9 | # all available extensions. 10 | #enable-all-extensions= 11 | 12 | # In error mode, messages with a category besides ERROR or FATAL are 13 | # suppressed, and no reports are done by default. Error mode is compatible with 14 | # disabling specific errors. 15 | #errors-only= 16 | 17 | # Always return a 0 (non-error) status code, even if lint errors are found. 18 | # This is primarily useful in continuous integration scripts. 19 | #exit-zero= 20 | 21 | # A comma-separated list of package or module names from where C extensions may 22 | # be loaded. Extensions are loading into the active Python interpreter and may 23 | # run arbitrary code. 24 | extension-pkg-allow-list= 25 | 26 | # A comma-separated list of package or module names from where C extensions may 27 | # be loaded. Extensions are loading into the active Python interpreter and may 28 | # run arbitrary code. (This is an alternative name to extension-pkg-allow-list 29 | # for backward compatibility.) 30 | extension-pkg-whitelist= 31 | 32 | # Return non-zero exit code if any of these messages/categories are detected, 33 | # even if score is above --fail-under value. Syntax same as enable. Messages 34 | # specified are enabled, while categories only check already-enabled messages. 35 | fail-on= 36 | 37 | # Specify a score threshold under which the program will exit with error. 38 | fail-under=10 39 | 40 | # Interpret the stdin as a python script, whose filename needs to be passed as 41 | # the module_or_package argument. 42 | #from-stdin= 43 | 44 | # Files or directories to be skipped. They should be base names, not paths. 45 | ignore=CVS 46 | 47 | # Add files or directories matching the regular expressions patterns to the 48 | # ignore-list. The regex matches against paths and can be in Posix or Windows 49 | # format. Because '\' represents the directory delimiter on Windows systems, it 50 | # can't be used as an escape character. 51 | ignore-paths= 52 | 53 | # Files or directories matching the regular expression patterns are skipped. 54 | # The regex matches against base names, not paths. The default value ignores 55 | # Emacs file locks 56 | ignore-patterns=^\.# 57 | 58 | # List of module names for which member attributes should not be checked 59 | # (useful for modules/projects where namespaces are manipulated during runtime 60 | # and thus existing member attributes cannot be deduced by static analysis). It 61 | # supports qualified module names, as well as Unix pattern matching. 62 | ignored-modules= 63 | 64 | # Python code to execute, usually for sys.path manipulation such as 65 | # pygtk.require(). 66 | #init-hook= 67 | 68 | # Use multiple processes to speed up Pylint. Specifying 0 will auto-detect the 69 | # number of processors available to use, and will cap the count on Windows to 70 | # avoid hangs. 71 | jobs=1 72 | 73 | # Control the amount of potential inferred values when inferring a single 74 | # object. This can help the performance when dealing with large functions or 75 | # complex, nested conditions. 76 | limit-inference-results=100 77 | 78 | # List of plugins (as comma separated values of python module names) to load, 79 | # usually to register additional checkers. 80 | load-plugins= 81 | 82 | # Pickle collected data for later comparisons. 83 | persistent=yes 84 | 85 | # Minimum Python version to use for version dependent checks. Will default to 86 | # the version used to run pylint. 
87 | py-version=3.10 88 | 89 | # Discover python modules and packages in the file system subtree. 90 | recursive=no 91 | 92 | # When enabled, pylint would attempt to guess common misconfiguration and emit 93 | # user-friendly hints instead of false-positive error messages. 94 | suggestion-mode=yes 95 | 96 | # Allow loading of arbitrary C extensions. Extensions are imported into the 97 | # active Python interpreter and may run arbitrary code. 98 | unsafe-load-any-extension=no 99 | 100 | # In verbose mode, extra non-checker-related info will be displayed. 101 | #verbose= 102 | 103 | 104 | [BASIC] 105 | 106 | # Naming style matching correct argument names. 107 | argument-naming-style=snake_case 108 | 109 | # Regular expression matching correct argument names. Overrides argument- 110 | # naming-style. If left empty, argument names will be checked with the set 111 | # naming style. 112 | #argument-rgx= 113 | 114 | # Naming style matching correct attribute names. 115 | attr-naming-style=snake_case 116 | 117 | # Regular expression matching correct attribute names. Overrides attr-naming- 118 | # style. If left empty, attribute names will be checked with the set naming 119 | # style. 120 | #attr-rgx= 121 | 122 | # Bad variable names which should always be refused, separated by a comma. 123 | bad-names=foo, 124 | bar, 125 | baz, 126 | toto, 127 | tutu, 128 | tata 129 | 130 | # Bad variable names regexes, separated by a comma. If names match any regex, 131 | # they will always be refused 132 | bad-names-rgxs= 133 | 134 | # Naming style matching correct class attribute names. 135 | class-attribute-naming-style=any 136 | 137 | # Regular expression matching correct class attribute names. Overrides class- 138 | # attribute-naming-style. If left empty, class attribute names will be checked 139 | # with the set naming style. 140 | #class-attribute-rgx= 141 | 142 | # Naming style matching correct class constant names. 143 | class-const-naming-style=UPPER_CASE 144 | 145 | # Regular expression matching correct class constant names. Overrides class- 146 | # const-naming-style. If left empty, class constant names will be checked with 147 | # the set naming style. 148 | #class-const-rgx= 149 | 150 | # Naming style matching correct class names. 151 | class-naming-style=PascalCase 152 | 153 | # Regular expression matching correct class names. Overrides class-naming- 154 | # style. If left empty, class names will be checked with the set naming style. 155 | #class-rgx= 156 | 157 | # Naming style matching correct constant names. 158 | const-naming-style=UPPER_CASE 159 | 160 | # Regular expression matching correct constant names. Overrides const-naming- 161 | # style. If left empty, constant names will be checked with the set naming 162 | # style. 163 | #const-rgx= 164 | 165 | # Minimum line length for functions/classes that require docstrings, shorter 166 | # ones are exempt. 167 | docstring-min-length=-1 168 | 169 | # Naming style matching correct function names. 170 | function-naming-style=snake_case 171 | 172 | # Regular expression matching correct function names. Overrides function- 173 | # naming-style. If left empty, function names will be checked with the set 174 | # naming style. 175 | #function-rgx= 176 | 177 | # Good variable names which should always be accepted, separated by a comma. 178 | good-names=i, 179 | j, 180 | k, 181 | ex, 182 | Run, 183 | _ 184 | 185 | # Good variable names regexes, separated by a comma. 
If names match any regex, 186 | # they will always be accepted 187 | good-names-rgxs= 188 | 189 | # Include a hint for the correct naming format with invalid-name. 190 | include-naming-hint=no 191 | 192 | # Naming style matching correct inline iteration names. 193 | inlinevar-naming-style=any 194 | 195 | # Regular expression matching correct inline iteration names. Overrides 196 | # inlinevar-naming-style. If left empty, inline iteration names will be checked 197 | # with the set naming style. 198 | #inlinevar-rgx= 199 | 200 | # Naming style matching correct method names. 201 | method-naming-style=snake_case 202 | 203 | # Regular expression matching correct method names. Overrides method-naming- 204 | # style. If left empty, method names will be checked with the set naming style. 205 | #method-rgx= 206 | 207 | # Naming style matching correct module names. 208 | module-naming-style=snake_case 209 | 210 | # Regular expression matching correct module names. Overrides module-naming- 211 | # style. If left empty, module names will be checked with the set naming style. 212 | #module-rgx= 213 | 214 | # Colon-delimited sets of names that determine each other's naming style when 215 | # the name regexes allow several styles. 216 | name-group= 217 | 218 | # Regular expression which should only match function or class names that do 219 | # not require a docstring. 220 | no-docstring-rgx=^_ 221 | 222 | # List of decorators that produce properties, such as abc.abstractproperty. Add 223 | # to this list to register other decorators that produce valid properties. 224 | # These decorators are taken in consideration only for invalid-name. 225 | property-classes=abc.abstractproperty 226 | 227 | # Regular expression matching correct type variable names. If left empty, type 228 | # variable names will be checked with the set naming style. 229 | #typevar-rgx= 230 | 231 | # Naming style matching correct variable names. 232 | variable-naming-style=snake_case 233 | 234 | # Regular expression matching correct variable names. Overrides variable- 235 | # naming-style. If left empty, variable names will be checked with the set 236 | # naming style. 237 | #variable-rgx= 238 | 239 | 240 | [CLASSES] 241 | 242 | # Warn about protected attribute access inside special methods 243 | check-protected-access-in-special-methods=no 244 | 245 | # List of method names used to declare (i.e. assign) instance attributes. 246 | defining-attr-methods=__init__, 247 | __new__, 248 | setUp, 249 | __post_init__ 250 | 251 | # List of member names, which should be excluded from the protected access 252 | # warning. 253 | exclude-protected=_asdict, 254 | _fields, 255 | _replace, 256 | _source, 257 | _make 258 | 259 | # List of valid names for the first argument in a class method. 260 | valid-classmethod-first-arg=cls 261 | 262 | # List of valid names for the first argument in a metaclass class method. 263 | valid-metaclass-classmethod-first-arg=cls 264 | 265 | 266 | [DESIGN] 267 | 268 | # List of regular expressions of class ancestor names to ignore when counting 269 | # public methods (see R0903) 270 | exclude-too-few-public-methods= 271 | 272 | # List of qualified class names to ignore when counting class parents (see 273 | # R0901) 274 | ignored-parents= 275 | 276 | # Maximum number of arguments for function / method. 277 | max-args=5 278 | 279 | # Maximum number of attributes for a class (see R0902). 280 | max-attributes=7 281 | 282 | # Maximum number of boolean expressions in an if statement (see R0916). 
283 | max-bool-expr=5 284 | 285 | # Maximum number of branch for function / method body. 286 | max-branches=12 287 | 288 | # Maximum number of locals for function / method body. 289 | max-locals=15 290 | 291 | # Maximum number of parents for a class (see R0901). 292 | max-parents=7 293 | 294 | # Maximum number of public methods for a class (see R0904). 295 | max-public-methods=20 296 | 297 | # Maximum number of return / yield for function / method body. 298 | max-returns=6 299 | 300 | # Maximum number of statements in function / method body. 301 | max-statements=50 302 | 303 | # Minimum number of public methods for a class (see R0903). 304 | min-public-methods=2 305 | 306 | 307 | [EXCEPTIONS] 308 | 309 | # Exceptions that will emit a warning when caught. 310 | overgeneral-exceptions=BaseException, 311 | Exception 312 | 313 | 314 | [FORMAT] 315 | 316 | # Expected format of line ending, e.g. empty (any line ending), LF or CRLF. 317 | expected-line-ending-format= 318 | 319 | # Regexp for a line that is allowed to be longer than the limit. 320 | ignore-long-lines=^\s*(# )??$ 321 | 322 | # Number of spaces of indent required inside a hanging or continued line. 323 | indent-after-paren=4 324 | 325 | # String used as indentation unit. This is usually " " (4 spaces) or "\t" (1 326 | # tab). 327 | indent-string=' ' 328 | 329 | # Maximum number of characters on a single line. 330 | max-line-length=140 331 | 332 | # Maximum number of lines in a module. 333 | max-module-lines=1000 334 | 335 | # Allow the body of a class to be on the same line as the declaration if body 336 | # contains single statement. 337 | single-line-class-stmt=no 338 | 339 | # Allow the body of an if to be on the same line as the test if there is no 340 | # else. 341 | single-line-if-stmt=no 342 | 343 | 344 | [IMPORTS] 345 | 346 | # List of modules that can be imported at any level, not just the top level 347 | # one. 348 | allow-any-import-level= 349 | 350 | # Allow wildcard imports from modules that define __all__. 351 | allow-wildcard-with-all=no 352 | 353 | # Deprecated modules which should not be used, separated by a comma. 354 | deprecated-modules= 355 | 356 | # Output a graph (.gv or any supported image format) of external dependencies 357 | # to the given file (report RP0402 must not be disabled). 358 | ext-import-graph= 359 | 360 | # Output a graph (.gv or any supported image format) of all (i.e. internal and 361 | # external) dependencies to the given file (report RP0402 must not be 362 | # disabled). 363 | import-graph= 364 | 365 | # Output a graph (.gv or any supported image format) of internal dependencies 366 | # to the given file (report RP0402 must not be disabled). 367 | int-import-graph= 368 | 369 | # Force import order to recognize a module as part of the standard 370 | # compatibility libraries. 371 | known-standard-library= 372 | 373 | # Force import order to recognize a module as part of a third party library. 374 | known-third-party=enchant 375 | 376 | # Couples of modules and preferred modules, separated by a comma. 377 | preferred-modules= 378 | 379 | 380 | [LOGGING] 381 | 382 | # The type of string formatting that logging methods do. `old` means using % 383 | # formatting, `new` is for `{}` formatting. 384 | logging-format-style=old 385 | 386 | # Logging modules to check that the string format arguments are in logging 387 | # function parameter format. 388 | logging-modules=logging 389 | 390 | 391 | [MESSAGES CONTROL] 392 | 393 | # Only show warnings with the listed confidence levels. 
Leave empty to show 394 | # all. Valid levels: HIGH, CONTROL_FLOW, INFERENCE, INFERENCE_FAILURE, 395 | # UNDEFINED. 396 | confidence=HIGH, 397 | CONTROL_FLOW, 398 | INFERENCE, 399 | INFERENCE_FAILURE, 400 | UNDEFINED 401 | 402 | # Disable the message, report, category or checker with the given id(s). You 403 | # can either give multiple identifiers separated by comma (,) or put this 404 | # option multiple times (only on the command line, not in the configuration 405 | # file where it should appear only once). You can also use "--disable=all" to 406 | # disable everything first and then re-enable specific checks. For example, if 407 | # you want to run only the similarities checker, you can use "--disable=all 408 | # --enable=similarities". If you want to run only the classes checker, but have 409 | # no Warning level messages displayed, use "--disable=all --enable=classes 410 | # --disable=W". 411 | disable=raw-checker-failed, 412 | bad-inline-option, 413 | locally-disabled, 414 | file-ignored, 415 | suppressed-message, 416 | useless-suppression, 417 | deprecated-pragma, 418 | use-symbolic-message-instead 419 | 420 | # Enable the message, report, category or checker with the given id(s). You can 421 | # either give multiple identifier separated by comma (,) or put this option 422 | # multiple time (only on the command line, not in the configuration file where 423 | # it should appear only once). See also the "--disable" option for examples. 424 | enable=c-extension-no-member 425 | 426 | 427 | [METHOD_ARGS] 428 | 429 | # List of qualified names (i.e., library.method) which require a timeout 430 | # parameter e.g. 'requests.api.get,requests.api.post' 431 | timeout-methods=requests.api.delete,requests.api.get,requests.api.head,requests.api.options,requests.api.patch,requests.api.post,requests.api.put,requests.api.request 432 | 433 | 434 | [MISCELLANEOUS] 435 | 436 | # List of note tags to take in consideration, separated by a comma. 437 | notes=FIXME, 438 | XXX, 439 | TODO 440 | 441 | # Regular expression of note tags to take in consideration. 442 | notes-rgx= 443 | 444 | 445 | [REFACTORING] 446 | 447 | # Maximum number of nested blocks for function / method body 448 | max-nested-blocks=5 449 | 450 | # Complete name of functions that never returns. When checking for 451 | # inconsistent-return-statements if a never returning function is called then 452 | # it will be considered as an explicit return statement and no message will be 453 | # printed. 454 | never-returning-functions=sys.exit,argparse.parse_error 455 | 456 | 457 | [REPORTS] 458 | 459 | # Python expression which should return a score less than or equal to 10. You 460 | # have access to the variables 'fatal', 'error', 'warning', 'refactor', 461 | # 'convention', and 'info' which contain the number of messages in each 462 | # category, as well as 'statement' which is the total number of statements 463 | # analyzed. This score is used by the global evaluation report (RP0004). 464 | evaluation=max(0, 0 if fatal else 10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10)) 465 | 466 | # Template used to display messages. This is a python new-style format string 467 | # used to format the message information. See doc for all details. 468 | msg-template= 469 | 470 | # Set the output format. Available formats are text, parseable, colorized, json 471 | # and msvs (visual studio). You can also give a reporter class, e.g. 472 | # mypackage.mymodule.MyReporterClass. 
473 | #output-format= 474 | 475 | # Tells whether to display a full report or only the messages. 476 | reports=no 477 | 478 | # Activate the evaluation score. 479 | score=yes 480 | 481 | 482 | [SIMILARITIES] 483 | 484 | # Comments are removed from the similarity computation 485 | ignore-comments=yes 486 | 487 | # Docstrings are removed from the similarity computation 488 | ignore-docstrings=yes 489 | 490 | # Imports are removed from the similarity computation 491 | ignore-imports=yes 492 | 493 | # Signatures are removed from the similarity computation 494 | ignore-signatures=yes 495 | 496 | # Minimum lines number of a similarity. 497 | min-similarity-lines=4 498 | 499 | 500 | [SPELLING] 501 | 502 | # Limits count of emitted suggestions for spelling mistakes. 503 | max-spelling-suggestions=4 504 | 505 | # Spelling dictionary name. Available dictionaries: none. To make it work, 506 | # install the 'python-enchant' package. 507 | spelling-dict= 508 | 509 | # List of comma separated words that should be considered directives if they 510 | # appear at the beginning of a comment and should not be checked. 511 | spelling-ignore-comment-directives=fmt: on,fmt: off,noqa:,noqa,nosec,isort:skip,mypy: 512 | 513 | # List of comma separated words that should not be checked. 514 | spelling-ignore-words= 515 | 516 | # A path to a file that contains the private dictionary; one word per line. 517 | spelling-private-dict-file= 518 | 519 | # Tells whether to store unknown words to the private dictionary (see the 520 | # --spelling-private-dict-file option) instead of raising a message. 521 | spelling-store-unknown-words=no 522 | 523 | 524 | [STRING] 525 | 526 | # This flag controls whether inconsistent-quotes generates a warning when the 527 | # character used as a quote delimiter is used inconsistently within a module. 528 | check-quote-consistency=no 529 | 530 | # This flag controls whether the implicit-str-concat should generate a warning 531 | # on implicit string concatenation in sequences defined over several lines. 532 | check-str-concat-over-line-jumps=no 533 | 534 | 535 | [TYPECHECK] 536 | 537 | # List of decorators that produce context managers, such as 538 | # contextlib.contextmanager. Add to this list to register other decorators that 539 | # produce valid context managers. 540 | contextmanager-decorators=contextlib.contextmanager 541 | 542 | # List of members which are set dynamically and missed by pylint inference 543 | # system, and so shouldn't trigger E1101 when accessed. Python regular 544 | # expressions are accepted. 545 | generated-members= 546 | 547 | # Tells whether to warn about missing members when the owner of the attribute 548 | # is inferred to be None. 549 | ignore-none=yes 550 | 551 | # This flag controls whether pylint should warn about no-member and similar 552 | # checks whenever an opaque object is returned when inferring. The inference 553 | # can return multiple potential results while evaluating a Python object, but 554 | # some branches might not be evaluated, which results in partial inference. In 555 | # that case, it might be useful to still emit no-member and other checks for 556 | # the rest of the inferred objects. 557 | ignore-on-opaque-inference=yes 558 | 559 | # List of symbolic message names to ignore for Mixin members. 
560 | ignored-checks-for-mixins=no-member, 561 | not-async-context-manager, 562 | not-context-manager, 563 | attribute-defined-outside-init 564 | 565 | # List of class names for which member attributes should not be checked (useful 566 | # for classes with dynamically set attributes). This supports the use of 567 | # qualified names. 568 | ignored-classes=optparse.Values,thread._local,_thread._local,argparse.Namespace 569 | 570 | # Show a hint with possible names when a member name was not found. The aspect 571 | # of finding the hint is based on edit distance. 572 | missing-member-hint=yes 573 | 574 | # The minimum edit distance a name should have in order to be considered a 575 | # similar match for a missing member name. 576 | missing-member-hint-distance=1 577 | 578 | # The total number of similar names that should be taken in consideration when 579 | # showing a hint for a missing member. 580 | missing-member-max-choices=1 581 | 582 | # Regex pattern to define which classes are considered mixins. 583 | mixin-class-rgx=.*[Mm]ixin 584 | 585 | # List of decorators that change the signature of a decorated function. 586 | signature-mutators= 587 | 588 | 589 | [VARIABLES] 590 | 591 | # List of additional names supposed to be defined in builtins. Remember that 592 | # you should avoid defining new builtins when possible. 593 | additional-builtins= 594 | 595 | # Tells whether unused global variables should be treated as a violation. 596 | allow-global-unused-variables=yes 597 | 598 | # List of names allowed to shadow builtins 599 | allowed-redefined-builtins= 600 | 601 | # List of strings which can identify a callback function by name. A callback 602 | # name must start or end with one of those strings. 603 | callbacks=cb_, 604 | _cb 605 | 606 | # A regular expression matching the name of dummy variables (i.e. expected to 607 | # not be used). 608 | dummy-variables-rgx=_+$|(_[a-zA-Z0-9_]*[a-zA-Z0-9]+?$)|dummy|^ignored_|^unused_ 609 | 610 | # Argument names that match this expression will be ignored. 611 | ignored-argument-names=_.*|^ignored_|^unused_ 612 | 613 | # Tells whether we should check for unused import in __init__ files. 614 | init-import=no 615 | 616 | # List of qualified module names which can have objects that can redefine 617 | # builtins. 
618 | redefining-builtins-modules=six.moves,past.builtins,future.builtins,builtins,io 619 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | ifdef OS 2 | PYTHON ?= .venv/Scripts/python.exe 3 | TYPE_CHECK_COMMAND ?= echo Pytype package doesn't support Windows OS 4 | else 5 | PYTHON ?= .venv/bin/python 6 | TYPE_CHECK_COMMAND ?= ${PYTHON} -m pytype --config=pytype.cfg src 7 | endif 8 | 9 | SETTINGS_FILENAME = pyproject.toml 10 | 11 | PHONY = help install install-dev build format lint type-check secure test install-flit enable-pre-commit-hooks run 12 | 13 | help: 14 | @echo "--------------- HELP ---------------" 15 | @echo "To install the project -> make install" 16 | @echo "To install the project using symlinks (for development) -> make install-dev" 17 | @echo "To build the wheel package -> make build" 18 | @echo "To test the project -> make test" 19 | @echo "To test with coverage [all tests] -> make test-cov" 20 | @echo "To format code -> make format" 21 | @echo "To check linter -> make lint" 22 | @echo "To run type checker -> make type-check" 23 | @echo "To run all security related commands -> make secure" 24 | @echo "------------------------------------" 25 | 26 | install: 27 | ${PYTHON} -m flit install --env --deps=production 28 | 29 | install-dev: 30 | ${PYTHON} -m flit install -s --env --deps=develop --symlink 31 | 32 | install-flit: 33 | ${PYTHON} -m pip install flit==3.8.0 34 | 35 | enable-pre-commit-hooks: 36 | ${PYTHON} -m pre_commit install 37 | 38 | build: 39 | ${PYTHON} -m flit build --format wheel 40 | ${PYTHON} -m pip install dist/*.whl 41 | ${PYTHON} -c 'import pytemplate; print(pytemplate.__version__)' 42 | 43 | format: 44 | ${PYTHON} -m isort src tests --force-single-line-imports --settings-file ${SETTINGS_FILENAME} 45 | ${PYTHON} -m autoflake --remove-all-unused-imports --recursive --remove-unused-variables --in-place src --exclude=__init__.py 46 | ${PYTHON} -m black src tests --config ${SETTINGS_FILENAME} 47 | ${PYTHON} -m isort src tests --settings-file ${SETTINGS_FILENAME} 48 | 49 | lint: 50 | ${PYTHON} -m flake8 --toml-config ${SETTINGS_FILENAME} --max-complexity 5 --max-cognitive-complexity=5 src 51 | ${PYTHON} -m black src tests --check --diff --config ${SETTINGS_FILENAME} 52 | ${PYTHON} -m isort src tests --check --diff --settings-file ${SETTINGS_FILENAME} 53 | 54 | type-check: 55 | @$(TYPE_CHECK_COMMAND) 56 | 57 | secure: 58 | ${PYTHON} -m bandit -r src --config ${SETTINGS_FILENAME} 59 | 60 | test: 61 | ${PYTHON} -m pytest -svvv -m "not slow and not integration" tests 62 | 63 | test-slow: 64 | ${PYTHON} -m pytest -svvv -m "slow" tests 65 | 66 | test-integration: 67 | ${PYTHON} -m pytest -svvv -m "integration" tests 68 | 69 | run: 70 | ${PYTHON} -m src.pytemplate.main 71 | -------------------------------------------------------------------------------- /ci/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM python:3.10-buster 2 | 3 | WORKDIR /app 4 | RUN apt-get update -y && apt-get install make -y 5 | 6 | COPY . . 
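# Install build tooling and the project; FLIT_ROOT_INSTALL=1 lets `flit install` run as the root user inside the container.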
7 | RUN pip install --upgrade pip 8 | RUN pip install flit==3.8.0 9 | RUN FLIT_ROOT_INSTALL=1 flit install 10 | 11 | ENTRYPOINT pytest -s tests 12 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["flit_core==3.8.0"] 3 | build-backend = "flit_core.buildapi" 4 | 5 | [project] 6 | name = "pydistributedkv" 7 | authors = [ 8 | { name = "rzayev sehriyar", email = "rzayev.sehriyar@gmail.com" }, 9 | ] 10 | dynamic = ["version"] 11 | description = "This is a Python Project Template" 12 | # Add here the production dependencies 13 | dependencies = [ 14 | "pydantic", 15 | "fastapi", 16 | "requests", 17 | "uvicorn" 18 | ] 19 | 20 | [project.optional-dependencies] 21 | dev = [ 22 | "black", 23 | "isort", 24 | "autoflake", 25 | "pytype; platform_system != 'Windows'", 26 | "flake8", 27 | "Flake8-pyproject", 28 | "bandit", 29 | "flake8-bugbear", 30 | "flake8-cognitive-complexity", 31 | "pre-commit", 32 | "safety", 33 | "pip-audit", 34 | ] 35 | test = [ 36 | "pytest", 37 | "pytest-cov", 38 | "pytest-xdist", 39 | "pytest-asyncio", 40 | "pytest-sugar", 41 | "httpx", 42 | ] 43 | 44 | [tool.isort] 45 | profile = "black" 46 | line_length = 140 47 | py_version = 312 48 | order_by_type = false 49 | skip = [".gitignore", ".dockerignore"] 50 | extend_skip = [".md", ".json"] 51 | skip_glob = ["docs/*"] 52 | 53 | [tool.flake8] 54 | max-line-length = 140 55 | select = ["C", "E", "F", "W", "B", "B9"] 56 | ignore = ["E203", "E501", "W503", "C812", "E731", "F811"] 57 | exclude = ["__init__.py"] 58 | extend-immutable-calls = ["Query", "fastapi.Query"] 59 | 60 | [tool.black] 61 | line-length = 140 62 | target-version = ['py312'] 63 | include = '\.pyi?$' 64 | 65 | [tool.bandit] 66 | skips = ["B311", "B404", "B104"] 67 | 68 | [tool.pytest.ini_options] 69 | pythonpath = [ 70 | "src" 71 | ] 72 | testpaths = [ 73 | "tests" 74 | ] 75 | markers = [ 76 | "slow: marks tests as slow (deselect with '-m \"not slow\"')", 77 | "integration: marks tests as integration relatively slow (deselect with '-m \"not integration\"')", 78 | "serial", 79 | ] 80 | addopts = [ 81 | "--strict-markers", 82 | "--strict-config", 83 | "-ra", 84 | ] 85 | -------------------------------------------------------------------------------- /pytype.cfg: -------------------------------------------------------------------------------- 1 | # NOTE: All relative paths are relative to the location of this file. 2 | 3 | [pytype] 4 | 5 | # Space-separated list of files or directories to exclude. 6 | exclude = 7 | **/*_test.py 8 | **/test_*.py 9 | 10 | # Space-separated list of files or directories to process. 11 | inputs = 12 | . 13 | 14 | # Keep going past errors to analyze as many files as possible. 15 | keep_going = False 16 | 17 | # Run N jobs in parallel. When 'auto' is used, this will be equivalent to the 18 | # number of CPUs on the host system. 19 | jobs = 4 20 | 21 | # All pytype output goes here. 22 | output = .pytype 23 | 24 | # Platform (e.g., "linux", "win32") that the target code runs on. 25 | platform = linux 26 | 27 | # Paths to source code directories, separated by ':'. 28 | pythonpath = 29 | ./src 30 | 31 | # Python version (major.minor) of the target code. 32 | python_version = 3.12 33 | 34 | # Always use function return type annotations. This flag is temporary and will 35 | # be removed once this behavior is enabled by default. 
36 | always_use_return_annotations = True 37 | 38 | # Enable parameter count checks for overriding methods. This flag is temporary 39 | # and will be removed once this behavior is enabled by default. 40 | overriding_parameter_count_checks = True 41 | 42 | # Enable return type checks for overriding methods. This flag is temporary and 43 | # will be removed once this behavior is enabled by default. 44 | overriding_return_type_checks = True 45 | 46 | # Use the enum overlay for more precise enum checking. This flag is temporary 47 | # and will be removed once this behavior is enabled by default. 48 | use_enum_overlay = True 49 | 50 | # Opt-in: Do not allow Any as a return type. 51 | no_return_any = False 52 | 53 | # Experimental: Support pyglib's @cached.property. 54 | enable_cached_property = True 55 | 56 | # Experimental: Infer precise return types even for invalid function calls. 57 | precise_return = True 58 | 59 | # Experimental: Solve unknown types to label with structural types. 60 | protocols = False 61 | 62 | # Experimental: Only load submodules that are explicitly imported. 63 | strict_import = True 64 | 65 | # Experimental: Enable exhaustive checking of function parameter types. 66 | strict_parameter_checks = True 67 | 68 | # Experimental: Emit errors for comparisons between incompatible primitive 69 | # types. 70 | strict_primitive_comparisons = True 71 | 72 | # Comma or space separated list of error names to ignore. 73 | disable = 74 | pyi-error 75 | 76 | # Don't report errors. 77 | report_errors = True 78 | -------------------------------------------------------------------------------- /src/pydistributedkv/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = "1.6.0" 2 | -------------------------------------------------------------------------------- /src/pydistributedkv/configurator/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShahriyarR/py-distributed-kv/c404cccae0e190fdfd2bd08a7caed53a86620f3d/src/pydistributedkv/configurator/__init__.py -------------------------------------------------------------------------------- /src/pydistributedkv/configurator/settings/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShahriyarR/py-distributed-kv/c404cccae0e190fdfd2bd08a7caed53a86620f3d/src/pydistributedkv/configurator/settings/__init__.py -------------------------------------------------------------------------------- /src/pydistributedkv/configurator/settings/base.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | # API timeouts in seconds 4 | API_TIMEOUT = int(os.getenv("API_TIMEOUT", "5")) 5 | 6 | # Log segmentation settings 7 | # Default max segment size: 1MB 8 | MAX_SEGMENT_SIZE = int(os.getenv("MAX_SEGMENT_SIZE", str(1024 * 1024))) 9 | 10 | # Heartbeat configuration 11 | HEARTBEAT_INTERVAL = int(os.getenv("HEARTBEAT_INTERVAL", 10)) # seconds 12 | HEARTBEAT_TIMEOUT = HEARTBEAT_INTERVAL * 3 # After this many seconds with no heartbeat, mark server as down 13 | 14 | # Compaction settings 15 | compaction_interval = int(os.getenv("COMPACTION_INTERVAL", "3600")) # Default: 1 hour 16 | -------------------------------------------------------------------------------- /src/pydistributedkv/configurator/settings/config.py: -------------------------------------------------------------------------------- 1 | # config.py 2 | 
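# Static cluster topology: a single leader on port 8000 and four followers on ports 8001-8004,
# each with its own WAL file under data/; the required directories are created at import time below.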
import os 3 | 4 | # Leader config 5 | LEADER_CONFIG = { 6 | "host": "0.0.0.0", 7 | "port": 8000, 8 | "wal_path": "data/leader/wal.log", 9 | "leader_url": "http://localhost:8000", 10 | } 11 | 12 | # Follower config 13 | FOLLOWER_CONFIGS = { 14 | "follower-1": { 15 | "host": "0.0.0.0", 16 | "port": 8001, 17 | "wal_path": "data/followers/follower-1/wal.log", 18 | "leader_url": "http://localhost:8000", 19 | "follower_url": "http://localhost:8001", 20 | }, 21 | "follower-2": { 22 | "host": "0.0.0.0", 23 | "port": 8002, 24 | "wal_path": "data/followers/follower-2/wal.log", 25 | "leader_url": "http://localhost:8000", 26 | "follower_url": "http://localhost:8002", 27 | }, 28 | "follower-3": { 29 | "host": "0.0.0.0", 30 | "port": 8003, 31 | "wal_path": "data/followers/follower-3/wal.log", 32 | "leader_url": "http://localhost:8000", 33 | "follower_url": "http://localhost:8003", 34 | }, 35 | "follower-4": { 36 | "host": "0.0.0.0", 37 | "port": 8004, 38 | "wal_path": "data/followers/follower-4/wal.log", 39 | "leader_url": "http://localhost:8000", 40 | "follower_url": "http://localhost:8004", 41 | }, 42 | } 43 | 44 | # Create directories for WAL files 45 | for config in [LEADER_CONFIG] + list(FOLLOWER_CONFIGS.values()): 46 | os.makedirs(os.path.dirname(config["wal_path"]), exist_ok=True) 47 | -------------------------------------------------------------------------------- /src/pydistributedkv/domain/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShahriyarR/py-distributed-kv/c404cccae0e190fdfd2bd08a7caed53a86620f3d/src/pydistributedkv/domain/__init__.py -------------------------------------------------------------------------------- /src/pydistributedkv/domain/models.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import json 3 | import os 4 | import zlib 5 | from enum import Enum 6 | from typing import Any, Dict, List, Optional, Set, Tuple 7 | 8 | from pydantic import BaseModel 9 | 10 | 11 | class OperationType(str, Enum): 12 | SET = "SET" 13 | DELETE = "DELETE" 14 | GET = "GET" 15 | 16 | 17 | class LogEntry(BaseModel): 18 | id: int 19 | operation: OperationType 20 | key: str 21 | value: Optional[Any] = None 22 | crc: Optional[int] = None 23 | version: Optional[int] = None # Version number for the key 24 | 25 | def calculate_crc(self) -> int: 26 | """Calculate CRC for this entry based on its content except the CRC itself.""" 27 | # Create a copy of the entry without the CRC field 28 | data_for_crc = self.model_dump() 29 | data_for_crc.pop("crc", None) 30 | # Convert to a stable string representation and calculate CRC32 31 | json_str = json.dumps(data_for_crc, sort_keys=True) 32 | return zlib.crc32(json_str.encode()) 33 | 34 | def validate_crc(self) -> bool: 35 | """Validate that the stored CRC matches the calculated one.""" 36 | if self.crc is None: 37 | return False 38 | return self.crc == self.calculate_crc() 39 | 40 | 41 | class KeyValue(BaseModel): 42 | value: Any 43 | version: Optional[int] = None # Optional version for specific version retrieval 44 | 45 | 46 | class VersionedValue(BaseModel): 47 | """Model representing a value with version history""" 48 | 49 | current_version: int # Latest version number 50 | value: Any # Current value 51 | history: Optional[Dict[int, Any]] = None # Version -> Value mapping 52 | 53 | def get_value(self, version: Optional[int] = None) -> Optional[Any]: 54 | """Get value at specific version, or latest if version is None""" 55 | if 
version is None: 56 | return self.value 57 | 58 | if version == self.current_version: 59 | return self.value 60 | 61 | if self.history and version in self.history: 62 | return self.history[version] 63 | 64 | return None 65 | 66 | def update(self, value: Any, version: int) -> None: 67 | """Update with a new value and version""" 68 | if version <= self.current_version: 69 | # Ignore updates with older versions 70 | return 71 | 72 | # Save current value to history before updating 73 | if self.history is None: 74 | self.history = {} 75 | 76 | # Always keep history of previous versions 77 | self.history[self.current_version] = self.value 78 | 79 | self.value = value 80 | self.current_version = version 81 | 82 | 83 | class ReplicationStatus(BaseModel): 84 | follower_id: str 85 | last_replicated_id: int 86 | 87 | 88 | class ClientRequest(BaseModel): 89 | """Request from a client with unique identifiers to enable idempotent processing""" 90 | 91 | client_id: str 92 | request_id: str 93 | operation: Optional[OperationType] = None 94 | key: Optional[str] = None 95 | value: Optional[Any] = None 96 | version: Optional[int] = None # Specific version for GET or precondition for SET 97 | 98 | 99 | class ReplicationRequest(BaseModel): 100 | entries: list[dict[str, Any]] 101 | 102 | 103 | class FollowerRegistration(BaseModel): 104 | id: str 105 | url: str 106 | last_applied_id: int = 0 107 | 108 | 109 | class WAL: 110 | def __init__(self, log_file_path: str, max_segment_size: int = 1024 * 1024): # Default 1MB per segment 111 | self.log_dir = os.path.dirname(log_file_path) 112 | self.base_name = os.path.basename(log_file_path) 113 | self.max_segment_size = max_segment_size 114 | self.current_id = 0 115 | self.existing_ids: Set[int] = set() 116 | self.active_segment_path = "" 117 | 118 | self._ensure_log_dir_exists() 119 | self._initialize_segments() 120 | 121 | def _ensure_log_dir_exists(self): 122 | """Ensure that the directory for log files exists.""" 123 | os.makedirs(self.log_dir, exist_ok=True) 124 | 125 | def _initialize_segments(self): 126 | """Initialize segments, find existing ones, and determine the active segment.""" 127 | segments = self._get_all_segments() 128 | 129 | if not segments: 130 | # No segments yet, create the first one 131 | self.active_segment_path = self._create_segment_path(1) 132 | with open(self.active_segment_path, "w"): 133 | pass # Create empty file 134 | else: 135 | # Find the highest segment 136 | latest_segment = segments[-1] 137 | self.active_segment_path = latest_segment 138 | 139 | self._load_all_entries() 140 | 141 | def _get_all_segments(self) -> List[str]: 142 | """Get all segment files sorted by segment number.""" 143 | segment_pattern = os.path.join(self.log_dir, f"{self.base_name}.segment.*") 144 | segments = sorted(glob.glob(segment_pattern), key=self._extract_segment_number) 145 | return segments 146 | 147 | def _extract_segment_number(self, segment_path: str) -> int: 148 | """Extract the segment number from a segment file path.""" 149 | try: 150 | return int(segment_path.split(".")[-1]) 151 | except (ValueError, IndexError): 152 | return 0 153 | 154 | def _create_segment_path(self, segment_number: int) -> str: 155 | """Create a path for a new segment file with the given number.""" 156 | return os.path.join(self.log_dir, f"{self.base_name}.segment.{segment_number}") 157 | 158 | def _load_all_entries(self): 159 | """Load all entries from all segments to populate existing IDs and determine current ID.""" 160 | segments = self._get_all_segments() 161 | for segment 
in segments: 162 | self._load_entries_from_file(segment) 163 | 164 | def _load_entries_from_file(self, file_path: str): 165 | """Load entries from a specific file.""" 166 | try: 167 | with open(file_path, "r") as f: 168 | self._process_log_entries(f) 169 | except FileNotFoundError: 170 | pass 171 | 172 | def _process_log_entries(self, file_handle): 173 | """Process each line in the log file to extract entry IDs.""" 174 | for line in file_handle: 175 | self._process_log_entry(line) 176 | 177 | def _process_log_entry(self, line): 178 | """Process a single log entry line and update tracking data.""" 179 | try: 180 | entry = json.loads(line) 181 | entry_id = entry["id"] 182 | 183 | if not self._is_valid_entry(entry, entry_id): 184 | return 185 | 186 | self._update_tracking_data(entry_id) 187 | except (json.JSONDecodeError, KeyError): 188 | pass 189 | 190 | def _is_valid_entry(self, entry, entry_id): 191 | """Check if an entry has valid CRC.""" 192 | if "crc" not in entry: 193 | return True 194 | 195 | log_entry = LogEntry(**entry) 196 | if not log_entry.validate_crc(): 197 | print(f"Warning: Entry with ID {entry_id} has invalid CRC, skipping") 198 | return False 199 | return True 200 | 201 | def _update_tracking_data(self, entry_id): 202 | """Update tracking data for a valid entry.""" 203 | self.existing_ids.add(entry_id) 204 | if entry_id > self.current_id: 205 | self.current_id = entry_id 206 | 207 | def _get_next_segment_number(self) -> int: 208 | """Get the next segment number based on the active segment.""" 209 | current_segment_num = self._extract_segment_number(self.active_segment_path) 210 | return current_segment_num + 1 211 | 212 | def _roll_segment_if_needed(self): 213 | """Roll over to a new segment file if the current one exceeds the size limit.""" 214 | try: 215 | if os.path.exists(self.active_segment_path) and os.path.getsize(self.active_segment_path) >= self.max_segment_size: 216 | next_segment_num = self._get_next_segment_number() 217 | self.active_segment_path = self._create_segment_path(next_segment_num) 218 | # Create the new empty segment file 219 | with open(self.active_segment_path, "w"): 220 | pass 221 | print(f"Rolled over to new segment: {self.active_segment_path}") 222 | except OSError: 223 | # If there's an issue checking the file size, just continue with the current segment 224 | pass 225 | 226 | def append(self, operation: OperationType, key: str, value: Optional[Any] = None, version: Optional[int] = None) -> LogEntry: 227 | self.current_id += 1 228 | entry = LogEntry(id=self.current_id, operation=operation, key=key, value=value, version=version) 229 | # Calculate and set CRC 230 | entry.crc = entry.calculate_crc() 231 | return self.append_entry(entry) 232 | 233 | def append_entry(self, entry: LogEntry) -> LogEntry: 234 | """Append a pre-created entry, used for replication""" 235 | # Skip if entry already exists 236 | if entry.id in self.existing_ids: 237 | return entry 238 | 239 | # Update current_id if needed 240 | if entry.id > self.current_id: 241 | self.current_id = entry.id 242 | 243 | # Ensure entry has valid CRC 244 | if entry.crc is None: 245 | entry.crc = entry.calculate_crc() 246 | elif not entry.validate_crc(): 247 | # Recalculate CRC if invalid 248 | entry.crc = entry.calculate_crc() 249 | 250 | # Check if we need to roll over to a new segment 251 | self._roll_segment_if_needed() 252 | 253 | # Append entry to the active segment 254 | with open(self.active_segment_path, "a") as f: 255 | f.write(json.dumps(entry.model_dump()) + "\n") 256 | 257 | 
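        # Record the ID so a repeated append of this entry (e.g. during replication) is skipped as a duplicate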
self.existing_ids.add(entry.id) 258 | return entry 259 | 260 | def has_entry(self, entry_id: int) -> bool: 261 | """Check if an entry with the given ID already exists in the WAL""" 262 | return entry_id in self.existing_ids 263 | 264 | def read_from(self, start_id: int = 0) -> list[LogEntry]: 265 | """Read log entries with ID >= start_id from all segments.""" 266 | entries = [] 267 | 268 | # Get all segments 269 | segments = self._get_all_segments() 270 | 271 | # Process each segment 272 | for segment in segments: 273 | try: 274 | self._append_entries(entries, segment, start_id) 275 | except FileNotFoundError: 276 | continue 277 | 278 | # Sort entries by ID to ensure correct order 279 | entries.sort(key=lambda e: e.id) 280 | return entries 281 | 282 | def _append_entries(self, entries, segment, start_id): 283 | with open(segment, "r") as f: 284 | for line in f: 285 | try: 286 | self._append_single_entry(entries, line, start_id) 287 | except (json.JSONDecodeError, ValueError) as e: 288 | print(f"Error parsing log entry: {str(e)}") 289 | continue 290 | 291 | def _append_single_entry(self, entries, line, start_id): 292 | entry = self._parse_log_entry(line) 293 | if entry and self._should_include_entry(entry, start_id): 294 | entries.append(entry) 295 | 296 | def _parse_log_entry(self, line: str) -> Optional[LogEntry]: 297 | """Parse a log entry from a line in the log file.""" 298 | try: 299 | entry_dict = json.loads(line) 300 | entry = LogEntry(**entry_dict) 301 | 302 | # Skip entries with invalid CRC 303 | if not entry.validate_crc(): 304 | print(f"Warning: Skipping entry with ID {entry.id} due to CRC validation failure") 305 | return None 306 | 307 | return entry 308 | except (json.JSONDecodeError, ValueError) as e: 309 | print(f"Error parsing log entry: {str(e)}") 310 | return None 311 | 312 | def _should_include_entry(self, entry: LogEntry, start_id: int) -> bool: 313 | """Check if an entry should be included based on its ID.""" 314 | return entry.id >= start_id 315 | 316 | def get_last_id(self) -> int: 317 | return self.current_id 318 | 319 | def get_segment_files(self) -> List[str]: 320 | """Get a list of all segment files.""" 321 | return self._get_all_segments() 322 | 323 | def get_active_segment(self) -> str: 324 | """Get the path of the currently active segment.""" 325 | return self.active_segment_path 326 | 327 | def compact_segments(self) -> Tuple[int, int]: 328 | """Compact segments by keeping only the latest operation for each key. 
329 | 330 | Returns: 331 | Tuple containing: 332 | - Number of segments compacted 333 | - Number of entries removed 334 | """ 335 | segments = self._get_segments_for_compaction() 336 | if not segments: 337 | return 0, 0 338 | 339 | # Read all entries from segments marked for compaction 340 | entries = self._read_entries_from_segments(segments) 341 | if not entries: 342 | return 0, 0 343 | 344 | # Filter to get only the latest operation for each key 345 | entries_to_keep = self._filter_latest_entries(entries) 346 | 347 | # Calculate entries removed 348 | entries_removed = len(entries) - len(entries_to_keep) 349 | 350 | # Write the compacted entries and handle cleanup 351 | self._write_compacted_entries(segments, entries_to_keep) 352 | 353 | return len(segments), entries_removed 354 | 355 | def _get_segments_for_compaction(self) -> List[str]: 356 | """Get segments that should be compacted (all except the active one)""" 357 | segments = self._get_all_segments() 358 | 359 | # Don't compact if we have only one segment (which is the active one) 360 | if len(segments) <= 1: 361 | return [] 362 | 363 | # The last segment is the active one, we won't compact it 364 | return segments[:-1] 365 | 366 | def _read_entries_from_segments(self, segments: List[str]) -> List[LogEntry]: 367 | """Read all entries from the given segments""" 368 | entries = [] 369 | for segment in segments: 370 | segment_entries = self._read_entries_from_segment(segment) 371 | entries.extend(segment_entries) 372 | return entries 373 | 374 | def _filter_latest_entries(self, entries: List[LogEntry]) -> List[LogEntry]: 375 | """Filter entries to keep only the latest operation for each key""" 376 | key_to_latest_entry: Dict[str, LogEntry] = {} 377 | 378 | # Track latest operation for each key 379 | for entry in sorted(entries, key=lambda e: e.id): 380 | if entry.operation in [OperationType.SET, OperationType.DELETE]: 381 | key_to_latest_entry[entry.key] = entry 382 | 383 | # Extract the values and sort by ID 384 | result = list(key_to_latest_entry.values()) 385 | result.sort(key=lambda e: e.id) 386 | return result 387 | 388 | def _write_compacted_entries(self, segments_to_remove: List[str], entries: List[LogEntry]) -> None: 389 | """Write compacted entries to a new segment and clean up old segments""" 390 | # Create a new compacted segment 391 | compacted_segment_path = self._create_compacted_segment() 392 | 393 | # Write entries to the compacted segment 394 | with open(compacted_segment_path, "w") as f: 395 | for entry in entries: 396 | f.write(json.dumps(entry.model_dump()) + "\n") 397 | 398 | # Delete old segments after successful compaction 399 | self._delete_segments(segments_to_remove) 400 | 401 | # Update segment numbers to be continuous 402 | self._renumber_segments() 403 | 404 | def _delete_segments(self, segments: List[str]) -> None: 405 | """Delete the given segment files""" 406 | for segment in segments: 407 | try: 408 | os.remove(segment) 409 | except OSError as e: 410 | print(f"Error removing segment {segment}: {e}") 411 | 412 | def _read_entries_from_segment(self, segment_path: str) -> List[LogEntry]: 413 | """Read all entries from a segment file.""" 414 | if not self._is_valid_segment_file(segment_path): 415 | return [] 416 | 417 | entries = [] 418 | for line in self._read_segment_lines(segment_path): 419 | entry = self._parse_entry_from_line(line) 420 | if entry: 421 | entries.append(entry) 422 | 423 | return entries 424 | 425 | def _is_valid_segment_file(self, segment_path: str) -> bool: 426 | """Check if the 
segment file exists and is readable.""" 427 | return os.path.exists(segment_path) 428 | 429 | def _read_segment_lines(self, segment_path: str) -> List[str]: 430 | """Read all lines from a segment file, handling errors.""" 431 | lines = [] 432 | try: 433 | with open(segment_path, "r") as f: 434 | lines = f.readlines() 435 | except OSError: 436 | print(f"Error accessing segment file {segment_path}") 437 | return lines 438 | 439 | def _parse_entry_from_line(self, line: str) -> Optional[LogEntry]: 440 | """Parse a single entry from a line in the log file.""" 441 | try: 442 | entry_dict = json.loads(line) 443 | entry = LogEntry(**entry_dict) 444 | 445 | if entry.crc and not entry.validate_crc(): 446 | return None 447 | 448 | return entry 449 | except (json.JSONDecodeError, ValueError): 450 | # Skip invalid entries 451 | return None 452 | 453 | def _create_compacted_segment(self) -> str: 454 | """Create a new segment file for compacted entries.""" 455 | return os.path.join(self.log_dir, f"{self.base_name}.compacted.temp") 456 | 457 | def _renumber_segments(self): 458 | """Rename segments to ensure contiguous numbering after compaction.""" 459 | segments = self._get_all_segments() 460 | compacted_path = os.path.join(self.log_dir, f"{self.base_name}.compacted.temp") 461 | 462 | # Check if compacted file exists 463 | if not os.path.exists(compacted_path): 464 | return 465 | 466 | # Mark the active segment 467 | self._set_active_segment(segments) 468 | 469 | # Prepare for renumbering by renaming segments to temporary names 470 | self._prepare_segments_for_renumbering(segments) 471 | 472 | # Move the compacted file to be the first segment 473 | self._rename_compacted_to_first_segment(compacted_path) 474 | 475 | # Rename the remaining segments to have contiguous numbers 476 | self._rename_remaining_segments(segments) 477 | 478 | # Update the active segment path 479 | self._update_active_segment_path() 480 | 481 | def _set_active_segment(self, segments): 482 | """Mark the active segment (the last one).""" 483 | if segments: 484 | # The last segment is the active one 485 | self.active_segment_path = segments[-1] 486 | 487 | def _prepare_segments_for_renumbering(self, segments): 488 | """Rename existing segments to temporary names to avoid conflicts.""" 489 | for segment in segments[:-1]: # Skip the last (active) segment 490 | try: 491 | temp_path = f"{segment}.tmp" 492 | if os.path.exists(segment): 493 | os.rename(segment, temp_path) 494 | except OSError as e: 495 | print(f"Error renaming segment {segment}: {e}") 496 | 497 | def _rename_compacted_to_first_segment(self, compacted_path): 498 | """Rename the compacted file to be the first segment.""" 499 | try: 500 | new_first_segment = self._create_segment_path(1) 501 | os.rename(compacted_path, new_first_segment) 502 | except OSError as e: 503 | print(f"Error renaming compacted file: {e}") 504 | 505 | def _rename_remaining_segments(self, segments): 506 | """Rename remaining segments to have contiguous numbers.""" 507 | # Start from 2 because compacted file is now segment 1 508 | for i, segment in enumerate(segments[:-1], 2): 509 | try: 510 | temp_path = f"{segment}.tmp" 511 | if os.path.exists(temp_path): 512 | new_path = self._create_segment_path(i) 513 | os.rename(temp_path, new_path) 514 | except OSError as e: 515 | print(f"Error renaming segment {temp_path}: {e}") 516 | 517 | def _update_active_segment_path(self): 518 | """Update the active segment path to be the highest segment.""" 519 | segments = self._get_all_segments() 520 | if segments: 521 | 
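# _get_all_segments() returns segment files in order, so the last (highest-numbered)
# segment is treated as the active one that new appends go to.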
self.active_segment_path = segments[-1] 522 | -------------------------------------------------------------------------------- /src/pydistributedkv/entrypoints/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShahriyarR/py-distributed-kv/c404cccae0e190fdfd2bd08a7caed53a86620f3d/src/pydistributedkv/entrypoints/__init__.py -------------------------------------------------------------------------------- /src/pydistributedkv/entrypoints/cli/__init__.py: -------------------------------------------------------------------------------- 1 | """Here you should define your CLI entrypoint""" 2 | -------------------------------------------------------------------------------- /src/pydistributedkv/entrypoints/cli/run_follower.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | import uvicorn 5 | 6 | from pydistributedkv.configurator.settings.config import FOLLOWER_CONFIGS 7 | 8 | if __name__ == "__main__": 9 | if len(sys.argv) != 2 or sys.argv[1] not in FOLLOWER_CONFIGS: 10 | print("Usage: python run_follower.py <follower_id>") 11 | print(f"Available follower IDs: {list(FOLLOWER_CONFIGS.keys())}") 12 | sys.exit(1) 13 | 14 | follower_id = sys.argv[1] 15 | config = FOLLOWER_CONFIGS[follower_id] 16 | 17 | os.environ["WAL_PATH"] = config["wal_path"] 18 | os.environ["LEADER_URL"] = config["leader_url"] 19 | os.environ["FOLLOWER_ID"] = follower_id 20 | os.environ["FOLLOWER_URL"] = config["follower_url"] 21 | 22 | uvicorn.run("pydistributedkv.entrypoints.web.follower.follower:app", host=config["host"], port=config["port"], reload=True) 23 | -------------------------------------------------------------------------------- /src/pydistributedkv/entrypoints/cli/run_leader.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import uvicorn 4 | 5 | from pydistributedkv.configurator.settings.config import LEADER_CONFIG 6 | 7 | os.environ["WAL_PATH"] = LEADER_CONFIG["wal_path"] 8 | os.environ["LEADER_URL"] = LEADER_CONFIG["leader_url"] 9 | 10 | if __name__ == "__main__": 11 | uvicorn.run("pydistributedkv.entrypoints.web.leader.leader:app", host=LEADER_CONFIG["host"], port=LEADER_CONFIG["port"], reload=True) 12 | -------------------------------------------------------------------------------- /src/pydistributedkv/entrypoints/main.py: -------------------------------------------------------------------------------- 1 | """Entrypoint file used as the main entrypoint to the application""" 2 | -------------------------------------------------------------------------------- /src/pydistributedkv/entrypoints/web/__init__.py: -------------------------------------------------------------------------------- 1 | """Here you should define your Web entrypoint""" 2 | -------------------------------------------------------------------------------- /src/pydistributedkv/entrypoints/web/follower/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShahriyarR/py-distributed-kv/c404cccae0e190fdfd2bd08a7caed53a86620f3d/src/pydistributedkv/entrypoints/web/follower/__init__.py -------------------------------------------------------------------------------- /src/pydistributedkv/entrypoints/web/follower/follower.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import time 4 | from typing import Any, Dict,
Optional, Tuple 5 | 6 | import requests 7 | from fastapi import FastAPI, HTTPException, Query 8 | 9 | from pydistributedkv.configurator.settings.base import API_TIMEOUT, compaction_interval, HEARTBEAT_INTERVAL, MAX_SEGMENT_SIZE 10 | from pydistributedkv.domain.models import ClientRequest, LogEntry, OperationType, ReplicationRequest, WAL 11 | from pydistributedkv.service.compaction import LogCompactionService 12 | from pydistributedkv.service.heartbeat import HeartbeatService 13 | from pydistributedkv.service.request_deduplication import RequestDeduplicationService 14 | from pydistributedkv.service.storage import KeyValueStorage 15 | 16 | # Configure logging 17 | logging.basicConfig(level=logging.INFO) 18 | logger = logging.getLogger(__name__) 19 | 20 | app = FastAPI() 21 | 22 | # Initialize WAL and storage 23 | wal = WAL(os.getenv("WAL_PATH", "data/follower/wal.log"), max_segment_size=MAX_SEGMENT_SIZE) 24 | storage = KeyValueStorage(wal) 25 | 26 | # Initialize compaction service 27 | compaction_service = LogCompactionService(storage, compaction_interval=compaction_interval) 28 | 29 | # Request deduplication service 30 | request_deduplication = RequestDeduplicationService(service_name="follower") 31 | 32 | # Leader connection info 33 | leader_url = os.getenv("LEADER_URL", "http://localhost:8000") 34 | follower_id = os.getenv("FOLLOWER_ID", "follower-1") 35 | follower_url = os.getenv("FOLLOWER_URL", "http://localhost:8001") 36 | 37 | # Create heartbeat service 38 | heartbeat_service = HeartbeatService(service_name="follower", server_id=follower_id, server_url=follower_url) 39 | 40 | # Replication state 41 | last_applied_id = wal.get_last_id() # Initialize with the current last ID in WAL 42 | 43 | 44 | @app.on_event("startup") 45 | async def startup_event(): 46 | # Register the leader with the heartbeat service 47 | heartbeat_service.register_server("leader", leader_url) 48 | 49 | # Start heartbeat monitoring and sending 50 | await heartbeat_service.start_monitoring() 51 | await heartbeat_service.start_sending() 52 | 53 | # Start compaction service 54 | await compaction_service.start() 55 | 56 | # Register with leader 57 | try: 58 | response = requests.post( 59 | f"{leader_url}/register_follower", 60 | json={"id": follower_id, "url": follower_url, "last_applied_id": last_applied_id}, 61 | timeout=API_TIMEOUT, 62 | ) 63 | response_data = response.json() 64 | 65 | # If leader has entries we don't, fetch them 66 | leader_last_id = response_data.get("last_log_id", 0) 67 | if leader_last_id > last_applied_id: 68 | await sync_with_leader() 69 | except requests.RequestException as e: 70 | # In production, you'd implement retry logic 71 | logger.error(f"Failed to register with leader at {leader_url}: {str(e)}") 72 | 73 | 74 | @app.on_event("shutdown") 75 | async def shutdown_event(): 76 | # Stop heartbeat service 77 | await heartbeat_service.stop() 78 | 79 | # Stop compaction service 80 | await compaction_service.stop() 81 | 82 | logger.info("Follower server shutting down") 83 | 84 | 85 | async def sync_with_leader(): 86 | """Synchronize the follower with the leader by fetching and applying new log entries.""" 87 | global last_applied_id # This global declaration is necessary here 88 | try: 89 | entries = await fetch_entries_from_leader() 90 | if not entries: 91 | return 92 | 93 | new_entries = append_entries_to_wal(entries) 94 | if new_entries: 95 | last_applied_id = apply_entries_to_storage(new_entries) 96 | except requests.RequestException: 97 | print("Failed to sync with leader") 98 | 99 | 100 | 
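# Catch-up replication helpers: the follower asks the leader for every entry with an
# ID greater than its last_applied_id, appends the CRC-valid ones to its local WAL,
# and then applies them to the in-memory store. Illustrative exchange (IDs made up):
# a follower at last_applied_id=41 calls GET {leader_url}/log_entries/41 and the
# leader replies with {"entries": [...]} holding every entry with id >= 42.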
async def fetch_entries_from_leader() -> list[LogEntry]: 101 | """Fetch new log entries from the leader.""" 102 | response = requests.get(f"{leader_url}/log_entries/{last_applied_id}", timeout=API_TIMEOUT) 103 | data = response.json() 104 | return _parse_and_validate_entries(data.get("entries", []), source="leader") 105 | 106 | 107 | def _parse_and_validate_entries(entry_data_list: list[dict], source: str = "") -> list[LogEntry]: 108 | """Parse and validate log entries from the provided data.""" 109 | valid_entries = [] 110 | 111 | for entry_data in entry_data_list: 112 | entry = _create_valid_entry(entry_data, source) 113 | if entry: 114 | valid_entries.append(entry) 115 | 116 | return valid_entries 117 | 118 | 119 | def _create_valid_entry(entry_data: dict, source: str = "") -> LogEntry | None: 120 | """Create and validate a single log entry.""" 121 | try: 122 | entry = LogEntry(**entry_data) 123 | if not entry.validate_crc(): 124 | print(f"Warning: Received entry with ID {entry.id} with invalid CRC from {source}") 125 | return None 126 | return entry 127 | except ValueError as e: 128 | print(f"Error parsing entry from {source}: {str(e)}") 129 | return None 130 | 131 | 132 | def append_entries_to_wal(entries: list[LogEntry]) -> list[LogEntry]: 133 | """Append new entries to the WAL and return only the newly added ones.""" 134 | new_entries = [] 135 | for entry in entries: 136 | # Ensure we only add entries with valid CRC 137 | if entry.validate_crc() and not wal.has_entry(entry.id): 138 | wal.append_entry(entry) 139 | new_entries.append(entry) 140 | return new_entries 141 | 142 | 143 | def apply_entries_to_storage(entries: list[LogEntry]) -> int: 144 | """Apply entries to storage and return the last applied ID.""" 145 | return storage.apply_entries(entries) 146 | 147 | 148 | @app.post("/replicate") 149 | async def replicate(req: ReplicationRequest): 150 | global last_applied_id # This global declaration is necessary here 151 | 152 | # Use the existing helper to parse and validate entries 153 | entries = _parse_and_validate_entries(req.entries, source="replication request") 154 | 155 | # Process and apply new entries 156 | new_entries = _process_new_entries(entries) 157 | 158 | # Update the last applied ID if we have new entries 159 | if new_entries: 160 | last_id = storage.apply_entries(new_entries) 161 | last_applied_id = max(last_applied_id, last_id) 162 | 163 | return {"status": "ok", "last_applied_id": last_applied_id} 164 | 165 | 166 | def _process_new_entries(entries: list[LogEntry]) -> list[LogEntry]: 167 | """Process and store only entries that don't exist in the WAL.""" 168 | if not entries: 169 | return [] 170 | 171 | new_entries = [] 172 | for entry in entries: 173 | if not wal.has_entry(entry.id): 174 | wal.append_entry(entry) 175 | new_entries.append(entry) 176 | return new_entries 177 | 178 | 179 | @app.get("/key/{key}") 180 | def get_key( 181 | key: str, 182 | version: Optional[int] = Query(None, description="Specific version to retrieve"), 183 | client_id: Optional[str] = Query(None), 184 | request_id: Optional[str] = Query(None), 185 | ): 186 | """Handle GET request for a specific key with deduplication support""" 187 | # Check for cached response from duplicate request 188 | cached_response = _check_request_cache(client_id, request_id, key, OperationType.GET) 189 | if cached_response: 190 | return cached_response 191 | 192 | # Get the value and prepare response 193 | result = _get_value_from_storage(key, version) 194 | if result is None: 195 | response = {"status": 
"error", "message": f"Key not found: {key}"} 196 | if version: 197 | response["message"] = f"Key not found or version {version} not available: {key}" 198 | _cache_response_if_needed(client_id, request_id, key, OperationType.GET, response) 199 | raise HTTPException(status_code=404, detail=response["message"]) 200 | 201 | value, actual_version = result 202 | 203 | # Create success response 204 | response = {"key": key, "value": value, "version": actual_version} 205 | 206 | # Cache the response if client tracking is enabled 207 | _cache_response_if_needed(client_id, request_id, key, OperationType.GET, response) 208 | 209 | return response 210 | 211 | 212 | def _check_request_cache(client_id: Optional[str], request_id: Optional[str], key: str, operation: OperationType) -> Optional[Dict]: 213 | """Check if this is a duplicate request with a cached response""" 214 | if not client_id or not request_id: 215 | logger.info(f"GET request for key={key} (no client ID)") 216 | return None 217 | 218 | logger.info(f"GET request for key={key} from client={client_id}, request={request_id}") 219 | previous_response = request_deduplication.get_processed_result(client_id, request_id, operation) 220 | 221 | if previous_response is not None: 222 | logger.info(f"✅ Returning cached response for GET key={key}, client={client_id}, request={request_id}") 223 | return previous_response 224 | 225 | return None 226 | 227 | 228 | def _get_value_from_storage(key: str, version: Optional[int] = None) -> Optional[Tuple[Any, int]]: 229 | """Get a value and its version from storage""" 230 | return storage.get_with_version(key, version) 231 | 232 | 233 | def _cache_response_if_needed( 234 | client_id: Optional[str], request_id: Optional[str], key: str, operation: OperationType, response: Dict 235 | ) -> None: 236 | """Cache the response if client tracking is enabled""" 237 | if not client_id or not request_id: 238 | return 239 | 240 | client_request = ClientRequest(client_id=client_id, request_id=request_id, operation=operation, key=key) 241 | request_deduplication.mark_request_processed(client_request, response) 242 | 243 | status = "error" if "status" in response and response["status"] == "error" else "success" 244 | logger.info(f"Cached {status} response for GET key={key}, client={client_id}, request={request_id}") 245 | 246 | 247 | @app.get("/status") 248 | def get_status(): 249 | return {"follower_id": follower_id, "last_applied_id": last_applied_id, "leader_url": leader_url} 250 | 251 | 252 | @app.get("/segments") 253 | def get_segments(): 254 | """Return information about the WAL segments""" 255 | segments = wal.get_segment_files() 256 | active_segment = wal.get_active_segment() 257 | 258 | segment_info = [] 259 | for segment in segments: 260 | try: 261 | size = os.path.getsize(segment) 262 | segment_info.append({"path": segment, "size": size, "is_active": segment == active_segment}) 263 | except FileNotFoundError: 264 | pass 265 | 266 | return {"segments": segment_info, "total_segments": len(segment_info), "max_segment_size": MAX_SEGMENT_SIZE} 267 | 268 | 269 | @app.get("/keys") 270 | def get_all_keys(): 271 | """Return all keys in the storage""" 272 | keys = storage.get_all_keys() 273 | return {"keys": keys, "count": len(keys)} 274 | 275 | 276 | @app.get("/request_status") 277 | def get_request_status(client_id: str, request_id: str, operation: Optional[str] = Query(None)): 278 | """Check if a client request has been processed""" 279 | logger.info(f"Checking status for client={client_id}, request={request_id}, 
operation={operation}") 280 | # Pass operation type to get_processed_result if provided 281 | result = request_deduplication.get_processed_result(client_id, request_id, operation) 282 | if result: 283 | logger.info(f"Found cached result for client={client_id}, request={request_id}, operation={operation}") 284 | return {"processed": True, "result": result} 285 | else: 286 | logger.info(f"No cached result found for client={client_id}, request={request_id}, operation={operation}") 287 | return {"processed": False} 288 | 289 | 290 | @app.get("/deduplication_stats") 291 | def get_deduplication_stats(): 292 | """Return statistics about the request deduplication service""" 293 | stats = request_deduplication.get_stats() 294 | logger.info(f"Returning deduplication stats: duplicates detected={stats['total_duplicates_detected']}") 295 | return stats 296 | 297 | 298 | @app.post("/heartbeat") 299 | def receive_heartbeat(data: dict): 300 | """Handle heartbeat from leader or other servers""" 301 | server_id = data.get("server_id") 302 | timestamp = data.get("timestamp", time.time()) 303 | 304 | if not server_id: 305 | return {"status": "error", "message": "Missing server_id"} 306 | 307 | heartbeat_service.record_heartbeat(server_id) 308 | logger.debug(f"Received heartbeat from {server_id} at {timestamp}") 309 | 310 | return {"status": "ok", "server_id": follower_id, "timestamp": time.time()} 311 | 312 | 313 | @app.get("/cluster_status") 314 | def get_cluster_status(): 315 | """Get status of the leader from this follower's perspective""" 316 | leader_status = heartbeat_service.get_server_status("leader") 317 | 318 | return { 319 | "follower": {"id": follower_id, "url": follower_url, "status": "healthy"}, # Follower always reports itself as healthy 320 | "leader": leader_status, 321 | "heartbeat_interval": HEARTBEAT_INTERVAL, 322 | } 323 | 324 | 325 | @app.get("/key/{key}/history") 326 | def get_key_history(key: str): 327 | """Get the version history of a key""" 328 | history = storage.get_version_history(key) 329 | 330 | if history is None: 331 | raise HTTPException(status_code=404, detail=f"Key not found: {key}") 332 | 333 | versions = sorted(history.keys()) 334 | 335 | return {"key": key, "versions": versions, "history": [{"version": v, "value": history[v]} for v in versions]} 336 | 337 | 338 | @app.get("/key/{key}/versions") 339 | def get_key_versions(key: str): 340 | """Get available versions for a key""" 341 | history = storage.get_version_history(key) 342 | 343 | if history is None: 344 | raise HTTPException(status_code=404, detail=f"Key not found: {key}") 345 | 346 | return {"key": key, "versions": sorted(history.keys()), "latest_version": storage.get_latest_version(key)} 347 | 348 | 349 | @app.post("/compaction/run") 350 | async def run_compaction(force: bool = Query(False, description="Force compaction even if minimum interval hasn't passed")): 351 | """Manually trigger log compaction""" 352 | try: 353 | segments_compacted, entries_removed = await compaction_service.run_compaction(force=force) 354 | return { 355 | "status": "success", 356 | "segments_compacted": segments_compacted, 357 | "entries_removed": entries_removed, 358 | } 359 | except Exception as e: 360 | logger.error(f"Error during manual compaction: {str(e)}") 361 | raise HTTPException(status_code=500, detail=f"Compaction error: {str(e)}") from e 362 | 363 | 364 | @app.get("/compaction/status") 365 | def get_compaction_status(): 366 | """Get the status of the log compaction service""" 367 | return compaction_service.get_status() 368 
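# Runtime tuning of log compaction goes through the endpoint below via plain query
# parameters. Illustrative calls against the default follower address assumed
# elsewhere in this module (http://localhost:8001):
#   curl -X POST "http://localhost:8001/compaction/configure?enabled=true"
#   curl -X POST "http://localhost:8001/compaction/configure?interval=1800"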
| 369 | 370 | @app.post("/compaction/configure") 371 | def configure_compaction( 372 | enabled: Optional[bool] = Query(None, description="Enable or disable compaction"), 373 | interval: Optional[int] = Query(None, description="Compaction interval in seconds"), 374 | ): 375 | """Configure the log compaction service""" 376 | changes = {} 377 | 378 | if enabled is not None: 379 | result = compaction_service.set_enabled(enabled) 380 | changes["enabled"] = result 381 | 382 | if interval is not None: 383 | result = compaction_service.set_compaction_interval(interval) 384 | changes["interval"] = result 385 | 386 | return {"status": "success", "changes": changes} 387 | -------------------------------------------------------------------------------- /src/pydistributedkv/entrypoints/web/leader/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShahriyarR/py-distributed-kv/c404cccae0e190fdfd2bd08a7caed53a86620f3d/src/pydistributedkv/entrypoints/web/leader/__init__.py -------------------------------------------------------------------------------- /src/pydistributedkv/entrypoints/web/leader/leader.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import time 4 | from typing import Any, Dict, Optional, Tuple 5 | 6 | import requests 7 | from fastapi import FastAPI, HTTPException, Query 8 | 9 | from pydistributedkv.configurator.settings.base import API_TIMEOUT, compaction_interval, HEARTBEAT_INTERVAL, MAX_SEGMENT_SIZE 10 | from pydistributedkv.domain.models import ClientRequest, FollowerRegistration, KeyValue, OperationType, WAL 11 | from pydistributedkv.service.compaction import LogCompactionService 12 | from pydistributedkv.service.heartbeat import HeartbeatService 13 | from pydistributedkv.service.request_deduplication import RequestDeduplicationService 14 | from pydistributedkv.service.storage import KeyValueStorage 15 | 16 | # Configure logging 17 | logging.basicConfig(level=logging.INFO) 18 | logger = logging.getLogger(__name__) 19 | 20 | app = FastAPI() 21 | 22 | # Initialize WAL and storage 23 | wal = WAL(os.getenv("WAL_PATH", "data/leader/wal.log"), max_segment_size=MAX_SEGMENT_SIZE) 24 | storage = KeyValueStorage(wal) 25 | 26 | # Initialize compaction service 27 | compaction_service = LogCompactionService(storage, compaction_interval=compaction_interval) 28 | 29 | # Request deduplication service 30 | request_deduplication = RequestDeduplicationService(service_name="leader") 31 | 32 | # Track followers and their replication status 33 | followers: dict[str, str] = {} # follower_id -> url 34 | replication_status: dict[str, int] = {} # follower_id -> last_replicated_id 35 | 36 | # Create heartbeat service 37 | leader_id = "leader" 38 | leader_url = os.getenv("LEADER_URL", "http://localhost:8000") 39 | heartbeat_service = HeartbeatService(service_name="leader", server_id=leader_id, server_url=leader_url) 40 | 41 | 42 | @app.on_event("startup") 43 | async def startup_event(): 44 | # Start heartbeat monitoring and sending tasks 45 | await heartbeat_service.start_monitoring() 46 | await heartbeat_service.start_sending() 47 | 48 | # Start compaction service 49 | await compaction_service.start() 50 | 51 | logger.info("Leader server started with heartbeat and compaction services") 52 | 53 | 54 | @app.on_event("shutdown") 55 | async def shutdown_event(): 56 | # Stop heartbeat service 57 | await heartbeat_service.stop() 58 | 59 | # Stop compaction service 60 | 
await compaction_service.stop() 61 | 62 | logger.info("Leader server shutting down") 63 | 64 | 65 | @app.get("/key/{key}") 66 | def get_key( 67 | key: str, 68 | version: Optional[int] = Query(None, description="Specific version to retrieve"), 69 | client_id: Optional[str] = Query(None), 70 | request_id: Optional[str] = Query(None), 71 | ): 72 | """Handle GET request for a specific key with deduplication support""" 73 | # Check for cached response if client tracking is enabled 74 | cached_response = _check_request_cache(client_id, request_id, key, OperationType.GET) 75 | if cached_response: 76 | return cached_response 77 | 78 | # Get value from storage and handle errors 79 | result = _get_value_from_storage(key, version) 80 | if result is None: 81 | response = {"status": "error", "message": f"Key not found: {key}"} 82 | if version: 83 | response["message"] = f"Key not found or version {version} not available: {key}" 84 | _cache_response_if_needed(client_id, request_id, key, OperationType.GET, response) 85 | raise HTTPException(status_code=404, detail=response["message"]) 86 | 87 | value, actual_version = result 88 | 89 | # Create success response 90 | response = {"key": key, "value": value, "version": actual_version} 91 | 92 | # Cache the response if client tracking is enabled 93 | _cache_response_if_needed(client_id, request_id, key, OperationType.GET, response) 94 | 95 | return response 96 | 97 | 98 | def _check_request_cache(client_id: Optional[str], request_id: Optional[str], key: str, operation: OperationType) -> Optional[Dict]: 99 | """Check if this is a duplicate request with a cached response""" 100 | if not client_id or not request_id: 101 | logger.info(f"GET request for key={key} (no client ID)") 102 | return None 103 | 104 | logger.info(f"GET request for key={key} from client={client_id}, request={request_id}") 105 | previous_response = request_deduplication.get_processed_result(client_id, request_id, operation) 106 | 107 | if previous_response is not None: 108 | logger.info(f"✅ Returning cached response for GET key={key}, client={client_id}, request={request_id}") 109 | return previous_response 110 | 111 | return None 112 | 113 | 114 | def _get_value_from_storage(key: str, version: Optional[int] = None) -> Optional[Tuple[Any, int]]: 115 | """Get a value and its version from storage""" 116 | return storage.get_with_version(key, version) 117 | 118 | 119 | def _cache_response_if_needed( 120 | client_id: Optional[str], request_id: Optional[str], key: str, operation: OperationType, response: Dict, value: Any = None 121 | ) -> None: 122 | """Cache the response if client tracking is enabled""" 123 | if not client_id or not request_id: 124 | return 125 | 126 | client_request = ClientRequest(client_id=client_id, request_id=request_id, operation=operation, key=key, value=value) 127 | request_deduplication.mark_request_processed(client_request, response) 128 | 129 | status_type = "error" if "status" in response and response["status"] == "error" else "success" 130 | logger.info(f"Cached {status_type} response for {operation.name} key={key}, client={client_id}, request={request_id}") 131 | 132 | 133 | @app.put("/key/{key}") 134 | def set_key(key: str, kv: KeyValue, client_id: Optional[str] = Query(None), request_id: Optional[str] = Query(None)): 135 | """Handle PUT request to set a specific key with deduplication support""" 136 | # Check for cached response if client tracking is enabled 137 | cached_response = _check_request_cache(client_id, request_id, key, OperationType.SET) 138 | if 
cached_response: 139 | return cached_response 140 | 141 | # Process the request and get the resulting entry 142 | entry, version = _process_set_key_request(key, kv.value, kv.version) 143 | 144 | # Handle version conflict 145 | if entry is None: 146 | response = { 147 | "status": "error", 148 | "message": f"Version conflict: provided version {kv.version} is outdated", 149 | "current_version": version, 150 | } 151 | _cache_response_if_needed(client_id, request_id, key, OperationType.SET, response) 152 | raise HTTPException(status_code=409, detail=response["message"]) 153 | 154 | # Create and cache the response 155 | response = {"status": "ok", "id": entry.id, "key": key, "version": version} 156 | _cache_response_if_needed(client_id, request_id, key, OperationType.SET, response, kv.value) 157 | 158 | return response 159 | 160 | 161 | def _process_set_key_request(key: str, value: Any, version: Optional[int] = None): 162 | """Process the key-value storage operation and handle replication""" 163 | # Store the value 164 | entry, actual_version = storage.set(key, value, version) 165 | 166 | if entry is None: 167 | return None, actual_version 168 | 169 | logger.info(f"Added SET entry id={entry.id} for key={key}, version={actual_version}") 170 | 171 | # Replicate to followers asynchronously 172 | _replicate_to_followers(entry) 173 | 174 | return entry, actual_version 175 | 176 | 177 | @app.delete("/key/{key}") 178 | def delete_key(key: str, client_id: Optional[str] = Query(None), request_id: Optional[str] = Query(None)): 179 | """Handle DELETE request for a specific key with deduplication support""" 180 | # Check for cached response if client tracking is enabled 181 | cached_response = _check_request_cache(client_id, request_id, key, OperationType.DELETE) 182 | if cached_response: 183 | return cached_response 184 | 185 | # Process the delete request 186 | entry, status_code, error_msg = _process_delete_request(key) 187 | 188 | # Handle error case 189 | if status_code != 200: 190 | response = {"status": "error", "message": error_msg} 191 | _cache_response_if_needed(client_id, request_id, key, OperationType.DELETE, response) 192 | raise HTTPException(status_code=status_code, detail=error_msg) 193 | 194 | # Create success response 195 | response = {"status": "ok", "id": entry.id} 196 | 197 | # Cache the response if client tracking is enabled 198 | _cache_response_if_needed(client_id, request_id, key, OperationType.DELETE, response) 199 | 200 | return response 201 | 202 | 203 | def _process_delete_request(key: str) -> Tuple[Any, int, Optional[str]]: 204 | """Process the key deletion operation and handle replication""" 205 | # Delete the key from storage 206 | entry = storage.delete(key) 207 | 208 | if entry is None: 209 | error_msg = f"Key not found: {key}" 210 | logger.warning(error_msg) 211 | return None, 404, error_msg 212 | 213 | logger.info(f"Added DELETE entry id={entry.id} for key={key}") 214 | 215 | # Replicate to followers asynchronously 216 | _replicate_to_followers(entry) 217 | 218 | return entry, 200, None 219 | 220 | 221 | def _replicate_to_followers(entry): 222 | """Helper method to replicate an entry to all followers""" 223 | # Only replicate to healthy followers 224 | healthy_followers = heartbeat_service.get_healthy_servers() 225 | 226 | for follower_id, follower_url in healthy_followers.items(): 227 | try: 228 | logger.info(f"Replicating entry id={entry.id} to follower {follower_id}") 229 | requests.post( 230 | f"{follower_url}/replicate", 231 | json={"entries": [entry.model_dump()]}, 
232 | timeout=API_TIMEOUT, 233 | ) 234 | replication_status[follower_id] = entry.id 235 | except requests.RequestException as e: 236 | logger.error(f"Failed to replicate entry id={entry.id} to follower {follower_id}: {str(e)}") 237 | # In production, you'd want better error handling and retry logic 238 | 239 | 240 | @app.post("/register_follower") 241 | def register_follower(follower_data: FollowerRegistration): 242 | follower_id = follower_data.id 243 | follower_url = follower_data.url 244 | last_applied_id = follower_data.last_applied_id 245 | 246 | followers[follower_id] = follower_url 247 | replication_status[follower_id] = last_applied_id 248 | 249 | # Register follower with heartbeat service 250 | heartbeat_service.register_server(follower_id, follower_url) 251 | 252 | return {"status": "ok", "last_log_id": wal.get_last_id()} 253 | 254 | 255 | @app.get("/log_entries/{last_id}") 256 | def get_log_entries(last_id: int): 257 | entries = wal.read_from(last_id + 1) 258 | return {"entries": [entry.model_dump() for entry in entries]} 259 | 260 | 261 | @app.get("/follower_status") 262 | def get_follower_status(): 263 | return { 264 | "followers": [{"id": f_id, "url": url, "last_replicated_id": replication_status.get(f_id, 0)} for f_id, url in followers.items()] 265 | } 266 | 267 | 268 | @app.get("/segments") 269 | def get_segments(): 270 | """Return information about the WAL segments""" 271 | segments = wal.get_segment_files() 272 | active_segment = wal.get_active_segment() 273 | 274 | segment_info = [] 275 | for segment in segments: 276 | try: 277 | size = os.path.getsize(segment) 278 | segment_info.append({"path": segment, "size": size, "is_active": segment == active_segment}) 279 | except FileNotFoundError: 280 | pass 281 | 282 | return {"segments": segment_info, "total_segments": len(segment_info), "max_segment_size": MAX_SEGMENT_SIZE} 283 | 284 | 285 | @app.get("/keys") 286 | def get_all_keys(): 287 | """Return all keys in the storage""" 288 | keys = storage.get_all_keys() 289 | return {"keys": keys, "count": len(keys)} 290 | 291 | 292 | @app.get("/request_status") 293 | def get_request_status(client_id: str, request_id: str, operation: Optional[str] = Query(None)): 294 | """Check if a client request has been processed""" 295 | logger.info(f"Checking status for client={client_id}, request={request_id}, operation={operation}") 296 | # Pass operation type to get_processed_result if provided 297 | result = request_deduplication.get_processed_result(client_id, request_id, operation) 298 | if result: 299 | logger.info(f"Found cached result for client={client_id}, request={request_id}, operation={operation}") 300 | return {"processed": True, "result": result} 301 | else: 302 | logger.info(f"No cached result found for client={client_id}, request={request_id}, operation={operation}") 303 | return {"processed": False} 304 | 305 | 306 | @app.get("/deduplication_stats") 307 | def get_deduplication_stats(): 308 | """Return statistics about the request deduplication service""" 309 | stats = request_deduplication.get_stats() 310 | logger.info(f"Returning deduplication stats: duplicates detected={stats['total_duplicates_detected']}") 311 | return stats 312 | 313 | 314 | @app.post("/heartbeat") 315 | def receive_heartbeat(data: dict): 316 | """Handle heartbeat from followers""" 317 | server_id = data.get("server_id") 318 | timestamp = data.get("timestamp", time.time()) 319 | 320 | if not server_id: 321 | return {"status": "error", "message": "Missing server_id"} 322 | 323 | 
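# Record the heartbeat timestamp; the monitoring loop compares it against
# HEARTBEAT_TIMEOUT and marks the follower as "down" if heartbeats stop arriving.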
heartbeat_service.record_heartbeat(server_id) 324 | logger.debug(f"Received heartbeat from {server_id} at {timestamp}") 325 | 326 | return {"status": "ok", "server_id": leader_id, "timestamp": time.time()} 327 | 328 | 329 | @app.get("/cluster_status") 330 | def get_cluster_status(): 331 | """Get status of all servers in the cluster""" 332 | return { 333 | "leader": {"id": leader_id, "url": leader_url, "status": "healthy"}, # Leader always reports itself as healthy 334 | "followers": heartbeat_service.get_all_statuses(), 335 | "heartbeat_interval": HEARTBEAT_INTERVAL, 336 | } 337 | 338 | 339 | @app.get("/key/{key}/history") 340 | def get_key_history(key: str): 341 | """Get the version history of a key""" 342 | history = storage.get_version_history(key) 343 | 344 | if history is None: 345 | raise HTTPException(status_code=404, detail=f"Key not found: {key}") 346 | 347 | versions = sorted(history.keys()) 348 | 349 | return {"key": key, "versions": versions, "history": [{"version": v, "value": history[v]} for v in versions]} 350 | 351 | 352 | @app.get("/key/{key}/versions") 353 | def get_key_versions(key: str): 354 | """Get available versions for a key""" 355 | history = storage.get_version_history(key) 356 | 357 | if history is None: 358 | raise HTTPException(status_code=404, detail=f"Key not found: {key}") 359 | 360 | return {"key": key, "versions": sorted(history.keys()), "latest_version": storage.get_latest_version(key)} 361 | 362 | 363 | @app.post("/compaction/run") 364 | async def run_compaction(force: bool = Query(False, description="Force compaction even if minimum interval hasn't passed")): 365 | """Manually trigger log compaction""" 366 | try: 367 | segments_compacted, entries_removed = await compaction_service.run_compaction(force=force) 368 | return { 369 | "status": "success", 370 | "segments_compacted": segments_compacted, 371 | "entries_removed": entries_removed, 372 | } 373 | except Exception as e: 374 | logger.error(f"Error during manual compaction: {str(e)}") 375 | raise HTTPException(status_code=500, detail=f"Compaction error: {str(e)}") from e 376 | 377 | 378 | @app.get("/compaction/status") 379 | def get_compaction_status(): 380 | """Get the status of the log compaction service""" 381 | return compaction_service.get_status() 382 | 383 | 384 | @app.post("/compaction/configure") 385 | def configure_compaction( 386 | enabled: Optional[bool] = Query(None, description="Enable or disable compaction"), 387 | interval: Optional[int] = Query(None, description="Compaction interval in seconds"), 388 | ): 389 | """Configure the log compaction service""" 390 | changes = {} 391 | 392 | if enabled is not None: 393 | result = compaction_service.set_enabled(enabled) 394 | changes["enabled"] = result 395 | 396 | if interval is not None: 397 | result = compaction_service.set_compaction_interval(interval) 398 | changes["interval"] = result 399 | 400 | return {"status": "success", "changes": changes} 401 | -------------------------------------------------------------------------------- /src/pydistributedkv/service/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShahriyarR/py-distributed-kv/c404cccae0e190fdfd2bd08a7caed53a86620f3d/src/pydistributedkv/service/__init__.py -------------------------------------------------------------------------------- /src/pydistributedkv/service/compaction.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import logging 3 | 
import time 4 | from datetime import datetime 5 | from typing import Any, Dict, List, Optional, Tuple 6 | 7 | from pydistributedkv.service.storage import KeyValueStorage 8 | 9 | logger = logging.getLogger(__name__) 10 | 11 | 12 | class LogCompactionService: 13 | """Service to handle periodic log compaction""" 14 | 15 | def __init__( 16 | self, 17 | storage: KeyValueStorage, 18 | compaction_interval: int = 3600, # Default: run compaction every hour 19 | min_compaction_interval: int = 600, # Minimum time between compactions (10 minutes) 20 | enabled: bool = True, 21 | ): 22 | self.storage = storage 23 | self.compaction_interval = compaction_interval 24 | self.min_compaction_interval = min_compaction_interval 25 | self.enabled = enabled 26 | self.last_compaction: Optional[datetime] = None 27 | self.compaction_running = False 28 | self.compaction_task = None 29 | self.compaction_history: List[dict] = [] 30 | 31 | async def start(self) -> None: 32 | """Start the compaction service""" 33 | if not self._can_start(): 34 | return 35 | 36 | self.compaction_task = asyncio.create_task(self._compaction_loop()) 37 | logger.info(f"Started log compaction service (interval: {self.compaction_interval}s)") 38 | 39 | def _can_start(self) -> bool: 40 | """Check if the service can start""" 41 | if self.compaction_task is not None: 42 | logger.warning("Compaction service already running") 43 | return False 44 | 45 | if not self.enabled: 46 | logger.info("Compaction service is disabled") 47 | return False 48 | 49 | return True 50 | 51 | async def stop(self) -> None: 52 | """Stop the compaction service""" 53 | if self.compaction_task: 54 | self.compaction_task.cancel() 55 | try: 56 | await self.compaction_task 57 | except asyncio.CancelledError: 58 | pass 59 | self.compaction_task = None 60 | logger.info("Stopped log compaction service") 61 | 62 | async def _compaction_loop(self) -> None: 63 | """Main compaction loop that runs periodically""" 64 | while True: 65 | try: 66 | # Sleep first to avoid immediate compaction on startup 67 | await asyncio.sleep(self.compaction_interval) 68 | await self._run_if_not_running() 69 | except asyncio.CancelledError: 70 | logger.info("Compaction loop cancelled") 71 | break 72 | except Exception as e: 73 | self._handle_loop_error(e) 74 | 75 | async def _run_if_not_running(self) -> None: 76 | """Run compaction if not already running""" 77 | if not self.compaction_running: 78 | await self.run_compaction() 79 | 80 | def _handle_loop_error(self, error: Exception) -> None: 81 | """Handle errors in the compaction loop""" 82 | logger.error(f"Error in compaction loop: {str(error)}") 83 | 84 | async def run_compaction(self, force: bool = False) -> Tuple[int, int]: 85 | """Run log compaction process 86 | 87 | Args: 88 | force: If True, runs compaction even if minimum interval hasn't passed 89 | 90 | Returns: 91 | Tuple of (segments_compacted, entries_removed) 92 | """ 93 | # Early returns for conditions where compaction should be skipped 94 | if self.compaction_running: 95 | logger.warning("Compaction already in progress, skipping") 96 | return 0, 0 97 | 98 | if not force and self._is_too_soon_for_compaction(): 99 | return 0, 0 100 | 101 | return await self._execute_compaction() 102 | 103 | def _is_too_soon_for_compaction(self) -> bool: 104 | """Check if it's too soon since the last compaction""" 105 | if not self.last_compaction: 106 | return False 107 | 108 | now = datetime.now() 109 | time_since_last = (now - self.last_compaction).total_seconds() 110 | 111 | if time_since_last < 
self.min_compaction_interval: 112 | logger.info(f"Skipping compaction, last run was {time_since_last:.1f}s ago " f"(min interval: {self.min_compaction_interval}s)") 113 | return True 114 | 115 | return False 116 | 117 | async def _execute_compaction(self) -> Tuple[int, int]: 118 | """Execute the actual compaction process""" 119 | self.compaction_running = True 120 | start_time = time.time() 121 | now = datetime.now() 122 | 123 | try: 124 | logger.info("Starting log compaction") 125 | result = self.storage.compact_log() 126 | self._record_compaction_result(result, start_time, now) 127 | return result 128 | except Exception as e: 129 | logger.error(f"Error during compaction: {str(e)}") 130 | raise 131 | finally: 132 | self.compaction_running = False 133 | 134 | def _record_compaction_result(self, result: Tuple[int, int], start_time: float, timestamp: datetime) -> None: 135 | """Record the result of a compaction operation""" 136 | segments_compacted, entries_removed = result 137 | duration = time.time() - start_time 138 | self.last_compaction = timestamp 139 | 140 | compaction_data = { 141 | "timestamp": timestamp.isoformat(), 142 | "duration_seconds": duration, 143 | "segments_compacted": segments_compacted, 144 | "entries_removed": entries_removed, 145 | } 146 | 147 | self._update_compaction_history(compaction_data) 148 | 149 | logger.info( 150 | f"Compaction completed in {duration:.2f}s: " f"compacted {segments_compacted} segments, removed {entries_removed} entries" 151 | ) 152 | 153 | def _update_compaction_history(self, compaction_data: Dict[str, Any]) -> None: 154 | """Update the compaction history, keeping only the latest entries""" 155 | self.compaction_history.append(compaction_data) 156 | if len(self.compaction_history) > 10: # Keep only last 10 entries 157 | self.compaction_history.pop(0) 158 | 159 | def get_status(self) -> dict: 160 | """Get the current status of the compaction service""" 161 | return { 162 | "enabled": self.enabled, 163 | "compaction_interval_seconds": self.compaction_interval, 164 | "min_compaction_interval_seconds": self.min_compaction_interval, 165 | "last_compaction": self.last_compaction.isoformat() if self.last_compaction else None, 166 | "compaction_running": self.compaction_running, 167 | "compaction_history": self.compaction_history, 168 | } 169 | 170 | def set_enabled(self, enabled: bool) -> bool: 171 | """Enable or disable the compaction service""" 172 | self.enabled = enabled 173 | return self.enabled 174 | 175 | def set_compaction_interval(self, interval: int) -> int: 176 | """Set the compaction interval in seconds""" 177 | if interval < 60: 178 | interval = 60 # Minimum 1 minute 179 | self.compaction_interval = interval 180 | return self.compaction_interval 181 | -------------------------------------------------------------------------------- /src/pydistributedkv/service/heartbeat.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import logging 3 | import time 4 | from typing import Dict, Optional 5 | 6 | import requests 7 | 8 | from pydistributedkv.configurator.settings.base import API_TIMEOUT, HEARTBEAT_INTERVAL, HEARTBEAT_TIMEOUT 9 | 10 | logger = logging.getLogger(__name__) 11 | 12 | 13 | class HeartbeatService: 14 | """ 15 | Service for sending and monitoring heartbeats between servers in the cluster. 
16 | """ 17 | 18 | def __init__(self, service_name: str, server_id: str, server_url: str): 19 | self.service_name = service_name 20 | self.server_id = server_id 21 | self.server_url = server_url 22 | self.servers: Dict[str, Dict] = {} # server_id -> {url, last_heartbeat, status} 23 | self._background_tasks = set() 24 | self._monitor_running = False 25 | self._send_running = False 26 | 27 | def register_server(self, server_id: str, server_url: str) -> None: 28 | """Register a server to be monitored""" 29 | current_time = time.time() 30 | self.servers[server_id] = {"url": server_url, "last_heartbeat": current_time, "status": "healthy"} 31 | logger.info(f"{self.service_name}: Registered server {server_id} at {server_url}") 32 | 33 | def deregister_server(self, server_id: str) -> None: 34 | """Deregister a server from monitoring""" 35 | if server_id in self.servers: 36 | del self.servers[server_id] 37 | logger.info(f"{self.service_name}: Deregistered server {server_id}") 38 | 39 | def record_heartbeat(self, server_id: str) -> None: 40 | """Record that a heartbeat was received from a server""" 41 | if server_id not in self.servers: 42 | logger.warning(f"{self.service_name}: Received heartbeat from unknown server {server_id}") 43 | return 44 | 45 | current_time = time.time() 46 | self.servers[server_id]["last_heartbeat"] = current_time 47 | 48 | # If server was previously down, mark it as healthy 49 | if self.servers[server_id]["status"] != "healthy": 50 | self.servers[server_id]["status"] = "healthy" 51 | logger.info(f"{self.service_name}: Server {server_id} is now healthy") 52 | 53 | def get_server_status(self, server_id: str) -> Optional[Dict]: 54 | """Get the status of a specific server""" 55 | return self.servers.get(server_id) 56 | 57 | def get_all_statuses(self) -> Dict: 58 | """Get the status of all servers""" 59 | return { 60 | server_id: { 61 | "url": info["url"], 62 | "status": info["status"], 63 | "last_heartbeat": info["last_heartbeat"], 64 | "seconds_since_last_heartbeat": time.time() - info["last_heartbeat"], 65 | } 66 | for server_id, info in self.servers.items() 67 | } 68 | 69 | def get_healthy_servers(self) -> Dict[str, str]: 70 | """Get a dictionary of healthy server IDs and URLs""" 71 | return {server_id: info["url"] for server_id, info in self.servers.items() if info["status"] == "healthy"} 72 | 73 | async def start_monitoring(self) -> None: 74 | """Start monitoring heartbeats from registered servers""" 75 | if self._monitor_running: 76 | return 77 | 78 | self._monitor_running = True 79 | task = asyncio.create_task(self._monitor_heartbeats()) 80 | self._background_tasks.add(task) 81 | task.add_done_callback(self._background_tasks.discard) 82 | logger.info(f"{self.service_name}: Started heartbeat monitoring") 83 | 84 | async def start_sending(self) -> None: 85 | """Start sending heartbeats to registered servers""" 86 | if self._send_running: 87 | return 88 | 89 | self._send_running = True 90 | task = asyncio.create_task(self._send_heartbeats()) 91 | self._background_tasks.add(task) 92 | task.add_done_callback(self._background_tasks.discard) 93 | logger.info(f"{self.service_name}: Started sending heartbeats") 94 | 95 | async def stop(self) -> None: 96 | """Stop all heartbeat activities""" 97 | self._monitor_running = False 98 | self._send_running = False 99 | # Wait for tasks to complete 100 | for task in self._background_tasks: 101 | task.cancel() 102 | logger.info(f"{self.service_name}: Stopped heartbeat service") 103 | 104 | async def _monitor_heartbeats(self) -> None: 105 
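# Runs every HEARTBEAT_INTERVAL seconds and flags any server whose last heartbeat
# is older than HEARTBEAT_TIMEOUT (see _update_server_status below).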
| """Monitor heartbeats and mark servers as down if they miss heartbeats""" 106 | while self._monitor_running: 107 | current_time = time.time() 108 | self._check_server_heartbeats(current_time) 109 | await asyncio.sleep(HEARTBEAT_INTERVAL) 110 | 111 | def _check_server_heartbeats(self, current_time: float) -> None: 112 | """Check each server's heartbeat and update status if needed""" 113 | for server_id, info in self.servers.items(): 114 | if info["status"] == "down": 115 | continue 116 | 117 | time_since_last_heartbeat = current_time - info["last_heartbeat"] 118 | self._update_server_status(server_id, info, time_since_last_heartbeat) 119 | 120 | def _update_server_status(self, server_id: str, info: Dict, elapsed_time: float) -> None: 121 | """Update server status based on heartbeat timeout""" 122 | if elapsed_time <= HEARTBEAT_TIMEOUT: 123 | return 124 | 125 | info["status"] = "down" 126 | logger.warning(f"{self.service_name}: Server {server_id} marked as down. " f"No heartbeat for {elapsed_time:.1f}s") 127 | 128 | async def _send_heartbeats(self) -> None: 129 | """Send periodic heartbeats to all registered servers""" 130 | while self._send_running: 131 | await self._send_heartbeats_to_all_servers() 132 | await asyncio.sleep(HEARTBEAT_INTERVAL) 133 | 134 | async def _send_heartbeats_to_all_servers(self) -> None: 135 | """Send heartbeats to all registered servers in parallel""" 136 | for server_id, info in list(self.servers.items()): 137 | # Send heartbeat even to servers marked as down (to detect recovery) 138 | self._schedule_heartbeat(server_id, info) 139 | 140 | def _schedule_heartbeat(self, server_id: str, info: dict) -> None: 141 | """Schedule a non-blocking heartbeat to a specific server""" 142 | try: 143 | server_url = info["url"] 144 | asyncio.create_task(self._send_single_heartbeat(server_id, server_url)) 145 | except Exception as e: 146 | logger.error(f"{self.service_name}: Error preparing heartbeat to {server_id}: {str(e)}") 147 | 148 | async def _send_single_heartbeat(self, server_id: str, server_url: str) -> None: 149 | """Send a single heartbeat to a specific server""" 150 | try: 151 | response = await asyncio.to_thread( 152 | requests.post, f"{server_url}/heartbeat", json={"server_id": self.server_id, "timestamp": time.time()}, timeout=API_TIMEOUT 153 | ) 154 | 155 | if response.status_code == 200: 156 | logger.debug(f"{self.service_name}: Heartbeat sent to {server_id}") 157 | else: 158 | logger.warning(f"{self.service_name}: Heartbeat to {server_id} failed with status {response.status_code}") 159 | except requests.RequestException as e: 160 | logger.warning(f"{self.service_name}: Failed to send heartbeat to {server_id}: {str(e)}") 161 | -------------------------------------------------------------------------------- /src/pydistributedkv/service/request_deduplication.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import time 3 | from collections import defaultdict 4 | from typing import Any, Dict, List, Optional, Tuple 5 | 6 | from pydistributedkv.domain.models import ClientRequest 7 | 8 | # Configure logging 9 | logging.basicConfig(level=logging.INFO) 10 | logger = logging.getLogger(__name__) 11 | 12 | 13 | class RequestDeduplicationService: 14 | """Service to track processed client requests and prevent duplicate processing""" 15 | 16 | def __init__(self, max_cache_size: int = 10000, expiry_seconds: int = 3600, service_name: str = "deduplication"): 17 | """ 18 | Initialize the request deduplication service. 
19 | 20 | Args: 21 | max_cache_size: Maximum number of client requests to track 22 | expiry_seconds: Time in seconds after which cached requests should expire 23 | service_name: Name of the service using this deduplication (for logging) 24 | """ 25 | # Change to store operation type as part of the key 26 | # Structure: {client_id: {(request_id, operation): (timestamp, result)}} 27 | self.processed_requests: Dict[str, Dict[Tuple[str, str], Tuple[float, Any]]] = defaultdict(dict) 28 | self.max_cache_size = max_cache_size 29 | self.expiry_seconds = expiry_seconds 30 | self.service_name = service_name 31 | 32 | # Statistics 33 | self.total_requests_cached = 0 34 | self.total_duplicates_detected = 0 35 | self.total_cache_cleanups = 0 36 | 37 | # Track different types of duplicates 38 | self.same_operation_duplicates = 0 39 | self.different_operation_duplicates = 0 40 | 41 | logger.info( 42 | f"[{service_name}] Request deduplication service initialized with max_cache_size={max_cache_size}, expiry_seconds={expiry_seconds}" 43 | ) 44 | 45 | def mark_request_processed(self, client_request: ClientRequest, result: Any): 46 | """Mark a client request as processed with its result""" 47 | self._clean_expired_requests() 48 | 49 | client_id = client_request.client_id 50 | request_id = client_request.request_id 51 | operation = client_request.operation or "UNKNOWN" 52 | key = client_request.key or "N/A" 53 | 54 | # Use a tuple of (request_id, operation) as the cache key 55 | cache_key = (request_id, str(operation)) 56 | 57 | # Store the result with the current timestamp 58 | self.processed_requests[client_id][cache_key] = (time.time(), result) 59 | self.total_requests_cached += 1 60 | 61 | logger.info(f"[{self.service_name}] Cached result for client={client_id}, request={request_id}, operation={operation}, key={key}") 62 | 63 | # If we've exceeded our cache size, remove the oldest entries 64 | if len(self.processed_requests) > self.max_cache_size: 65 | self._clean_oldest_requests() 66 | 67 | def get_processed_result(self, client_id: str, request_id: str, operation: Optional[str] = None) -> Optional[Any]: 68 | """ 69 | Check if a request has been processed and return its result if found. 70 | Now takes into account the operation type. 
71 | 72 | Args: 73 | client_id: Client identifier 74 | request_id: Request identifier 75 | operation: Operation type (GET, SET, DELETE) - must match for cache hit 76 | 77 | Returns: 78 | The stored result if the request was already processed with the same operation, None otherwise 79 | """ 80 | self._clean_expired_requests() 81 | 82 | if client_id not in self.processed_requests: 83 | return None 84 | 85 | # Check for exact operation match 86 | result = self._check_exact_operation_match(client_id, request_id, operation) 87 | if result is not None: 88 | return result 89 | 90 | # Log different operations with same request ID (no result returned) 91 | self._log_different_operations(client_id, request_id, operation) 92 | 93 | return None 94 | 95 | def _check_exact_operation_match(self, client_id: str, request_id: str, operation: Optional[str]) -> Optional[Any]: 96 | """Check if we have an exact match for request ID and operation""" 97 | if not operation: 98 | return None 99 | 100 | cache_key = (request_id, str(operation)) 101 | if cache_key not in self.processed_requests[client_id]: 102 | return None 103 | 104 | timestamp, result = self.processed_requests[client_id][cache_key] 105 | self.total_duplicates_detected += 1 106 | self.same_operation_duplicates += 1 107 | 108 | # Calculate how long ago this request was first processed 109 | time_since_original = time.time() - timestamp 110 | 111 | logger.warning( 112 | f"[{self.service_name}] DUPLICATE REQUEST DETECTED: client={client_id}, request={request_id}, " 113 | f"operation={operation}, originally processed {time_since_original:.2f} seconds ago" 114 | ) 115 | return result 116 | 117 | def _log_different_operations(self, client_id: str, request_id: str, operation: Optional[str]) -> None: 118 | """Log when different operations are attempted with the same request ID""" 119 | if not operation: 120 | return 121 | 122 | for (req_id, op), (_, _) in list(self.processed_requests[client_id].items()): 123 | if req_id == request_id and op != str(operation): 124 | self.different_operation_duplicates += 1 125 | logger.warning( 126 | f"[{self.service_name}] DIFFERENT OPERATION ATTEMPTED: client={client_id}, request={request_id}, " 127 | f"previous_op={op}, current_op={operation}" 128 | ) 129 | break 130 | 131 | def _clean_expired_requests(self): 132 | """Remove expired entries from the cache""" 133 | current_time = time.time() 134 | expired_count = self._remove_expired_entries(current_time) 135 | self._remove_empty_clients() 136 | 137 | if expired_count > 0: 138 | logger.info(f"[{self.service_name}] Cleaned up {expired_count} expired cache entries") 139 | self.total_cache_cleanups += 1 140 | 141 | def _remove_expired_entries(self, current_time: float) -> int: 142 | """Remove expired request entries and return count of removed entries""" 143 | expired_count = 0 144 | 145 | for _, requests in list(self.processed_requests.items()): 146 | # Find and remove expired requests for this client 147 | expired_requests = [req_key for req_key, (timestamp, _) in requests.items() if current_time - timestamp > self.expiry_seconds] 148 | 149 | # Remove expired requests 150 | for req_key in expired_requests: 151 | del requests[req_key] 152 | expired_count += 1 153 | 154 | return expired_count 155 | 156 | def _remove_empty_clients(self) -> None: 157 | """Remove client entries that have no requests""" 158 | empty_clients = [client_id for client_id, requests in self.processed_requests.items() if not requests] 159 | 160 | for client_id in empty_clients: 161 | del 
self.processed_requests[client_id] 162 | 163 | def _clean_oldest_requests(self): 164 | """Remove the oldest entries when the cache exceeds max size""" 165 | # Calculate how many entries to remove 166 | total_entries = self._count_total_entries() 167 | entries_to_remove = max(0, total_entries - self.max_cache_size) 168 | 169 | if entries_to_remove <= 0: 170 | return 171 | 172 | logger.info(f"[{self.service_name}] Cache size limit reached, removing {entries_to_remove} oldest entries") 173 | 174 | # Get sorted entries by age (oldest first) 175 | oldest_entries = self._get_entries_sorted_by_age() 176 | 177 | # Remove oldest entries 178 | self._remove_oldest_entries(oldest_entries, entries_to_remove) 179 | self.total_cache_cleanups += 1 180 | 181 | def _count_total_entries(self) -> int: 182 | """Count total entries across all clients""" 183 | return sum(len(requests) for requests in self.processed_requests.values()) 184 | 185 | def _get_entries_sorted_by_age(self) -> List[Tuple[float, str, Tuple[str, str]]]: 186 | """Get all entries sorted by timestamp (oldest first)""" 187 | all_entries = [] 188 | for client_id, requests in self.processed_requests.items(): 189 | for req_key, (timestamp, _) in requests.items(): 190 | all_entries.append((timestamp, client_id, req_key)) 191 | 192 | all_entries.sort() # Sort by timestamp (oldest first) 193 | return all_entries 194 | 195 | def _remove_oldest_entries(self, sorted_entries: List[Tuple[float, str, Tuple[str, str]]], count: int) -> None: 196 | """Remove the specified number of oldest entries""" 197 | entries_to_remove = sorted_entries[: min(count, len(sorted_entries))] 198 | 199 | for _, client_id, req_key in entries_to_remove: 200 | self._remove_entry_if_exists(client_id, req_key) 201 | 202 | def _remove_entry_if_exists(self, client_id: str, req_key: Tuple[str, str]) -> None: 203 | """Remove a single entry if it exists and clean up empty client entries""" 204 | if client_id not in self.processed_requests: 205 | return 206 | 207 | client_requests = self.processed_requests[client_id] 208 | if req_key in client_requests: 209 | del client_requests[req_key] 210 | 211 | # Clean up empty client entry 212 | if not client_requests: 213 | del self.processed_requests[client_id] 214 | 215 | def get_stats(self) -> dict: 216 | """Return statistics about the deduplication service""" 217 | total_cached = 0 218 | unique_request_ids = set() 219 | 220 | for _, requests in self.processed_requests.items(): 221 | total_cached += len(requests) 222 | for req_id, _ in requests.keys(): 223 | unique_request_ids.add(req_id) 224 | 225 | return { 226 | "service_name": self.service_name, 227 | "current_cache_size": total_cached, 228 | "unique_request_ids": len(unique_request_ids), 229 | "total_client_count": len(self.processed_requests), 230 | "total_requests_cached": self.total_requests_cached, 231 | "total_duplicates_detected": self.total_duplicates_detected, 232 | "same_operation_duplicates": self.same_operation_duplicates, 233 | "different_operation_duplicates": self.different_operation_duplicates, 234 | "total_cache_cleanups": self.total_cache_cleanups, 235 | } 236 | -------------------------------------------------------------------------------- /src/pydistributedkv/service/storage.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, List, Optional, Tuple 2 | 3 | from pydistributedkv.domain.models import LogEntry, OperationType, VersionedValue, WAL 4 | 5 | 6 | class KeyValueStorage: 7 | def __init__(self, wal: 
WAL): 8 | self.wal = wal 9 | # Changed from Dict[str, Any] to Dict[str, VersionedValue] 10 | self.data: Dict[str, VersionedValue] = {} 11 | self._replay_log() 12 | 13 | def _replay_log(self): 14 | """Replay the WAL to rebuild the in-memory state""" 15 | entries = self.wal.read_from(0) 16 | entries_count = len(entries) 17 | print(f"Replaying {entries_count} entries from WAL...") 18 | 19 | for i, entry in enumerate(entries): 20 | self._apply_log_entry(entry) 21 | 22 | # Log progress for larger datasets 23 | if entries_count > 1000 and i % 1000 == 0: 24 | print(f"Replayed {i}/{entries_count} entries...") 25 | 26 | print(f"Finished replaying {entries_count} entries, data store contains {len(self.data)} keys") 27 | 28 | def _apply_log_entry(self, entry: LogEntry) -> None: 29 | """Apply a single log entry to the in-memory state""" 30 | if entry.operation == OperationType.SET: 31 | self._apply_set_operation(entry) 32 | elif entry.operation == OperationType.DELETE: 33 | self._apply_delete_operation(entry) 34 | 35 | def _apply_set_operation(self, entry: LogEntry) -> None: 36 | """Apply SET operation to the in-memory state""" 37 | version = entry.version if entry.version is not None else 1 38 | 39 | if entry.key in self.data: 40 | self.data[entry.key].update(entry.value, version) 41 | else: 42 | self.data[entry.key] = VersionedValue(current_version=version, value=entry.value) 43 | 44 | def _apply_delete_operation(self, entry: LogEntry) -> None: 45 | """Apply DELETE operation to the in-memory state""" 46 | if entry.key in self.data: 47 | del self.data[entry.key] 48 | 49 | def set(self, key: str, value: Any, version: Optional[int] = None) -> Tuple[LogEntry | None, int]: 50 | """Set a key-value pair and log the operation 51 | 52 | Returns: 53 | Tuple[LogEntry, int]: The log entry and the actual version used 54 | """ 55 | # Determine version and check for conflicts 56 | next_version, conflict = self._determine_version(key, version) 57 | if conflict: 58 | current_version = self.data[key].current_version 59 | return None, current_version 60 | 61 | # Create and log the entry with version 62 | entry = self.wal.append(OperationType.SET, key, value, version=next_version) 63 | 64 | # Update in-memory state 65 | self._update_in_memory_state(key, value, next_version) 66 | 67 | return entry, next_version 68 | 69 | def _determine_version(self, key: str, requested_version: Optional[int]) -> Tuple[int, bool]: 70 | """Determine the appropriate version and check for conflicts 71 | 72 | Returns: 73 | Tuple[int, bool]: (next_version, has_conflict) 74 | """ 75 | if key not in self.data: 76 | return self._handle_new_key(requested_version) 77 | 78 | return self._handle_existing_key(key, requested_version) 79 | 80 | def _handle_new_key(self, requested_version: Optional[int]) -> Tuple[int, bool]: 81 | """Handle version determination for a new key""" 82 | next_version = requested_version if requested_version and requested_version > 1 else 1 83 | return next_version, False 84 | 85 | def _handle_existing_key(self, key: str, requested_version: Optional[int]) -> Tuple[int, bool]: 86 | """Handle version determination for an existing key""" 87 | current_version = self.data[key].current_version 88 | if requested_version and requested_version <= current_version: 89 | return current_version, True 90 | return current_version + 1, False 91 | 92 | def _update_in_memory_state(self, key: str, value: Any, version: int) -> None: 93 | """Update the in-memory state with the new value and version""" 94 | if key in self.data: 95 | 
self.data[key].update(value, version) 96 | else: 97 | self.data[key] = VersionedValue(current_version=version, value=value) 98 | 99 | def get(self, key: str, version: Optional[int] = None) -> Optional[Any]: 100 | """Get a value by key and optional version""" 101 | if key not in self.data: 102 | return None 103 | 104 | return self.data[key].get_value(version) 105 | 106 | def get_with_version(self, key: str, version: Optional[int] = None) -> Optional[Tuple[Any, int]]: 107 | """Get a value and its version by key""" 108 | if key not in self.data: 109 | return None 110 | 111 | versioned_value = self.data[key] 112 | value = versioned_value.get_value(version) 113 | 114 | if value is None: 115 | return None 116 | 117 | actual_version = version if version is not None else versioned_value.current_version 118 | return (value, actual_version) 119 | 120 | def get_version_history(self, key: str) -> Optional[Dict[int, Any]]: 121 | """Get the version history for a key""" 122 | if key not in self.data: 123 | return None 124 | 125 | versioned_value = self.data[key] 126 | # Start with the current version 127 | history = {versioned_value.current_version: versioned_value.value} 128 | 129 | # Add all historical versions if they exist 130 | if versioned_value.history: 131 | history.update(versioned_value.history) 132 | 133 | return history 134 | 135 | def delete(self, key: str) -> Optional[LogEntry]: 136 | """Delete a key and log the operation""" 137 | if key in self.data: 138 | entry = self.wal.append(OperationType.DELETE, key) 139 | del self.data[key] 140 | return entry 141 | return None 142 | 143 | def apply_entries(self, entries: List[LogEntry]) -> int: 144 | """Apply multiple log entries and return the last applied ID""" 145 | last_id = 0 146 | for entry in entries: 147 | self._apply_log_entry(entry) 148 | last_id = entry.id 149 | return last_id 150 | 151 | def get_all_keys(self) -> List[str]: 152 | """Get a list of all keys in the storage""" 153 | return list(self.data.keys()) 154 | 155 | def get_latest_version(self, key: str) -> Optional[int]: 156 | """Get the latest version number for a key""" 157 | if key not in self.data: 158 | return None 159 | return self.data[key].current_version 160 | 161 | def compact_log(self) -> Tuple[int, int]: 162 | """Compact the write-ahead log by removing redundant entries. 
163 | 164 | Returns: 165 | Tuple containing (segments_compacted, entries_removed) 166 | """ 167 | segments_compacted, entries_removed = self.wal.compact_segments() 168 | return segments_compacted, entries_removed 169 | -------------------------------------------------------------------------------- /src/pydistributedkv/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShahriyarR/py-distributed-kv/c404cccae0e190fdfd2bd08a7caed53a86620f3d/src/pydistributedkv/utils/__init__.py -------------------------------------------------------------------------------- /src/pydistributedkv/utils/common.py: -------------------------------------------------------------------------------- 1 | """Common shared file for supplementary utils""" 2 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | # This file makes the tests directory a proper Python package 2 | -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShahriyarR/py-distributed-kv/c404cccae0e190fdfd2bd08a7caed53a86620f3d/tests/conftest.py -------------------------------------------------------------------------------- /tests/domain/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShahriyarR/py-distributed-kv/c404cccae0e190fdfd2bd08a7caed53a86620f3d/tests/domain/__init__.py -------------------------------------------------------------------------------- /tests/domain/test_models.py: -------------------------------------------------------------------------------- 1 | # test_distributed_kv.py 2 | import os 3 | import shutil 4 | import tempfile 5 | 6 | from pydistributedkv.domain.models import OperationType, WAL 7 | from pydistributedkv.service.storage import KeyValueStorage 8 | 9 | 10 | # Test the WAL component 11 | def test_wal(): 12 | temp_dir = tempfile.mkdtemp() 13 | try: 14 | log_path = os.path.join(temp_dir, "test_wal.log") 15 | 16 | # Import here to avoid import errors 17 | 18 | wal = WAL(log_path) 19 | 20 | # Test appending entries 21 | entry1 = wal.append(OperationType.SET, "key1", "value1") 22 | assert entry1.id == 1 23 | assert entry1.operation == OperationType.SET 24 | assert entry1.key == "key1" 25 | assert entry1.value == "value1" 26 | 27 | entry2 = wal.append(OperationType.SET, "key2", "value2") 28 | assert entry2.id == 2 29 | 30 | entry3 = wal.append(OperationType.DELETE, "key1") 31 | assert entry3.id == 3 32 | assert entry3.operation == OperationType.DELETE 33 | assert entry3.key == "key1" 34 | 35 | # Test reading entries 36 | entries = wal.read_from(0) 37 | assert len(entries) == 3 38 | 39 | entries = wal.read_from(2) 40 | assert len(entries) == 2 41 | assert entries[0].id == 2 42 | 43 | # Test persistence and reloading 44 | wal2 = WAL(log_path) 45 | assert wal2.get_last_id() == 3 46 | 47 | entries = wal2.read_from(0) 48 | assert len(entries) == 3 49 | finally: 50 | shutil.rmtree(temp_dir) 51 | 52 | 53 | # Test the storage component 54 | def test_storage(): 55 | temp_dir = tempfile.mkdtemp() 56 | try: 57 | log_path = os.path.join(temp_dir, "test_storage.log") 58 | 59 | wal = WAL(log_path) 60 | storage = KeyValueStorage(wal) 61 | 62 | # Test basic operations 63 | storage.set("key1", 
"value1") 64 | assert storage.get("key1") == "value1" 65 | 66 | storage.set("key2", {"nested": "value"}) 67 | assert storage.get("key2") == {"nested": "value"} 68 | 69 | storage.delete("key1") 70 | assert storage.get("key1") is None 71 | 72 | # Test durability: recreate the storage with the same WAL 73 | storage2 = KeyValueStorage(wal) 74 | assert storage2.get("key1") is None 75 | assert storage2.get("key2") == {"nested": "value"} 76 | finally: 77 | shutil.rmtree(temp_dir) 78 | -------------------------------------------------------------------------------- /tests/domain/test_wal_crc.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import shutil 4 | import tempfile 5 | 6 | from pydistributedkv.domain.models import LogEntry, OperationType, WAL 7 | 8 | 9 | def test_log_entry_crc_calculation(): 10 | # Create a log entry and verify CRC calculation works 11 | entry = LogEntry(id=1, operation=OperationType.SET, key="test_key", value="test_value") 12 | 13 | # Initially CRC is None 14 | assert entry.crc is None 15 | 16 | # Calculate CRC 17 | crc = entry.calculate_crc() 18 | assert isinstance(crc, int) 19 | 20 | # Set CRC and validate 21 | entry.crc = crc 22 | assert entry.validate_crc() is True 23 | 24 | # Modify entry and verify CRC becomes invalid 25 | entry.value = "modified_value" 26 | assert entry.validate_crc() is False 27 | 28 | # Recalculate CRC after modification 29 | new_crc = entry.calculate_crc() 30 | assert new_crc != crc 31 | 32 | entry.crc = new_crc 33 | assert entry.validate_crc() is True 34 | 35 | 36 | def test_wal_skips_invalid_crc(): 37 | temp_dir = tempfile.mkdtemp() 38 | try: 39 | log_path = os.path.join(temp_dir, "wal.log") 40 | 41 | # Create a WAL and add some entries 42 | wal = WAL(log_path) 43 | entry1 = wal.append(OperationType.SET, "key1", "value1") 44 | entry2 = wal.append(OperationType.SET, "key2", "value2") 45 | 46 | # Get the actual segment file path that WAL is using 47 | active_segment = wal.get_active_segment() 48 | 49 | # Manually corrupt the second entry in the log file 50 | with open(active_segment, "r") as f: 51 | lines = f.readlines() 52 | 53 | # Parse the second entry, modify its value but keep the old CRC 54 | corrupted_entry = json.loads(lines[1]) 55 | corrupted_entry["value"] = "corrupted_value" 56 | 57 | # Write back the corrupted log 58 | with open(active_segment, "w") as f: 59 | f.write(lines[0]) # Write the first entry unchanged 60 | f.write(json.dumps(corrupted_entry) + "\n") 61 | 62 | # Create a new WAL instance to load from the corrupted file 63 | wal2 = WAL(log_path) 64 | 65 | # Only the valid entry should be loaded 66 | assert wal2.get_last_id() == 1 67 | assert 1 in wal2.existing_ids 68 | assert 2 not in wal2.existing_ids 69 | 70 | # Reading entries should only return the valid one 71 | entries = wal2.read_from(0) 72 | assert len(entries) == 1 73 | assert entries[0].id == 1 74 | 75 | finally: 76 | shutil.rmtree(temp_dir) 77 | 78 | 79 | def test_wal_handles_missing_crc(): 80 | temp_dir = tempfile.mkdtemp() 81 | try: 82 | log_path = os.path.join(temp_dir, "wal.log") 83 | 84 | # Create a WAL to initialize the segment structure 85 | wal = WAL(log_path) 86 | active_segment = wal.get_active_segment() 87 | 88 | # Clear the segment file to start fresh 89 | open(active_segment, "w").close() 90 | 91 | # Manually create a log file with an entry missing CRC 92 | entry_without_crc = {"id": 1, "operation": "SET", "key": "test_key", "value": "test_value"} 93 | 94 | with open(active_segment, 
"w") as f: 95 | f.write(json.dumps(entry_without_crc) + "\n") 96 | 97 | # Load the WAL and verify it skips the entry without CRC 98 | wal = WAL(log_path) 99 | 100 | # The entry should be loaded despite missing CRC (legacy compatibility) 101 | assert wal.get_last_id() == 1 102 | assert 1 in wal.existing_ids 103 | 104 | finally: 105 | shutil.rmtree(temp_dir) 106 | 107 | 108 | def test_append_entry_recalculates_invalid_crc(): 109 | temp_dir = tempfile.mkdtemp() 110 | try: 111 | log_path = os.path.join(temp_dir, "wal.log") 112 | wal = WAL(log_path) 113 | 114 | # Create an entry with an invalid CRC 115 | entry = LogEntry(id=1, operation=OperationType.SET, key="key1", value="value1", crc=12345) 116 | 117 | # Append the entry - WAL should recalculate the CRC 118 | appended_entry = wal.append_entry(entry) 119 | 120 | # The entry should have a valid CRC now 121 | assert appended_entry.validate_crc() is True 122 | assert appended_entry.crc != 12345 # CRC should be recalculated 123 | 124 | # Load the WAL again to verify the entry was stored with valid CRC 125 | wal2 = WAL(log_path) 126 | entries = wal2.read_from(0) 127 | 128 | assert len(entries) == 1 129 | assert entries[0].validate_crc() is True 130 | 131 | finally: 132 | shutil.rmtree(temp_dir) 133 | 134 | 135 | def test_duplicate_entry_handling(): 136 | temp_dir = tempfile.mkdtemp() 137 | try: 138 | log_path = os.path.join(temp_dir, "test_duplicate.log") 139 | wal = WAL(log_path) 140 | 141 | # Add an entry 142 | original_entry = wal.append(OperationType.SET, "key1", "value1") 143 | 144 | # Try to add a duplicate entry with the same ID but different content 145 | duplicate_entry = LogEntry(id=1, operation=OperationType.SET, key="duplicate_key", value="duplicate_value") 146 | duplicate_entry.crc = duplicate_entry.calculate_crc() 147 | 148 | # WAL should ignore the duplicate entry 149 | wal.append_entry(duplicate_entry) 150 | 151 | # Verify the original entry wasn't replaced 152 | entries = wal.read_from(0) 153 | assert len(entries) == 1 154 | assert entries[0].key == "key1" 155 | assert entries[0].value == "value1" 156 | 157 | finally: 158 | shutil.rmtree(temp_dir) 159 | 160 | 161 | def test_corrupted_json_handling(): 162 | temp_dir = tempfile.mkdtemp() 163 | try: 164 | log_path = os.path.join(temp_dir, "test_corrupted_json.log") 165 | 166 | # Create a WAL with some valid entries 167 | wal = WAL(log_path) 168 | wal.append(OperationType.SET, "key1", "value1") 169 | wal.append(OperationType.SET, "key2", "value2") 170 | 171 | # Get the active segment 172 | active_segment = wal.get_active_segment() 173 | 174 | # Append some corrupted JSON to the log file 175 | with open(active_segment, "a") as f: 176 | f.write("{this is not valid JSON}\n") 177 | 178 | # Add a valid entry after the corrupted one - the WAL should process this correctly 179 | valid_entry = LogEntry(id=3, operation=OperationType.SET, key="key3", value="value3") 180 | valid_entry.crc = valid_entry.calculate_crc() 181 | f.write(json.dumps(valid_entry.model_dump()) + "\n") 182 | 183 | # Load the WAL again 184 | wal2 = WAL(log_path) 185 | 186 | # It should have loaded all valid entries (including entry 3) and skipped the corrupted one 187 | valid_entries = wal2.read_from(0) 188 | assert len(valid_entries) == 3 189 | assert valid_entries[0].id == 1 190 | assert valid_entries[1].id == 2 191 | assert valid_entries[2].id == 3 192 | 193 | finally: 194 | shutil.rmtree(temp_dir) 195 | -------------------------------------------------------------------------------- 
/tests/domain/test_wal_segmentation.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import shutil 4 | import tempfile 5 | import unittest 6 | from unittest.mock import patch 7 | 8 | from pydistributedkv.domain.models import LogEntry, OperationType, WAL 9 | from pydistributedkv.service.storage import KeyValueStorage 10 | 11 | 12 | class TestWALSegmentation(unittest.TestCase): 13 | def setUp(self): 14 | # Create a temporary directory for test WAL files 15 | self.temp_dir = tempfile.mkdtemp() 16 | self.wal_path = os.path.join(self.temp_dir, "wal.log") 17 | 18 | # Use a small segment size for testing 19 | self.small_segment_size = 100 # bytes 20 | 21 | def tearDown(self): 22 | # Clean up the temporary directory after tests 23 | shutil.rmtree(self.temp_dir) 24 | 25 | def test_segment_creation(self): 26 | """Test that segments are created with correct naming pattern""" 27 | wal = WAL(self.wal_path, max_segment_size=self.small_segment_size) 28 | 29 | # Verify initial segment 30 | segments = wal.get_segment_files() 31 | self.assertEqual(len(segments), 1) 32 | self.assertTrue(segments[0].endswith("wal.log.segment.1")) 33 | 34 | # Verify active segment 35 | active_segment = wal.get_active_segment() 36 | self.assertEqual(active_segment, segments[0]) 37 | 38 | def test_segment_rollover(self): 39 | """Test that new segments are created when size limit is reached""" 40 | wal = WAL(self.wal_path, max_segment_size=self.small_segment_size) 41 | 42 | # Add entries to force segment rollover 43 | # Each JSON entry will be ~50-70 bytes, so a few entries should trigger rollover 44 | for i in range(10): 45 | wal.append(OperationType.SET, f"key{i}", f"value{i}" * 5) # Make value larger to exceed segment size quicker 46 | 47 | # Verify multiple segments were created 48 | segments = wal.get_segment_files() 49 | self.assertGreater(len(segments), 1, f"Expected multiple segments but got {len(segments)}") 50 | 51 | # Verify segments have sequential numbering 52 | segment_numbers = [int(seg.split(".")[-1]) for seg in segments] 53 | self.assertEqual(segment_numbers, list(range(1, len(segments) + 1))) 54 | 55 | # Verify the active segment is the last one 56 | active_segment = wal.get_active_segment() 57 | self.assertEqual(active_segment, segments[-1]) 58 | 59 | def test_segment_size_limit(self): 60 | """Test that segments respect the configured size limit""" 61 | # Increase segment size to handle the larger entries with version information 62 | self.small_segment_size = 150 # bytes - increased to accommodate version field 63 | wal = WAL(self.wal_path, max_segment_size=self.small_segment_size) 64 | 65 | # Generate entries with much larger values to ensure segments fill up properly 66 | # Use a fixed string size to make segments more predictable 67 | large_value = "X" * 60 # Fixed size string to make entries larger 68 | 69 | # Add enough entries to ensure multiple segments with substantial content 70 | for i in range(10): # Fewer entries, but each one is larger 71 | wal.append(OperationType.SET, f"key{i}", large_value) 72 | 73 | segments = wal.get_segment_files() 74 | self.assertGreater(len(segments), 1, "Test needs multiple segments to be valid") 75 | 76 | # Only test the size of completed segments (not the active one) 77 | # The active segment might not be full yet 78 | completed_segments = segments[:-1] 79 | 80 | print(f"Testing {len(segments)} segments with size limit {self.small_segment_size}") 81 | print(f"Active segment: {segments[-1]}") 82 | 83 | 
for i, segment in enumerate(completed_segments): 84 | size = os.path.getsize(segment) 85 | print(f"Completed segment {i+1}: {segment} size = {size} bytes") 86 | 87 | # 1. Check that size doesn't greatly exceed the limit 88 | allowed_buffer = 20 # Bytes of buffer for overhead 89 | self.assertLessEqual( 90 | size, 91 | self.small_segment_size + allowed_buffer, 92 | f"Segment {segment} size {size} exceeds limit {self.small_segment_size} by more than {allowed_buffer} bytes", 93 | ) 94 | 95 | # 2. Skip the minimum size check if we have only one completed segment 96 | # The first segment might not be full if we roll over quickly due to entry size 97 | if len(completed_segments) > 1 and i > 0: # Skip first segment, check others if we have multiple 98 | min_expected_size = self.small_segment_size * 0.5 99 | self.assertGreater( 100 | size, 101 | min_expected_size, 102 | f"Segment {segment} size {size} is too small relative to limit {self.small_segment_size} (min expected {min_expected_size})", 103 | ) 104 | 105 | def test_replay_across_segments(self): 106 | """Test that log replay works correctly across multiple segments""" 107 | # Create a WAL with small segments and add many entries 108 | wal = WAL(self.wal_path, max_segment_size=self.small_segment_size) 109 | 110 | # Add entries that will span multiple segments 111 | expected_data = {} 112 | for i in range(30): 113 | key = f"key{i}" 114 | value = f"value{i}" 115 | 116 | if i % 3 == 0 and i > 0: # Delete some keys periodically 117 | prev_key = f"key{i-3}" 118 | wal.append(OperationType.DELETE, prev_key) 119 | if prev_key in expected_data: 120 | del expected_data[prev_key] 121 | else: 122 | wal.append(OperationType.SET, key, value) 123 | expected_data[key] = value 124 | 125 | # Verify we have multiple segments 126 | segments = wal.get_segment_files() 127 | self.assertGreater(len(segments), 2, "Expected at least 3 segments for this test") 128 | 129 | # Create a new WAL instance that will replay the log 130 | new_wal = WAL(self.wal_path, max_segment_size=self.small_segment_size) 131 | 132 | # Create a storage that will replay the log 133 | storage = KeyValueStorage(new_wal) 134 | 135 | # Verify the replayed data matches expectations 136 | for key, expected_value in expected_data.items(): 137 | self.assertEqual(storage.get(key), expected_value, f"Replayed value for {key} doesn't match expected value") 138 | 139 | # Verify keys that should be deleted are not present 140 | for i in range(30): 141 | if i % 3 == 0 and i > 0: 142 | self.assertIsNone(storage.get(f"key{i-3}"), f"Key key{i-3} should have been deleted") 143 | 144 | def test_read_from_specific_id(self): 145 | """Test that reading from a specific ID works across segments""" 146 | wal = WAL(self.wal_path, max_segment_size=self.small_segment_size) 147 | 148 | # Add entries that will span multiple segments 149 | entries_by_id = {} 150 | for i in range(30): 151 | entry = wal.append(OperationType.SET, f"key{i}", f"value{i}") 152 | entries_by_id[entry.id] = entry 153 | 154 | # Pick a starting ID in the middle 155 | start_id = wal.get_last_id() // 2 156 | 157 | # Read entries from that ID 158 | entries = wal.read_from(start_id) 159 | 160 | # Verify we got the correct entries 161 | self.assertEqual(len(entries), len(entries_by_id) - start_id + 1) 162 | for entry in entries: 163 | self.assertGreaterEqual(entry.id, start_id) 164 | self.assertEqual(entry.key, entries_by_id[entry.id].key) 165 | self.assertEqual(entry.value, entries_by_id[entry.id].value) 166 | self.assertEqual(entry.operation, 
entries_by_id[entry.id].operation) 167 | 168 | def test_data_integrity_across_segments(self): 169 | """Test that CRC validation works across segments""" 170 | wal = WAL(self.wal_path, max_segment_size=self.small_segment_size) 171 | 172 | # Add entries that will span multiple segments 173 | for i in range(20): 174 | wal.append(OperationType.SET, f"key{i}", f"value{i}") 175 | 176 | # Verify we have multiple segments 177 | segments = wal.get_segment_files() 178 | self.assertGreater(len(segments), 1, "Expected multiple segments for this test") 179 | 180 | # Read all entries 181 | entries = wal.read_from(0) 182 | 183 | # Verify all entries have valid CRCs 184 | for entry in entries: 185 | self.assertTrue(entry.validate_crc(), f"Entry {entry.id} failed CRC validation") 186 | 187 | def test_integrity_after_corruption(self): 188 | """Test that corrupted entries are handled properly during replay""" 189 | wal = WAL(self.wal_path, max_segment_size=self.small_segment_size) 190 | 191 | # Add entries that will span multiple segments 192 | for i in range(15): 193 | wal.append(OperationType.SET, f"key{i}", f"value{i}") 194 | 195 | # Verify we have multiple segments 196 | segments = wal.get_segment_files() 197 | self.assertGreater(len(segments), 1, "Expected multiple segments for this test") 198 | 199 | # Corrupt the middle segment by modifying an entry's value but keeping its original CRC 200 | middle_segment = segments[len(segments) // 2] 201 | with open(middle_segment, "r") as f: 202 | lines = f.readlines() 203 | 204 | # Modify at least one line to have invalid CRC 205 | if lines: 206 | line_to_corrupt = lines[0] # Take the first entry in the middle segment 207 | entry = json.loads(line_to_corrupt) 208 | original_crc = entry["crc"] # Save the original CRC 209 | 210 | # Modify the value but keep the original CRC - this will cause CRC validation to fail 211 | entry["value"] = "corrupted_value" 212 | 213 | # Write back to file with the original CRC which is now invalid 214 | lines[0] = json.dumps(entry) + "\n" 215 | 216 | with open(middle_segment, "w") as f: 217 | f.writelines(lines) 218 | 219 | # Create a new WAL instance that will replay the log and capture print statements 220 | with patch("builtins.print") as mock_print: 221 | new_wal = WAL(self.wal_path, max_segment_size=self.small_segment_size) 222 | # Explicitly read all entries to trigger validation 223 | entries = new_wal.read_from(0) 224 | 225 | # Verify warning was printed about invalid CRC or skipping entries 226 | crc_warning_printed = False 227 | for call in mock_print.call_args_list: 228 | call_args = " ".join(str(arg) for arg in call[0]) 229 | if "invalid CRC" in call_args.lower() or "skipping" in call_args.lower() or "crc validation" in call_args.lower(): 230 | crc_warning_printed = True 231 | break 232 | 233 | self.assertTrue(crc_warning_printed, "Expected CRC validation warning") 234 | 235 | # The WAL should still initialize and contain valid entries 236 | entries = new_wal.read_from(0) 237 | self.assertGreater(len(entries), 0, "Expected some valid entries despite corruption") 238 | -------------------------------------------------------------------------------- /tests/integrations/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShahriyarR/py-distributed-kv/c404cccae0e190fdfd2bd08a7caed53a86620f3d/tests/integrations/__init__.py -------------------------------------------------------------------------------- /tests/integrations/delete_key.http: 
-------------------------------------------------------------------------------- 1 | ### DELETE request to leader server 2 | DELETE http://localhost:8000/key/mykey 3 | -------------------------------------------------------------------------------- /tests/integrations/get_key.http: -------------------------------------------------------------------------------- 1 | ### GET request to leader server 2 | GET http://localhost:8000/key/mykey 3 | 4 | ### GET request to follower server 5 | 6 | GET http://localhost:8001/key/mykey 7 | 8 | 9 | ### GET request to leader server 10 | 11 | GET http://localhost:8000/key/new 12 | -------------------------------------------------------------------------------- /tests/integrations/set_key.http: -------------------------------------------------------------------------------- 1 | ### PUT request to leader server 2 | PUT http://localhost:8000/key/mykey 3 | Content-Type: application/json 4 | 5 | {"value": "myvalue"} 6 | 7 | 8 | ### Another request 9 | 10 | PUT http://localhost:8000/key/new 11 | Content-Type: application/json 12 | 13 | {"value": "new-myvalue"} 14 | 15 | 16 | 17 | -------------------------------------------------------------------------------- /tests/service/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShahriyarR/py-distributed-kv/c404cccae0e190fdfd2bd08a7caed53a86620f3d/tests/service/__init__.py -------------------------------------------------------------------------------- /tests/service/test_compaction_service.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import datetime 3 | import logging 4 | from unittest import mock 5 | 6 | import pytest 7 | 8 | from pydistributedkv.domain.models import WAL 9 | from pydistributedkv.service.compaction import LogCompactionService 10 | from pydistributedkv.service.storage import KeyValueStorage 11 | 12 | 13 | @pytest.fixture 14 | def mock_storage(): 15 | """Create a mock storage object""" 16 | storage = mock.MagicMock(spec=KeyValueStorage) 17 | storage.compact_log.return_value = (3, 100) # Default: 3 segments compacted, 100 entries removed 18 | return storage 19 | 20 | 21 | @pytest.fixture 22 | def compaction_service(mock_storage): 23 | """Create a compaction service with a mock storage""" 24 | return LogCompactionService( 25 | storage=mock_storage, 26 | compaction_interval=1, # Use short interval for testing 27 | min_compaction_interval=0.1, # Use short min interval for testing 28 | enabled=True, 29 | ) 30 | 31 | 32 | class TestLogCompactionService: 33 | def test_initialization(self, mock_storage): 34 | """Test that the service initializes with correct parameters""" 35 | # Default initialization 36 | service = LogCompactionService(storage=mock_storage) 37 | assert service.storage == mock_storage 38 | assert service.compaction_interval == 3600 # Default: 1 hour 39 | assert service.min_compaction_interval == 600 # Default: 10 minutes 40 | assert service.enabled is True 41 | assert service.last_compaction is None 42 | assert service.compaction_running is False 43 | assert service.compaction_task is None 44 | assert service.compaction_history == [] 45 | 46 | # Custom initialization 47 | service = LogCompactionService(storage=mock_storage, compaction_interval=300, min_compaction_interval=60, enabled=False) 48 | assert service.compaction_interval == 300 49 | assert service.min_compaction_interval == 60 50 | assert service.enabled is False 51 | 52 | def 
test_set_enabled(self, compaction_service): 53 | """Test enabling and disabling the service""" 54 | # Initially enabled 55 | assert compaction_service.enabled is True 56 | 57 | # Disable 58 | result = compaction_service.set_enabled(False) 59 | assert result is False 60 | assert compaction_service.enabled is False 61 | 62 | # Enable 63 | result = compaction_service.set_enabled(True) 64 | assert result is True 65 | assert compaction_service.enabled is True 66 | 67 | def test_set_compaction_interval(self, compaction_service): 68 | """Test setting the compaction interval""" 69 | # Initial interval 70 | assert compaction_service.compaction_interval == 1 71 | 72 | # Set to 120 seconds 73 | result = compaction_service.set_compaction_interval(120) 74 | assert result == 120 75 | assert compaction_service.compaction_interval == 120 76 | 77 | # Test minimum limit (should be clamped to 60) 78 | result = compaction_service.set_compaction_interval(30) 79 | assert result == 60 80 | assert compaction_service.compaction_interval == 60 81 | 82 | @pytest.mark.asyncio 83 | async def test_start_already_running(self, compaction_service): 84 | """Test starting the service when it's already running""" 85 | # Mock the compaction task 86 | compaction_service.compaction_task = mock.MagicMock() 87 | 88 | # Try to start it again - use the module-level logger from the compaction service module 89 | with mock.patch("pydistributedkv.service.compaction.logger.warning") as mock_warning: 90 | await compaction_service.start() 91 | mock_warning.assert_called_once_with("Compaction service already running") 92 | 93 | # Task should not have been changed 94 | assert compaction_service.compaction_task is not None 95 | 96 | @pytest.mark.asyncio 97 | async def test_start_disabled(self, compaction_service): 98 | """Test starting the service when it's disabled""" 99 | compaction_service.enabled = False 100 | 101 | # Fix: Use the module-level logger from the compaction service module 102 | with mock.patch("pydistributedkv.service.compaction.logger.info") as mock_info: 103 | await compaction_service.start() 104 | mock_info.assert_called_once_with("Compaction service is disabled") 105 | 106 | # No task should be created 107 | assert compaction_service.compaction_task is None 108 | 109 | @pytest.mark.asyncio 110 | async def test_start_and_stop(self, compaction_service): 111 | """Test starting and stopping the compaction service""" 112 | # Start the service 113 | with mock.patch("pydistributedkv.service.compaction.logger.info") as mock_info: 114 | await compaction_service.start() 115 | mock_info.assert_called_with(f"Started log compaction service (interval: {compaction_service.compaction_interval}s)") 116 | 117 | # Task should have been created 118 | assert compaction_service.compaction_task is not None 119 | 120 | # Stop the service 121 | with mock.patch("pydistributedkv.service.compaction.logger.info") as mock_info: 122 | await compaction_service.stop() 123 | mock_info.assert_called_with("Stopped log compaction service") 124 | 125 | # Task should be None after stopping 126 | assert compaction_service.compaction_task is None 127 | 128 | @pytest.mark.asyncio 129 | async def test_compaction_loop(self, compaction_service): 130 | """Test the compaction loop logic""" 131 | # Mock the run_compaction method 132 | compaction_service.run_compaction = mock.AsyncMock() 133 | 134 | # Create a mock task that will be cancelled 135 | async def mock_loop(): 136 | try: 137 | # Sleep first 138 | await asyncio.sleep(compaction_service.compaction_interval) 
139 | # Run compaction 140 | await compaction_service.run_compaction() 141 | # Raise CancelledError to simulate the loop being cancelled 142 | raise asyncio.CancelledError() 143 | except asyncio.CancelledError: 144 | raise 145 | 146 | # Patch the _compaction_loop method to use our mock 147 | with mock.patch.object(compaction_service, "_compaction_loop", mock_loop): 148 | # Start the service 149 | await compaction_service.start() 150 | # Let the mock loop run 151 | await asyncio.sleep(compaction_service.compaction_interval * 1.5) 152 | # Stop the service 153 | await compaction_service.stop() 154 | 155 | # Check if run_compaction was called 156 | compaction_service.run_compaction.assert_called_once() 157 | 158 | @pytest.mark.asyncio 159 | async def test_run_compaction_already_running(self, compaction_service): 160 | """Test that run_compaction skips when already running""" 161 | # Set compaction_running to True 162 | compaction_service.compaction_running = True 163 | 164 | # Use the module-level logger instead of the global logging module 165 | with mock.patch("pydistributedkv.service.compaction.logger.warning") as mock_warning: 166 | result = await compaction_service.run_compaction() 167 | mock_warning.assert_called_once_with("Compaction already in progress, skipping") 168 | 169 | # Should return (0, 0) without doing anything 170 | assert result == (0, 0) 171 | 172 | @pytest.mark.asyncio 173 | async def test_run_compaction_too_soon(self, compaction_service): 174 | """Test that run_compaction skips when run too soon after the last run""" 175 | # Set last compaction time to now 176 | compaction_service.last_compaction = datetime.datetime.now() 177 | 178 | # Use the module-level logger instead of the global logging module 179 | with mock.patch("pydistributedkv.service.compaction.logger.info") as mock_info: 180 | result = await compaction_service.run_compaction() 181 | # Check that appropriate message was logged 182 | assert "Skipping compaction" in mock_info.call_args[0][0] 183 | 184 | # Should return (0, 0) without doing anything 185 | assert result == (0, 0) 186 | 187 | @pytest.mark.asyncio 188 | async def test_run_compaction_force(self, compaction_service): 189 | """Test that run_compaction with force=True ignores minimum interval""" 190 | # Set last compaction time to now 191 | compaction_service.last_compaction = datetime.datetime.now() 192 | 193 | # Ensure storage.compact_log returns a known value 194 | mock_result = (5, 200) # 5 segments, 200 entries 195 | compaction_service.storage.compact_log.return_value = mock_result 196 | 197 | result = await compaction_service.run_compaction(force=True) 198 | 199 | # Should call compact_log and return its result 200 | compaction_service.storage.compact_log.assert_called_once() 201 | assert result == mock_result 202 | 203 | # Should update history 204 | assert len(compaction_service.compaction_history) == 1 205 | assert compaction_service.compaction_history[0]["segments_compacted"] == mock_result[0] 206 | assert compaction_service.compaction_history[0]["entries_removed"] == mock_result[1] 207 | 208 | @pytest.mark.asyncio 209 | async def test_run_compaction_normal(self, compaction_service): 210 | """Test normal compaction run""" 211 | # No previous compaction 212 | assert compaction_service.last_compaction is None 213 | 214 | # Ensure storage.compact_log returns a known value 215 | mock_result = (2, 50) # 2 segments, 50 entries 216 | compaction_service.storage.compact_log.return_value = mock_result 217 | 218 | # Use the module-level logger instead 
of the global logging module 219 | with mock.patch("pydistributedkv.service.compaction.logger.info") as mock_info: 220 | result = await compaction_service.run_compaction() 221 | 222 | # Check that start and completion logs were made 223 | log_messages = [call[0][0] for call in mock_info.call_args_list] 224 | assert any("Starting log compaction" in msg for msg in log_messages) 225 | assert any("Compaction completed" in msg for msg in log_messages) 226 | 227 | # Should return storage.compact_log result 228 | assert result == mock_result 229 | 230 | # Should have updated last_compaction 231 | assert compaction_service.last_compaction is not None 232 | 233 | # Should have updated history 234 | assert len(compaction_service.compaction_history) == 1 235 | assert compaction_service.compaction_history[0]["segments_compacted"] == mock_result[0] 236 | assert compaction_service.compaction_history[0]["entries_removed"] == mock_result[1] 237 | assert "duration_seconds" in compaction_service.compaction_history[0] 238 | assert "timestamp" in compaction_service.compaction_history[0] 239 | 240 | @pytest.mark.asyncio 241 | async def test_run_compaction_error(self, compaction_service): 242 | """Test compaction run with error""" 243 | # Make storage.compact_log raise an exception 244 | compaction_service.storage.compact_log.side_effect = ValueError("Test error") 245 | 246 | # Should propagate the error 247 | with pytest.raises(ValueError, match="Test error"): 248 | await compaction_service.run_compaction() 249 | 250 | # compaction_running should be reset to False 251 | assert compaction_service.compaction_running is False 252 | 253 | # No history should be added 254 | assert len(compaction_service.compaction_history) == 0 255 | 256 | @pytest.mark.asyncio 257 | async def test_compaction_history_limit(self, compaction_service): 258 | """Test compaction history is limited to 10 entries""" 259 | # Make 12 compactions 260 | for i in range(12): 261 | compaction_service.storage.compact_log.return_value = (i, i * 10) 262 | await compaction_service.run_compaction(force=True) 263 | 264 | # History should be limited to 10 entries 265 | assert len(compaction_service.compaction_history) == 10 266 | 267 | # The oldest 2 entries should have been removed 268 | assert compaction_service.compaction_history[0]["segments_compacted"] == 2 269 | assert compaction_service.compaction_history[0]["entries_removed"] == 20 270 | 271 | def test_get_status(self, compaction_service): 272 | """Test getting status information""" 273 | # Set some data 274 | compaction_service.last_compaction = datetime.datetime(2023, 1, 1, 12, 0, 0) 275 | compaction_service.compaction_history.append( 276 | {"timestamp": "2023-01-01T12:00:00", "duration_seconds": 0.5, "segments_compacted": 3, "entries_removed": 100} 277 | ) 278 | 279 | status = compaction_service.get_status() 280 | 281 | assert status["enabled"] == compaction_service.enabled 282 | assert status["compaction_interval_seconds"] == compaction_service.compaction_interval 283 | assert status["min_compaction_interval_seconds"] == compaction_service.min_compaction_interval 284 | assert status["last_compaction"] == "2023-01-01T12:00:00" 285 | assert status["compaction_running"] == compaction_service.compaction_running 286 | assert len(status["compaction_history"]) == 1 287 | -------------------------------------------------------------------------------- /tests/service/test_heartbeat.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import time 
3 | from unittest.mock import AsyncMock, MagicMock, patch 4 | 5 | import pytest 6 | 7 | from pydistributedkv.service.heartbeat import HeartbeatService 8 | 9 | # Test constants 10 | SERVICE_NAME = "test-service" 11 | SERVER_ID = "test-server" 12 | SERVER_URL = "http://localhost:8000" 13 | OTHER_SERVER_ID = "other-server" 14 | OTHER_SERVER_URL = "http://localhost:8001" 15 | 16 | 17 | @pytest.fixture 18 | def heartbeat_service(): 19 | """Create a HeartbeatService instance for testing""" 20 | service = HeartbeatService(SERVICE_NAME, SERVER_ID, SERVER_URL) 21 | return service 22 | 23 | 24 | def test_init(heartbeat_service): 25 | """Test initialization of HeartbeatService""" 26 | assert heartbeat_service.service_name == SERVICE_NAME 27 | assert heartbeat_service.server_id == SERVER_ID 28 | assert heartbeat_service.server_url == SERVER_URL 29 | assert heartbeat_service.servers == {} 30 | assert not heartbeat_service._monitor_running 31 | assert not heartbeat_service._send_running 32 | 33 | 34 | def test_register_server(heartbeat_service): 35 | """Test registering a server""" 36 | # Given 37 | assert len(heartbeat_service.servers) == 0 38 | 39 | # When 40 | heartbeat_service.register_server(OTHER_SERVER_ID, OTHER_SERVER_URL) 41 | 42 | # Then 43 | assert len(heartbeat_service.servers) == 1 44 | assert OTHER_SERVER_ID in heartbeat_service.servers 45 | assert heartbeat_service.servers[OTHER_SERVER_ID]["url"] == OTHER_SERVER_URL 46 | assert heartbeat_service.servers[OTHER_SERVER_ID]["status"] == "healthy" 47 | assert "last_heartbeat" in heartbeat_service.servers[OTHER_SERVER_ID] 48 | 49 | 50 | def test_deregister_server(heartbeat_service): 51 | """Test deregistering a server""" 52 | # Given 53 | heartbeat_service.register_server(OTHER_SERVER_ID, OTHER_SERVER_URL) 54 | assert len(heartbeat_service.servers) == 1 55 | 56 | # When 57 | heartbeat_service.deregister_server(OTHER_SERVER_ID) 58 | 59 | # Then 60 | assert len(heartbeat_service.servers) == 0 61 | assert OTHER_SERVER_ID not in heartbeat_service.servers 62 | 63 | 64 | def test_deregister_nonexistent_server(heartbeat_service): 65 | """Test deregistering a server that doesn't exist""" 66 | # Given 67 | assert "nonexistent" not in heartbeat_service.servers 68 | 69 | # When/Then - should not raise exception 70 | heartbeat_service.deregister_server("nonexistent") 71 | 72 | 73 | def test_record_heartbeat(heartbeat_service): 74 | """Test recording a heartbeat from a server""" 75 | # Given 76 | heartbeat_service.register_server(OTHER_SERVER_ID, OTHER_SERVER_URL) 77 | initial_heartbeat = heartbeat_service.servers[OTHER_SERVER_ID]["last_heartbeat"] 78 | 79 | # Wait briefly to ensure timestamp changes 80 | time.sleep(0.001) 81 | 82 | # When 83 | heartbeat_service.record_heartbeat(OTHER_SERVER_ID) 84 | 85 | # Then 86 | assert heartbeat_service.servers[OTHER_SERVER_ID]["last_heartbeat"] > initial_heartbeat 87 | 88 | 89 | def test_record_heartbeat_unknown_server(heartbeat_service): 90 | """Test recording a heartbeat from an unknown server""" 91 | # When/Then - Should log warning but not raise exception 92 | heartbeat_service.record_heartbeat("unknown-server") 93 | # Verify no server was added 94 | assert "unknown-server" not in heartbeat_service.servers 95 | 96 | 97 | def test_record_heartbeat_recovers_down_server(heartbeat_service): 98 | """Test that a down server becomes healthy when heartbeat is received""" 99 | # Given 100 | heartbeat_service.register_server(OTHER_SERVER_ID, OTHER_SERVER_URL) 101 | heartbeat_service.servers[OTHER_SERVER_ID]["status"] = 
"down" 102 | assert heartbeat_service.servers[OTHER_SERVER_ID]["status"] == "down" 103 | 104 | # When 105 | heartbeat_service.record_heartbeat(OTHER_SERVER_ID) 106 | 107 | # Then 108 | assert heartbeat_service.servers[OTHER_SERVER_ID]["status"] == "healthy" 109 | 110 | 111 | def test_get_server_status(heartbeat_service): 112 | """Test getting status of a specific server""" 113 | # Given 114 | heartbeat_service.register_server(OTHER_SERVER_ID, OTHER_SERVER_URL) 115 | 116 | # When 117 | status = heartbeat_service.get_server_status(OTHER_SERVER_ID) 118 | 119 | # Then 120 | assert status is not None 121 | assert status["url"] == OTHER_SERVER_URL 122 | assert status["status"] == "healthy" 123 | 124 | 125 | def test_get_server_status_unknown(heartbeat_service): 126 | """Test getting status of an unknown server""" 127 | # When 128 | status = heartbeat_service.get_server_status("unknown-server") 129 | 130 | # Then 131 | assert status is None 132 | 133 | 134 | def test_get_all_statuses(heartbeat_service): 135 | """Test getting status of all servers""" 136 | # Given 137 | heartbeat_service.register_server("server1", "http://server1:8001") 138 | heartbeat_service.register_server("server2", "http://server2:8002") 139 | 140 | # When 141 | statuses = heartbeat_service.get_all_statuses() 142 | 143 | # Then 144 | assert len(statuses) == 2 145 | assert "server1" in statuses 146 | assert "server2" in statuses 147 | assert "url" in statuses["server1"] 148 | assert "status" in statuses["server1"] 149 | assert "last_heartbeat" in statuses["server1"] 150 | assert "seconds_since_last_heartbeat" in statuses["server1"] 151 | 152 | 153 | def test_get_healthy_servers(heartbeat_service): 154 | """Test getting only healthy servers""" 155 | # Given 156 | heartbeat_service.register_server("server1", "http://server1:8001") 157 | heartbeat_service.register_server("server2", "http://server2:8002") 158 | heartbeat_service.servers["server2"]["status"] = "down" 159 | 160 | # When 161 | healthy_servers = heartbeat_service.get_healthy_servers() 162 | 163 | # Then 164 | assert len(healthy_servers) == 1 165 | assert "server1" in healthy_servers 166 | assert "server2" not in healthy_servers 167 | assert healthy_servers["server1"] == "http://server1:8001" 168 | 169 | 170 | @pytest.mark.asyncio 171 | async def test_start_monitoring(): 172 | """Test starting heartbeat monitoring""" 173 | # Create the service 174 | service = HeartbeatService(SERVICE_NAME, SERVER_ID, SERVER_URL) 175 | 176 | # We need to mock the _monitor_heartbeats method in a way that 177 | # allows us to verify it was called without actually running it 178 | with patch.object(service, "_monitor_heartbeats") as mock_coroutine: 179 | # Configure the mock to return a coroutine that can be awaited 180 | # but doesn't actually do anything 181 | mock_coroutine.return_value = asyncio.sleep(0) 182 | 183 | # When 184 | await service.start_monitoring() 185 | 186 | # Then 187 | assert service._monitor_running is True 188 | assert len(service._background_tasks) > 0 189 | 190 | # Verify the mock was called when creating the task 191 | mock_coroutine.assert_called_once() 192 | 193 | # Cleanup 194 | await service.stop() 195 | 196 | 197 | @pytest.mark.asyncio 198 | async def test_start_sending(): 199 | """Test starting heartbeat sending""" 200 | # Create the service 201 | service = HeartbeatService(SERVICE_NAME, SERVER_ID, SERVER_URL) 202 | 203 | # Similar approach to test_start_monitoring 204 | with patch.object(service, "_send_heartbeats") as mock_coroutine: 205 | # Configure the 
mock to return a coroutine that can be awaited 206 | mock_coroutine.return_value = asyncio.sleep(0) 207 | 208 | # When 209 | await service.start_sending() 210 | 211 | # Then 212 | assert service._send_running is True 213 | assert len(service._background_tasks) > 0 214 | 215 | # Verify the mock was called when creating the task 216 | mock_coroutine.assert_called_once() 217 | 218 | # Cleanup 219 | await service.stop() 220 | 221 | 222 | @pytest.mark.asyncio 223 | async def test_stop(): 224 | """Test stopping the heartbeat service""" 225 | # Given 226 | service = HeartbeatService(SERVICE_NAME, SERVER_ID, SERVER_URL) 227 | 228 | # Create mocks that just return immediately awaitable coroutines 229 | with ( 230 | patch.object(service, "_monitor_heartbeats", return_value=asyncio.sleep(0)), 231 | patch.object(service, "_send_heartbeats", return_value=asyncio.sleep(0)), 232 | ): 233 | 234 | # Start both services 235 | await service.start_monitoring() 236 | await service.start_sending() 237 | 238 | # Confirm they started correctly 239 | assert service._monitor_running is True 240 | assert service._send_running is True 241 | assert len(service._background_tasks) == 2 242 | 243 | # When 244 | await service.stop() 245 | 246 | # Then 247 | assert service._monitor_running is False 248 | assert service._send_running is False 249 | 250 | 251 | @pytest.mark.asyncio 252 | async def test_monitor_heartbeats_marks_server_down(): 253 | """Test that monitor detects and marks servers as down after timeout""" 254 | # Mock HEARTBEAT_TIMEOUT for testing 255 | with patch("pydistributedkv.service.heartbeat.HEARTBEAT_TIMEOUT", 0.1): 256 | service = HeartbeatService(SERVICE_NAME, SERVER_ID, SERVER_URL) 257 | service.register_server(OTHER_SERVER_ID, OTHER_SERVER_URL) 258 | 259 | # Set last_heartbeat far in the past to trigger the timeout 260 | service.servers[OTHER_SERVER_ID]["last_heartbeat"] = time.time() - 1.0 261 | 262 | # Directly call the check method to test its behavior 263 | service._check_server_heartbeats(time.time()) 264 | 265 | # Verify server is marked down 266 | assert service.servers[OTHER_SERVER_ID]["status"] == "down" 267 | 268 | 269 | @pytest.mark.asyncio 270 | async def test_send_heartbeats_to_all_servers(): 271 | """Test sending heartbeats to all registered servers""" 272 | # Given 273 | service = HeartbeatService(SERVICE_NAME, SERVER_ID, SERVER_URL) 274 | service.register_server("server1", "http://server1:8001") 275 | service.register_server("server2", "http://server2:8002") 276 | 277 | # Mock _schedule_heartbeat to test if it's called for each server 278 | with patch.object(service, "_schedule_heartbeat") as mock_schedule: 279 | # When 280 | await service._send_heartbeats_to_all_servers() 281 | 282 | # Then 283 | assert mock_schedule.call_count == 2 284 | # Check that it was called for each server 285 | calls = mock_schedule.call_args_list 286 | servers_called = {call.args[0] for call in calls} 287 | assert servers_called == {"server1", "server2"} 288 | 289 | 290 | @pytest.mark.asyncio 291 | async def test_send_single_heartbeat_success(): 292 | """Test sending a single heartbeat successfully""" 293 | # Given 294 | service = HeartbeatService(SERVICE_NAME, SERVER_ID, SERVER_URL) 295 | 296 | # Create a mock response 297 | mock_response = MagicMock() 298 | mock_response.status_code = 200 299 | 300 | # Mock asyncio.to_thread to return our mock response 301 | with patch("asyncio.to_thread", return_value=mock_response) as mock_to_thread: 302 | # When 303 | await 
service._send_single_heartbeat(OTHER_SERVER_ID, OTHER_SERVER_URL) 304 | 305 | # Then 306 | mock_to_thread.assert_called_once() 307 | # Check that the heartbeat URL is correct 308 | call_args = mock_to_thread.call_args 309 | assert f"{OTHER_SERVER_URL}/heartbeat" in call_args.args[1] 310 | # Check that the server_id is included in the payload 311 | assert call_args.kwargs["json"]["server_id"] == SERVER_ID 312 | 313 | 314 | @pytest.mark.asyncio 315 | async def test_send_single_heartbeat_failure(): 316 | """Test sending a single heartbeat with HTTP error""" 317 | # Given 318 | service = HeartbeatService(SERVICE_NAME, SERVER_ID, SERVER_URL) 319 | 320 | # Create a mock response with error status 321 | mock_response = MagicMock() 322 | mock_response.status_code = 500 323 | 324 | # Mock asyncio.to_thread to return our mock response 325 | with patch("asyncio.to_thread", return_value=mock_response) as mock_to_thread: 326 | # When 327 | await service._send_single_heartbeat(OTHER_SERVER_ID, OTHER_SERVER_URL) 328 | 329 | # Then 330 | mock_to_thread.assert_called_once() 331 | # No exception should be raised 332 | 333 | 334 | @pytest.mark.asyncio 335 | async def test_send_single_heartbeat_network_error(): 336 | """Test sending a single heartbeat with network error""" 337 | # Given 338 | service = HeartbeatService(SERVICE_NAME, SERVER_ID, SERVER_URL) 339 | 340 | # Mock asyncio.to_thread to raise a RequestException 341 | import requests 342 | 343 | with patch("asyncio.to_thread", side_effect=requests.RequestException("Network error")): 344 | # When/Then - should not raise exception outside 345 | await service._send_single_heartbeat(OTHER_SERVER_ID, OTHER_SERVER_URL) 346 | # No assertion needed - we're testing that no exception is raised 347 | 348 | 349 | @pytest.mark.asyncio 350 | async def test_schedule_heartbeat(): 351 | """Test scheduling a heartbeat to a server""" 352 | # Given 353 | service = HeartbeatService(SERVICE_NAME, SERVER_ID, SERVER_URL) 354 | 355 | # Mock asyncio.create_task to verify it's called with a coroutine 356 | mock_task = MagicMock() 357 | 358 | with patch("asyncio.create_task", return_value=mock_task) as mock_create_task: 359 | # Mock _send_single_heartbeat to verify it's called 360 | with patch.object(service, "_send_single_heartbeat", return_value=AsyncMock()): 361 | # When 362 | service._schedule_heartbeat(OTHER_SERVER_ID, {"url": OTHER_SERVER_URL}) 363 | 364 | # Then 365 | mock_create_task.assert_called_once() 366 | -------------------------------------------------------------------------------- /tests/service/test_request_deduplication.py: -------------------------------------------------------------------------------- 1 | import time 2 | from unittest.mock import patch 3 | 4 | import pytest 5 | 6 | from pydistributedkv.domain.models import ClientRequest, OperationType 7 | from pydistributedkv.service.request_deduplication import RequestDeduplicationService 8 | 9 | 10 | class TestRequestDeduplicationService: 11 | """Tests for the request deduplication service""" 12 | 13 | @pytest.fixture 14 | def dedup_service(self): 15 | """Create a fresh deduplication service for each test""" 16 | return RequestDeduplicationService(max_cache_size=10, expiry_seconds=1) 17 | 18 | def test_basic_deduplication(self, dedup_service): 19 | """Test that a request is correctly identified as a duplicate""" 20 | # First request 21 | client_request = ClientRequest(client_id="client1", request_id="req1", operation=OperationType.GET, key="test_key") 22 | result = {"key": "test_key", "value": "test_value"} 
23 | dedup_service.mark_request_processed(client_request, result) 24 | 25 | # Same request should be identified as duplicate 26 | cached_result = dedup_service.get_processed_result("client1", "req1", OperationType.GET) 27 | assert cached_result == result 28 | assert dedup_service.total_duplicates_detected == 1 29 | assert dedup_service.same_operation_duplicates == 1 30 | 31 | def test_operation_type_differentiation(self, dedup_service): 32 | """Test that different operation types with the same request ID are treated as different requests""" 33 | # Mark a GET request as processed 34 | get_request = ClientRequest(client_id="client1", request_id="req1", operation=OperationType.GET, key="test_key") 35 | get_result = {"key": "test_key", "value": "test_value"} 36 | dedup_service.mark_request_processed(get_request, get_result) 37 | 38 | # Different operation with the same request ID should NOT be considered a duplicate 39 | set_result = dedup_service.get_processed_result("client1", "req1", OperationType.SET) 40 | assert set_result is None 41 | assert dedup_service.different_operation_duplicates == 1 42 | 43 | # But the GET operation should still be considered a duplicate 44 | get_result_again = dedup_service.get_processed_result("client1", "req1", OperationType.GET) 45 | assert get_result_again == get_result 46 | assert dedup_service.same_operation_duplicates == 1 47 | 48 | def test_different_clients_same_request_id(self, dedup_service): 49 | """Test that the same request ID from different clients is treated as different requests""" 50 | # First client request 51 | client1_request = ClientRequest(client_id="client1", request_id="req1", operation=OperationType.GET, key="test_key") 52 | client1_result = {"key": "test_key", "value": "client1_value"} 53 | dedup_service.mark_request_processed(client1_request, client1_result) 54 | 55 | # Second client with same request ID 56 | client2_request = ClientRequest(client_id="client2", request_id="req1", operation=OperationType.GET, key="test_key") 57 | client2_result = {"key": "test_key", "value": "client2_value"} 58 | dedup_service.mark_request_processed(client2_request, client2_result) 59 | 60 | # Both clients should get their own results 61 | cached_result1 = dedup_service.get_processed_result("client1", "req1", OperationType.GET) 62 | cached_result2 = dedup_service.get_processed_result("client2", "req1", OperationType.GET) 63 | 64 | assert cached_result1 == client1_result 65 | assert cached_result2 == client2_result 66 | assert cached_result1 != cached_result2 67 | 68 | def test_expiry_of_cached_results(self, dedup_service): 69 | """Test that cached results expire after the configured time""" 70 | # Set a very short expiry time for testing 71 | dedup_service.expiry_seconds = 0.1 72 | 73 | # Cache a request 74 | client_request = ClientRequest(client_id="client1", request_id="req1", operation=OperationType.GET, key="test_key") 75 | result = {"key": "test_key", "value": "test_value"} 76 | dedup_service.mark_request_processed(client_request, result) 77 | 78 | # Request should be cached initially 79 | assert dedup_service.get_processed_result("client1", "req1", OperationType.GET) == result 80 | 81 | # Wait for the cache to expire 82 | time.sleep(0.2) 83 | 84 | # After expiry, the result should be gone 85 | assert dedup_service.get_processed_result("client1", "req1", OperationType.GET) is None 86 | 87 | @patch("time.time") 88 | def test_cache_cleanup_on_access(self, mock_time, dedup_service): 89 | """Test that expired entries are cleaned up when accessing the 
cache""" 90 | # Set up mock times 91 | mock_time.return_value = 1000 # Starting time 92 | 93 | # Cache a request 94 | client_request = ClientRequest(client_id="client1", request_id="req1", operation=OperationType.GET, key="test_key") 95 | result = {"key": "test_key", "value": "test_value"} 96 | dedup_service.mark_request_processed(client_request, result) 97 | 98 | # Add another entry 99 | client_request2 = ClientRequest(client_id="client1", request_id="req2", operation=OperationType.GET, key="test_key2") 100 | result2 = {"key": "test_key2", "value": "test_value2"} 101 | dedup_service.mark_request_processed(client_request2, result2) 102 | 103 | # Advance time beyond expiry 104 | mock_time.return_value = 1000 + dedup_service.expiry_seconds + 1 105 | 106 | # Access the cache, which should trigger cleanup 107 | dedup_service.get_processed_result("client1", "req1", OperationType.GET) 108 | 109 | # All entries should be gone 110 | assert len(dedup_service.processed_requests) == 0 111 | assert dedup_service.total_cache_cleanups >= 1 112 | 113 | def test_same_request_different_operations(self, dedup_service): 114 | """Test handling multiple operations on the same request ID""" 115 | # First operation: SET 116 | set_request = ClientRequest( 117 | client_id="client1", request_id="req1", operation=OperationType.SET, key="test_key", value="initial_value" 118 | ) 119 | set_result = {"status": "ok", "id": 1} 120 | dedup_service.mark_request_processed(set_request, set_result) 121 | 122 | # Second operation: DELETE 123 | delete_request = ClientRequest(client_id="client1", request_id="req1", operation=OperationType.DELETE, key="test_key") 124 | delete_result = {"status": "ok", "id": 2} 125 | dedup_service.mark_request_processed(delete_request, delete_result) 126 | 127 | # Each operation should have its own cached result 128 | cached_set_result = dedup_service.get_processed_result("client1", "req1", OperationType.SET) 129 | cached_delete_result = dedup_service.get_processed_result("client1", "req1", OperationType.DELETE) 130 | 131 | assert cached_set_result == set_result 132 | assert cached_delete_result == delete_result 133 | assert cached_set_result != cached_delete_result 134 | 135 | def test_get_stats(self, dedup_service): 136 | """Test the statistics collection feature""" 137 | # Cache a few different requests 138 | for i in range(3): 139 | client_request = ClientRequest(client_id="client1", request_id=f"req{i}", operation=OperationType.GET, key=f"key{i}") 140 | dedup_service.mark_request_processed(client_request, {"value": f"value{i}"}) 141 | 142 | # Create some duplicates 143 | dedup_service.get_processed_result("client1", "req0", OperationType.GET) 144 | dedup_service.get_processed_result("client1", "req1", OperationType.GET) 145 | dedup_service.get_processed_result("client1", "req0", OperationType.SET) # Different operation 146 | 147 | # Check stats 148 | stats = dedup_service.get_stats() 149 | 150 | assert stats["total_requests_cached"] == 3 151 | assert stats["total_duplicates_detected"] == 2 152 | assert stats["same_operation_duplicates"] == 2 153 | assert stats["different_operation_duplicates"] == 1 154 | assert stats["current_cache_size"] == 3 155 | assert stats["unique_request_ids"] == 3 156 | assert stats["total_client_count"] == 1 157 | -------------------------------------------------------------------------------- /tests/service/test_storage_with_segmented_wal.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import 
shutil 4 | import tempfile 5 | import unittest 6 | from unittest.mock import patch 7 | 8 | from pydistributedkv.domain.models import OperationType, WAL 9 | from pydistributedkv.service.storage import KeyValueStorage 10 | 11 | 12 | class TestStorageWithSegmentedWAL(unittest.TestCase): 13 | def setUp(self): 14 | # Create a temporary directory for test WAL files 15 | self.temp_dir = tempfile.mkdtemp() 16 | self.wal_path = os.path.join(self.temp_dir, "wal.log") 17 | 18 | # Use a small segment size for testing 19 | self.small_segment_size = 200 # bytes 20 | 21 | def tearDown(self): 22 | # Clean up the temporary directory after tests 23 | shutil.rmtree(self.temp_dir) 24 | 25 | def test_storage_replay_after_restart(self): 26 | """Test that storage correctly rebuilds state after restart with segmented WAL""" 27 | # Create initial storage with some data 28 | wal = WAL(self.wal_path, max_segment_size=self.small_segment_size) 29 | storage = KeyValueStorage(wal) 30 | 31 | # Add enough data to span multiple segments 32 | test_data = {} 33 | for i in range(20): 34 | key = f"test-key-{i}" 35 | value = f"test-value-{i}" * 3 # Make the value larger to reach segment limit faster 36 | storage.set(key, value) 37 | test_data[key] = value 38 | 39 | # Delete some keys to verify delete operations are also replayed 40 | for i in range(0, 20, 4): # Delete every 4th key 41 | key = f"test-key-{i}" 42 | storage.delete(key) 43 | if key in test_data: 44 | del test_data[key] 45 | 46 | # Verify we have multiple segments 47 | segments = wal.get_segment_files() 48 | self.assertGreater(len(segments), 1, "Expected multiple segments for this test") 49 | 50 | # Create a new storage instance that will replay the log 51 | new_wal = WAL(self.wal_path, max_segment_size=self.small_segment_size) 52 | 53 | # Create a new storage with the existing WAL 54 | with patch("builtins.print"): # suppress print statements during replay 55 | new_storage = KeyValueStorage(new_wal) 56 | 57 | # Verify all data was correctly replayed 58 | for key, expected_value in test_data.items(): 59 | self.assertEqual(new_storage.get(key), expected_value, f"Replayed value for {key} doesn't match expected value") 60 | 61 | # Verify deleted keys are not present 62 | for i in range(0, 20, 4): 63 | self.assertIsNone(new_storage.get(f"test-key-{i}"), f"Key test-key-{i} should have been deleted") 64 | 65 | def test_storage_with_complex_operations(self): 66 | """Test storage with complex operations across multiple segments""" 67 | # Create a WAL with small segment size 68 | wal = WAL(self.wal_path, max_segment_size=self.small_segment_size) 69 | storage = KeyValueStorage(wal) 70 | 71 | # Test complex data types 72 | complex_data = [ 73 | {"key": "dict-key", "value": {"nested": "value", "list": [1, 2, 3], "number": 42}}, 74 | {"key": "list-key", "value": [1, "string", {"nested": "object"}, None]}, 75 | {"key": "number-key", "value": 12345.6789}, 76 | {"key": "bool-key", "value": True}, 77 | {"key": "null-key", "value": None}, 78 | {"key": "string-key", "value": "This is a longer string value that should take up more space in the log."}, 79 | ] 80 | 81 | # Set the values 82 | expected_data = {} 83 | for item in complex_data: 84 | storage.set(item["key"], item["value"]) 85 | expected_data[item["key"]] = item["value"] 86 | 87 | # Update some values to create more log entries 88 | for i in range(10): 89 | storage.set("counter", i) 90 | expected_data["counter"] = i 91 | 92 | # Delete and recreate a key multiple times 93 | for i in range(5): 94 | storage.set("temp-key", 
f"temp-value-{i}") 95 | if i % 2 == 0: # Delete on even iterations 96 | storage.delete("temp-key") 97 | if "temp-key" in expected_data: 98 | del expected_data["temp-key"] 99 | else: 100 | expected_data["temp-key"] = f"temp-value-{i}" 101 | 102 | # Verify multiple segments were created 103 | segments = wal.get_segment_files() 104 | self.assertGreater(len(segments), 1, f"Expected multiple segments but got {len(segments)}") 105 | 106 | # Create a new storage instance to test replay 107 | new_wal = WAL(self.wal_path, max_segment_size=self.small_segment_size) 108 | with patch("builtins.print"): # suppress print statements 109 | new_storage = KeyValueStorage(new_wal) 110 | 111 | # Verify all data was correctly replayed 112 | for key, expected_value in expected_data.items(): 113 | actual_value = new_storage.get(key) 114 | 115 | # For complex objects, compare serialized JSON to handle floating-point differences 116 | if isinstance(expected_value, (dict, list)): 117 | self.assertEqual( 118 | json.dumps(actual_value, sort_keys=True), 119 | json.dumps(expected_value, sort_keys=True), 120 | f"Replayed value for {key} doesn't match expected value", 121 | ) 122 | else: 123 | self.assertEqual(actual_value, expected_value, f"Replayed value for {key} doesn't match expected value") 124 | 125 | def test_wal_with_huge_values(self): 126 | """Test WAL segmentation with large values that exceed segment size""" 127 | # Create a WAL with small segment size 128 | wal = WAL(self.wal_path, max_segment_size=self.small_segment_size) 129 | storage = KeyValueStorage(wal) 130 | 131 | # Add a value that is larger than the segment size 132 | large_value = "x" * (self.small_segment_size * 2) # Value larger than segment size 133 | storage.set("large-key", large_value) 134 | 135 | # Add some more normal values 136 | storage.set("key1", "value1") 137 | storage.set("key2", "value2") 138 | 139 | # Verify we have at least two segments (since large value should force segment rollover) 140 | segments = wal.get_segment_files() 141 | self.assertGreater(len(segments), 1, "Expected multiple segments after adding large value") 142 | 143 | # Create a new storage instance to verify replay works 144 | new_wal = WAL(self.wal_path, max_segment_size=self.small_segment_size) 145 | new_storage = KeyValueStorage(new_wal) 146 | 147 | # Verify all values, including the large one, were replayed correctly 148 | self.assertEqual(new_storage.get("large-key"), large_value) 149 | self.assertEqual(new_storage.get("key1"), "value1") 150 | self.assertEqual(new_storage.get("key2"), "value2") 151 | 152 | def test_wal_truncated_segment(self): 153 | """Test handling of truncated segment files""" 154 | # Create a WAL with small segment size 155 | wal = WAL(self.wal_path, max_segment_size=self.small_segment_size) 156 | 157 | # Add enough entries to create at least 2 segments 158 | for i in range(20): 159 | wal.append(OperationType.SET, f"key{i}", f"value{i}" * 5) 160 | 161 | # Verify we have multiple segments 162 | segments = wal.get_segment_files() 163 | self.assertGreater(len(segments), 1, "Expected multiple segments for this test") 164 | 165 | # Truncate the last segment file (remove the last few bytes) 166 | with open(segments[-1], "r+") as f: 167 | content = f.read() 168 | # Truncate the file to half its size 169 | truncated_content = content[: len(content) // 2] 170 | f.seek(0) 171 | f.write(truncated_content) 172 | f.truncate() 173 | 174 | # Create a new WAL instance with the truncated segment 175 | # It should still load successfully but skip the 
corrupted entries 176 | with patch("builtins.print"): # suppress print statements 177 | new_wal = WAL(self.wal_path, max_segment_size=self.small_segment_size) 178 | 179 | # Verify we can still read entries from the WAL 180 | entries = new_wal.read_from(0) 181 | self.assertGreater(len(entries), 0, "Expected some valid entries despite truncation") 182 | -------------------------------------------------------------------------------- /tests/test_compaction_integration.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import os 3 | import tempfile 4 | import time 5 | 6 | import pytest 7 | 8 | from pydistributedkv.domain.models import OperationType, WAL 9 | from pydistributedkv.service.compaction import LogCompactionService 10 | from pydistributedkv.service.storage import KeyValueStorage 11 | 12 | 13 | @pytest.fixture 14 | def temp_dir(): 15 | """Create a temporary directory for test files""" 16 | with tempfile.TemporaryDirectory() as tmpdirname: 17 | yield tmpdirname 18 | 19 | 20 | @pytest.fixture 21 | def wal(temp_dir): 22 | """Create a WAL with a small segment size for testing""" 23 | log_path = os.path.join(temp_dir, "test_log") 24 | # Use a very small segment size to trigger multiple segments 25 | return WAL(log_path, max_segment_size=100) 26 | 27 | 28 | @pytest.fixture 29 | def storage(wal): 30 | """Create a storage with the test WAL""" 31 | return KeyValueStorage(wal) 32 | 33 | 34 | @pytest.fixture 35 | def compaction_service(storage): 36 | """Create a compaction service with test settings""" 37 | return LogCompactionService( 38 | storage=storage, 39 | compaction_interval=0.5, # Run every 0.5 seconds 40 | min_compaction_interval=0.1, # Allow compaction after 0.1 seconds 41 | enabled=True, 42 | ) 43 | 44 | 45 | class TestCompactionIntegration: 46 | 47 | @pytest.mark.asyncio 48 | async def test_compaction_with_real_storage(self, compaction_service, storage): 49 | """Test that compaction works with real storage and WAL""" 50 | # Add many entries to create multiple segments 51 | for i in range(100): 52 | key = f"key_{i}" 53 | value = f"value_{i}" 54 | storage.set(key, value) 55 | 56 | # Overwrite some keys to create redundancy 57 | if i % 3 == 0: # Every third key gets overwritten 58 | storage.set(key, f"updated_{value}") 59 | 60 | # Delete some keys 61 | if i % 7 == 0: # Every seventh key gets deleted 62 | storage.delete(key) 63 | 64 | # Check we have data and segments 65 | assert len(storage.get_all_keys()) > 0 66 | segments_before = len(storage.wal.get_segment_files()) 67 | assert segments_before > 1, "Test needs multiple segments to be meaningful" 68 | 69 | # Run compaction manually 70 | segments_compacted, entries_removed = await compaction_service.run_compaction(force=True) 71 | 72 | # Verify compaction results 73 | assert segments_compacted > 0, "No segments were compacted" 74 | assert entries_removed > 0, "No entries were removed" 75 | 76 | # Check history was updated 77 | assert len(compaction_service.compaction_history) == 1 78 | assert compaction_service.compaction_history[0]["segments_compacted"] == segments_compacted 79 | assert compaction_service.compaction_history[0]["entries_removed"] == entries_removed 80 | 81 | # Verify status returns correct information 82 | status = compaction_service.get_status() 83 | assert status["compaction_history"][0]["segments_compacted"] == segments_compacted 84 | 85 | # Verify data integrity - all accessible keys should still have correct values 86 | for key in storage.get_all_keys(): 87 | 
if key.startswith("key_"): 88 | i = int(key.split("_")[1]) 89 | expected_value = None 90 | 91 | if i % 7 == 0: 92 | # These keys should have been deleted 93 | continue 94 | elif i % 3 == 0: 95 | # These keys should have updated values 96 | expected_value = f"updated_value_{i}" 97 | else: 98 | # These keys should have original values 99 | expected_value = f"value_{i}" 100 | 101 | actual_value = storage.get(key) 102 | assert actual_value == expected_value, f"Key {key} has incorrect value" 103 | 104 | @pytest.mark.asyncio 105 | async def test_compaction_service_lifecycle(self, compaction_service): 106 | """Test the full lifecycle of a compaction service""" 107 | # Start the service 108 | await compaction_service.start() 109 | 110 | # Verify it's running 111 | assert compaction_service.compaction_task is not None 112 | 113 | # Let it run for a bit to complete at least one compaction cycle 114 | await asyncio.sleep(1.0) 115 | 116 | # Stop the service 117 | await compaction_service.stop() 118 | 119 | # Verify it's stopped 120 | assert compaction_service.compaction_task is None 121 | 122 | # Check if at least one compaction happened 123 | assert len(compaction_service.compaction_history) > 0 124 | -------------------------------------------------------------------------------- /tests/test_storage_versioning.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tempfile 3 | from unittest.mock import patch 4 | 5 | import pytest 6 | 7 | from pydistributedkv.domain.models import LogEntry, OperationType, VersionedValue, WAL 8 | from pydistributedkv.service.storage import KeyValueStorage 9 | 10 | 11 | class TestStorageVersioning: 12 | @pytest.fixture 13 | def wal(self): 14 | """Create a temporary WAL for testing""" 15 | with tempfile.TemporaryDirectory() as tmp_dir: 16 | wal_path = os.path.join(tmp_dir, "test_wal.log") 17 | yield WAL(wal_path) 18 | 19 | @pytest.fixture 20 | def storage(self, wal): 21 | """Create a storage with the test WAL""" 22 | return KeyValueStorage(wal) 23 | 24 | def test_set_new_key(self, storage): 25 | """Test setting a new key creates version 1""" 26 | entry, version = storage.set("key1", "value1") 27 | 28 | assert version == 1 29 | assert storage.get("key1") == "value1" 30 | assert storage.get_latest_version("key1") == 1 31 | 32 | def test_set_existing_key_increments_version(self, storage): 33 | """Test setting an existing key increments the version""" 34 | storage.set("key1", "value1") 35 | entry, version = storage.set("key1", "value2") 36 | 37 | assert version == 2 38 | assert storage.get("key1") == "value2" 39 | assert storage.get_latest_version("key1") == 2 40 | 41 | def test_setting_specific_version_for_new_key(self, storage): 42 | """Test setting a specific version for a new key""" 43 | entry, version = storage.set("key1", "value1", version=5) 44 | 45 | assert version == 5 46 | assert storage.get("key1") == "value1" 47 | assert storage.get_latest_version("key1") == 5 48 | 49 | def test_version_conflict_returns_none(self, storage): 50 | """Test setting with an outdated version returns None and current version""" 51 | storage.set("key1", "value1") # Version 1 52 | storage.set("key1", "value2") # Version 2 53 | 54 | # Try to set with version 1 which is outdated 55 | entry, version = storage.set("key1", "value3", version=1) 56 | 57 | assert entry is None 58 | assert version == 2 # Current version 59 | assert storage.get("key1") == "value2" # Value unchanged 60 | 61 | def test_get_with_version(self, storage): 62 | 
"""Test getting a specific version of a key""" 63 | storage.set("key1", "value1") # Version 1 64 | storage.set("key1", "value2") # Version 2 65 | storage.set("key1", "value3") # Version 3 66 | 67 | assert storage.get("key1", version=1) == "value1" 68 | assert storage.get("key1", version=2) == "value2" 69 | assert storage.get("key1", version=3) == "value3" 70 | assert storage.get("key1") == "value3" # Latest version 71 | assert storage.get("key1", version=4) is None # Non-existent version 72 | 73 | def test_get_with_version_returns_tuple(self, storage): 74 | """Test get_with_version returns both value and version""" 75 | storage.set("key1", "value1") # Version 1 76 | storage.set("key1", "value2") # Version 2 77 | 78 | result = storage.get_with_version("key1", version=1) 79 | assert result == ("value1", 1) 80 | 81 | result = storage.get_with_version("key1") # Latest version 82 | assert result == ("value2", 2) 83 | 84 | def test_get_version_history(self, storage): 85 | """Test getting the version history of a key""" 86 | storage.set("key1", "value1") # Version 1 87 | storage.set("key1", "value2") # Version 2 88 | storage.set("key1", "value3") # Version 3 89 | 90 | history = storage.get_version_history("key1") 91 | 92 | assert history == {1: "value1", 2: "value2", 3: "value3"} 93 | 94 | def test_get_version_history_nonexistent_key(self, storage): 95 | """Test getting history for a non-existent key returns None""" 96 | assert storage.get_version_history("nonexistent") is None 97 | 98 | def test_replay_log_with_versions(self, wal): 99 | """Test replaying a WAL with versioned entries""" 100 | # Create entries in the WAL 101 | wal.append(OperationType.SET, "key1", "value1", version=1) 102 | wal.append(OperationType.SET, "key1", "value2", version=2) 103 | wal.append(OperationType.SET, "key2", "value-a", version=1) 104 | 105 | # Create a new storage that will replay the WAL 106 | storage = KeyValueStorage(wal) 107 | 108 | assert storage.get("key1") == "value2" 109 | assert storage.get_latest_version("key1") == 2 110 | 111 | assert storage.get("key2") == "value-a" 112 | assert storage.get_latest_version("key2") == 1 113 | 114 | # Check version histories 115 | assert storage.get_version_history("key1") == {1: "value1", 2: "value2"} 116 | assert storage.get_version_history("key2") == {1: "value-a"} 117 | 118 | def test_delete_removes_all_versions(self, storage): 119 | """Test that delete removes a key and all its versions""" 120 | storage.set("key1", "value1") # Version 1 121 | storage.set("key1", "value2") # Version 2 122 | 123 | entry = storage.delete("key1") 124 | 125 | assert entry is not None 126 | assert storage.get("key1") is None 127 | assert storage.get_version_history("key1") is None 128 | assert storage.get_latest_version("key1") is None 129 | -------------------------------------------------------------------------------- /tests/test_versioned_value.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from pydistributedkv.domain.models import VersionedValue 4 | 5 | 6 | class TestVersionedValue: 7 | def test_init(self): 8 | """Test VersionedValue initialization""" 9 | vv = VersionedValue(current_version=1, value="test") 10 | 11 | assert vv.current_version == 1 12 | assert vv.value == "test" 13 | assert vv.history is None 14 | 15 | def test_get_value_latest(self): 16 | """Test getting the latest value""" 17 | vv = VersionedValue(current_version=3, value="version3", history={1: "version1", 2: "version2"}) 18 | 19 | # When version is 
None, it should return the latest value 20 | assert vv.get_value() == "version3" 21 | 22 | def test_get_value_specific_version(self): 23 | """Test getting a specific version""" 24 | vv = VersionedValue(current_version=3, value="version3", history={1: "version1", 2: "version2"}) 25 | 26 | assert vv.get_value(version=2) == "version2" 27 | assert vv.get_value(version=1) == "version1" 28 | assert vv.get_value(version=3) == "version3" 29 | 30 | def test_get_value_nonexistent_version(self): 31 | """Test getting a version that doesn't exist""" 32 | vv = VersionedValue(current_version=3, value="version3", history={1: "version1", 2: "version2"}) 33 | 34 | assert vv.get_value(version=4) is None 35 | assert vv.get_value(version=0) is None 36 | 37 | def test_update_next_version(self): 38 | """Test updating to the next consecutive version""" 39 | vv = VersionedValue(current_version=1, value="version1") 40 | 41 | vv.update("version2", 2) 42 | 43 | assert vv.current_version == 2 44 | assert vv.value == "version2" 45 | assert vv.history == {1: "version1"} 46 | 47 | def test_update_nonconsecutive_version(self): 48 | """Test updating with a gap in version numbers""" 49 | vv = VersionedValue(current_version=1, value="version1") 50 | 51 | vv.update("version3", 3) 52 | 53 | assert vv.current_version == 3 54 | assert vv.value == "version3" 55 | assert vv.history == {1: "version1"} 56 | 57 | def test_update_older_version(self): 58 | """Test that updating with an older version is ignored""" 59 | vv = VersionedValue(current_version=3, value="version3", history={1: "version1", 2: "version2"}) 60 | 61 | vv.update("old_value", 2) 62 | 63 | # The update should be ignored 64 | assert vv.current_version == 3 65 | assert vv.value == "version3" 66 | assert 2 in vv.history 67 | assert vv.history[2] == "version2" 68 | 69 | def test_update_same_version(self): 70 | """Test that updating with the same version is ignored""" 71 | vv = VersionedValue(current_version=3, value="version3", history={1: "version1", 2: "version2"}) 72 | 73 | vv.update("new_version3", 3) 74 | 75 | # The update should be ignored 76 | assert vv.current_version == 3 77 | assert vv.value == "version3" 78 | 79 | def test_multiple_updates(self): 80 | """Test multiple sequential updates""" 81 | vv = VersionedValue(current_version=1, value="version1") 82 | 83 | vv.update("version2", 2) 84 | vv.update("version3", 3) 85 | vv.update("version4", 4) 86 | 87 | assert vv.current_version == 4 88 | assert vv.value == "version4" 89 | assert vv.history == {1: "version1", 2: "version2", 3: "version3"} 90 | 91 | # Check that we can retrieve all versions 92 | assert vv.get_value(1) == "version1" 93 | assert vv.get_value(2) == "version2" 94 | assert vv.get_value(3) == "version3" 95 | assert vv.get_value(4) == "version4" 96 | assert vv.get_value() == "version4" 97 | --------------------------------------------------------------------------------