├── .github
│   └── workflows
│       ├── build_docs.yml
│       ├── end2endtest.yml
│       ├── lint.yml
│       ├── publish.yml
│       └── unittest.yml
├── .gitignore
├── .pre-commit-config.yaml
├── LICENSE.md
├── README.md
├── docs
│   ├── Makefile
│   ├── README.md
│   ├── make.bat
│   └── source
│       ├── conf.py
│       ├── git_theta.checkpoints.rst
│       ├── git_theta.lsh.rst
│       ├── git_theta.merges.rst
│       ├── git_theta.rst
│       ├── git_theta.updates.rst
│       └── modules.rst
├── examples
│   └── git_theta_example.md
├── git_theta
│   ├── __init__.py
│   ├── api.py
│   ├── async_utils.py
│   ├── checkpoints
│   │   ├── __init__.py
│   │   ├── base.py
│   │   ├── flax_checkpoint.py
│   │   ├── pickled_dict_checkpoint.py
│   │   ├── safetensors_checkpoint.py
│   │   └── tensorflow_checkpoint.py
│   ├── filters.py
│   ├── git_utils.py
│   ├── hooks
│   │   ├── post-commit
│   │   └── pre-push
│   ├── lsh
│   │   ├── __init__.py
│   │   ├── base.py
│   │   ├── euclidean_lsh.py
│   │   ├── pool.py
│   │   └── types.py
│   ├── merges
│   │   ├── __init__.py
│   │   ├── average.py
│   │   ├── base.py
│   │   ├── context.py
│   │   └── take.py
│   ├── metadata.py
│   ├── params.py
│   ├── scripts
│   │   ├── __init__.py
│   │   ├── git_theta_cli.py
│   │   ├── git_theta_diff.py
│   │   ├── git_theta_filter.py
│   │   └── git_theta_merge.py
│   ├── theta.py
│   ├── types.py
│   ├── updates
│   │   ├── __init__.py
│   │   ├── base.py
│   │   ├── dense.py
│   │   ├── ia3.py
│   │   ├── low_rank.py
│   │   └── sparse.py
│   └── utils.py
├── plugins
│   ├── README.md
│   └── json-checkpoint
│       ├── README.md
│       ├── git_theta_json_checkpoint
│       │   ├── __init__.py
│       │   └── checkpoints.py
│       └── setup.py
├── pyproject.toml
├── requirements-ci.txt
├── requirements-dev.txt
├── setup.py
└── tests
    ├── checkpoints
    │   ├── checkpoints_test.py
    │   ├── safetensors_checkpoint_test.py
    │   └── tensorflow_checkpoint_test.py
    ├── conftest.py
    ├── end2end
    │   ├── README.md
    │   ├── checkout
    │   │   └── test.sh
    │   ├── clean.sh
    │   ├── commit
    │   │   └── test.sh
    │   ├── ia3
    │   │   └── test.sh
    │   ├── inprocess
    │   │   ├── test.py
    │   │   └── test.sh
    │   ├── low-rank
    │   │   └── test.sh
    │   ├── make-test.sh
    │   ├── model.py
    │   ├── runner.sh
    │   ├── smudge
    │   │   ├── clean.sh
    │   │   └── test.sh
    │   ├── sparse
    │   │   ├── clean.sh
    │   │   └── test.sh
    │   ├── utils.sh
    │   └── verify.py
    ├── git_utils_test.py
    ├── helpers
    │   ├── __init__.py
    │   └── utils.py
    ├── metadata_test.py
    ├── params_test.py
    ├── theta_test.py
    ├── trie_test.py
    ├── updates
    │   ├── base_test.py
    │   ├── ia3_test.py
    │   ├── low_rank_test.py
    │   └── sparse_update_test.py
    └── utils_test.py

-------------------------------------------------------------------------------- /.github/workflows/build_docs.yml: --------------------------------------------------------------------------------

name: BuildDocs

on:
  push:
    branches:
      - main
  pull_request:
    branches:
      - main
  workflow_dispatch:

jobs:
  build:
    runs-on: ubuntu-latest

    steps:
      - uses: actions/checkout@v2
      - name: Set Up Python
        uses: actions/setup-python@v2
        with:
          python-version: "3.8"
      - name: Install Dependencies and Package
        run: |
          python -m pip install --upgrade pip
          python -m pip install .[all,docs]
      - name: Build docs
        working-directory: ./docs/
        run: |
          make html

-------------------------------------------------------------------------------- /.github/workflows/end2endtest.yml: --------------------------------------------------------------------------------

name: End2EndTests

on:
  push:
    branches:
      - main
  pull_request:
    branches:
      - main
  workflow_dispatch:

jobs:
  build:
    runs-on: ${{ matrix.os }}
    defaults:
      run:
        shell: bash
    strategy:
      matrix:
        os: [ubuntu-latest, windows-latest, macos-latest, macos-13]
        python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
        exclude:
          - os: macos-latest
            python-version: "3.8"
          - os: macos-latest
            python-version: "3.9"
          - os: macos-latest
            python-version: "3.10"
          - os: macos-13
            python-version: "3.11"
          - os: macos-13
            python-version: "3.12"

    steps:
      - uses: actions/checkout@v2
      - name: Set Up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v4
        with:
          python-version: ${{ matrix.python-version }}
          cache: "pip"
          cache-dependency-path: "setup.py"
      - name: Install Dependencies and Package
        run: |
          python -m pip install --upgrade pip
          python -m pip install .[pytorch]
      - name: Run End2End Tests
        working-directory: ./tests/end2end
        run: |
          ./runner.sh

-------------------------------------------------------------------------------- /.github/workflows/lint.yml: --------------------------------------------------------------------------------

name: Lint

on:
  push:
    branches:
      - main
  pull_request:
    branches:
      - main
  workflow_dispatch:

jobs:
  black:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v2
      - uses: psf/black@stable
        with:
          version: 23.1.0
  isort:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v2
      - uses: actions/setup-python@v4
        with:
          python-version: 3.8
      # Install package and deps so third-party packages are sorted
      - name: Install Dependencies and Package
        run: |
          python -m pip install --upgrade pip
          python -m pip install .[test,all]
      - uses: isort/isort-action@master

-------------------------------------------------------------------------------- /.github/workflows/publish.yml: --------------------------------------------------------------------------------

name: Publish to PyPI

on:
  push:
    tags:
      - '*'

jobs:
  build:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v2
      - name: Setup Python
        uses: actions/setup-python@v4
        with:
          python-version: "3.8"
      - name: Install Build Package
        run: |
          python -m pip install --upgrade build setuptools wheel twine
      - name: Build Package
        run: |
          python -m build --sdist --wheel --outdir dist/ .
      - name: Publish Package
        if: github.event_name == 'push' && startsWith(github.event.ref, 'refs/tags')
        uses: pypa/gh-action-pypi-publish@release/v1
        with:
          user: __token__
          password: ${{ secrets.PYPI_PASSWORD }}

-------------------------------------------------------------------------------- /.github/workflows/unittest.yml: --------------------------------------------------------------------------------

name: Unittests

on:
  push:
    branches:
      - main
  pull_request:
    branches:
      - main
  workflow_dispatch:

jobs:
  build:
    runs-on: ${{ matrix.os }}
    defaults:
      run:
        shell: bash
    strategy:
      matrix:
        os: [ubuntu-latest, windows-latest, macos-latest, macos-13]
        python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
        exclude:
          - os: macos-latest
            python-version: "3.8"
          - os: macos-latest
            python-version: "3.9"
          - os: macos-latest
            python-version: "3.10"
          - os: macos-13
            python-version: "3.11"
          - os: macos-13
            python-version: "3.12"

    steps:
      - uses: actions/checkout@v2
      - name: Set Up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v4
        with:
          python-version: ${{ matrix.python-version }}
          cache: "pip"
          cache-dependency-path: "setup.py"
      - name: Install Dependencies and Package
        run: |
          python -m pip install --upgrade pip
          # Install pinned deps when testing to avoid backtracking
          python -m pip install -r requirements-ci.txt
          python -m pip install .[test,all]
      - name: Run Unit Tests
        run: |
          pytest

-------------------------------------------------------------------------------- /.gitignore: --------------------------------------------------------------------------------

*.modeldiff
*.onnx
.DS_Store
.huskyrc.json
out
log.log
**/node_modules
*.pyc
*.vsix
**/.vscode/.ropeproject/**
**/testFiles/**/.cache/**
*.noseids
.nyc_output
.vscode-test
__pycache__
npm-debug.log
**/.mypy_cache/**
!yarn.lock
coverage/
cucumber-report.json
**/.vscode-test/**
**/.vscode test/**
**/.vscode-smoke/**
**/.venv*/
port.txt
precommit.hook
pythonFiles/lib/**
debug_coverage*/**
languageServer/**
languageServer.*/**
obj/**
.pytest_cache
tmp/**
.python-version
.vs/
test-results*.xml
xunit-test-results.xml
build/ci/performance/performance-results.json
!build/
debug*.log
debugpy*.log
pydevd*.log
nodeLanguageServer/**
nodeLanguageServer.*/**
dist/**
data/*
.idea/*
.ipynb_checkpoints
*.egg-info
build/

-------------------------------------------------------------------------------- /.pre-commit-config.yaml: --------------------------------------------------------------------------------

repos:
  - repo: https://github.com/pre-commit/pre-commit-hooks
    rev: v4.4.0
    hooks:
      - id: end-of-file-fixer
      - id: trailing-whitespace
  - repo: https://github.com/psf/black
    rev: 23.1.0
    hooks:
      - id: black
  - repo: https://github.com/pycqa/isort
    rev: 5.12.0
    hooks:
      - id: isort
        name: isort (python)

-------------------------------------------------------------------------------- /docs/Makefile: --------------------------------------------------------------------------------

# Minimal makefile for Sphinx documentation
#

# You can set these variables from the command line, and also
# from the environment for the first two.
SPHINXOPTS    ?=
SPHINXBUILD   ?= sphinx-build
SOURCEDIR     = source
BUILDDIR      = build

# Put it first so that "make" without argument is like "make help".
help:
	@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)

.PHONY: help Makefile

# Catch-all target: route all unknown targets to Sphinx using the new
# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
%: Makefile
	@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)

-------------------------------------------------------------------------------- /docs/README.md: --------------------------------------------------------------------------------

# Git-Theta docs

This directory contains the code for building `git-theta`'s documentation.
Specifically, it builds the documentation for the Python library, not the command line utilities.
The individual submodule files can be regenerated by running
```
sphinx-apidoc -o source git_theta
```
This must be done whenever a new submodule is added or an existing submodule is renamed.

-------------------------------------------------------------------------------- /docs/make.bat: --------------------------------------------------------------------------------

@ECHO OFF

pushd %~dp0

REM Command file for Sphinx documentation

if "%SPHINXBUILD%" == "" (
	set SPHINXBUILD=sphinx-build
)
set SOURCEDIR=source
set BUILDDIR=build

if "%1" == "" goto help

%SPHINXBUILD% >NUL 2>NUL
if errorlevel 9009 (
	echo.
	echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
	echo.installed, then set the SPHINXBUILD environment variable to point
	echo.to the full path of the 'sphinx-build' executable. Alternatively you
	echo.may add the Sphinx directory to PATH.
	echo.
	echo.If you don't have Sphinx installed, grab it from
	echo.https://www.sphinx-doc.org/
	exit /b 1
)

%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
goto end

:help
%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%

:end
popd

-------------------------------------------------------------------------------- /docs/source/conf.py: --------------------------------------------------------------------------------

# Configuration file for the Sphinx documentation builder.
#
# This file only contains a selection of the most common options. For a full
# list see the documentation:
# https://www.sphinx-doc.org/en/master/usage/configuration.html

# -- Path setup --------------------------------------------------------------

# If extensions (or modules to document with autodoc) are in another directory,
# add these directories to sys.path here. If the directory is relative to the
# documentation root, use os.path.abspath to make it absolute, like shown here.
#
import os
import sys

# Insert the repository root before importing git_theta so autodoc can find
# the package even when it is not installed.
sys.path.insert(0, os.path.abspath(os.path.join("..", "..")))

import git_theta


# -- Project information -----------------------------------------------------

project = "git-theta"
copyright = "2023, r-three"
author = "r-three"

# The full version, including alpha/beta/rc tags
release = git_theta.__version__


# -- General configuration ---------------------------------------------------

# Add any Sphinx extension module names here, as strings. They can be
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
# ones.
extensions = [
    "sphinx.ext.autodoc",
    "sphinx.ext.imgmath",
    "numpydoc",
]

# Add any paths that contain templates here, relative to this directory.
templates_path = ["_templates"]

# List of patterns, relative to source directory, that match files and
# directories to ignore when looking for source files.
# This pattern also affects html_static_path and html_extra_path.
exclude_patterns = []


# -- Options for HTML output -------------------------------------------------

# The theme to use for HTML and HTML Help pages. See the documentation for
# a list of builtin themes.
#
html_theme = "default"

# Add any paths that contain custom static files (such as style sheets) here,
# relative to this directory. They are copied after the builtin static files,
# so a file named "default.css" will overwrite the builtin "default.css".
html_static_path = ["_static"]

autodoc_member_order = "bysource"

root_doc = "modules"

-------------------------------------------------------------------------------- /docs/source/git_theta.checkpoints.rst: --------------------------------------------------------------------------------

git\_theta.checkpoints package
==============================

Submodules
----------

git\_theta.checkpoints.base module
----------------------------------

.. automodule:: git_theta.checkpoints.base
   :members:
   :undoc-members:
   :show-inheritance:

git\_theta.checkpoints.flax\_checkpoint module
----------------------------------------------

.. automodule:: git_theta.checkpoints.flax_checkpoint
   :members:
   :undoc-members:
   :show-inheritance:

git\_theta.checkpoints.pickled\_dict\_checkpoint module
-------------------------------------------------------

.. automodule:: git_theta.checkpoints.pickled_dict_checkpoint
   :members:
   :undoc-members:
   :show-inheritance:

git\_theta.checkpoints.tensorflow\_checkpoint module
----------------------------------------------------

.. automodule:: git_theta.checkpoints.tensorflow_checkpoint
   :members:
   :undoc-members:
   :show-inheritance:

Module contents
---------------

.. automodule:: git_theta.checkpoints
   :members:
   :undoc-members:
   :show-inheritance:

-------------------------------------------------------------------------------- /docs/source/git_theta.lsh.rst: --------------------------------------------------------------------------------

git\_theta.lsh package
======================

Submodules
----------

git\_theta.lsh.base module
--------------------------

.. automodule:: git_theta.lsh.base
   :members:
   :undoc-members:
   :show-inheritance:

git\_theta.lsh.euclidean\_lsh module
------------------------------------

.. automodule:: git_theta.lsh.euclidean_lsh
   :members:
   :undoc-members:
   :show-inheritance:

git\_theta.lsh.pool module
--------------------------

.. automodule:: git_theta.lsh.pool
   :members:
   :undoc-members:
   :show-inheritance:

git\_theta.lsh.types module
---------------------------

.. automodule:: git_theta.lsh.types
   :members:
   :undoc-members:
   :show-inheritance:

Module contents
---------------

.. automodule:: git_theta.lsh
   :members:
   :undoc-members:
   :show-inheritance:

-------------------------------------------------------------------------------- /docs/source/git_theta.merges.rst: --------------------------------------------------------------------------------

git\_theta.merges package
=========================

Submodules
----------

git\_theta.merges.average module
--------------------------------

.. automodule:: git_theta.merges.average
   :members:
   :undoc-members:
   :show-inheritance:

git\_theta.merges.base module
-----------------------------

.. automodule:: git_theta.merges.base
   :members:
   :undoc-members:
   :show-inheritance:

git\_theta.merges.context module
--------------------------------

.. automodule:: git_theta.merges.context
   :members:
   :undoc-members:
   :show-inheritance:

git\_theta.merges.take module
-----------------------------

.. automodule:: git_theta.merges.take
   :members:
   :undoc-members:
   :show-inheritance:

Module contents
---------------

.. automodule:: git_theta.merges
   :members:
   :undoc-members:
   :show-inheritance:

-------------------------------------------------------------------------------- /docs/source/git_theta.rst: --------------------------------------------------------------------------------

git\_theta package
==================

Subpackages
-----------

.. toctree::
   :maxdepth: 4

   git_theta.checkpoints
   git_theta.lsh
   git_theta.merges
   git_theta.updates

Submodules
----------

git\_theta.async\_utils module
------------------------------

.. automodule:: git_theta.async_utils
   :members:
   :undoc-members:
   :show-inheritance:

git\_theta.git\_utils module
----------------------------

.. automodule:: git_theta.git_utils
   :members:
   :undoc-members:
   :show-inheritance:

git\_theta.metadata module
--------------------------

.. automodule:: git_theta.metadata
   :members:
   :undoc-members:
   :show-inheritance:

git\_theta.params module
------------------------

.. automodule:: git_theta.params
   :members:
   :undoc-members:
   :show-inheritance:

git\_theta.theta module
-----------------------

.. automodule:: git_theta.theta
   :members:
   :undoc-members:
   :show-inheritance:

git\_theta.types module
-----------------------

.. automodule:: git_theta.types
   :members:
   :undoc-members:
   :show-inheritance:

git\_theta.utils module
-----------------------

.. automodule:: git_theta.utils
   :members:
   :undoc-members:
   :show-inheritance:

Module contents
---------------

.. automodule:: git_theta
   :members:
   :undoc-members:
   :show-inheritance:

-------------------------------------------------------------------------------- /docs/source/git_theta.updates.rst: --------------------------------------------------------------------------------

git\_theta.updates package
==========================

Submodules
----------

git\_theta.updates.base module
------------------------------

.. automodule:: git_theta.updates.base
   :members:
   :undoc-members:
   :show-inheritance:

git\_theta.updates.dense module
-------------------------------

.. automodule:: git_theta.updates.dense
   :members:
   :undoc-members:
   :show-inheritance:

git\_theta.updates.low\_rank module
-----------------------------------

.. automodule:: git_theta.updates.low_rank
   :members:
   :undoc-members:
   :show-inheritance:

git\_theta.updates.sparse module
--------------------------------

.. automodule:: git_theta.updates.sparse
   :members:
   :undoc-members:
   :show-inheritance:

Module contents
---------------

.. automodule:: git_theta.updates
   :members:
   :undoc-members:
   :show-inheritance:

-------------------------------------------------------------------------------- /docs/source/modules.rst: --------------------------------------------------------------------------------

git_theta
=========

.. toctree::
   :maxdepth: 4

   git_theta

-------------------------------------------------------------------------------- /examples/git_theta_example.md: --------------------------------------------------------------------------------

Implemented a Hello World example using a design similar to the proposal in [#16](https://github.com/r-three/checkpoint-vcs/issues/16).

Below is a demo of the implemented proof of concept:

1. First initialize a new git repo
```
git init
```

2. Initialize the git filter driver for model files:
```
git-theta init
```
For this simple example, models are stored as json files. The command defines a filter that captures model files by adding
```
*.json filter=theta
```
to `.git/info/attributes` and defines a smudge and clean filter by adding
```
[filter "theta"]
    clean = git-theta-filter clean %f
    smudge = git-theta-filter smudge %f
    required = true
```
to `.git/config`. This means that when `foo.json` is being added to the staging area, `git-theta-filter clean foo.json` is run, and when it is checked out, `git-theta-filter smudge foo.json` is run.

3. Create a file `my_model.json` in the repo containing
```
{
    "layer1": {
        "w": [1,2,3,4],
        "b": [10]
    },
    "layer2": {
        "w": [-1,-2,-3,-4],
        "b": [-10]
    },
    "other_params": {
        "alpha": 0.1,
        "lr": 0.001
    }
}
```

4. Run `git-theta add my_model.json`.

`git-theta` is a python program that (1) loads `my_model.json`, (2) saves each individual parameter group to the filesystem under `.git_theta/my_model`, (3) runs git add on each parameter group file saved under `.git_theta`, and (4) runs git add on `my_model.json`.

Note that when `my_model.json` is added to the staging area (step 4), it gets intercepted by the previously defined clean filter for *.json files. The clean filter runs `git-theta-filter clean my_model.json`. `git-theta-filter clean` is another python program that replaces the contents of `my_model.json` with a dictionary containing `{'model_dir': '.git_theta/my_model', 'model_hash': }`

After all this, the staging area contains a snapshot of the model's parameter groups under `.git_theta/my_model` and a file called `my_model.json` that doesn't actually have the model parameters but instead some metadata about where to find the parameters at a later time. Although the staged version of `my_model.json` only contains metadata, the working copy still contains the model parameters.

The output of `git status` at this point is:

```
Changes to be committed:
  (use "git rm --cached <file>..." to unstage)
        new file:   .git_theta/my_model/layer1/b
        new file:   .git_theta/my_model/layer1/w
        new file:   .git_theta/my_model/layer2/b
        new file:   .git_theta/my_model/layer2/w
        new file:   .git_theta/my_model/other_params/alpha
        new file:   .git_theta/my_model/other_params/lr
        new file:   my_model.json
```

5. `git commit` to commit the model
6. Modify one parameter group in `my_model.json`
7. Run `git-theta add my_model.json` to stage the changes to the model.

At this point in time, only the modified parameter group's file under `.git_theta` has been modified. The output of `git status` at this point is:

```
Changes to be committed:
  (use "git restore --staged <file>..." to unstage)
        modified:   .git_theta/my_model/layer1/w
        modified:   my_model.json
```

8. `git commit` to commit the change
9. Look at the output of `git log` to get the commit hashes:

```
commit 11b37c3aacd5d6f4ec23986d95e739ed44433d6c (HEAD -> master)
Author: Nikhil Kandpal
Date:   Wed Oct 5 00:31:02 2022 -0400

    modify layer1/w

commit d13dca536d2690cb758c3866acb2abd1e0f32790
Author: Nikhil Kandpal
Date:   Wed Oct 5 00:29:21 2022 -0400

    initial commit
```

10. Checkout the first commit and make a new branch to test whether we can re-create the initial model

`git checkout d13dca536d2690cb758c3866acb2abd1e0f32790 -b my_branch`

When this occurs, the smudge filter intercepts the model file and is called with `git-theta-filter smudge my_model.json`. `git-theta-filter smudge` is a python program that reads the metadata file and reconstructs the model checkpoint from the data in `.git_theta/my_model`.

11. Check the contents of `my_model.json`

```
cat my_model.json
{"other_params": {"lr": 0.001, "alpha": 0.1}, "layer1": {"b": [10], "w": [1, 2, 3, 4]}, "layer2": {"b": [-10], "w": [-1, -2, -3, -4]}}
```

Notice that the model parameters are what we started with (although in a different order, since the model checkpoints are json) and `layer1/w` has its starting parameters.

12. Switch back to HEAD and check `my_model.json`

```
git switch -
cat my_model.json
{"other_params": {"lr": 0.001, "alpha": 0.1}, "layer1": {"b": [10], "w": [10, 20, 30, 40]}, "layer2": {"b": [-10], "w": [-1, -2, -3, -4]}}
```

At the HEAD commit, the parameters for `layer1/w` are the modified version once again.
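
As an aside, you can inspect the cleaned version of the model that git actually stores (as opposed to the smudged working copy) with `git show`. A sketch; the metadata shown is illustrative, not the literal output format:

```
git show HEAD:my_model.json
{"model_dir": ".git_theta/my_model", "model_hash": "..."}
```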
-------------------------------------------------------------------------------- /git_theta/__init__.py: --------------------------------------------------------------------------------

__version__ = "0.2.0"

from git_theta import (
    checkpoints,
    filters,
    git_utils,
    lsh,
    metadata,
    params,
    scripts,
    theta,
    updates,
    utils,
)
from git_theta.api import load_from_git, save_to_git

-------------------------------------------------------------------------------- /git_theta/api.py: --------------------------------------------------------------------------------

"""User-facing functions to interact with git-theta without incurring extra I/O costs."""

import io
import sys
from typing import Optional, Union

import git

import git_theta
from git_theta import checkpoints, filters, git_utils, metadata, utils


def save_to_git(
    state_dict,
    path: str,
    commit_msg: str,
    tag: Optional[str] = None,
    checkpoint_type: str = "pytorch",
    checkout: bool = False,
) -> git.Commit:
    """Save a model using git-theta without needing to write it to the working tree.

    Args:
        state_dict: The model weights in the framework-native format.
        path: The path where the model will be saved.
        commit_msg: The message to include in the new commit.
        tag: If provided, a tag to add to the new commit.
        checkpoint_type: The checkpoint format name, used to get the checkpoint plugin.
        checkout: If true, the new commit will be checked out (this incurs extra
            compute and I/O cost as the model will be moved from git storage to the
            working tree).

    Returns:
        The GitPython object representing the commit made with this save. Includes
        information like the sha.
    """
    repo = git_utils.get_git_repo()
    # Convert the deep learning framework native state dict into our checkpoint format.
    checkpoint_handler = checkpoints.get_checkpoint_handler(checkpoint_type)
    ckpt = checkpoint_handler.from_framework(state_dict)
    # Convert the checkpoint into the cleaned metadata file.
    metadata = filters.clean(ckpt, repo, path)
    # Capture metadata writing into a string.
    metadata = str(metadata)
    # Convert the metadata file into a git blob without having it on disk.
    blob = git_theta.git_utils.make_blob(repo, metadata, path)
    # Add the metadata to staging.
    repo.index.add([blob])
    # Commit the metadata.
    if sys.platform in ("win32", "cygwin"):
        # When you use GitPython to commit, things like hooks, i.e. our post-commit
        # hook, run as subprocesses. Currently it seems that running shell scripts
        # via subprocess does not work on Windows.
        # Commit using the GitPython wrapper around the `git commit` command. This
        # way hooks will be handled the same way as in a normal commit.
        repo.git.commit(m=commit_msg)
        # We now need to get the sha manually in order to reference this commit.
        sha = repo.commit("HEAD")
    else:
        # Commit directly from python.
        sha = repo.index.commit(commit_msg)
    if checkout:
        repo.git.checkout(sha)
    if tag is not None:
        repo.create_tag(tag, ref=sha)
    return sha


def load_from_git(
    sha_or_tag: Union[str, git.Commit],
    path: str,
    checkpoint_type: str = "pytorch",
    checkout: bool = False,
):
    """Load a model from git-theta without having it checked out.
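
    The weights are reconstructed directly from git's internal storage, so the
    working tree is only touched when ``checkout=True``.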

    Args:
        sha_or_tag: A reference to the commit to load the model from. It can be the
            sha1 or a tag.
        path: The path to where the model was saved in the working tree.
        checkpoint_type: The checkpoint format name, used to get the checkpoint plugin.
        checkout: If true, the commit is also checked out, keeping the on-disk model
            in sync with the in-memory model (this incurs extra compute and I/O cost
            as the model will be moved from git storage to the working tree).

    Returns:
        The loaded model in the checkpoint-native format.
    """
    repo = git_utils.get_git_repo()
    # Set the checkpoint type env variable so that it respects the user input.
    utils.EnvVarConstants.CHECKPOINT_TYPE = checkpoint_type
    # Look up the metadata for this checkpoint in git.
    metadata_blob = git_utils.get_file_version(repo, path, sha_or_tag)
    # Build the metadata object from the blob data.
    metadata_obj = metadata.Metadata.from_file(metadata_blob.data_stream)
    # Convert the metadata into a checkpoint with weights.
    ckpt = filters.smudge(metadata_obj, repo, path)
    # Checkout the commit we are loading from so the state on disk matches the
    # state in memory.
    if checkout:
        repo.git.checkout(sha_or_tag)
    # Convert the checkpoint to the native state dict.
    return ckpt.to_framework()

-------------------------------------------------------------------------------- /git_theta/async_utils.py: --------------------------------------------------------------------------------

"""Utilities for running I/O-bound operations asynchronously."""

import asyncio
import dataclasses
import functools
import itertools
import logging
import sys
import weakref
from typing import Any, Awaitable, Dict, Optional, Sequence, Tuple, TypeVar, Union

import six

if sys.version_info >= (3, 8):
    from typing import Protocol
else:
    from typing_extensions import Protocol


class AsyncTaskMixin(logging.Handler):
    """Include an async task index in the log record."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # A way to always get the next id. Loggers are singletons, so
        # there will only be one of these counters and therefore no duplicates.
        self._next_id = itertools.count().__next__
        # WeakDict lets us use a reference to an async task as a key
        # without stopping it from being garbage collected.
        self._task_ids = weakref.WeakKeyDictionary()

    def _task_id(self):
        """Map an Async Task to an id."""
        try:
            task = asyncio.current_task()
            if task not in self._task_ids:
                self._task_ids[task] = self._next_id()
            return f"task-{self._task_ids[task]}"
        except RuntimeError:
            return "main"

    def emit(self, record):
        """Add the task id to the record."""
        # Use setattr over `.` notation to avoid some overloading on the
        # record class; this is what most online examples seem to do.
        record.__setattr__("task", self._task_id())
        super().emit(record)


class AsyncTaskStreamHandler(AsyncTaskMixin, logging.StreamHandler):
    """Include an Async task-id when logging to a stream."""


class AsyncTaskFileHandler(AsyncTaskMixin, logging.FileHandler):
    """Include an Async task-id when logging to a file."""


def run(*args, **kwargs):
    """Run an awaitable to completion, dispatching based on python version."""
    if sys.version_info < (3, 8):
        if sys.platform in ("win32", "cygwin"):
            asyncio.set_event_loop_policy(asyncio.WindowsProactorEventLoopPolicy())
    return asyncio.run(*args, **kwargs)


# A type variable to indicate that the keys of the dict will not change.
K = TypeVar("K")
# A type variable to indicate that the async task is called with K, V tuples.
V = TypeVar("V")


class MapTask(Protocol):
    def __call__(self, key: K, value: V) -> Awaitable[Tuple[K, Any]]:
        """An async function that runs on each key, value pair in a map."""


async def limited_concurrency(*args, f: MapTask, sem: asyncio.Semaphore, **kwargs):
    """Run f, but limit the number of coroutines that can run at once."""
    async with sem:
        return await f(*args, **kwargs)


async def run_map(
    mapping: Dict[K, V],
    func: MapTask,
    max_concurrency: int = -1,
) -> Dict[K, Any]:
    """Run an async function on K, V pairs; return a map with the results as new values."""
    if max_concurrency > 0:
        sem = asyncio.Semaphore(max_concurrency)
        func = functools.partial(limited_concurrency, f=func, sem=sem)
    return dict(await asyncio.gather(*(func(k, v) for k, v in mapping.items())))


@dataclasses.dataclass
class CompletedAsyncProcess:
    """Results from a finished async subprocess run."""

    args: Union[Sequence[str], str]
    returncode: Optional[int]
    stdout: Optional[bytes] = None
    stderr: Optional[bytes] = None


async def subprocess_run(
    command: Union[Sequence[str], str],
    input: Optional[Union[str, bytes]] = None,
    capture_output: bool = False,
) -> CompletedAsyncProcess:
    """Run a subprocess with async. Tries to mirror the subprocess.run API."""
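    # Note: the argv list is joined with spaces and run through the shell, so
    # arguments containing spaces or shell metacharacters are assumed to be
    # pre-quoted by the caller; nothing here escapes them.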
    if not isinstance(command, str):
        shell_command = " ".join(command)
    else:
        shell_command = command
    proc = await asyncio.create_subprocess_shell(
        shell_command,
        stdin=asyncio.subprocess.PIPE,
        stdout=asyncio.subprocess.PIPE,
        stderr=asyncio.subprocess.PIPE,
    )
    if input is not None:
        stdout, stderr = await proc.communicate(input=six.ensure_binary(input))
    else:
        stdout, stderr = await proc.communicate()
    return CompletedAsyncProcess(
        command,
        proc.returncode,
        stdout if capture_output else None,
        stderr if capture_output else None,
    )

-------------------------------------------------------------------------------- /git_theta/checkpoints/__init__.py: --------------------------------------------------------------------------------

"""A module of checkpoint handlers."""

from git_theta.checkpoints.base import (
    Checkpoint,
    get_checkpoint_handler,
    get_checkpoint_handler_name,
)

-------------------------------------------------------------------------------- /git_theta/checkpoints/base.py: --------------------------------------------------------------------------------

"""Base class and utilities for different checkpoint format backends."""

import os
import sys
from abc import ABCMeta, abstractmethod
from typing import Optional, Tuple

import numpy as np

if sys.version_info < (3, 10):
    from importlib_metadata import entry_points
else:
    from importlib.metadata import entry_points

from git_theta import utils


@utils.abstract_classattributes("name")
class Checkpoint(dict, metaclass=ABCMeta):
    """Abstract base class for wrapping checkpoint formats."""

    name: str = NotImplemented  # The name of this checkpoint handler, can be used to look up the plugin.

    @classmethod
    def from_file(cls, checkpoint_path):
        """Create a new Checkpoint object.

        Parameters
        ----------
        checkpoint_path : str or file-like object
            Path to a checkpoint file
        """
        return cls.from_framework(cls.load(checkpoint_path))

    @classmethod
    @abstractmethod
    def load(cls, checkpoint_path):
        """Load a checkpoint into a dict format.

        Parameters
        ----------
        checkpoint_path : str or file-like object
            Path to a checkpoint file

        Returns
        -------
        model_dict : dict
            Dictionary mapping parameter names to parameter values. Parameters
            should be numpy arrays.
        """

    @classmethod
    @abstractmethod
    def from_framework(cls, model_dict):
        """Convert a checkpoint from the native framework format to git-theta's."""

    @abstractmethod
    def to_framework(self):
        """Convert our checkpoint into the native framework state dict."""

    @abstractmethod
    def save(self, checkpoint_path):
        """Save the checkpoint to a file.

        Parameters
        ----------
        checkpoint_path : str or file-like object
            Path to write out the checkpoint file to
        """

    def flatten(self):
        return utils.flatten(self, is_leaf=lambda v: isinstance(v, np.ndarray))

    def unflatten(self):
        return utils.unflatten(self)

    @classmethod
    def diff(
        cls, m1: "Checkpoint", m2: "Checkpoint"
    ) -> Tuple["Checkpoint", "Checkpoint", "Checkpoint"]:
        """Compute the diff between two checkpoints.

        Parameters
        ----------
        m1 : Checkpoint
            The new checkpoint
        m2 : Checkpoint
            The old checkpoint

        Returns
        -------
        added : Checkpoint
            Checkpoint containing the parameter groups added to m1
        removed : Checkpoint
            Checkpoint containing the parameter groups removed from m2
        modified : Checkpoint
            Checkpoint containing the parameter groups modified between m1 and m2
        """
        m1_flat = m1.flatten()
        m2_flat = m2.flatten()
        # N.b.: This is actually faster than set operations on m1 and m2's keys.
        added = cls({k: v for k, v in m1_flat.items() if k not in m2_flat}).unflatten()
        removed = cls(
            {k: v for k, v in m2_flat.items() if k not in m1_flat}
        ).unflatten()
        modified = cls(
            {
                k: v
                for k, v in m1_flat.items()
                if k in m2_flat and not np.allclose(v, m2_flat[k])
            }
        ).unflatten()
        return added, removed, modified


def get_checkpoint_handler_name(checkpoint_type: Optional[str] = None) -> str:
    """Get the name of the checkpoint handler to use.

    The order of precedence is:

    1. the `checkpoint_type` argument
    2. the `$GIT_THETA_CHECKPOINT_TYPE` environment variable
    3. the default value (currently pytorch)

    Parameters
    ----------
    checkpoint_type
        Name of the checkpoint handler

    Returns
    -------
    str
        Name of the checkpoint handler
    """
    # TODO(bdlester): Find a better way to include checkpoint type information
    # in git clean filters that are run without `git theta add`.
    # TODO: Don't default to pytorch once other checkpoint formats are supported.
    return checkpoint_type or utils.EnvVarConstants.CHECKPOINT_TYPE


def get_checkpoint_handler(checkpoint_type: Optional[str] = None) -> Checkpoint:
    """Get the checkpoint handler either by name or from an environment variable.

    Gets the checkpoint handler for either the `checkpoint_type` argument or the
    `$GIT_THETA_CHECKPOINT_TYPE` environment variable.

    Defaults to pytorch when neither is defined.

    Parameters
    ----------
    checkpoint_type
        Name of the checkpoint handler

    Returns
    -------
    Checkpoint
        The checkpoint handler (usually an instance of `git_theta.checkpoints.Checkpoint`).
        The returned handler may be defined in a user-installed plugin.
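
    For example, with the built-in plugins installed, calling
    `get_checkpoint_handler("pytorch")` should resolve to the pickled-dict
    checkpoint handler registered under that entry point name.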
    """
    checkpoint_type = get_checkpoint_handler_name(checkpoint_type)
    discovered_plugins = entry_points(group="git_theta.plugins.checkpoints")
    return discovered_plugins[checkpoint_type].load()

-------------------------------------------------------------------------------- /git_theta/checkpoints/flax_checkpoint.py: --------------------------------------------------------------------------------

"""Checkpoint plugin for Flax's msgpack format."""

from file_or_name import file_or_name
from flax import serialization

from git_theta.checkpoints import Checkpoint


class FlaxCheckpoint(Checkpoint):
    """Load a msgpack based Flax Checkpoint."""

    name: str = "flax"

    @classmethod
    @file_or_name(checkpoint_path="rb")
    def load(cls, checkpoint_path, **kwargs):
        return serialization.msgpack_restore(checkpoint_path.read())

    @classmethod
    def from_framework(cls, model_dict):
        return cls(model_dict)

    def to_framework(self):
        return self

    @file_or_name(checkpoint_path="wb")
    def save(self, checkpoint_path, **kwargs):
        checkpoint_path.write(serialization.msgpack_serialize(dict(self)))

-------------------------------------------------------------------------------- /git_theta/checkpoints/pickled_dict_checkpoint.py: --------------------------------------------------------------------------------

"""Checkpoints that use a pickled dict format like pytorch."""

import io

import torch

from git_theta.checkpoints import Checkpoint


# TODO: We should rename this back to be Torch related as we do things like check if they are torch.Tensors.
class PickledDictCheckpoint(Checkpoint):
    """Class for wrapping pickled dict checkpoints, commonly used with PyTorch."""

    name: str = "pickled_dict"

    @classmethod
    def load(cls, checkpoint_path):
        """Load a checkpoint into a dict format.

        Parameters
        ----------
        checkpoint_path : str or file-like object
            Path to a checkpoint file

        Returns
        -------
        model_dict : dict
            Dictionary mapping parameter names to parameter values
        """
        if isinstance(checkpoint_path, io.IOBase):
            checkpoint_path = io.BytesIO(checkpoint_path.read())

        # Map all values to the CPU as they may have been saved to the GPU and we don't
        # know if the same GPU topology is available now.
        model_dict = torch.load(checkpoint_path, map_location=torch.device("cpu"))
        if not isinstance(model_dict, dict):
            raise ValueError("Supplied PyTorch checkpoint must be a dict.")
        if not all(isinstance(k, str) for k in model_dict.keys()):
            raise ValueError("All PyTorch checkpoint keys must be strings.")
        if not all(isinstance(v, torch.Tensor) for v in model_dict.values()):
            raise ValueError("All PyTorch checkpoint values must be tensors.")
        return model_dict

    @classmethod
    def from_framework(cls, model_dict):
        # If things were saved with gradient requirements, we need to detach them
        # before converting them to numpy arrays.
        return cls({k: v.cpu().detach().numpy() for k, v in model_dict.items()})

    def to_framework(self):
        return {k: torch.as_tensor(v) for k, v in self.items()}

    def save(self, checkpoint_path):
        """Save the checkpoint to a file.

        Parameters
        ----------
        checkpoint_path : str or file-like object
            Path to write out the checkpoint file to
        """
        checkpoint_dict = self.to_framework()
        torch.save(checkpoint_dict, checkpoint_path)

-------------------------------------------------------------------------------- /git_theta/checkpoints/safetensors_checkpoint.py: --------------------------------------------------------------------------------

"""Checkpoint using the HF safetensors format.

safetensors has the ability to write a model checkpoint from "dl-native" -> "safetensors"
and read "safetensors" -> any "dl-native" framework, not just the one that was
used to write it. Therefore, we read/write with their numpy API.
"""

import safetensors.numpy
from file_or_name import file_or_name

from git_theta.checkpoints import Checkpoint


# TODO(bdlester): Can we leverage the lazy loading ability to make things faster?
class SafeTensorsCheckpoint(Checkpoint):
    """Class for r/w of the safetensors format. https://github.com/huggingface/safetensors"""

    name: str = "safetensors"

    @classmethod
    @file_or_name(checkpoint_path="rb")
    def load(cls, checkpoint_path: str):
        # Note that we use numpy as the framework because we don't care what
        # their downstream dl framework is; we only want the results back as
        # numpy arrays.
        return safetensors.numpy.load(checkpoint_path.read())

    @file_or_name(checkpoint_path="wb")
    def save(self, checkpoint_path: str):
        # Note: git-theta uses numpy internally, so we save using the numpy api,
        # regardless of the original framework used to write the checkpoint.
        checkpoint_dict = self.to_framework()
        checkpoint_path.write(safetensors.numpy.save(checkpoint_dict))

    def to_framework(self):
        return self

    @classmethod
    def from_framework(cls, model_dict):
        return cls(model_dict)

-------------------------------------------------------------------------------- /git_theta/checkpoints/tensorflow_checkpoint.py: --------------------------------------------------------------------------------

"""A Checkpoint backend for TensorFlow models."""


import numpy as np
import tensorflow as tf

from git_theta import utils
from git_theta.checkpoints import Checkpoint


class DynamicNetwork(tf.keras.Model):
    """A keras model that can dynamically build itself from a map of params for tf saving."""

    def __init__(self, params):
        super().__init__()
        for name, param in params.items():
            # Convert numpy arrays to tf.Variable so they will be saved.
            if isinstance(param, np.ndarray):
                param = tf.Variable(param, name=name)
            # Convert nested dicts into nested networks.
            elif isinstance(param, dict):
                param = DynamicNetwork(param)
            # Save the variable (or sub-model) to an attribute so it will get tracked.
            self.__setattr__(name, param)


class TensorFlowCheckpoint(Checkpoint):
    """Process a TensorFlow checkpoint via `tf.keras.Model.save_weights` (no computation graph included)."""
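
    # A raw checkpoint key such as
    #   "model/layer1/kernel/.ATTRIBUTES/VARIABLE_VALUE"
    # (an illustrative name, not one this class requires) is normalized below to
    # "model/layer1/kernel" and then split on "/" into nested dict keys.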

    name: str = "tensorflow-checkpoint"
    VALUE_STRING = ".ATTRIBUTES/VARIABLE_VALUE"

    @staticmethod
    def is_parameter(param_name: str) -> bool:
        return param_name.endswith(TensorFlowCheckpoint.VALUE_STRING)

    @staticmethod
    def normalize_name(param_name: str) -> str:
        param_name = utils.remove_suffix(param_name, TensorFlowCheckpoint.VALUE_STRING)
        param_name = utils.remove_suffix(param_name, "/")
        return param_name

    @classmethod
    def load(cls, checkpoint_path: str):
        ckpt_read = tf.train.load_checkpoint(checkpoint_path)
        params = {}
        for param_name in ckpt_read.get_variable_to_shape_map():
            if not TensorFlowCheckpoint.is_parameter(param_name):
                continue
            simple_name = TensorFlowCheckpoint.normalize_name(param_name)
            params[tuple(simple_name.split("/"))] = ckpt_read.get_tensor(param_name)
        return utils.unflatten(params)

    @classmethod
    def from_framework(cls, model_dict):
        return cls(model_dict)

    def to_framework(self):
        return self

    def save(self, checkpoint_path: str):
        model = DynamicNetwork(self)
        model.save_weights(checkpoint_path)


# TODO
class TensorFlowSavedModel(Checkpoint):
    """Process a TensorFlow SavedModel (computation graph included)."""

    name: str = "tensorflow-savedmodel"

    @classmethod
    def load(cls, checkpoint_path: str):
        raise ValueError("Sorry, SavedModel support is a work in progress.")

    def save(self, checkpoint_path: str):
        raise ValueError("Sorry, SavedModel support is a work in progress.")

-------------------------------------------------------------------------------- /git_theta/filters.py: --------------------------------------------------------------------------------

"""Clean and smudge filter functions."""

import logging

import git
import numpy as np

from git_theta import (
    async_utils,
    checkpoints,
    git_utils,
    lsh,
    metadata,
    params,
    updates,
)
from git_theta.utils import EnvVarConstants


def clean(
    checkpoint: checkpoints.Checkpoint, repo: git.Repo, path: str
) -> metadata.Metadata:
    """Convert a `Checkpoint` to cleaned `Metadata`."""
    # Note: If the update serializer becomes configurable per-parameter, it will
    # need to be created inside _clean.
    update_serializer = params.get_update_serializer()
    # Create an update handler based on user input.
    update_handler = updates.get_update_handler()(
        update_serializer, EnvVarConstants.UPDATE_DATA_PATH
    )
    prev_metadata = metadata.Metadata.from_commit(repo, path, "HEAD").flatten()
    logger = logging.getLogger("git_theta")

    async def _clean(param_keys, new_param):
        logger.debug(f"Cleaning {'/'.join(param_keys)}")
        # Get the metadata from the previous version of the parameter.
        param_metadata = prev_metadata.get(param_keys)
        # Create new metadata from the current value.
        logger.debug(f"Making new Metadata for {'/'.join(param_keys)}")
        new_tensor_metadata = metadata.TensorMetadata.from_tensor(new_param)
        logger.debug(f"Finished new Metadata for {'/'.join(param_keys)}")

        # If the parameter tensor has not changed, just keep the metadata the same.
        # TODO: Encapsulate this parameter check within an equality check.
        if (
            param_metadata
            and param_metadata.tensor_metadata.shape == new_tensor_metadata.shape
            and param_metadata.tensor_metadata.dtype == new_tensor_metadata.dtype
            # A parameter with a side-loaded update will not have changed in the
            # normal checkpoint, so ask the updater if it will be updated with
            # side-loaded information.
            and not update_handler.will_update(param_keys)
        ):
            # Compare the parameters using the LSH.
            hasher = lsh.get_lsh()
            # TODO: Is it possible to make this comparison async?
            logger.debug(f"Comparing Hashes for: {'/'.join(param_keys)}")
            hash_distance = hasher.distance(
                param_metadata.tensor_metadata.hash, new_tensor_metadata.hash
            )
            # If hash_distance < PARAMETER_ATOL, assume the tensors pass
            # np.allclose and the parameter hasn't changed.
            if hash_distance < EnvVarConstants.PARAMETER_ATOL:
                return param_keys, param_metadata
            # If PARAMETER_ATOL < hash_distance < LSH_THRESHOLD, load the parameters
            # and check if the parameter has changed with np.allclose.
            elif hash_distance < EnvVarConstants.LSH_THRESHOLD:
                # Load the previous parameter using the specific update handler
                # for that parameter.
                param_update_handler = updates.get_update_handler(
                    param_metadata.theta_metadata.update_type
                )(update_serializer)
                param = await param_update_handler.apply(
                    param_metadata, param_keys, repo=repo, path=path
                )
                if np.allclose(
                    param,
                    new_param,
                    rtol=EnvVarConstants.PARAMETER_RTOL,
                    atol=EnvVarConstants.PARAMETER_ATOL,
                ):
                    return param_keys, param_metadata

        # Create git-theta metadata for the new parameter.
        new_theta_metadata = metadata.ThetaMetadata(
            update_type=update_handler.name, last_commit=git_utils.get_head(repo)
        )
        # Write the new parameter into git-lfs.
        lfs_metadata, param_hash = await update_handler.write(
            new_param,
            param_keys,
            prev_metadata=param_metadata,
            repo=repo,
            path=path,
        )
        # If we are an IncrementalUpdate, we need to re-calculate the hash
        # so it is based on the updated value, not the old one.
        if param_hash is not None:
            new_tensor_metadata.hash = param_hash
        # Combine the metadata into a single parameter metadata blob.
        new_param_metadata = metadata.ParamMetadata(
            lfs_metadata=lfs_metadata,
            tensor_metadata=new_tensor_metadata,
            theta_metadata=new_theta_metadata,
        )
        logger.debug(f"Finished Cleaning {'/'.join(param_keys)}")
        del new_param
        return param_keys, new_param_metadata

    # Sort the keys so we don't get changing diffs based on serialization order.
    sorted_checkpoint = dict(sorted(checkpoint.flatten().items()))
    if EnvVarConstants.LOW_MEMORY:
        # Run one at a time and delete the old values as you go.
        # TODO: Is it possible/better to process the keys based on the size
        # of the tensor and re-sort later? Then you could do things like delete
        # all the small ones before you have to process the large one.
        logger.warning(
            "Running Git-Theta in Low Memory Mode; no concurrency will be used, and references to parameter weights will be freed after use."
        )
        meta = {}
        for k in list(sorted_checkpoint.keys()):
            # Get the param while removing it from the dict; removing the
            # reference in the dict will allow the tensor to be gc'd.
            v = sorted_checkpoint.pop(k)
            param_name, param_meta = async_utils.run(_clean(k, v))
            meta[param_name] = param_meta
            # Drop the reference to the value to allow it to be gc'd.
            del v
        return metadata.Metadata(meta).unflatten()
    return metadata.Metadata(
        **async_utils.run(
            async_utils.run_map(
                sorted_checkpoint,
                _clean,
                max_concurrency=EnvVarConstants.MAX_CONCURRENCY,
            )
        )
    ).unflatten()


# TODO: Now that we have this as a separate function, use it (instead of
# `subprocess.run`) in the manual merge escape hatch.
def smudge(
    cleaned_metadata: metadata.Metadata, repo: git.Repo, path: str
) -> checkpoints.Checkpoint:
    """Convert cleaned `Metadata` to a `Checkpoint`."""
    curr_metadata = cleaned_metadata.flatten()

    async def _smudge(param_keys, param_metadata):
        logger = logging.getLogger("git_theta")
        logger.debug(f"Smudging {'/'.join(param_keys)}")
        update_handler = updates.get_update_handler(
            param_metadata.theta_metadata.update_type
        )(params.get_update_serializer())
        param_value = await update_handler.apply(
            param_metadata, param_keys, repo=repo, path=path
        )
        return param_keys, param_value

    model_dict = async_utils.run(
        async_utils.run_map(
            curr_metadata, _smudge, max_concurrency=EnvVarConstants.MAX_CONCURRENCY
        )
    )

    checkpoint_handler = checkpoints.get_checkpoint_handler()
    return checkpoint_handler(model_dict).unflatten()

-------------------------------------------------------------------------------- /git_theta/hooks/post-commit: --------------------------------------------------------------------------------

#!/bin/sh

# git-theta post-commit hook that:
# 1. Checks for git-theta tracked models in the last commit
# 2. Finds the Git LFS object ids corresponding to the parameter groups modified in that commit
# 3. Records the modified object ids in .git/theta/commits/

git theta post-commit "$@"

-------------------------------------------------------------------------------- /git_theta/hooks/pre-push: --------------------------------------------------------------------------------

#!/bin/sh

# git-theta pre-push hook that:
# 1. Gets the commits being pushed to the remote
# 2. Checks for git-theta tracked models in those commits
# 3. Reads .git/theta/commits/ to get the parameter group object-ids modified in each commit
# 4. Runs git lfs push --object-id to push the modified parameter groups to the LFS store
Runs git lfs push --object-id to push the modified parameter groups to the LFS store 8 | 9 | git theta pre-push "$@" 10 | -------------------------------------------------------------------------------- /git_theta/lsh/__init__.py: -------------------------------------------------------------------------------- 1 | from git_theta.lsh.base import HashFamily 2 | from git_theta.lsh.euclidean_lsh import get_lsh 3 | -------------------------------------------------------------------------------- /git_theta/lsh/base.py: -------------------------------------------------------------------------------- 1 | """Base class for computing locality-sensitive hashes""" 2 | 3 | import abc 4 | 5 | from git_theta.lsh.pool import RandomnessPool 6 | from git_theta.lsh.types import Parameter, Signature 7 | 8 | 9 | class HashFamily(metaclass=abc.ABCMeta): 10 | def __init__(self, signature_size: int): 11 | self._signature_size = signature_size 12 | self.pool = RandomnessPool(signature_size) 13 | 14 | @property 15 | @abc.abstractmethod 16 | def name(self) -> str: 17 | """The name of the distance function this hash family approximates.""" 18 | 19 | @property 20 | def signature_size(self) -> int: 21 | """The size of the signatures this hash produces.""" 22 | return self._signature_size 23 | 24 | @abc.abstractmethod 25 | def hash(self, x: Parameter) -> Signature: 26 | """Convert `x` to its signature.""" 27 | 28 | @abc.abstractmethod 29 | def distance(self, query: Signature, data: Signature) -> float: 30 | """Calculate the approximate distance between two signatures.""" 31 | -------------------------------------------------------------------------------- /git_theta/lsh/euclidean_lsh.py: -------------------------------------------------------------------------------- 1 | """Classes for computing Euclidean locality-sensitive hashes""" 2 | 3 | import os 4 | 5 | import numba as nb 6 | import numpy as np 7 | 8 | from git_theta.lsh import HashFamily 9 | from git_theta.lsh.pool import RandomnessPool 10 | from git_theta.lsh.types import Parameter, Signature 11 | from git_theta.utils import EnvVarConstants 12 | 13 | 14 | class EuclideanLSH(HashFamily): 15 | """ 16 | Class for performing the Euclidean l2 LSH (E2LSH) algorithm described in https://www.cs.princeton.edu/courses/archive/spring05/cos598E/bib/p253-datar.pdf 17 | with a pre-computed randomness pool as described in Section 3 of http://personal.denison.edu/~lalla/papers/online-lsh.pdf 18 | 19 | Also see http://mlwiki.org/index.php/Euclidean_LSH for an introduction to E2LSH 20 | """ 21 | 22 | def __init__(self, signature_size: int, bucket_width: float): 23 | super().__init__(signature_size) 24 | self.bucket_width = bucket_width 25 | 26 | @property 27 | def name(self) -> str: 28 | return "euclidean" 29 | 30 | def hash(self, x: Parameter) -> Signature: 31 | """Convert `x` to its signature.""" 32 | x = x.ravel() 33 | hyperplanes = self.pool.get_hyperplanes(x.size) 34 | return np.floor((x @ hyperplanes) / self.bucket_width).astype(np.int64) 35 | 36 | def distance(self, query: Signature, data: Signature) -> float: 37 | """Compute the distance between two EuclideanLSH signatures""" 38 | return ( 39 | (1 / np.sqrt(self.signature_size)) 40 | * np.linalg.norm(query - data) 41 | * self.bucket_width 42 | ) 43 | 44 | 45 | class FastEuclideanLSH(EuclideanLSH): 46 | """ 47 | Class for performing the Euclidean LSH using numba-jitted loops. 
This is both faster than the numpy matrix-multiplication implementation in EuclideanLSH and
48 |     also uses less memory, since the whole hyperplane matrix (feature_size x signature_size) is never in memory at once.
49 |     """
50 | 
51 |     def hash(self, x: Parameter) -> Signature:
52 |         """Convert `x` to its signature."""
53 |         return nb_hash(x.ravel(), self.signature_size, self.pool, self.bucket_width)
54 | 
55 | 
56 | @nb.jit(nopython=True, parallel=True)
57 | def nb_hash(
58 |     x: Parameter, signature_size: int, pool: RandomnessPool, bucket_width: float
59 | ) -> Signature:
60 |     signature = np.zeros(signature_size)
61 | 
62 |     for signature_idx in nb.prange(signature_size):
63 |         for feature_idx, feature in enumerate(x):
64 |             hyperplane_element = pool.get_hyperplane_element(feature_idx, signature_idx)
65 |             signature[signature_idx] += feature * hyperplane_element
66 | 
67 |     return np.floor(signature / bucket_width).astype(np.int64)
68 | 
69 | 
70 | def get_lsh():
71 |     # TODO: we need a better way of keeping track of configuration at the repository level.
72 |     # For LSH configuration, once it is set for a repository, changing it should be handled with care.
73 |     return FastEuclideanLSH(
74 |         EnvVarConstants.LSH_SIGNATURE_SIZE, EnvVarConstants.PARAMETER_ATOL
75 |     )
76 | 
--------------------------------------------------------------------------------
/git_theta/lsh/pool.py:
--------------------------------------------------------------------------------
1 | """Class for deterministically supplying pre-computed random values"""
2 | 
3 | import sys
4 | 
5 | import numba as nb
6 | import numpy as np
7 | from numpy.random import MT19937, Generator
8 | 
9 | from git_theta.utils import EnvVarConstants
10 | 
11 | spec = [("pool", nb.float64[:]), ("signature_offsets", nb.int64[:])]
12 | 
13 | 
14 | @nb.experimental.jitclass(spec)
15 | class RandomnessPool:
16 |     def __init__(self, signature_size):
17 |         with nb.objmode(pool="float64[:]", signature_offsets="int64[:]"):
18 |             # N.b. we use a fixed seed so that every instance of RandomnessPool has the same set of random numbers.
19 |             rng = Generator(MT19937(seed=42))
20 |             pool = rng.normal(size=EnvVarConstants.LSH_POOL_SIZE)
21 |             int64_range = np.iinfo(np.int64)
22 |             signature_offsets = rng.integers(
23 |                 int64_range.min, int64_range.max, size=signature_size, dtype=np.int64
24 |             )
25 |         self.pool = pool
26 |         self.signature_offsets = signature_offsets
27 | 
28 |     def get_hyperplanes(self, feature_size):
29 |         hyperplanes = np.empty((feature_size, self.signature_offsets.size))
30 |         for feature_idx in nb.prange(feature_size):
31 |             for signature_idx in nb.prange(self.signature_offsets.size):
32 |                 hyperplanes[feature_idx, signature_idx] = self.get_hyperplane_element(
33 |                     feature_idx, signature_idx
34 |                 )
35 |         return hyperplanes
36 | 
37 |     def get_hyperplane_element(self, feature_idx, signature_idx):
38 |         signature_offset = self.signature_offsets[signature_idx]
39 |         pool_idx = np.mod(np.bitwise_xor(feature_idx, signature_offset), self.pool.size)
40 |         return self.pool[pool_idx]
41 | 
--------------------------------------------------------------------------------
/git_theta/lsh/types.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | 
3 | Signature = np.ndarray
4 | Parameter = np.ndarray
5 | 
--------------------------------------------------------------------------------
/git_theta/merges/__init__.py:
--------------------------------------------------------------------------------
1 | """Plugins for Model Merging.
2 | 
3 | Note:
4 |     In order to dynamically create the menu of possible actions that describe what
5 |     each plug-in does, the plugins get imported at the start of the merge tool.
6 |     Therefore, plug-ins must not have slow side-effects that happen at import-time.
7 | """
8 | 
9 | from git_theta.merges.base import Merge, MergeArgument, all_merge_handlers
10 | 
--------------------------------------------------------------------------------
/git_theta/merges/base.py:
--------------------------------------------------------------------------------
1 | """Plugins for Model Merging.
2 | 
3 | Note:
4 |     In order to dynamically create the menu of possible actions that describe what
5 |     each plug-in does, the plugins get imported at the start of the merge tool.
6 |     Therefore, plug-ins must not have slow side-effects that happen at import-time.
7 | """
8 | 
9 | 
10 | import logging
11 | import sys
12 | from abc import ABCMeta, abstractmethod
13 | from dataclasses import dataclass
14 | from typing import Any, Dict, FrozenSet, List, Optional, Tuple, Type, Union
15 | 
16 | if sys.version_info < (3, 10):
17 |     from importlib_metadata import entry_points
18 | else:
19 |     from importlib.metadata import entry_points
20 | 
21 | from git_theta import metadata, utils
22 | 
23 | ParamName = Tuple[str, ...]
24 | Parameter = Any
25 | PartialModel = Dict[ParamName, Parameter]
26 | 
27 | 
28 | @dataclass
29 | class MergeArgument:
30 |     """Metadata for how to describe and validate a user-specified merge-strategy-specific argument"""
31 | 
32 |     name: str
33 |     description: str
34 |     type: Type
35 |     range: Optional[Tuple[Union[float, int], Union[float, int]]]
36 | 
37 |     @property
38 |     def validator(self):
39 |         """Returns a function checking whether a given string is a valid input for this argument"""
40 | 
41 |         def is_valid(x):
42 |             # TODO: May need to support non-numeric types at some point
43 |             try:
44 |                 x = self.type(x)
45 |                 if self.range:
46 |                     return x >= self.range[0] and x <= self.range[1]
47 |                 return True  # No range constraint; any value that converts cleanly is valid.
48 |             except:
49 |                 return False
50 | 
51 |         return is_valid
52 | 
53 | 
54 | class PrintableABCMeta(ABCMeta):
55 |     """Add custom `str` to /classes/, not objects."""
56 | 
57 |     def __str__(cls):
58 |         return f"{cls.NAME}: {cls.DESCRIPTION}"
59 | 
60 | 
61 | @utils.abstract_classattributes("DESCRIPTION", "NAME", "SHORT_CUT", "INACTIVE_STATES")
62 | class Merge(metaclass=PrintableABCMeta):
63 |     """A Plug-in that handles parameter merging.
64 | 
65 |     Note:
66 |         Informational strings about the plugin can contain `prompt_toolkit`
67 |         supported HTML markup for styling and coloring text.
68 |     """
69 | 
70 |     DESCRIPTION: str = NotImplemented  # Description of Merge Action, shown in menu.
71 |     NAME: str = NotImplemented  # Unique name of the merge, to look up the plugin with.
72 |     SHORT_CUT: str = (
73 |         NotImplemented  # The requested keyboard shortcut to use during merging.
74 |     )
75 |     INACTIVE_STATES: FrozenSet[
76 |         utils.DiffState
77 |     ] = frozenset()  # States where this action will not appear in the menu.
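    # Example (illustrative, not executed): concrete plugins such as `TakeUs`
    # in git_theta/merges/take.py subclass `Merge`, fill in the class
    # attributes above, and implement `merge`. They are discovered through the
    # "git_theta.plugins.merges" entry-point group, registered by the package
    # that defines them, e.g. in its setup.py:
    #
    #   entry_points={
    #       "git_theta.plugins.merges": [
    #           "take_us = git_theta.merges.take:TakeUs",
    #       ],
    #   }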
78 | 
79 |     def __call__(self, param_name, *args, **kwargs):
80 |         logger = logging.getLogger("git_theta")
81 |         logger.info(f"Running {self.NAME} merge on parameter {'/'.join(param_name)}")
82 |         return self.merge(param_name, *args, **kwargs)
83 | 
84 |     @abstractmethod
85 |     def merge(
86 |         self,
87 |         param_name: ParamName,
88 |         paramA: metadata.ParamMetadata,
89 |         paramB: metadata.ParamMetadata,
90 |         paramO: metadata.ParamMetadata,
91 |         metadataA: metadata.Metadata,
92 |         metadataB: metadata.Metadata,
93 |         metadataO: metadata.Metadata,
94 |         modelA: PartialModel,
95 |         modelB: PartialModel,
96 |         modelO: PartialModel,
97 |         path: str,
98 |         **kwargs,
99 |     ) -> metadata.ParamMetadata:
100 |         """Merge parameters.
101 | 
102 |         Parameters
103 |         ----------
104 |         param_name: The name of the parameter we are looking at.
105 |         paramA: The parameter metadata from branch A (current).
106 |         paramB: The parameter metadata from branch B (other).
107 |         paramO: The parameter metadata from the ancestor.
108 |         metadataA: The full model metadata from branch A (current).
109 |         metadataB: The full model metadata from branch B (other).
110 |         metadataO: The full model metadata from the ancestor.
111 |         modelA: A partially filled in model of real parameter values from
112 |             branch A (current). Allows caching and reuse for any sort of
113 |             "full model" merging method.
114 |         modelB: A partially filled in model of real parameter values from
115 |             branch B (other). Allows caching and reuse for any sort of
116 |             "full model" merging method.
117 |         modelO: A partially filled in model of real parameter values from
118 |             the ancestor. Allows caching and reuse for any sort of
119 |             "full model" merging method.
120 |         path: The path to where the model actually lives.
121 |         kwargs: Merge-strategy-specific arguments.
122 |         """
123 | 
124 |     @classmethod
125 |     def merge_arguments(cls) -> List[MergeArgument]:
126 |         """Returns a list of `MergeArgument`s that provide information about the arguments specific to each merge strategy.
127 |         Each `MergeArgument` contains:
128 |             1. The name of the merge argument
129 |             2. A text description of what the argument does
130 |             3. The type of the argument
131 |         """
132 |         return []
133 | 
134 | 
135 | def all_merge_handlers() -> Dict[str, Merge]:
136 |     """Enumerate and Load (import) all merge plugins."""
137 |     discovered_plugins = entry_points(group="git_theta.plugins.merges")
138 |     loaded_plugins = {ep.name: ep.load() for ep in discovered_plugins}
139 |     return loaded_plugins
140 | 
--------------------------------------------------------------------------------
/git_theta/merges/context.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | 
3 | import os
4 | 
5 | from prompt_toolkit import print_formatted_text
6 | from prompt_toolkit.formatted_text import HTML
7 | 
8 | from git_theta import git_utils
9 | from git_theta.merges import Merge
10 | from git_theta.types import ParamName
11 | from git_theta.utils import TEXT_STYLE, DiffState, NoResult
12 | 
13 | 
14 | def get_other_commit_in_merge() -> str:
15 |     git_heads = [e for e in os.environ.keys() if e.startswith("GITHEAD_")]
16 |     if not git_heads:
17 |         return None
18 |     return git_heads[0].split("_", maxsplit=1)[-1]
19 | 
20 | 
21 | def trim_log(log: str, limit: int = 80) -> str:
22 |     if len(log) > limit:
23 |         return f"{log[:limit - 3]}..."
24 | return f"{log}" 25 | 26 | 27 | class Context(Merge): 28 | DESCRIPTION = f"Show extra information about what {TEXT_STYLE.format_who('us')} vs {TEXT_STYLE.format_who('them')} means." 29 | NAME = "context" 30 | SHORT_CUT = "c" 31 | INACTIVE_STATES = frozenset() 32 | 33 | def merge(self, *args, **kwargs): 34 | repo = git_utils.get_git_repo() 35 | other_hash = get_other_commit_in_merge() 36 | other_commit = repo.commit(other_hash) 37 | other_branch = repo.git.branch("--contains", other_hash).strip() 38 | other_log = other_commit.summary 39 | 40 | my_commit = repo.commit("HEAD") 41 | my_hash = my_commit.hexsha 42 | my_branch = repo.active_branch 43 | my_log = my_commit.summary 44 | 45 | print_formatted_text( 46 | HTML( 47 | "Merge Context:\n" 48 | f"\t{TEXT_STYLE.format_who('us')}, {my_hash[:6]} ({my_branch}): {trim_log(my_log)}\n" 49 | f"\t{TEXT_STYLE.format_who('them')}, {other_hash[:6]} ({other_branch}): {trim_log(other_log)}" 50 | ) 51 | ) 52 | 53 | return NoResult 54 | -------------------------------------------------------------------------------- /git_theta/merges/take.py: -------------------------------------------------------------------------------- 1 | """Merge operations that select one version or another.""" 2 | 3 | from git_theta import metadata 4 | from git_theta.merges import Merge 5 | from git_theta.types import ParamName 6 | from git_theta.utils import TEXT_STYLE, DiffState 7 | 8 | 9 | class TakeUs(Merge): 10 | DESCRIPTION = f"Use {TEXT_STYLE.format_who('our')} change to the parameter." 11 | NAME = "take_us" 12 | SHORT_CUT = "tu" 13 | # If only they made a change take "us" doesn't make sense. 14 | INACTIVE_STATES = frozenset( 15 | {DiffState.CHANGED_B, DiffState.ADDED_B, DiffState.DELETED_B} 16 | ) 17 | 18 | def merge( 19 | self, 20 | param_name: ParamName, 21 | paramA: metadata.ParamMetadata, 22 | paramB: metadata.ParamMetadata, 23 | paramO: metadata.ParamMetadata, 24 | *args, 25 | **kwargs, 26 | ) -> metadata.ParamMetadata: 27 | """Grab the changes from branch A (current).""" 28 | return paramA 29 | 30 | 31 | class TakeThem(Merge): 32 | DESCRIPTION = f"Use {TEXT_STYLE.format_who('their')} change to the parameter." 33 | NAME = "take_them" 34 | SHORT_CUT = "tt" 35 | # If only we made a change take "them" doesn't make sense. 36 | INACTIVE_STATES = frozenset( 37 | { 38 | DiffState.CHANGED_A, 39 | DiffState.ADDED_A, 40 | DiffState.DELETED_A, 41 | } 42 | ) 43 | 44 | def merge( 45 | self, 46 | param_name: ParamName, 47 | paramA: metadata.ParamMetadata, 48 | paramB: metadata.ParamMetadata, 49 | paramO: metadata.ParamMetadata, 50 | *args, 51 | **kwargs, 52 | ) -> metadata.ParamMetadata: 53 | """Grab the changes from branch B (other).""" 54 | return paramB 55 | 56 | 57 | class TakeOriginal(Merge): 58 | DESCRIPTION = f"Use the {TEXT_STYLE.format_who('original')} parameter." 
59 |     NAME = "take_original"
60 |     SHORT_CUT = "to"
61 |     INACTIVE_STATES = frozenset({})
62 | 
63 |     def merge(
64 |         self,
65 |         param_name: ParamName,
66 |         paramA: metadata.ParamMetadata,
67 |         paramB: metadata.ParamMetadata,
68 |         paramO: metadata.ParamMetadata,
69 |         *args,
70 |         **kwargs,
71 |     ) -> metadata.ParamMetadata:
72 |         """Grab the changes from the ancestor."""
73 |         return paramO
74 | 
--------------------------------------------------------------------------------
/git_theta/metadata.py:
--------------------------------------------------------------------------------
1 | """Classes representing checkpoint metadata files"""
2 | 
3 | from __future__ import annotations
4 | 
5 | import dataclasses
6 | import hashlib
7 | import json
8 | import logging
9 | import re
10 | from collections import OrderedDict
11 | from typing import Any, ClassVar, Dict, TextIO, Tuple, Union
12 | 
13 | import git
14 | import numpy as np
15 | from file_or_name import file_or_name
16 | 
17 | from git_theta import git_utils, lsh, utils
18 | 
19 | 
20 | @dataclasses.dataclass(eq=True)
21 | class MetadataField:
22 |     def serialize(self) -> Dict[str, Any]:
23 |         return dataclasses.asdict(self, dict_factory=OrderedDict)
24 | 
25 | 
26 | @dataclasses.dataclass(eq=True)
27 | class LfsMetadata(MetadataField):
28 |     version: str
29 |     oid: str
30 |     size: str
31 |     name: ClassVar[str] = "lfs_metadata"
32 | 
33 |     @property
34 |     def lfs_pointer(self) -> str:
35 |         return f"version {self.version}\noid sha256:{self.oid}\nsize {self.size}\n"
36 | 
37 |     @classmethod
38 |     def from_pointer(cls, pointer_contents: str) -> LfsMetadata:
39 |         match = re.match(
40 |             r"^version (?P<version>[^\s]+)\noid sha256:(?P<oid>[0-9a-f]{64})\nsize (?P<size>[0-9]+)\n$",
41 |             pointer_contents,
42 |         )
43 |         if match is None:
44 |             raise ValueError(f"Failed to parse pointer file {pointer_contents}")
45 |         return cls(
46 |             version=match.group("version"),
47 |             oid=match.group("oid"),
48 |             size=match.group("size"),
49 |         )
50 | 
51 |     @classmethod
52 |     def from_bytes(cls, b: bytes) -> LfsMetadata:
53 |         return cls.from_pointer(git_utils.git_lfs_clean(b))
54 | 
55 | 
56 | @dataclasses.dataclass(eq=True)
57 | class TensorMetadata(MetadataField):
58 |     shape: str
59 |     dtype: str
60 |     hash: np.ndarray
61 |     name: ClassVar[str] = "tensor_metadata"
62 | 
63 |     def __post_init__(self):
64 |         self.hash = np.array(self.hash)
65 | 
66 |     def __eq__(self, other):
67 |         return (
68 |             self.shape == other.shape
69 |             and self.dtype == other.dtype
70 |             and np.array_equal(self.hash, other.hash)
71 |         )
72 | 
73 |     @classmethod
74 |     def from_tensor(cls, tensor: np.ndarray) -> TensorMetadata:
75 |         shape = str(tensor.shape)
76 |         dtype = str(tensor.dtype)
77 |         logger = logging.getLogger("git_theta")
78 |         logger.debug("Starting LSH Hash")
79 |         hash = lsh.get_lsh().hash(tensor)
80 |         logger.debug("Finished LSH Hash")
81 |         return cls(shape=shape, dtype=dtype, hash=hash)
82 | 
83 | 
84 | @dataclasses.dataclass(eq=True)
85 | class ThetaMetadata(MetadataField):
86 |     update_type: str
87 |     last_commit: str
88 |     name: ClassVar[str] = "theta_metadata"
89 | 
90 | 
91 | @dataclasses.dataclass(eq=True)
92 | class ParamMetadata(MetadataField):
93 |     tensor_metadata: TensorMetadata
94 |     lfs_metadata: LfsMetadata
95 |     theta_metadata: ThetaMetadata
96 | 
97 |     @classmethod
98 |     def from_metadata_dict(cls, d: Dict[str, Any]) -> ParamMetadata:
99 |         tensor_metadata = TensorMetadata(**d[TensorMetadata.name])
100 |         lfs_metadata = LfsMetadata(**d[LfsMetadata.name])
101 |         theta_metadata = ThetaMetadata(**d[ThetaMetadata.name])
102 |         return cls(tensor_metadata, 
lfs_metadata, theta_metadata) 103 | 104 | 105 | class Metadata(OrderedDict): 106 | @classmethod 107 | def from_metadata_dict(cls, d: Dict[str, Any]) -> Metadata: 108 | flattened = utils.flatten(d, is_leaf=lambda v: LfsMetadata.name in v) 109 | for param_keys, param_metadata in flattened.items(): 110 | flattened[param_keys] = ParamMetadata.from_metadata_dict(param_metadata) 111 | metadata = utils.unflatten(flattened) 112 | return cls(metadata) 113 | 114 | @classmethod 115 | @file_or_name(file="r") 116 | def from_file(cls, file: TextIO) -> Metadata: 117 | metadata_dict = json.load(file) 118 | return cls.from_metadata_dict(metadata_dict) 119 | 120 | @classmethod 121 | def from_commit(cls, repo: git.Repo, path: str, commit_hash: str) -> Metadata: 122 | obj = git_utils.get_file_version(repo, path, commit_hash) 123 | if obj is None: 124 | return cls() 125 | else: 126 | return cls.from_file(obj.data_stream) 127 | 128 | @file_or_name(file="w") 129 | def write(self, file: TextIO): 130 | file.write(str(self)) 131 | 132 | def flatten(self) -> Metadata: 133 | return utils.flatten(self, is_leaf=lambda v: isinstance(v, ParamMetadata)) 134 | 135 | def unflatten(self) -> Metadata: 136 | return utils.unflatten(self) 137 | 138 | def diff(self, other: Metadata) -> Tuple[Metadata, Metadata, Metadata]: 139 | self_flattened = self.flatten() 140 | other_flattened = other.flatten() 141 | added = Metadata( 142 | { 143 | k: self_flattened[k] 144 | for k in self_flattened.keys() - other_flattened.keys() 145 | } 146 | ).unflatten() 147 | removed = Metadata( 148 | { 149 | k: other_flattened[k] 150 | for k in other_flattened.keys() - self_flattened.keys() 151 | } 152 | ).unflatten() 153 | modified = Metadata() 154 | for param_keys in set(self_flattened.keys()).intersection( 155 | other_flattened.keys() 156 | ): 157 | if ( 158 | self_flattened[param_keys].lfs_metadata 159 | != other_flattened[param_keys].lfs_metadata 160 | ): 161 | modified[param_keys] = self_flattened[param_keys] 162 | 163 | modified = modified.unflatten() 164 | return added, removed, modified 165 | 166 | def serialize(self) -> Dict[str, Any]: 167 | flattened = self.flatten() 168 | for param_keys, param_metadata in flattened.items(): 169 | flattened[param_keys] = param_metadata.serialize() 170 | return flattened.unflatten() 171 | 172 | def __str__(self) -> str: 173 | metadata_dict = self.serialize() 174 | return json.dumps(metadata_dict, indent=4, cls=MetadataEncoder) 175 | 176 | 177 | class MetadataEncoder(json.JSONEncoder): 178 | def default(self, obj): 179 | if isinstance(obj, np.ndarray): 180 | return obj.tolist() 181 | else: 182 | return json.JSONEncoder.default(self, obj) 183 | -------------------------------------------------------------------------------- /git_theta/params.py: -------------------------------------------------------------------------------- 1 | """Classes for serializing model updates.""" 2 | 3 | from abc import ABCMeta, abstractmethod 4 | 5 | import msgpack 6 | import tensorstore as ts 7 | 8 | 9 | class TensorSerializer(metaclass=ABCMeta): 10 | """Serialize/Deserialize tensors.""" 11 | 12 | @abstractmethod 13 | async def serialize(self, tensor): 14 | """Convert a tensor to bytes.""" 15 | 16 | @abstractmethod 17 | async def deserialize(self, serialized_tensor): 18 | """Convert bytes to a tensor object.""" 19 | 20 | 21 | class TensorStoreSerializer(TensorSerializer): 22 | async def serialize(self, tensor): 23 | store = await ts.open( 24 | { 25 | "driver": "zarr", 26 | "kvstore": {"driver": "memory"}, 27 | "metadata": {"shape": 
tensor.shape, "dtype": tensor.dtype.str},
28 |                 "create": True,
29 |             },
30 |         )
31 |         await store.write(tensor)
32 |         serialized_param = {
33 |             k.decode("utf-8"): store.kvstore[k] for k in await store.kvstore.list()
34 |         }
35 |         return serialized_param
36 | 
37 |     async def deserialize(self, serialized_tensor):
38 |         ctx = ts.Context()
39 |         kvs = await ts.KvStore.open("memory://", context=ctx)
40 |         for name, contents in serialized_tensor.items():
41 |             kvs[name] = contents
42 | 
43 |         store = await ts.open({"driver": "zarr", "kvstore": "memory://"}, context=ctx)
44 |         param = await store.read()
45 |         return param
46 | 
47 | 
48 | class FileCombiner(metaclass=ABCMeta):
49 |     """Combine and split serialized tensors; enables single-blob processing for multiple tensors."""
50 | 
51 |     @abstractmethod
52 |     def combine(self, files):
53 |         """Combine multiple byte streams into one."""
54 | 
55 |     @abstractmethod
56 |     def split(self, file):
57 |         """Split a combined byte stream into the original bytes."""
58 | 
59 | 
60 | class MsgPackCombiner(FileCombiner):
61 |     def combine(self, files):
62 |         return msgpack.packb(files, use_bin_type=True)
63 | 
64 |     def split(self, file):
65 |         return msgpack.unpackb(file, raw=False)
66 | 
67 | 
68 | class Serializer(metaclass=ABCMeta):
69 |     """Serialize/Deserialize parameters, even when represented with multiple tensors."""
70 | 
71 |     @abstractmethod
72 |     async def serialize(self, params):
73 |         """Serialize parameter."""
74 | 
75 |     @abstractmethod
76 |     async def deserialize(self, serialized):
77 |         """Deserialize parameter."""
78 | 
79 | 
80 | class UpdateSerializer(Serializer):
81 |     def __init__(self, tensor_serializer, file_combiner):
82 |         self.serializer = tensor_serializer
83 |         self.combiner = file_combiner
84 | 
85 |     async def serialize(self, params):
86 |         serialized_params = {
87 |             name: await self.serializer.serialize(param)
88 |             for name, param in params.items()
89 |         }
90 |         return self.combiner.combine(serialized_params)
91 | 
92 |     async def deserialize(self, serialized):
93 |         serialized_params = self.combiner.split(serialized)
94 |         update_params = {
95 |             name: await self.serializer.deserialize(serialized_param)
96 |             for name, serialized_param in serialized_params.items()
97 |         }
98 |         return update_params
99 | 
100 | 
101 | def get_update_serializer():
102 |     # TODO: Right now this just returns a tensorstore/msgpack serializer but in
103 |     # the future we can implement other Serializers and/or support user plugins.
104 |     return UpdateSerializer(TensorStoreSerializer(), MsgPackCombiner())
105 | 
--------------------------------------------------------------------------------
/git_theta/scripts/__init__.py:
--------------------------------------------------------------------------------
1 | """Shared logging setup for scripts.
2 | 
3 | This is not part of the main git_theta.__init__ as we don't want to configure
4 | logging if git_theta is being used as a library.
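
Example (illustrative; "my-git-theta-tool" is a made-up name):

    from git_theta import scripts
    logger = scripts.configure_logging("my-git-theta-tool")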
5 | """ 6 | 7 | import logging 8 | import os 9 | import tempfile 10 | from typing import Optional 11 | 12 | import git_theta 13 | 14 | 15 | def configure_logging( 16 | exe_name: str, logger_name: str = "git_theta", root: Optional[str] = None 17 | ): 18 | logger = logging.getLogger(logger_name) 19 | format_str = f"{exe_name}: [%(asctime)s] [%(task)s] [%(package)s:%(funcName)s] %(levelname)s - %(message)s" 20 | log_level = getattr( 21 | logging, git_theta.utils.EnvVarConstants.LOG_LEVEL.upper(), logging.DEBUG 22 | ) 23 | logger.setLevel(log_level) 24 | formatter = logging.Formatter(fmt=format_str) 25 | 26 | root = ( 27 | os.path.dirname(os.path.dirname(git_theta.__file__)) if root is None else root 28 | ) 29 | 30 | def log_filter(record: logging.LogRecord) -> logging.LogRecord: 31 | package = record.pathname[len(root) + 1 :] 32 | if package.endswith(".py"): 33 | package = package[:-3] 34 | record.package = package.replace(os.sep, ".") 35 | return record 36 | 37 | handlers = ( 38 | git_theta.async_utils.AsyncTaskStreamHandler(), 39 | git_theta.async_utils.AsyncTaskFileHandler( 40 | filename=os.path.join(tempfile.gettempdir(), "git-theta.log") 41 | ), 42 | ) 43 | for handler in handlers: 44 | handler.setLevel(log_level) 45 | handler.setFormatter(formatter) 46 | handler.addFilter(log_filter) 47 | logger.addHandler(handler) 48 | 49 | return logger 50 | -------------------------------------------------------------------------------- /git_theta/scripts/git_theta_cli.py: -------------------------------------------------------------------------------- 1 | """Installation and .git manipulation scripts.""" 2 | 3 | import argparse 4 | import fnmatch 5 | import logging 6 | import re 7 | import sys 8 | 9 | import git 10 | 11 | if sys.version_info < (3, 10): 12 | from importlib_metadata import entry_points 13 | else: 14 | from importlib.metadata import entry_points 15 | 16 | import git_theta 17 | from git_theta import async_utils, git_utils, metadata, theta, utils 18 | 19 | git_theta.scripts.configure_logging("git-theta") 20 | 21 | 22 | def parse_args(): 23 | parser = argparse.ArgumentParser(description="git-theta filter program") 24 | subparsers = parser.add_subparsers(title="Commands", dest="command") 25 | subparsers.required = True 26 | 27 | post_commit_parser = subparsers.add_parser( 28 | "post-commit", 29 | help="post-commit command that records parameter groups changed in a commit", 30 | ) 31 | post_commit_parser.set_defaults(func=post_commit) 32 | 33 | pre_push_parser = subparsers.add_parser( 34 | "pre-push", 35 | help="pre-push command used to send parameter groups to an LFS store. Should only be called internally by git push.", 36 | ) 37 | pre_push_parser.add_argument( 38 | "remote_name", help="Name of the remote being pushed to" 39 | ) 40 | pre_push_parser.add_argument( 41 | "remote_location", help="Location of the remote being pushed to" 42 | ) 43 | pre_push_parser.set_defaults(func=pre_push) 44 | 45 | ls_files_parser = subparsers.add_parser( 46 | "ls-files", help="List files that are tracked by git-theta." 47 | ) 48 | ls_files_parser.add_argument( 49 | "args", nargs="*", default=None, help="The raw args to pass to git ls-files" 50 | ) 51 | ls_files_parser.set_defaults(func=ls_files) 52 | 53 | install_parser = subparsers.add_parser( 54 | "install", help="Install command used to setup git-theta via git configs." 
55 |     )
56 |     install_parser.add_argument(
57 |         "--scope",
58 |         choices=["global", "repository", "user", "system"],
59 |         default="global",
60 |         help="Which git config location to use.",
61 |     )
62 |     install_parser.set_defaults(func=install)
63 | 
64 |     track_parser = subparsers.add_parser(
65 |         "track",
66 |         help="track command used to identify model checkpoint for git-theta to track",
67 |     )
68 |     track_parser.add_argument(
69 |         "file", help="model checkpoint file or file pattern to track"
70 |     )
71 |     track_parser.set_defaults(func=track)
72 | 
73 |     add_parser = subparsers.add_parser("add", help="add command used to stage files.")
74 |     add_parser.add_argument("file", help="The file we are git adding.")
75 |     add_parser.add_argument(
76 |         "--update-type",
77 |         choices=[e.name for e in entry_points(group="git_theta.plugins.updates")],
78 |         help="Type of update being applied",
79 |     )
80 |     add_parser.add_argument("--update-data", help="Where update data is stored.")
81 |     add_parser.set_defaults(func=add)
82 | 
83 |     args = parser.parse_known_args()
84 |     return args
85 | 
86 | 
87 | def post_commit(args):
88 |     """
89 |     Post-commit git hook that records the LFS objects that were created (equivalent to parameter groups that were modified) in this commit
90 |     """
91 |     repo = git_utils.get_git_repo()
92 |     theta_commits = theta.ThetaCommits(repo)
93 | 
94 |     gitattributes_file = git_utils.get_gitattributes_file(repo)
95 |     gitattributes = git_utils.read_gitattributes(gitattributes_file)
96 | 
97 |     oids = set()
98 |     commit = repo.commit("HEAD")
99 |     for path in commit.stats.files.keys():
100 |         if git_utils.is_theta_tracked(path, gitattributes):
101 |             curr_metadata = metadata.Metadata.from_file(commit.tree[path].data_stream)
102 |             prev_metadata = metadata.Metadata.from_commit(repo, path, "HEAD~1")
103 | 
104 |             added, removed, modified = curr_metadata.diff(prev_metadata)
105 |             oids.update([param.lfs_metadata.oid for param in added.flatten().values()])
106 |             oids.update(
107 |                 [param.lfs_metadata.oid for param in modified.flatten().values()]
108 |             )
109 | 
110 |     commit_info = theta.CommitInfo(oids)
111 |     theta_commits.write_commit_info(commit.hexsha, commit_info)
112 | 
113 | 
114 | def pre_push(args):
115 |     """
116 |     Pre-push git hook for sending objects to the LFS server
117 |     """
118 |     repo = git_utils.get_git_repo()
119 |     theta_commits = theta.ThetaCommits(repo)
120 | 
121 |     # Read lines of the form "<local ref> <local sha1> <remote ref> <remote sha1> LF"
122 |     lines = sys.stdin.readlines()
123 |     lines_parsed = git_utils.parse_pre_push_args(lines)
124 |     commit_ranges = [
125 |         (l.group("remote_sha1"), l.group("local_sha1")) for l in lines_parsed
126 |     ]
127 |     oids = theta_commits.get_commit_oids_ranges(*commit_ranges)
128 |     async_utils.run(git_utils.git_lfs_push_oids(args.remote_name, oids))
129 | 
130 | 
131 | def ls_files(args):
132 |     repo = git_utils.get_git_repo()
133 |     # TODO: git ls-files can take a bunch of extra args that we pass in here;
134 |     # it is unclear if that is working/if the format of the output changes, but
135 |     # this implementation covers the common usages.
136 |     # Note: We use repo.git.ls_files instead of traversing the tree ourselves as
137 |     # git ls-files leverages the index better than we can from GitPython.
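    # Illustrative example (the path and attribute line are made up): with a
    # .gitattributes entry like "model.ckpt filter=theta merge=theta diff=theta",
    # `git theta ls-files` prints just "model.ckpt" out of all tracked files.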
138 |     if args.args:
139 |         files = repo.git.ls_files(args.args).split("\n")
140 |     else:
141 |         files = repo.git.ls_files().split("\n")
142 | 
143 |     gitattributes_file = git_utils.get_gitattributes_file(repo)
144 |     gitattributes = git_utils.read_gitattributes(gitattributes_file)
145 | 
146 |     for path in files:
147 |         if git_utils.is_theta_tracked(path, gitattributes):
148 |             print(path)
149 | 
150 | 
151 | def install(args):
152 |     """
153 |     Install git-lfs and initialize the git-theta filter driver
154 |     """
155 |     # check if git-lfs is installed and abort if not
156 |     logger = logging.getLogger("git_theta")
157 |     if not git_utils.is_git_lfs_installed():
158 |         logger.error(
159 |             "git-theta depends on git-lfs and it does not appear to be installed. See installation directions at https://github.com/r-three/git-theta/blob/main/README.md#git-lfs-installation"
160 |         )
161 |         sys.exit(1)
162 | 
163 |     if args.scope == "repository":
164 |         # To install at the repository level, you need to be in a repo.
165 |         try:
166 |             repo = git_utils.get_git_repo()
167 |         except git.exc.InvalidGitRepositoryError as e:
168 |             logger.error(
169 |                 "Tried to install git-theta at the repository level, but you are not in a git repository. Please navigate to the repository you want to use git-theta in."
170 |             )
171 |             sys.exit(1)
172 |         # We are using a private method, but the error message from trying to get
173 |         # a repository-level config file from `git.config.get_config_path` said to
174 |         # use this method.
175 |         path = repo._get_config_path(args.scope)
176 |     else:
177 |         path = git.config.get_config_path(args.scope)
178 | 
179 |     logger.debug(f"Installing git-theta via git configuration file at {path}")
180 |     config_writer = git.GitConfigParser(path, config_level=args.scope, read_only=False)
181 |     config_writer.set_value('filter "theta"', "clean", "git-theta-filter clean %f")
182 |     config_writer.set_value('filter "theta"', "smudge", "git-theta-filter smudge %f")
183 |     config_writer.set_value('filter "theta"', "required", "true")
184 |     config_writer.set_value('merge "theta"', "name", "Merge Models with Git-Theta")
185 |     config_writer.set_value('merge "theta"', "driver", "git-theta-merge %O %A %B %P")
186 |     config_writer.set_value('diff "theta"', "command", "git-theta-diff")
187 |     config_writer.release()
188 | 
189 | 
190 | def track(args):
191 |     """
192 |     Track a particular model checkpoint file with git-theta
193 |     """
194 |     repo = git_utils.get_git_repo()
195 |     if not git_utils.is_git_theta_installed():
196 |         logger = logging.getLogger("git_theta")
197 |         logger.error(
198 |             "You are trying to track a file with git-theta, but git-theta is not installed, please run `git theta install`."
199 |         )
200 |         sys.exit(1)
201 |     model_path = git_utils.get_relative_path_from_root(repo, args.file)
202 | 
203 |     gitattributes_file = git_utils.get_gitattributes_file(repo)
204 |     gitattributes = git_utils.read_gitattributes(gitattributes_file)
205 | 
206 |     new_gitattributes = git_utils.add_theta_to_gitattributes(gitattributes, model_path)
207 | 
208 |     git_utils.write_gitattributes(gitattributes_file, new_gitattributes)
209 |     git_utils.add_file(gitattributes_file, repo)
210 | 
211 | 
212 | def add(args, unparsed_args):
213 |     repo = git_utils.get_git_repo()
214 |     if not git_utils.is_git_theta_installed():
215 |         logger = logging.getLogger("git_theta")
216 |         logger.error(
217 |             "You are trying to add a file using git-theta, but git-theta is not installed, please run `git theta install`."
218 |         )
219 |         sys.exit(1)
220 |     env_vars = {
221 |         "GIT_THETA_UPDATE_TYPE": args.update_type,
222 |         "GIT_THETA_UPDATE_DATA_PATH": args.update_data,
223 |     }
224 |     # The most common use for `git theta add` is when you have side-loaded
225 |     # information and thus the main checkpoint file has not been modified. This
226 |     # results in git not running the add command as the modification time has
227 |     # not changed. We touch the file so it will actually get added.
228 |     utils.touch(args.file)
229 |     with repo.git.custom_environment(**env_vars):
230 |         repo.git.add(args.file, *unparsed_args)
231 | 
232 | 
233 | def main():
234 |     args, unparsed_args = parse_args()
235 |     if args.func != install:
236 |         git_utils.set_hooks()
237 |     if args.func == add:
238 |         args.func(args, unparsed_args)
239 |     else:
240 |         args.func(args)
241 | 
242 | 
243 | if __name__ == "__main__":
244 |     main()
245 | 
--------------------------------------------------------------------------------
/git_theta/scripts/git_theta_diff.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import sys
3 | import textwrap
4 | 
5 | import numpy as np
6 | from colorama import Fore, Style
7 | 
8 | import git_theta
9 | from git_theta import checkpoints, metadata
10 | 
11 | git_theta.scripts.configure_logging("git-theta-diff")
12 | 
13 | 
14 | def parse_args():
15 |     parser = argparse.ArgumentParser(description="git-theta diff program")
16 |     parser.add_argument("path", help="path to file being diff-ed")
17 | 
18 |     parser.add_argument(
19 |         "old_checkpoint", help="file that old version of checkpoint can be read from"
20 |     )
21 |     parser.add_argument("old_hex", help="SHA-1 hash of old version of checkpoint")
22 |     parser.add_argument("old_mode", help="file mode for old version of checkpoint")
23 | 
24 |     parser.add_argument(
25 |         "new_checkpoint", help="file that new version of checkpoint can be read from"
26 |     )
27 |     parser.add_argument("new_hex", help="SHA-1 hash of new version of checkpoint")
28 |     parser.add_argument("new_mode", help="file mode for new version of checkpoint")
29 | 
30 |     args = parser.parse_args()
31 |     return args
32 | 
33 | 
34 | def color_string(s, color):
35 |     return f"{color}{s}" if color else s
36 | 
37 | 
38 | def bold_string(s):
39 |     return f"{Style.BRIGHT}{s}"
40 | 
41 | 
42 | def print_formatted(s, indent=0, color=None, bold=False):
43 |     if indent:
44 |         s = "\n".join(
45 |             textwrap.wrap(
46 |                 s, initial_indent=" " * 4 * indent, subsequent_indent=" " * 4 * (indent + 1)
47 |             )
48 |         )
49 |     if color:
50 |         s = color_string(s, color)
51 |     if bold:
52 |         s = bold_string(s)
53 |     print(s)
54 | 
55 | 
56 | def print_header(header, indent=0, color=None):
57 |     print_formatted(header, indent=indent, color=color, bold=True)
58 |     print_formatted("-" * len(header), indent=indent, color=color, bold=True)
59 | 
60 | 
61 | def print_added_params_summary(added, indent=0, color=None):
62 |     if added:
63 |         print_header("ADDED PARAMETER GROUPS", indent=indent, color=color)
64 |         for flattened_group, param in added.flatten().items():
65 |             group = "/".join(flattened_group)
66 |             print_formatted(group, indent=indent, color=color)
67 |         print_formatted("\n")
68 | 
69 | 
70 | def print_removed_params_summary(removed, indent=0, color=None):
71 |     if removed:
72 |         print_header("REMOVED PARAMETER GROUPS", indent=indent, color=color)
73 |         for flattened_group, param in removed.flatten().items():
74 |             group = "/".join(flattened_group)
75 |             print_formatted(group, indent=indent, color=color)
76 |         print_formatted("\n")
77 | 
78 | 
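# Illustrative invocation (the hashes and temp paths are made up): when a model
# file uses the "theta" diff driver, git invokes this script roughly as
#
#   git-theta-diff model.ckpt /tmp/old_ckpt 1a2b3c... 100644 /tmp/new_ckpt 4d5e6f... 100644
#
# which matches the seven positional arguments parsed above.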
79 | def print_modified_params_summary(modified, indent=0, color=None):
80 |     if modified:
81 |         print_header("MODIFIED PARAMETER GROUPS", indent=indent, color=color)
82 |         for flattened_group, param in modified.flatten().items():
83 |             group = "/".join(flattened_group)
84 |             print_formatted(group, indent=indent, color=color)
85 |         print_formatted("\n")
86 | 
87 | 
88 | def main():
89 |     args = parse_args()
90 |     checkpoint_handler = checkpoints.get_checkpoint_handler()
91 |     old_checkpoint = checkpoint_handler.from_file(args.old_checkpoint)
92 |     new_checkpoint = checkpoint_handler.from_file(args.new_checkpoint)
93 |     added, removed, modified = checkpoint_handler.diff(new_checkpoint, old_checkpoint)
94 | 
95 |     print_added_params_summary(added, indent=0, color=Fore.GREEN)
96 |     print_removed_params_summary(removed, indent=0, color=Fore.RED)
97 |     print_modified_params_summary(modified, indent=0, color=Fore.YELLOW)
98 | 
--------------------------------------------------------------------------------
/git_theta/scripts/git_theta_filter.py:
--------------------------------------------------------------------------------
1 | """Clean and Smudge filters for version controlling machine learning models."""
2 | 
3 | import argparse
4 | import logging
5 | import os
6 | import sys
7 | 
8 | import git_theta
9 | from git_theta import checkpoints, git_utils, metadata
10 | from git_theta.filters import clean, smudge
11 | from git_theta.utils import EnvVarConstants
12 | 
13 | git_theta.scripts.configure_logging("git-theta-filter")
14 | 
15 | 
16 | def parse_args():
17 |     parser = argparse.ArgumentParser(description="git-theta filter program")
18 |     subparsers = parser.add_subparsers(title="Commands", dest="command")
19 |     subparsers.required = True
20 | 
21 |     clean_parser = subparsers.add_parser("clean", help="clean filter")
22 |     clean_parser.add_argument("file", help="file being passed to clean filter")
23 |     clean_parser.set_defaults(func=run_clean)
24 | 
25 |     smudge_parser = subparsers.add_parser("smudge", help="smudge filter")
26 |     smudge_parser.add_argument("file", help="file being passed to smudge filter")
27 |     smudge_parser.set_defaults(func=run_smudge)
28 | 
29 |     args = parser.parse_args()
30 |     return args
31 | 
32 | 
33 | def run_clean(args):
34 |     """
35 |     Implements clean filter for model files
36 |     """
37 |     logger = logging.getLogger("git_theta")
38 |     logger.debug(f"Running clean filter on {args.file}")
39 |     repo = git_utils.get_git_repo()
40 |     checkpoint_handler = checkpoints.get_checkpoint_handler()
41 |     if EnvVarConstants.LOW_MEMORY:
42 |         logger.warning(
43 |             "Running Git-Theta in low memory mode. No concurrency is supported and the original checkpoint will be transiently stored in a temporary file."
44 |         )
45 |         temp_file = f".{args.file}-temp-checkpoint"
46 |         try:
47 |             # In some environments there isn't enough space to write to the
48 |             # default tempfile location.
49 |             logger.debug(f"Writing checkpoint to {temp_file}")
50 |             with open(temp_file, "w+b") as tmp:
51 |                 tmp.write(sys.stdin.buffer.read())
52 |                 logger.debug(f"Reading checkpoint from {temp_file}")
53 |                 # We write and then seek instead of write, close, open because this
54 |                 # was originally written to use the tempfile lib, but there were
55 |                 # space issues. We keep that paradigm as we may switch back eventually.
56 |                 tmp.seek(0)
57 |                 model_checkpoint = checkpoint_handler.from_file(tmp)
58 |         finally:
59 |             # Make sure we always remove the temp checkpoint file.
60 | os.remove(temp_file) 61 | else: 62 | model_checkpoint = checkpoint_handler.from_file(sys.stdin.buffer) 63 | new_metadata = clean(model_checkpoint, repo, args.file) 64 | new_metadata.write(sys.stdout) 65 | # If we had side-loaded information, write it out so we don't get false 66 | # positives for `git status` 67 | if EnvVarConstants.UPDATE_DATA_PATH: 68 | smudge(new_metadata, repo, args.file) 69 | 70 | 71 | def run_smudge(args): 72 | """ 73 | Implements smudge filter for model files 74 | """ 75 | logger = logging.getLogger("git_theta") 76 | logger.debug(f"Running smudge filter on {args.file}") 77 | 78 | repo = git_utils.get_git_repo() 79 | curr_metadata = metadata.Metadata.from_file(sys.stdin) 80 | model_checkpoint = smudge(curr_metadata, repo, args.file) 81 | model_checkpoint.save(sys.stdout.buffer) 82 | 83 | 84 | def main(): 85 | args = parse_args() 86 | git_utils.set_hooks() 87 | args.func(args) 88 | 89 | 90 | if __name__ == "__main__": 91 | main() 92 | -------------------------------------------------------------------------------- /git_theta/theta.py: -------------------------------------------------------------------------------- 1 | """Module for reading and writing to .git/theta""" 2 | 3 | import functools 4 | import json 5 | import logging 6 | import os 7 | import re 8 | 9 | from file_or_name import file_or_name 10 | 11 | from git_theta import utils 12 | 13 | 14 | class CommitInfo: 15 | def __init__(self, oids): 16 | self.oids = set(oids) if oids else set() 17 | if not all(map(utils.is_valid_oid, self.oids)): 18 | invalid_oids = filter(lambda x: not utils.is_valid_oid(x), self.oids) 19 | raise ValueError(f"Invalid LFS object-ids {list(invalid_oids)}") 20 | 21 | def __eq__(self, other): 22 | return self.oids == other.oids 23 | 24 | @classmethod 25 | @file_or_name(f="r") 26 | def from_file(cls, f): 27 | commit_info_dict = json.load(f) 28 | return cls(commit_info_dict.get("oids")) 29 | 30 | @file_or_name(f="w") 31 | def write(self, f): 32 | commit_info_dict = {"oids": list(self.oids)} 33 | json.dump(commit_info_dict, f, indent=4) 34 | 35 | 36 | class ThetaCommits: 37 | def __init__(self, repo): 38 | self.repo = repo 39 | self.path = os.path.abspath(os.path.join(repo.git_dir, "theta", "commits")) 40 | os.makedirs(self.path, exist_ok=True) 41 | self.logger = logging.getLogger("git_theta") 42 | 43 | @staticmethod 44 | def combine_oid_sets(oid_sets): 45 | return functools.reduce(lambda a, b: a.union(b), oid_sets, set()) 46 | 47 | def get_commit_path(self, commit_hash): 48 | if not utils.is_valid_commit_hash(commit_hash): 49 | raise ValueError(f"Invalid commit hash {commit_hash}") 50 | return os.path.join(self.path, commit_hash) 51 | 52 | def get_commit_info(self, commit_hash): 53 | path = self.get_commit_path(commit_hash) 54 | if not (os.path.exists(path) and os.path.isfile(path)): 55 | raise ValueError(f"commit {commit_hash} is not found in {self.path}") 56 | commit = CommitInfo.from_file(path) 57 | return commit 58 | 59 | def get_commit_info_range(self, start_hash, end_hash): 60 | self.logger.debug(f"Getting commits from {start_hash}..{end_hash}") 61 | # N.b. 
the all-zero hash is used by git to indicate a non-existent start hash.
62 |         # For example, a git pre-push hook will receive the all-zero hash if the remote ref does not have any commit history.
63 |         if re.match("^0{40}$", start_hash):
64 |             commits = list(self.repo.iter_commits(end_hash))
65 |         else:
66 |             commits = list(self.repo.iter_commits(f"{start_hash}..{end_hash}"))
67 | 
68 |         self.logger.debug(f"Found commits {commits}")
69 |         commit_infos = [self.get_commit_info(commit.hexsha) for commit in commits]
70 |         return commit_infos
71 | 
72 |     def get_commit_oids(self, commit_hash):
73 |         self.logger.debug(f"Getting oids from commit {commit_hash}")
74 |         commit_info = self.get_commit_info(commit_hash)
75 |         oids = commit_info.oids
76 |         self.logger.debug(f"Found oids {oids}")
77 |         return oids
78 | 
79 |     def get_commit_oids_range(self, start_hash, end_hash):
80 |         self.logger.debug(f"Getting oids from commit range {start_hash}..{end_hash}")
81 |         commit_infos = self.get_commit_info_range(start_hash, end_hash)
82 |         oids = ThetaCommits.combine_oid_sets(
83 |             [commit_info.oids for commit_info in commit_infos]
84 |         )
85 |         self.logger.debug(f"Found oids {oids}")
86 |         return oids
87 | 
88 |     def get_commit_oids_ranges(self, *ranges):
89 |         oids = [
90 |             self.get_commit_oids_range(start_hash, end_hash)
91 |             for start_hash, end_hash in ranges
92 |         ]
93 |         return ThetaCommits.combine_oid_sets(oids)
94 | 
95 |     def write_commit_info(self, commit_hash, commit_info):
96 |         self.logger.debug(f"Writing commit_info to commit {commit_hash}")
97 |         if not utils.is_valid_commit_hash(commit_hash):
98 |             raise ValueError(f"Cannot write commit info for invalid hash {commit_hash}")
99 |         path = self.get_commit_path(commit_hash)
100 |         if os.path.exists(path):
101 |             raise ValueError(
102 |                 f"Cannot duplicate commit info at {path}. Something is wrong!"
103 |             )
104 |         commit_info.write(path)
105 | 
--------------------------------------------------------------------------------
/git_theta/types.py:
--------------------------------------------------------------------------------
1 | """Common types used in git-theta."""
2 | 
3 | 
4 | from typing import Tuple
5 | 
6 | ParamName = Tuple[str, ...]
7 | 
--------------------------------------------------------------------------------
/git_theta/updates/__init__.py:
--------------------------------------------------------------------------------
1 | """Classes for controlling how parameter updates are made."""
2 | 
3 | from git_theta.updates.base import IncrementalUpdate, Update, get_update_handler
4 | 
--------------------------------------------------------------------------------
/git_theta/updates/base.py:
--------------------------------------------------------------------------------
1 | """Base class for parameter update plugins."""
2 | 
3 | import os
4 | import sys
5 | from abc import ABCMeta, abstractmethod
6 | 
7 | if sys.version_info < (3, 10):
8 |     from importlib_metadata import entry_points
9 | else:
10 |     from importlib.metadata import entry_points
11 | 
12 | import logging
13 | from typing import Dict, FrozenSet, Optional, Tuple
14 | 
15 | import numpy as np
16 | 
17 | from git_theta import checkpoints, git_utils, lsh, metadata, params, utils
18 | from git_theta.lsh.types import Signature
19 | 
20 | Parameter = np.ndarray
21 | 
22 | 
23 | @utils.abstract_classattributes("name")
24 | class Update(metaclass=ABCMeta):
25 |     """Base class for parameter update plugins."""
26 | 
27 |     name: str = NotImplemented  # The name used to look up the plug-in.
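    # Example (illustrative): third-party update types are discovered through
    # the "git_theta.plugins.updates" entry-point group, so a hypothetical
    # package `my_pkg` defining `MyUpdate` could register it in its setup.py:
    #
    #   entry_points={
    #       "git_theta.plugins.updates": ["my-update = my_pkg:MyUpdate"],
    #   }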
28 | 
29 |     def __init__(self, serializer: params.Serializer, *args, **kwargs):
30 |         self.serializer = serializer
31 | 
32 |     async def read(self, param_metadata: metadata.ParamMetadata) -> Parameter:
33 |         """Read in and deserialize a single parameter value based on its metadata."""
34 |         lfs_pointer = param_metadata.lfs_metadata.lfs_pointer
35 |         serialized_param = await git_utils.git_lfs_smudge(lfs_pointer)
36 |         param = await self.serializer.deserialize(serialized_param)
37 |         return param.get("parameter", param)
38 | 
39 |     def will_update(self, param_keys: Tuple[str]) -> bool:
40 |         return False
41 | 
42 |     @abstractmethod
43 |     async def write(
44 |         self, param: Parameter, param_keys: Tuple[str], **kwargs
45 |     ) -> Tuple[metadata.LfsMetadata, Signature]:
46 |         """Serialize and save a parameter with git-lfs."""
47 | 
48 |     @abstractmethod
49 |     async def apply(
50 |         self, param_metadata: metadata.ParamMetadata, param_keys: Tuple[str], **kwargs
51 |     ) -> Parameter:
52 |         """Get the final parameter value, including fetching previous values."""
53 | 
54 | 
55 | # TODO: Fix this for inheritance so we don't need to dup "name" here.
56 | @utils.abstract_classattributes("name", "required_keys")
57 | class IncrementalUpdate(Update):
58 |     """Base class for parameter updates that depend on the previous value."""
59 | 
60 |     required_keys: FrozenSet[str] = NotImplemented  # Names for side-loaded information.
61 | 
62 |     def __init__(self, serializer: params.Serializer, update_data: str = ""):
63 |         super().__init__(serializer)
64 |         self.update_information: Dict[str, np.ndarray] = None
65 |         self.update_names: utils.Trie = None
66 |         # Flatten the side-loaded information into a mapping of string keys to arrays.
67 |         if update_data:
68 |             self.update_information = {
69 |                 "/".join(k): v
70 |                 for k, v in checkpoints.get_checkpoint_handler()
71 |                 .from_file(update_data)
72 |                 .flatten()
73 |                 .items()
74 |             }
75 |             self.update_names = utils.Trie.from_iterable(self.update_information.keys())
76 |         self.logger = logging.getLogger("git_theta")
77 | 
78 |     def will_update(self, param_keys: Tuple[str]) -> bool:
79 |         if self.update_information is not None:
80 |             param_keys = "/".join(param_keys)
81 |             return self.update_names.prefix(param_keys)
82 |         return False
83 | 
84 |     async def get_previous_metadata(
85 |         self,
86 |         param_metadata: metadata.ParamMetadata,
87 |         param_keys: Tuple[str],
88 |         repo,
89 |         path: str,
90 |     ) -> metadata.ParamMetadata:
91 |         """Get the metadata from the last time this parameter was updated via git."""
92 |         self.logger.debug(f"Getting previous metadata for {'/'.join(param_keys)}")
93 |         self.logger.debug(
94 |             f"Current Metadata for {'/'.join(param_keys)}: {param_metadata}"
95 |         )
96 |         last_commit = param_metadata.theta_metadata.last_commit
97 |         # TODO: Currently, if the model checkpoint is added during the first commit
98 |         # then we can't do a sparse update until a second dense update is committed.
99 | if not last_commit: 100 | raise ValueError( 101 | f"Cannot find previous version for parameter {'/'.join(param_keys)}" 102 | ) 103 | self.logger.debug( 104 | f"Getting metadata for {'/'.join(param_keys)} from commit {last_commit}" 105 | ) 106 | last_metadata_obj = git_utils.get_file_version(repo, path, last_commit) 107 | last_metadata = metadata.Metadata.from_file(last_metadata_obj.data_stream) 108 | last_param_metadata = last_metadata.flatten()[param_keys] 109 | self.logger.debug( 110 | f"Previous Metadata for {'/'.join(param_keys)}: {last_param_metadata}" 111 | ) 112 | return last_param_metadata 113 | 114 | async def get_previous_value( 115 | self, 116 | param_metadata: metadata.ParamMetadata, 117 | param_keys: Tuple[str], 118 | repo, 119 | path: str, 120 | ) -> Parameter: 121 | """Get the last value for this parameter via git.""" 122 | self.logger.debug(f"Getting previous value for {'/'.join(param_keys)}") 123 | # TODO: get_update_serializer returns instantiated objects while the other 124 | # getters return classes to be instantiated. 125 | prev_serializer = params.get_update_serializer() 126 | prev_update = get_update_handler(param_metadata.theta_metadata.update_type)( 127 | prev_serializer 128 | ) 129 | return await prev_update.apply(param_metadata, param_keys, repo=repo, path=path) 130 | 131 | @abstractmethod 132 | async def calculate_update( 133 | self, parameter: Parameter, previous_parameter: Parameter 134 | ) -> Parameter: 135 | """Calculate the update required to go from previous_parameter -> parameter.""" 136 | 137 | async def read_update(self, param_keys) -> Parameter: 138 | return { 139 | k: self.update_information["/".join(param_keys + (k,))] 140 | for k in self.required_keys 141 | } 142 | 143 | @classmethod 144 | @abstractmethod 145 | async def apply_update(cls, update: Parameter, previous: Parameter) -> Parameter: 146 | """Apply the update to the previous value to get the new value.""" 147 | 148 | @abstractmethod 149 | def format_update(self, param: Parameter, *args, **kwargs) -> Parameter: 150 | """A user-facing helper function to help format an update for git-theta.""" 151 | 152 | async def write_update(self, update: Parameter) -> metadata.LfsMetadata: 153 | """Save and serialize (just) the update weights.""" 154 | if not isinstance(update, dict): 155 | update = {"parameter": update} 156 | serialized_update = await self.serializer.serialize(update) 157 | lfs_pointer = await git_utils.git_lfs_clean(serialized_update) 158 | return metadata.LfsMetadata.from_pointer(lfs_pointer) 159 | 160 | # TODO: Revisit what the metadata a write takes, right now it gets the full 161 | # metadata of the previous parameter value but only uses the update type. If 162 | # we do call get_previous_metadata on it, like we do in apply, the result is 163 | # that a parameter value is skipped and we calculate the incremental update 164 | # from 2 steps back, which can be a foot-gun. 
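    # Illustrative sketch (not executed): the write/apply pair must round-trip,
    # i.e. apply_update(calculate_update(new, prev), prev) == new. For a simple
    # additive delta that looks like:
    #
    #   update = new_param - prev_param       # calculate_update
    #   recovered = prev_param + update       # apply_update
    #   assert np.allclose(recovered, new_param)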
165 |     async def write(
166 |         self,
167 |         param: Parameter,
168 |         param_keys,
169 |         *,
170 |         prev_metadata: metadata.ParamMetadata,
171 |         repo,
172 |         path: str,
173 |         **kwargs,
174 |     ) -> Tuple[metadata.LfsMetadata, Optional[Signature]]:
175 |         """Serialize and save a parameter with git-lfs as a delta from the previous value."""
176 |         self.logger.debug(f"Writing {self.name} update for {'/'.join(param_keys)}")
177 |         previous_value = await self.get_previous_value(
178 |             prev_metadata, param_keys, repo=repo, path=path
179 |         )
180 |         if self.update_information is not None and self.will_update(param_keys):
181 |             update_value = await self.read_update(param_keys)
182 |             # Calculate and hash the *new* value so that we can update the
183 |             # metadata when using side-loaded information.
184 |             new_value = await self.apply_update(update_value, previous_value)
185 |             new_hash = lsh.get_lsh().hash(new_value)
186 |             return await self.write_update(update_value), new_hash
187 |         else:
188 |             update_value = await self.calculate_update(param, previous_value)
189 |             return await self.write_update(update_value), None
190 | 
191 |     async def apply(
192 |         self,
193 |         param_metadata: metadata.ParamMetadata,
194 |         param_keys: Tuple[str],
195 |         *,
196 |         repo,
197 |         path: str,
198 |         **kwargs,
199 |     ) -> Parameter:
200 |         """Get the final parameter value, including fetching previous values."""
201 |         self.logger.debug(f"Applying {self.name} update for {'/'.join(param_keys)}")
202 |         update_value = await self.read(param_metadata)
203 |         # param_metadata is the metadata for the parameter as it is *at this
204 |         # commit*.
205 |         prev_metadata = await self.get_previous_metadata(
206 |             param_metadata, param_keys, repo=repo, path=path
207 |         )
208 |         prev_value = await self.get_previous_value(
209 |             prev_metadata, param_keys, repo=repo, path=path
210 |         )
211 |         return await self.apply_update(update_value, prev_value)
212 | 
213 | 
214 | def get_update_handler_name(update_type: Optional[str] = None) -> str:
215 |     return update_type or utils.EnvVarConstants.UPDATE_TYPE
216 | 
217 | 
218 | def get_update_handler(update_type: Optional[str] = None) -> Update:
219 |     """Get an Update class by name.
220 | 
221 |     Parameters
222 |     ----------
223 |     update_type:
224 |         The name of the update type we want to use.
225 | 
226 |     Returns
227 |     -------
228 |     Update
229 |         The update class. The returned class may be defined in a user-installed
230 |         plugin.
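
    Example (illustrative; "dense" is one of the built-in update types):
        handler_cls = get_update_handler("dense")
        handler = handler_cls(params.get_update_serializer())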
231 | """ 232 | update_name = get_update_handler_name(update_type) 233 | discovered_plugins = entry_points(group="git_theta.plugins.updates") 234 | return discovered_plugins[update_name].load() 235 | -------------------------------------------------------------------------------- /git_theta/updates/dense.py: -------------------------------------------------------------------------------- 1 | """Class managing dense parameter updates.""" 2 | 3 | import logging 4 | from typing import Any, Optional 5 | 6 | from git_theta import git_utils, metadata 7 | from git_theta.updates import Update 8 | 9 | Parameter = Any 10 | 11 | 12 | class DenseUpdate(Update): 13 | """An update where all parameters are changed.""" 14 | 15 | name: str = "dense" 16 | 17 | def __init__(self, *args, **kwargs): 18 | super().__init__(*args, **kwargs) 19 | self.logger = logging.getLogger("git_theta") 20 | 21 | async def apply(self, param_metadata, param_keys, *args, **kwargs) -> Parameter: 22 | param_name = "/".join(param_keys) 23 | self.logger.debug(f"Reading Dense update for {param_name}") 24 | tensor = await self.read(param_metadata) 25 | self.logger.debug(f"Finished Read Dense update for {param_name}") 26 | return tensor 27 | 28 | async def write(self, param, param_keys, *args, **kwargs) -> metadata.LfsMetadata: 29 | param_name = "/".join(param_keys) 30 | self.logger.debug(f"Writing Dense update for {param_name}") 31 | self.logger.debug(f"Starting Serializing {param_name}") 32 | serialized = await self.serializer.serialize({"parameter": param}) 33 | self.logger.debug(f"Finished Serializing {param_name}") 34 | self.logger.debug(f"Starting git-lfs clean for {param_name}") 35 | lfs_pointer = await git_utils.git_lfs_clean(serialized) 36 | self.logger.debug(f"Finished git-lfs clean for {param_name}") 37 | return metadata.LfsMetadata.from_pointer(lfs_pointer), None 38 | -------------------------------------------------------------------------------- /git_theta/updates/ia3.py: -------------------------------------------------------------------------------- 1 | """A class for handling activations scaling using ia3 vectors.""" 2 | 3 | import logging 4 | from typing import Any, FrozenSet, List, Optional 5 | 6 | import numpy as np 7 | 8 | from git_theta import params 9 | from git_theta.updates import IncrementalUpdate 10 | 11 | Parameter = Any 12 | 13 | 14 | class IA3Update(IncrementalUpdate): 15 | """An update where activations are scaled.""" 16 | 17 | name: str = "ia3" 18 | required_keys: FrozenSet[str] = frozenset(("ia3",)) 19 | 20 | @classmethod 21 | def format_update(cls, param: Parameter, *args, **kwargs) -> Parameter: 22 | """User-facing helper to convert an update to ia3.""" 23 | return {"ia3": param} 24 | 25 | async def calculate_update( 26 | self, 27 | parameter: Parameter, 28 | previous_parameter: Parameter, 29 | broadcast_dims: List[int], 30 | ) -> Parameter: 31 | """Calculate the update for the given parameter where ia3 is applied over broadcast dims.""" 32 | 33 | # use mask1 to prevent divide by zeros 34 | mask1 = previous_parameter != 0 35 | multiplier = np.divide( 36 | parameter, previous_parameter, out=np.zeros(parameter.shape), where=mask1 37 | ) 38 | 39 | # Calcuate ia3 by averaging multiplier over broadcast dims and take into account the fact that some values may be zero 40 | denominator = np.sum(mask1, axis=tuple(broadcast_dims), keepdims=True) 41 | mask2 = denominator != 0 42 | ia3_update = np.divide( 43 | np.sum(multiplier, axis=tuple(broadcast_dims), keepdims=True), 44 | denominator, 45 | 
out=np.zeros(denominator.shape), 46 | where=mask2, 47 | ) 48 | 49 | return {"ia3": ia3_update} 50 | 51 | async def apply_update(self, update: Parameter, previous: Parameter) -> Parameter: 52 | return previous * update["ia3"] 53 | -------------------------------------------------------------------------------- /git_theta/updates/low_rank.py: -------------------------------------------------------------------------------- 1 | """An update type where the update is stored as 2 low-rank matrices.""" 2 | 3 | 4 | import logging 5 | from typing import Any, FrozenSet, Optional 6 | 7 | import numpy as np 8 | 9 | from git_theta.updates import IncrementalUpdate 10 | 11 | Parameter = Any 12 | 13 | 14 | class LowRankUpdate(IncrementalUpdate): 15 | """An update made from 2 low-rank matrices.""" 16 | 17 | name: str = "low-rank" 18 | required_keys: FrozenSet[str] = frozenset(("R", "C")) 19 | 20 | # TODO: Make these configuration options easy to set. 21 | def __init__( 22 | self, *args, K: Optional[int] = None, threshold: float = 1e-11, **kwargs 23 | ): 24 | super().__init__(*args, **kwargs) 25 | self.K = K 26 | self.threshold = threshold 27 | 28 | @classmethod 29 | def format_update( 30 | cls, param1: Parameter, param2: Parameter, *args, **kwargs 31 | ) -> Parameter: 32 | return { 33 | "R": param1, 34 | "C": param2, 35 | } 36 | 37 | async def calculate_update( 38 | self, parameter: Parameter, previous_parameter: Parameter 39 | ) -> Parameter: 40 | update = parameter - previous_parameter 41 | if update.ndim < 2: 42 | return update 43 | logger = logging.getLogger("git_theta") 44 | logger.info("Inferring low-rank update based on SVD") 45 | u, s, vh = np.linalg.svd(update, full_matrices=False) 46 | if self.K is not None: 47 | k = self.K 48 | logger.info(f"Low Rank Update configured to have a rank of {k}") 49 | else: 50 | k = np.sum(s > self.threshold) 51 | logger.info(f"Low Rank Update inferred to have a rank of {k}") 52 | return {"R": u[:, :k], "C": (np.diag(s[:k]) @ vh[:k, :])} 53 | 54 | async def apply_update(self, update: Parameter, previous: Parameter) -> Parameter: 55 | if not isinstance(update, dict): 56 | return update + previous 57 | return update["R"] @ update["C"] + previous 58 | -------------------------------------------------------------------------------- /git_theta/updates/sparse.py: -------------------------------------------------------------------------------- 1 | """A class for handling sparse updates to parameters.""" 2 | 3 | import logging 4 | from typing import Any, FrozenSet, Optional 5 | 6 | import numpy as np 7 | import scipy.sparse 8 | 9 | from git_theta import params 10 | from git_theta.updates import IncrementalUpdate 11 | 12 | Parameter = Any 13 | 14 | 15 | class SparseUpdate(IncrementalUpdate): 16 | """An update where only some parameters are touched.""" 17 | 18 | name: str = "sparse" 19 | required_keys: FrozenSet[str] = frozenset(("data", "indices", "indptr", "shape")) 20 | 21 | def __init__( 22 | self, 23 | serializer: params.Serializer, 24 | update_data: str = "", 25 | threshold: float = 1e-12, 26 | ): 27 | # TODO: Make threshold configurable 28 | super().__init__(serializer, update_data) 29 | self.threshold = threshold 30 | 31 | @classmethod 32 | def format_update(cls, param: Parameter, *args, **kwargs) -> Parameter: 33 | """User-facing helper to convert an array to sparse storage.""" 34 | update = scipy.sparse.csr_matrix(np.reshape(param, (1, -1))) 35 | return { 36 | "data": update.data, 37 | "indices": update.indices, 38 | "indptr": update.indptr, 39 | "shape":
np.array(param.shape), 40 | } 41 | 42 | async def calculate_update( 43 | self, parameter: Parameter, previous_parameter: Parameter 44 | ) -> Parameter: 45 | diff = parameter - previous_parameter 46 | diff[np.abs(diff) < self.threshold] = 0 47 | # csr_matrix only drops exact zeros from the diff tensor, so we zero out values below a configurable threshold to make the diff tensor (the update) really sparse. 48 | update = scipy.sparse.csr_matrix(np.reshape(diff, (1, -1))) 49 | return { 50 | "data": update.data, 51 | "indices": update.indices, 52 | "indptr": update.indptr, 53 | "shape": np.array(parameter.shape), 54 | } 55 | 56 | async def apply_update(self, update: Parameter, previous: Parameter) -> Parameter: 57 | # Provide the shape of the original flattened array to ensure the output has the correct shape. Without the provided shape, csr_matrix interprets the shape as (1, index of the last occurrence of a non-zero number in the original flattened array). 58 | param_update = scipy.sparse.csr_matrix( 59 | (update["data"], update["indices"], update["indptr"]), 60 | shape=(1, np.prod(update["shape"])), 61 | ) 62 | return np.reshape(param_update.toarray(), update["shape"]) + previous 63 | -------------------------------------------------------------------------------- /plugins/README.md: -------------------------------------------------------------------------------- 1 | # git-theta Plug-ins 2 | 3 | git-theta supports plug-ins for custom model checkpoint formats. 4 | 5 | A checkpoint plug-in should subclass `git_theta.checkpoints.Checkpoint`. The plugin 6 | should implement the `load` method, which reads the checkpoint format into a dict 7 | mapping parameter names to parameter weights. It should also implement `save`, which 8 | writes the original checkpoint format based on the dict-of-weights representation. 9 | 10 | ## Packaging a Plug-in 11 | 12 | The plug-in should be wrapped in an installable package that declares itself as a plugin using the 13 | `"git_theta.plugins.checkpoints"` entry point. The following should appear in the 14 | `setup.py` for the package. 15 | 16 | ```python 17 | setup( 18 | ..., 19 | install_requires=[ 20 | "git_theta", 21 | ..., 22 | ], 23 | entry_points={ 24 | "git_theta.plugins.checkpoints": [ 25 | "my-cool-checkpoint = package.module:MyCoolCheckpointClass", 26 | ], 27 | }, 28 | ..., 29 | ) 30 | ``` 31 | 32 | Note: user plug-in packages can have any name as long as they register the 33 | `"git_theta.plugins.checkpoints"` entry point. 34 | 35 | ## Using a Plugin 36 | 37 | Having a plug-in installed will give `git_theta` access to the checkpoint class it provides. 38 | To use this class, include the CLI argument `--type my-cool-checkpoint` in the 39 | `git-theta add ...` command. 40 | -------------------------------------------------------------------------------- /plugins/json-checkpoint/README.md: -------------------------------------------------------------------------------- 1 | # git-theta Plug-in Example 2 | 3 | This JSON-based checkpoint is an example of using the plug-in system to create 4 | a new checkpoint type for git-theta.
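
## Example Usage

A minimal sketch of trying the plug-in once it is installed (the `model.json` file name here is just for illustration):

```python
import json

# Write a toy "checkpoint" that maps parameter names to (nested) lists of values.
params = {"layer/weight": [[0.1, 0.2], [0.3, 0.4]], "layer/bias": [0.0, 1.0]}
with open("model.json", "w") as f:
    json.dump(params, f)
```

Then, inside a repo where `git theta install` has been run, track the file with `git theta track model.json` and stage it with `git theta add model.json --type json`.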
5 | -------------------------------------------------------------------------------- /plugins/json-checkpoint/git_theta_json_checkpoint/__init__.py: -------------------------------------------------------------------------------- 1 | """A demo plugin for reading json checkpoints with git_theta.""" 2 | -------------------------------------------------------------------------------- /plugins/json-checkpoint/git_theta_json_checkpoint/checkpoints.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import io 4 | import json 5 | 6 | from git_theta import checkpoints 7 | 8 | 9 | class JSONCheckpoint(checkpoints.Checkpoint): 10 | """Class for prototyping with JSON checkpoints""" 11 | 12 | @classmethod 13 | def load(cls, checkpoint_path): 14 | """Load a checkpoint into a dict format. 15 | 16 | Parameters 17 | ---------- 18 | checkpoint_path : str or file-like object 19 | Path to a checkpoint file 20 | 21 | Returns 22 | ------- 23 | model_dict : dict 24 | Dictionary mapping parameter names to parameter values 25 | """ 26 | if isinstance(checkpoint_path, io.IOBase): 27 | return json.load(checkpoint_path) 28 | else: 29 | with open(checkpoint_path, "r") as f: 30 | return json.load(f) 31 | 32 | def save(self, checkpoint_path): 33 | """Save the checkpoint dict as JSON. 34 | 35 | Parameters 36 | ---------- 37 | checkpoint_path : str or file-like object 38 | Path to write out the checkpoint file to 39 | """ 40 | if isinstance(checkpoint_path, io.IOBase): 41 | json.dump(self, checkpoint_path) 42 | else: 43 | with open(checkpoint_path, "w") as f: 44 | json.dump(self, f) 45 | -------------------------------------------------------------------------------- /plugins/json-checkpoint/setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | from setuptools import setup 4 | 5 | setup( 6 | name="git_theta_json_checkpoint", 7 | description="Demo plugin for the git theta VCS.", 8 | install_requires=[ 9 | "git_theta", 10 | ], 11 | packages=["git_theta_json_checkpoint"], 12 | entry_points={ 13 | "git_theta.plugins.checkpoints": [ 14 | "json = git_theta_json_checkpoint.checkpoints:JSONCheckpoint", 15 | ] 16 | }, 17 | ) 18 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.black] 2 | include = '((\.pyi?|\.ipynb)$|bin/)' 3 | 4 | [tool.isort] 5 | profile = "black" 6 | 7 | [build-system] 8 | # Minimum requirements for the build system to execute. 9 | requires = ["setuptools", "wheel"] # PEP 508 specifications.
10 | 11 | [tool.pytest.ini_options] 12 | norecursedirs = "tests/helpers tests/end2end" 13 | testpaths = [ 14 | "tests" 15 | ] 16 | -------------------------------------------------------------------------------- /requirements-ci.txt: -------------------------------------------------------------------------------- 1 | file_or_name>=1.1.6 2 | flax 3 | GitPython>=3.1.31 4 | importlib_metadata>=4.11.1 5 | importlib_resources>=5.1.4 6 | numba>=0.56.4 7 | numpy>=1.23.5 8 | prompt_toolkit>=3.0.18 9 | pytest>=6.2.3 10 | scipy>=1.10.1 11 | setuptools>=49.2.1 12 | six>=1.15.0 13 | tensorflow>=2.12.0 14 | tensorstore>=0.1.27 15 | torch>=1.8.1 16 | typing_extensions>=4.1.1 17 | -------------------------------------------------------------------------------- /requirements-dev.txt: -------------------------------------------------------------------------------- 1 | pre-commit 2 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | """Install the git-theta package.""" 2 | 3 | import ast 4 | import itertools 5 | from pathlib import Path 6 | 7 | from setuptools import find_packages, setup 8 | 9 | 10 | def get_version(file_name: str, version_variable: str = "__version__") -> str: 11 | """Find the version by walking the AST to avoid duplication. 12 | 13 | Parameters 14 | ---------- 15 | file_name : str 16 | The file we are parsing to get the version string from. 17 | version_variable : str 18 | The variable name that holds the version string. 19 | 20 | Raises 21 | ------ 22 | ValueError 23 | If there was no assignment to version_variable in file_name. 24 | 25 | Returns 26 | ------- 27 | version_string : str 28 | The version string parsed from file_name. 29 | """ 30 | with open(file_name) as f: 31 | tree = ast.parse(f.read()) 32 | # Look at all assignment nodes that happen in the ast. If the variable 33 | # name matches the given parameter, grab the value (which will be 34 | # the version string we are looking for). 35 | for node in ast.walk(tree): 36 | if isinstance(node, ast.Assign): 37 | if node.targets[0].id == version_variable: 38 | return node.value.s 39 | raise ValueError( 40 | f"Could not find an assignment to {version_variable} " f"within '{file_name}'" 41 | ) 42 | 43 | 44 | # Packages to install for using different deep learning frameworks.
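# Each key also doubles as a pip extra (e.g. `pip install git_theta[pytorch]`), and the `all` extra below is the union of every framework's requirements.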
45 | frameworks_require = { 46 | "pytorch": ["torch"], 47 | "tensorflow": ["tensorflow"], 48 | "flax": ["flax"], 49 | "safetensors": ["safetensors"], 50 | } 51 | 52 | 53 | with open(Path(__file__).parent / "README.md", encoding="utf-8") as f: 54 | LONG_DESCRIPTION = f.read() 55 | 56 | 57 | setup( 58 | name="git_theta", 59 | version=get_version("git_theta/__init__.py"), 60 | description="Version control system for machine learning model checkpoints.", 61 | long_description=LONG_DESCRIPTION, 62 | long_description_content_type="text/markdown", 63 | author="Colin Raffel", 64 | author_email="craffel@gmail.com", 65 | url="https://github.com/r-three/git-theta", 66 | packages=find_packages(), 67 | include_package_data=True, 68 | package_data={"git_theta": ["hooks/post-commit", "hooks/pre-push"]}, 69 | python_requires=">=3.8", 70 | classifiers=[ 71 | "License :: OSI Approved :: MIT License", 72 | "Programming Language :: Python", 73 | "Development Status :: 4 - Beta", 74 | "Environment :: Console", 75 | "Intended Audience :: Developers", 76 | "Intended Audience :: Science/Research", 78 | "Topic :: Software Development :: Version Control", 79 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 80 | "Programming Language :: Python :: 3 :: Only", 81 | "Natural Language :: English", 82 | "Operating System :: MacOS :: MacOS X", 83 | "Operating System :: Microsoft :: Windows", 84 | "Operating System :: POSIX :: Linux", 85 | ], 86 | keywords="git vcs machine-learning", 87 | license="MIT", 88 | install_requires=[ 89 | "GitPython", 90 | "gitdb", 91 | "tensorstore >= 0.1.14", 92 | "file-or-name", 93 | "six", 94 | "scipy", 95 | "numba", 96 | "msgpack", 97 | 'importlib_resources; python_version < "3.9.0"', 98 | 'importlib_metadata; python_version < "3.10.0"', 99 | 'typing_extensions; python_version < "3.8.0"', 100 | "prompt_toolkit", 101 | "colorama", 102 | ], 103 | extras_require={ 104 | **frameworks_require, 105 | # Install all framework deps with the all target.
106 | "test": ["pytest"], 107 | "all": list(set(itertools.chain(*frameworks_require.values()))), 108 | "docs": ["sphinx", "numpydoc"], 109 | }, 110 | entry_points={ 111 | "console_scripts": [ 112 | "git-theta = git_theta.scripts.git_theta_cli:main", 113 | "git-theta-filter = git_theta.scripts.git_theta_filter:main", 114 | "git-theta-merge = git_theta.scripts.git_theta_merge:main", 115 | "git-theta-diff = git_theta.scripts.git_theta_diff:main", 116 | ], 117 | "git_theta.plugins.checkpoints": [ 118 | "pytorch = git_theta.checkpoints.pickled_dict_checkpoint:PickledDictCheckpoint", 119 | "pickled-dict = git_theta.checkpoints.pickled_dict_checkpoint:PickledDictCheckpoint", 120 | "tf = git_theta.checkpoints.tensorflow_checkpoint:TensorFlowCheckpoint", 121 | "tensorflow = git_theta.checkpoints.tensorflow_checkpoint:TensorFlowCheckpoint", 122 | "tensorflow-checkpoint = git_theta.checkpoints.tensorflow_checkpoint:TensorFlowCheckpoint", 123 | "tf-savedmodel = git_theta.checkpoints.tensorflow_checkpoint:TensorFlowSavedModel", 124 | "tensorflow-savedmodel = git_theta.checkpoints.tensorflow_checkpoint:TensorFlowSavedModel", 125 | "flax = git_theta.checkpoints.flax_checkpoint:FlaxCheckpoint", 126 | "flax-checkpoint = git_theta.checkpoints.flax_checkpoint:FlaxCheckpoint", 127 | "safetensors = git_theta.checkpoints.safetensors_checkpoint:SafeTensorsCheckpoint", 128 | "safetensors-checkpoint = git_theta.checkpoints.safetensors_checkpoint:SafeTensorsCheckpoint", 129 | ], 130 | "git_theta.plugins.updates": [ 131 | "dense = git_theta.updates.dense:DenseUpdate", 132 | "sparse = git_theta.updates.sparse:SparseUpdate", 133 | "low-rank = git_theta.updates.low_rank:LowRankUpdate", 134 | "ia3 = git_theta.updates.ia3:IA3Update", 135 | ], 136 | "git_theta.plugins.merges": [ 137 | "take_us = git_theta.merges.take:TakeUs", 138 | "take_them = git_theta.merges.take:TakeThem", 139 | "take_original = git_theta.merges.take:TakeOriginal", 140 | "average-ours-theirs = git_theta.merges.average:Average", 141 | "average-all = git_theta.merges.average:AverageAll", 142 | "average-ours-original = git_theta.merges.average:AverageOursOriginal", 143 | "average-theirs-original = git_theta.merges.average:AverageTheirsOriginal", 144 | "context = git_theta.merges.context:Context", 145 | ], 146 | }, 147 | ) 148 | -------------------------------------------------------------------------------- /tests/checkpoints/checkpoints_test.py: -------------------------------------------------------------------------------- 1 | """Tests for checkpoints.py""" 2 | 3 | import os 4 | 5 | import pytest 6 | 7 | from git_theta import checkpoints 8 | 9 | ENV_CHECKPOINT_TYPE = "GIT_THETA_CHECKPOINT_TYPE" 10 | 11 | torch = pytest.importorskip("torch") 12 | 13 | 14 | @pytest.fixture 15 | def env_var(): 16 | current_env = dict(os.environ) 17 | os.environ[ENV_CHECKPOINT_TYPE] = "env_variable_handler" 18 | 19 | yield 20 | os.environ.clear() 21 | os.environ.update(current_env) 22 | 23 | 24 | @pytest.fixture 25 | def no_env_var(): 26 | current_env = dict(os.environ) 27 | os.environ.pop(ENV_CHECKPOINT_TYPE, None) 28 | 29 | yield 30 | os.environ.clear() 31 | os.environ.update(current_env) 32 | 33 | 34 | @pytest.fixture 35 | def empty_env_var(): 36 | current_env = dict(os.environ) 37 | os.environ[ENV_CHECKPOINT_TYPE] = "" 38 | 39 | yield 40 | os.environ.clear() 41 | os.environ.update(current_env) 42 | 43 | 44 | def test_get_checkpoint_handler_name_user_input(env_var): 45 | """Check that function prefers user input to environment variable""" 46 | 47 | user_input = 
"user_input_handler" 48 | name = checkpoints.get_checkpoint_handler_name(user_input) 49 | assert name == user_input 50 | 51 | 52 | def test_get_checkpoint_handler_name_env_variable(env_var): 53 | """Check that function uses environment variable no user input specified""" 54 | 55 | name = checkpoints.get_checkpoint_handler_name() 56 | assert name == "env_variable_handler" 57 | 58 | 59 | def test_get_checkpoint_handler_name_default1(no_env_var): 60 | """Check that function has correct default behavior with no user input and environment variable""" 61 | 62 | name = checkpoints.get_checkpoint_handler_name() 63 | assert name == "pytorch" 64 | 65 | 66 | def test_get_checkpoint_handler_name_default2(empty_env_var): 67 | """Check that function has correct default behavior with no user input and environment variable is empty string""" 68 | 69 | name = checkpoints.get_checkpoint_handler_name() 70 | assert name == "pytorch" 71 | 72 | 73 | # TODO: Move this (and other pytorch checkpoint tests) to new file. Remove the 74 | # importorskip too. 75 | def test_get_checkpoint_handler_pytorch(no_env_var): 76 | """Check that checkpoint_handler type is correct for when checkpoint_handler name resolves to pytorch""" 77 | 78 | out = checkpoints.get_checkpoint_handler("pytorch") 79 | assert out == checkpoints.pickled_dict_checkpoint.PickledDictCheckpoint 80 | -------------------------------------------------------------------------------- /tests/checkpoints/safetensors_checkpoint_test.py: -------------------------------------------------------------------------------- 1 | """safetensors checkpoint tests.""" 2 | 3 | import operator as op 4 | import os 5 | 6 | import helpers 7 | import numpy as np 8 | import pytest 9 | 10 | # Skip all these tests if tensorflow is not installed 11 | safetensors = pytest.importorskip("safetensors") 12 | 13 | from git_theta import checkpoints 14 | from git_theta.checkpoints import safetensors_checkpoint 15 | 16 | 17 | @pytest.fixture 18 | def fake_model(): 19 | return { 20 | "layer1/weight": np.random.rand(1024, 1024), 21 | "layer1/bias": np.random.rand(1024), 22 | "layer2/weight": np.random.rand(512, 1024), 23 | "layer2/bias": np.random.rand(512), 24 | } 25 | 26 | 27 | def test_round_trip(fake_model): 28 | with helpers.utils.named_temporary_file() as f: 29 | ckpt = safetensors_checkpoint.SafeTensorsCheckpoint(fake_model) 30 | ckpt.save(f.name) 31 | f.flush() 32 | f.close() 33 | ckpt2 = safetensors_checkpoint.SafeTensorsCheckpoint.from_file(f.name) 34 | for (_, og), (_, new) in zip( 35 | sorted(ckpt.items(), key=op.itemgetter(0)), 36 | sorted(ckpt2.items(), key=op.itemgetter(0)), 37 | ): 38 | np.testing.assert_array_equal(og, new) 39 | 40 | 41 | def test_get_checkpoint_handler_safetensors(): 42 | for alias in ("safetensors", "safetensors-checkpoint"): 43 | out = checkpoints.get_checkpoint_handler(alias) 44 | assert out == safetensors_checkpoint.SafeTensorsCheckpoint 45 | -------------------------------------------------------------------------------- /tests/checkpoints/tensorflow_checkpoint_test.py: -------------------------------------------------------------------------------- 1 | """Tensorflow checkpoint tests.""" 2 | 3 | import os 4 | import tempfile 5 | from unittest import mock 6 | 7 | import numpy as np 8 | import pytest 9 | 10 | # Skip all these tests if tensorflow is not installed 11 | tf = pytest.importorskip("tensorflow") 12 | 13 | from git_theta import checkpoints 14 | from git_theta.checkpoints import tensorflow_checkpoint 15 | 16 | INPUT_SIZE = 10 17 | 18 | 19 | 
@pytest.fixture(autouse=True) 20 | def hide_cuda(): 21 | with mock.patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": ""}): 22 | yield 23 | 24 | 25 | class InnerLayer(tf.keras.layers.Layer): 26 | def __init__(self): 27 | super().__init__(name="inner") 28 | self.dense = tf.keras.layers.Dense(5) 29 | # Add a non-trainable parameter for completeness. 30 | self.scale = tf.Variable(12.0, name="scalar", trainable=False) 31 | 32 | def call(self, x): 33 | return self.dense(x) * self.scale 34 | 35 | 36 | class DemoModel(tf.keras.Model): 37 | def __init__(self): 38 | super().__init__() 39 | self.inner = InnerLayer() 40 | self.logits = tf.keras.layers.Dense(2, name="logits") 41 | 42 | def call(self, x): 43 | return self.logits(self.inner(x)) 44 | 45 | 46 | def make_fake_model(): 47 | dm = DemoModel() 48 | dm(np.zeros((1, INPUT_SIZE))) 49 | return dm 50 | 51 | 52 | @pytest.fixture 53 | def fake_model(): 54 | return make_fake_model() 55 | 56 | 57 | @pytest.mark.xfail(reason="Changes to Tensorflow saved model need to be accounted for.") 58 | def test_round_trip(fake_model): 59 | with tempfile.NamedTemporaryFile() as f: 60 | # Make a checkpoint via tensorflow 61 | fake_model.save_weights(f.name) 62 | # Load the Checkpoint 63 | ckpt = tensorflow_checkpoint.TensorFlowCheckpoint.from_file(f.name) 64 | with tempfile.NamedTemporaryFile() as f2: 65 | # Use the git-theta save to create a new checkpoint 66 | ckpt.save(f2.name) 67 | # Load a model from the checkpoint we just saved 68 | loaded_model = make_fake_model() 69 | loaded_model.load_weights(f2.name) 70 | for og, new in zip(fake_model.variables, loaded_model.variables): 71 | np.testing.assert_array_equal(og.numpy(), new.numpy()) 72 | 73 | 74 | @pytest.mark.xfail(reason="Changes to Tensorflow saved model need to be accounted for.") 75 | def test_round_trip_with_modifications(fake_model): 76 | with tempfile.NamedTemporaryFile() as f: 77 | # Make a checkpoint via tensorflow 78 | fake_model.save_weights(f.name) 79 | # Load the Checkpoint 80 | ckpt = tensorflow_checkpoint.TensorFlowCheckpoint.from_file(f.name) 81 | # Update value 82 | ckpt["logits"]["bias"] = np.ones_like(ckpt["logits"]["bias"]) 83 | with tempfile.NamedTemporaryFile() as f2: 84 | # Use the git-theta save to create a new checkpoint 85 | ckpt.save(f2.name) 86 | # Load a model from the checkpoint we just saved 87 | loaded_model = make_fake_model() 88 | loaded_model.load_weights(f2.name) 89 | for og, new in zip(fake_model.variables, loaded_model.variables): 90 | # Check that the updated value is loaded.
91 | if "logits" in new.name and "bias" in new.name: 92 | new_numpy = new.numpy() 93 | np.testing.assert_allclose(new_numpy, np.ones(*new_numpy.shape)) 94 | else: 95 | np.testing.assert_allclose(og.numpy(), new.numpy()) 96 | 97 | 98 | def test_get_checkpoint_handler_tensorflow(): 99 | out = checkpoints.get_checkpoint_handler("tensorflow-checkpoint") 100 | assert out == tensorflow_checkpoint.TensorFlowCheckpoint 101 | -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | """Shared fixtures for running tests""" 2 | 3 | import os 4 | import random 5 | import shutil 6 | import string 7 | 8 | import git 9 | import numpy as np 10 | import pytest 11 | 12 | from git_theta import metadata, theta 13 | 14 | 15 | class DataGenerator: 16 | @staticmethod 17 | def random_oid(): 18 | return "".join([random.choice(string.hexdigits.lower()) for _ in range(64)]) 19 | 20 | @staticmethod 21 | def random_commit_hash(): 22 | return "".join([random.choice(string.hexdigits.lower()) for _ in range(40)]) 23 | 24 | @staticmethod 25 | def random_lfs_metadata(): 26 | version = random.choice(["lfs_version1", "my_version", "version1"]) 27 | oid = "".join([random.choice(string.hexdigits.lower()) for _ in range(64)]) 28 | size = str(random.randint(0, 10000)) 29 | return metadata.LfsMetadata(version=version, oid=oid, size=size) 30 | 31 | @staticmethod 32 | def random_tensor_metadata(): 33 | ndims = random.choice(range(1, 6)) 34 | shape = tuple([random.choice(range(1, 50)) for _ in range(ndims)]) 35 | tensor = np.random.rand(*shape) 36 | return metadata.TensorMetadata.from_tensor(tensor) 37 | 38 | @staticmethod 39 | def random_theta_metadata(): 40 | update_type = random.choice(["dense", "sparse"]) 41 | last_commit = "".join( 42 | [random.choice(string.hexdigits.lower()) for _ in range(40)] 43 | ) 44 | return metadata.ThetaMetadata(update_type=update_type, last_commit=last_commit) 45 | 46 | @staticmethod 47 | def random_param_metadata(): 48 | tensor_metadata = DataGenerator.random_tensor_metadata() 49 | lfs_metadata = DataGenerator.random_lfs_metadata() 50 | theta_metadata = DataGenerator.random_theta_metadata() 51 | return metadata.ParamMetadata( 52 | tensor_metadata=tensor_metadata, 53 | lfs_metadata=lfs_metadata, 54 | theta_metadata=theta_metadata, 55 | ) 56 | 57 | @staticmethod 58 | def random_nested_dict( 59 | allowed_keys=list(string.ascii_letters), allowed_values=list(range(100)) 60 | ): 61 | """Generate random nested dicts for testing.""" 62 | result = {} 63 | prev = [result] 64 | curr = result 65 | for _ in range(random.randint(20, 50)): 66 | # Pick a key 67 | key = random.choice(allowed_keys) 68 | # 50/50, do we make a new nest level? 69 | if random.choice([True, False]): 70 | curr[key] = {} 71 | prev.append(curr) 72 | curr = curr[key] 73 | continue 74 | # Otherwise, add a leaf value 75 | value = random.choice(allowed_values) 76 | curr[key] = value 77 | # 50/50 are we done adding values to this node? 78 | if random.choice([True, False]): 79 | curr = prev.pop() 80 | # If we have tried to to up the tree from the root, stop generating. 
81 | if not prev: 82 | break 83 | return result 84 | 85 | @staticmethod 86 | def random_metadata(): 87 | values = [DataGenerator.random_param_metadata() for _ in range(100)] 88 | random_metadata_dict = DataGenerator.random_nested_dict(allowed_values=values) 89 | return metadata.Metadata(random_metadata_dict) 90 | 91 | @staticmethod 92 | def random_commit_info(): 93 | oids = [DataGenerator.random_oid() for _ in range(random.randint(5, 20))] 94 | return theta.CommitInfo(oids) 95 | 96 | 97 | @pytest.fixture 98 | def data_generator(): 99 | return DataGenerator 100 | 101 | 102 | @pytest.fixture 103 | def git_repo_with_commits(): 104 | commit_infos = [ 105 | DataGenerator.random_commit_info() for _ in range(random.randint(5, 20)) 106 | ] 107 | commit_hashes = [] 108 | 109 | repo_dir = ".delete-me" 110 | os.mkdir(repo_dir) 111 | try: 112 | repo = git.Repo.init(repo_dir) 113 | 114 | config_writer = repo.config_writer(config_level="repository") 115 | config_writer.set_value("user", "name", "myusername") 116 | config_writer.set_value("user", "email", "myemail") 117 | config_writer.release() 118 | 119 | theta_commits = theta.ThetaCommits(repo) 120 | 121 | # Write a bunch of empty commits and random ThetaCommits entries 122 | for commit_info in commit_infos: 123 | repo.git.commit("--allow-empty", "-m", "empty commit") 124 | commit_hash = repo.commit("HEAD").hexsha 125 | theta_commits.write_commit_info(commit_hash, commit_info) 126 | commit_hashes.append(commit_hash) 127 | 128 | yield repo, commit_hashes, commit_infos 129 | finally: 130 | repo.close() 131 | git.rmtree(repo_dir) 132 | -------------------------------------------------------------------------------- /tests/end2end/README.md: -------------------------------------------------------------------------------- 1 | # Git-Theta End-2-End tests. 2 | 3 | This directory contains a collection of end-2-end tests for git-theta. 4 | 5 | ## Running the Tests 6 | 7 | Each subdirectory represents a test that actually interacts with git. The `runner.sh` script is responsible for running them and reports whether they passed or failed. All tests can be run with `./runner.sh`. Additionally, the `.github/workflows/end2endtest.yml` configures these tests to run via GitHub Actions. 8 | 9 | ## Anatomy of a Test 10 | 11 | Each test is a directory and includes a `test.sh` script. Execution of this script runs the test, and pass/fail is determined by its exit code (`0` means pass). 12 | 13 | Each test may also include a `clean.sh` script; if one is missing, the shared `clean.sh` in this directory is used instead. This is run before and after the test to clean up any generated artifacts. This script should never have a non-zero exit code and should be idempotent (we can run it multiple times). 14 | 15 | The command `make-test.sh "${test name}"` can be used to create the skeleton of a new test. 16 | 17 | ### `test.sh`, the (Vegan-)?Meat and Potatoes 18 | 19 | The first steps of a `test.sh` file include sourcing `../utils.sh` so it can use some of our shared functions. Then it should run `set -e` to ensure that errors in any part of the test cause the whole test to fail. It should also call `test_init`, which will create a git repo (with a `main` branch) for it and ensure that `set -e` was used. 20 | 21 | The next step is often to create and modify some model. The provided `../model.py` file can be used as a helper to create new models and updates with special forms. This script creates two copies of the model: one that lives in the path that is version controlled and one whose name includes information about how it was created.
This path is returned on stdout and can be captured in `test.sh`. This checkpoint is not version controlled, uses the deep-learning framework's native checkpoint format, and should be removed by the clean script. 22 | 23 | The provided `verify.py` can be used to help check that version controlled models match their original values. 24 | 25 | `../utils.sh` provides a `commit` function that will create a git commit and return the hash for that commit; for example, `SHA=$(commit "commit msg")` will make a commit and save the hash to `${SHA}`. It can be called with just `commit "commit msg"` if you do not want to track the hash. This should help for tests that travel through git history. 26 | 27 | > **Warning** 28 | > The `set -e` option causes a whole bash script to exit if one of its subcommands errors/exits with a non-zero return code. Without this setting, tests that have failing steps will often appear as passing, so we require it to be active. If your test has subcommands that are expected to fail and your test accounts for that correctly, you can suppress this `-e` behavior by appending `|| true` to the failing commands (i.e. `${subcommand} || true`), or by calling `set +e` after the `test_init` call. Both of **these settings can cause tests to fail silently**, so only use them if you are sure you understand what they are doing. 29 | -------------------------------------------------------------------------------- /tests/end2end/checkout/test.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # checkout test: Test that we are able to checkout a git-theta model 3 | 4 | source ../utils.sh 5 | 6 | set -e 7 | 8 | test_init 9 | 10 | MODEL_SCRIPT=model.py 11 | MODEL_NAME=model.pt 12 | 13 | echo "Making model ${MODEL_NAME}" 14 | INIT_MODEL=`python ../${MODEL_SCRIPT} --action init --seed 1337 --model-name=${MODEL_NAME}` 15 | 16 | echo "Installing git-theta and tracking ${MODEL_NAME}" 17 | git theta install 18 | git theta track ${MODEL_NAME} 19 | 20 | echo "Adding ${MODEL_NAME} to git repo." 21 | git add ${MODEL_NAME} 22 | echo "Committing ${MODEL_NAME} to git repo." 23 | SHA=$(commit "first commit") 24 | echo "Initial model commit was at ${SHA}" 25 | 26 | echo "Making Dense update to ${MODEL_NAME}" 27 | DENSE_MODEL=`python ../${MODEL_SCRIPT} --action dense --seed 42 --model-name=${MODEL_NAME}` 28 | 29 | echo "Adding Dense update to git repo." 30 | git add ${MODEL_NAME} 31 | echo "Committing dense update to repo." 32 | commit "updated model" 33 | 34 | echo "Checking out initial model at ${SHA}" 35 | git checkout ${SHA} 36 | echo "Comparing checked out model (${MODEL_NAME}) to original save (${INIT_MODEL})." 37 | python ../verify.py --old-model "${MODEL_NAME}" --new-model "${INIT_MODEL}" 38 | 39 | echo "Checking out the dense update." 40 | git checkout main 41 | echo "Comparing checked out model (${MODEL_NAME}) to original save (${DENSE_MODEL})." 42 | python ../verify.py --old-model "${MODEL_NAME}" --new-model "${DENSE_MODEL}" 43 | 44 | green_echo "git checkout test passed!"
45 | -------------------------------------------------------------------------------- /tests/end2end/clean.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | rm -rf .git > /dev/null 2>&1 3 | rm -rf .gitignore > /dev/null 2>&1 4 | rm -rf .gitattributes > /dev/null 2>&1 5 | rm *.pt > /dev/null 2>&1 6 | -------------------------------------------------------------------------------- /tests/end2end/commit/test.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # commit test: Test that we are able to commit a git-theta model 3 | 4 | source ../utils.sh 5 | 6 | set -e 7 | 8 | test_init 9 | 10 | MODEL_SCRIPT=model.py 11 | MODEL_NAME=model.pt 12 | 13 | echo "Making model ${MODEL_NAME}" 14 | INIT_MODEL=`python ../${MODEL_SCRIPT} --action init --seed 1337 --model-name=${MODEL_NAME}` 15 | 16 | echo "Installing git-theta and tracking ${MODEL_NAME}" 17 | git theta install 18 | git theta track ${MODEL_NAME} 19 | 20 | echo "Adding ${MODEL_NAME} to git repo." 21 | git add ${MODEL_NAME} 22 | echo "Committing ${MODEL_NAME} to git repo." 23 | commit "first commit" 24 | 25 | echo "Comparing model (${MODEL_NAME}) to original save (${INIT_MODEL})." 26 | python ../verify.py --old-model "${MODEL_NAME}" --new-model "${INIT_MODEL}" 27 | 28 | echo "Making sure ${MODEL_NAME} is committed." 29 | FILES=$(git ls-files) 30 | if [[ ! ${FILES} =~ ${MODEL_NAME} ]]; then 31 | red_echo "${MODEL_NAME} not found in 'git ls-files'." 32 | exit 1 33 | fi 34 | 35 | green_echo "git commit test passed!" 36 | -------------------------------------------------------------------------------- /tests/end2end/ia3/test.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # ia3 test: Test committing and checking out ia3 updates with side-loaded 3 | # information. 4 | 5 | source ../utils.sh 6 | 7 | set -e 8 | 9 | test_init 10 | 11 | MODEL_SCRIPT=model.py 12 | MODEL_NAME=model.pt 13 | 14 | echo "Making model ${MODEL_NAME}" 15 | INIT_MODEL=`python ../${MODEL_SCRIPT} --action init --seed 1337 --model-name=${MODEL_NAME}` 16 | 17 | echo "Installing git-theta and tracking ${MODEL_NAME}" 18 | git theta install 19 | git theta track ${MODEL_NAME} 20 | 21 | echo "Adding ${MODEL_NAME} to git repo." 22 | git add ${MODEL_NAME} 23 | echo "Committing ${MODEL_NAME} to git repo." 24 | SHA=$(commit "first commit") 25 | 26 | echo "Making an ia3 update to ${MODEL_NAME}" 27 | IA3_MODEL=`python ../${MODEL_SCRIPT} --action ia3 --seed 42 --model-name=${MODEL_NAME}` 28 | git theta add ${MODEL_NAME} --update-type="ia3" --update-data="ia3-data.pt" 29 | IA3_SHA=$(commit "ia3 update") 30 | 31 | echo "Checking out initial model at ${SHA}" 32 | git checkout ${SHA} 33 | echo "Comparing checked out model (${MODEL_NAME}) to original save (${INIT_MODEL})" 34 | python ../verify.py --new-model "${MODEL_NAME}" --old-model "${INIT_MODEL}" 35 | 36 | echo "Checking out the ia3 update." 37 | git checkout main 38 | echo "Comparing checked out model (${MODEL_NAME}) to original save (${IA3_MODEL})" 39 | python ../verify.py --new-model "${MODEL_NAME}" --old-model "${IA3_MODEL}" 40 | 41 | echo "Verify that 'git status' doesn't have a diff." 42 | git diff-index --quiet HEAD -- 43 | if [[ "$?" != 0 ]]; then 44 | exit 1 45 | fi 46 | 47 | green_echo "ia3 update test passed!"
48 | -------------------------------------------------------------------------------- /tests/end2end/inprocess/test.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import copy 4 | import os 5 | 6 | import torch 7 | 8 | import git_theta 9 | 10 | print("loading old model") 11 | model = torch.load("og_model.pt") 12 | print("making a copy of the model with a single value change") 13 | updated_model = copy.deepcopy(model) 14 | updated_model["layers.0.hidden.weight"] = torch.rand( 15 | *updated_model["layers.0.hidden.weight"].shape 16 | ) 17 | 18 | print("committing the same model to different paths") 19 | model_1_sha = git_theta.save_to_git(model, "model_1.pt", "commit first model") 20 | model_2_sha = git_theta.save_to_git(model, "model_2.pt", "commit second model") 21 | print("committing the changed model to the same path.") 22 | model_tag = "updated-model" 23 | model_1_1_sha = git_theta.save_to_git( 24 | updated_model, 25 | "model_1.pt", 26 | "committing changed model to the same path.", 27 | tag=model_tag, 28 | ) 29 | 30 | print("Making sure the models were not saved to disk.") 31 | assert not os.path.exists("model_1.pt") 32 | assert not os.path.exists("model_2.pt") 33 | 34 | print("loading the model from git-theta directly.") 35 | m1 = git_theta.load_from_git(model_1_sha, "model_1.pt") 36 | print("loading model from a tag") 37 | m11 = git_theta.load_from_git(model_tag, "model_1.pt") 38 | m2 = git_theta.load_from_git(model_2_sha, "model_2.pt") 39 | 40 | print("saving the models to disk to inspect them later.") 41 | torch.save(m1, "should_match_1.pt") 42 | torch.save(m2, "should_match_2.pt") 43 | torch.save(m11, "no_match.pt") 44 | -------------------------------------------------------------------------------- /tests/end2end/inprocess/test.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # inprocess test: Test the in-process API (git_theta.save_to_git / git_theta.load_from_git), which commits and loads models without writing checkpoints to the working tree. 3 | 4 | source ../utils.sh 5 | 6 | set -e 7 | 8 | test_init 9 | 10 | MODEL_SCRIPT=model.py 11 | MODEL_NAME=og_model.pt 12 | 13 | echo "Making model ${MODEL_NAME}" 14 | INIT_MODEL=`python ../${MODEL_SCRIPT} --action init --seed 1337 --model-name=${MODEL_NAME}` 15 | 16 | python test.py 17 | if [[ "$?" != 0 ]]; then 18 | exit 1 19 | fi 20 | echo "Verifying that the same model saved in different paths match" 21 | python ../verify.py --old-model should_match_1.pt --new-model should_match_2.pt 22 | echo "Verifying that the changed model, which was saved to the same path but committed later, is different." 23 | R=$(python ../verify.py --old-model should_match_1.pt --new-model no_match.pt 2> /dev/null || true) 24 | if [[ "$R" == 0 ]]; then 25 | exit 1 26 | fi 27 | 28 | green_echo "in-process test passed!" 29 | -------------------------------------------------------------------------------- /tests/end2end/low-rank/test.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # low-rank test: Test committing and checking out low-rank updates with 3 | # side-loaded information.
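# "Side-loaded" means the low-rank factors are handed to git-theta via --update-data instead of being re-derived from the checkpoint diff.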
4 | 5 | source ../utils.sh 6 | 7 | set -e 8 | 9 | test_init 10 | 11 | MODEL_SCRIPT=model.py 12 | MODEL_NAME=model.pt 13 | 14 | echo "Making model ${MODEL_NAME}" 15 | INIT_MODEL=`python ../${MODEL_SCRIPT} --action init --seed 1337 --model-name=${MODEL_NAME}` 16 | 17 | echo "Installing git-theta and tracking ${MODEL_NAME}" 18 | git theta install 19 | git theta track ${MODEL_NAME} 20 | 21 | echo "Adding ${MODEL_NAME} to git repo." 22 | git add ${MODEL_NAME} 23 | echo "Committing ${MODEL_NAME} to git repo." 24 | SHA=$(commit "first commit") 25 | 26 | echo "Making a low-rank update to ${MODEL_NAME}" 27 | LOW_RANK_MODEL=`python ../${MODEL_SCRIPT} --action low-rank --seed 42 --model-name=${MODEL_NAME}` 28 | git theta add ${MODEL_NAME} --update-type="low-rank" --update-data="low-rank-data.pt" 29 | LOW_RANK_SHA=$(commit "low-rank") 30 | 31 | echo "Checking out initial model at ${SHA}" 32 | git checkout ${SHA} 33 | echo "Comparing checked out model (${MODEL_NAME}) to original save (${INIT_MODEL})" 34 | python ../verify.py --new-model "${MODEL_NAME}" --old-model "${INIT_MODEL}" 35 | 36 | echo "Checking out the low-rank update." 37 | git checkout main 38 | echo "Comparing checked out model (${MODEL_NAME}) to original save (${LOW_RANK_MODEL})" 39 | python ../verify.py --new-model "${MODEL_NAME}" --old-model "${LOW_RANK_MODEL}" 40 | 41 | echo "Verify that 'git status' doesn't have a diff." 42 | git diff-index --quiet HEAD -- 43 | if [[ "$?" != 0 ]]; then 44 | exit 1 45 | fi 46 | 47 | green_echo "low-rank update test passed!" 48 | -------------------------------------------------------------------------------- /tests/end2end/make-test.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # Make a test skeleton. 3 | 4 | TEST_NAME=${1} 5 | 6 | if [[ -z ${TEST_NAME} ]]; then 7 | echo "usage: make-test.sh 'test-name'" 8 | exit 1 9 | fi 10 | 11 | echo "Making skeleton for a test named ${TEST_NAME}" 12 | mkdir ${TEST_NAME} 13 | pushd ${TEST_NAME} > /dev/null 14 | 15 | cat << EOF > test.sh 16 | #!/usr/bin/env bash 17 | # ${TEST_NAME} test: TODO Add description of test.
18 | 19 | source ../utils.sh 20 | 21 | set -e 22 | 23 | test_init 24 | EOF 25 | 26 | chmod +x test.sh 27 | -------------------------------------------------------------------------------- /tests/end2end/model.py: -------------------------------------------------------------------------------- 1 | """Utility for creating models for testing.""" 2 | 3 | import argparse 4 | import os 5 | import random 6 | 7 | import numpy as np 8 | import scipy.sparse 9 | import torch 10 | from torch import nn 11 | from torch.nn import functional as F 12 | 13 | import git_theta 14 | 15 | parser = argparse.ArgumentParser(description="Model building for Integration tests.") 16 | parser.add_argument( 17 | "--action", choices=["init", "dense", "sparse", "low-rank", "ia3"], required=True 18 | ) 19 | parser.add_argument("--seed", default=1337, type=int) 20 | parser.add_argument("--model-name", default="model.pt") 21 | parser.add_argument("--previous") 22 | 23 | 24 | class TestingModel(nn.Module): 25 | """A small model for testing, weird architecture but tries to cover several pytorch paradigms.""" 26 | 27 | def __init__(self): 28 | super().__init__() 29 | self.embeddings = nn.Embedding(30, 10) 30 | self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1, bias=False) 31 | self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1, bias=False) 32 | self.pool = nn.MaxPool2d(kernel_size=2, stride=2) 33 | self.fc1 = nn.Linear(32 * 7 * 7, 120, bias=False) 34 | self.fc2 = nn.Linear(120, 84, bias=False) 35 | self.fc3 = nn.Linear(84, 10, bias=False) 36 | self.layers = nn.Sequential( 37 | TestingLayer(10, 8, 6), 38 | nn.ReLU(), 39 | TestingLayer(6, 4, 2), 40 | nn.Tanh(), 41 | nn.LogSoftmax(dim=-1), 42 | ) 43 | 44 | def __call__(self, x): 45 | x = self.pool(F.relu(self.conv1(x))) 46 | x = self.pool(F.relu(self.conv2(x))) 47 | x = torch.flatten(x, 1) # flatten all dimensions except batch 48 | x = F.relu(self.fc1(x)) 49 | x = F.relu(self.fc2(x)) 50 | x = self.fc3(x) 51 | return self.layers(x) 52 | 53 | 54 | class TestingLayer(nn.Module): 55 | def __init__(self, n_in, n_hidden, n_out): 56 | super().__init__() 57 | self.hidden = nn.Linear(n_in, n_hidden) 58 | self.output = nn.Linear(n_hidden, n_out) 59 | self.highway_like = nn.Linear(n_in, n_out) 60 | 61 | def __call__(self, x): 62 | y = self.hidden(x) 63 | y = F.relu(y) 64 | y = self.output(y) 65 | y += self.highway_like(x) 66 | return y 67 | 68 | 69 | def low_rank_update(t, rank): 70 | # TODO: Add a way to get numpy dtype from torch dtype easily. 
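# 1D parameters (e.g. biases) cannot be factored into R @ C, so they just get a dense random update.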
71 | if t.ndim == 1: 72 | return np.random.uniform(size=t.shape).astype(np.float32) 73 | R = np.random.uniform(size=(*t.shape[:-1], rank)).astype(np.float32) 74 | C = np.random.uniform(size=(rank, *t.shape[-1:])).astype(np.float32) 75 | return {"A": R, "B": C} 76 | 77 | 78 | def make_ia3_update(value): 79 | ia3 = np.random.randn(*value.shape).astype(np.float32) 80 | axes = (0, -1) if ia3.ndim > 3 else (-1,) 81 | ia3 = np.mean(ia3, axis=axes, keepdims=True) 82 | return {"ia3": ia3} 83 | 84 | 85 | def main(args): 86 | # Set seeds and enable determinism 87 | random.seed(args.seed) 88 | torch.manual_seed(args.seed) 89 | np.random.seed(args.seed) 90 | torch.use_deterministic_algorithms(True) 91 | 92 | file_name, ext = os.path.splitext(args.model_name) 93 | persistent_name = f"{file_name}-{args.action}-{args.seed}{ext}" 94 | if args.previous is None: 95 | args.previous = args.model_name 96 | 97 | if args.action == "init" or args.action == "dense": 98 | model = TestingModel().state_dict() 99 | torch.save(model, args.model_name) 100 | torch.save(model, persistent_name) 101 | elif args.action == "sparse": 102 | update_handler = git_theta.updates.get_update_handler("sparse") 103 | previous = torch.load(args.previous) 104 | # Create a new version of the model and call it the "sparse" update 105 | sparse = TestingModel().state_dict() 106 | # Combine the sparse update and the old values to create the "new" model 107 | with_sparse = {name: value + previous[name] for name, value in sparse.items()} 108 | # Save the combined model to the persistent location for comparisons. 109 | torch.save(with_sparse, persistent_name) 110 | # Convert the sparse update into a sparse format. 111 | sparse_update = {} 112 | for name, value in sparse.items(): 113 | update = update_handler.format_update(value.numpy()) 114 | for k, v in update.items(): 115 | sparse_update[f"{name}/{k}"] = torch.tensor(v) 116 | torch.save(sparse_update, "sparse-data.pt") 117 | 118 | elif args.action == "low-rank": 119 | previous = torch.load(args.previous) 120 | low_rank = 2 121 | lr_update = { 122 | name: low_rank_update(value, low_rank) for name, value in previous.items() 123 | } 124 | update_handler = git_theta.updates.get_update_handler("low-rank") 125 | # Just the low-rank data 126 | update_data = {} 127 | # The updated parameter values 128 | new_model = {} 129 | for name, update in lr_update.items(): 130 | if isinstance(update, dict): 131 | # The formatter is simple, it just assigns param1 and param2 to R and 132 |
133 | update = update_handler.format_update(update["A"], update["B"]) 134 | for k, v in update.items(): 135 | update_data[f"{name}/{k}"] = torch.tensor(v) 136 | new_model[name] = previous[name] + update["R"] @ update["C"] 137 | else: 138 | new_model[name] = previous[name] + update 139 | previous[name] = previous[name] + update 140 | # Checkpoint with the non-low-rank changes added 141 | torch.save(previous, args.model_name) 142 | # Checkpoint with the low-rank updates 143 | torch.save(update_data, "low-rank-data.pt") 144 | # Checkpoint with all changes added 145 | torch.save(new_model, persistent_name) 146 | 147 | elif args.action == "ia3": 148 | previous = torch.load(args.previous) 149 | ia3_update = {name: make_ia3_update(value) for name, value in previous.items()} 150 | update_handler = git_theta.updates.get_update_handler("ia3") 151 | # Just the ia3 data 152 | update_data = {} 153 | # The updated parameter values 154 | new_model = {} 155 | for name, update in ia3_update.items(): 156 | # Formatter is simple, just assign param to ia3 157 | update = update_handler.format_update(update["ia3"]) 158 | for k, v in update.items(): 159 | update_data[f"{name}/{k}"] = torch.tensor(v) 160 | new_model[name] = previous[name] * update["ia3"] 161 | # Save the new model 162 | torch.save(new_model, persistent_name) 163 | # save the ia3 data 164 | torch.save(update_data, "ia3-data.pt") 165 | 166 | print(persistent_name) 167 | 168 | 169 | if __name__ == "__main__": 170 | args = parser.parse_args() 171 | main(args) 172 | -------------------------------------------------------------------------------- /tests/end2end/runner.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | source ./utils.sh 4 | 5 | # Create an associated array mapping test name to exit code. 6 | declare -A TESTS 7 | # Map testname to the number of times the test was run. 8 | declare -A RUNS 9 | 10 | TRIALS=3 11 | 12 | # Each sub-directory in the current directory is a test. 13 | for test in ./*/ 14 | do 15 | # Remove things like dir / from the test name 16 | testname="${test%*/}" 17 | testname="${testname#./}" 18 | yellow_echo "Running Test: ${testname}" 19 | echo "=================================================================" 20 | # Move into the test dir 21 | pushd ${testname} 22 | # Run the test up to ${TRIALS} times, stopping when it has a return of 0 23 | i=0 24 | # This is just a non-zero value, just to get us into the loop, it doesn't 25 | # matter what the value is as it will be overwritten by the test return code. 26 | return_code=-1 27 | while [[ "${i}" < "${TRIALS}" && ${return_code} != 0 ]]; do 28 | # If there is a local clean script run that, otherwise run the global one. 29 | # This is in the loop to ensure it is cleaned before each attempt. 30 | if [[ -f ./clean.sh ]]; then 31 | ./clean.sh 32 | else 33 | ../clean.sh 34 | fi 35 | if [[ "${i}" > 0 ]]; then 36 | red_echo "${testname} failed, running trial $((i + 1))" 37 | fi 38 | ./test.sh 39 | return_code="${?}" 40 | i=$((i + 1)) 41 | done 42 | # Save the tests return value 43 | TESTS[${testname}]="${return_code}" 44 | # Save the number of times a test had to be run. 45 | RUNS[${testname}]="${i}" 46 | # Cleanup after the test, again, look for a local clean and use if found. 
47 | if [[ -f ./clean.sh ]]; then 48 | ./clean.sh 49 | else 50 | ../clean.sh 51 | fi 52 | popd 53 | done 54 | 55 | FAILED=0 56 | echo "Test Summary:" 57 | for test in "${!TESTS[@]}" 58 | do 59 | if [[ "${TESTS[$test]}" == 0 ]]; then 60 | green_echo "${test}" "n" 61 | # Check if we had to re-run tests. 62 | if [[ "${RUNS[$test]}" != 1 ]]; then 63 | yellow_echo " (Had to run test '${test}' ${RUNS[$test]} times to pass)." 64 | else 65 | echo 66 | fi 67 | else 68 | red_echo "${test}" 69 | fi 70 | # Passed tests have return values of 0. Summing all passed test results keeps 71 | # the return value for the whole run at 0. If one of the tests has a non-zero 72 | # return value, the runner will have a non-zero value (the exact non-zero 73 | # value is not meaningful). 74 | FAILED+="${TESTS[$test]}" 75 | done 76 | 77 | exit "${FAILED}" 78 | -------------------------------------------------------------------------------- /tests/end2end/smudge/clean.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | rm -rf .git > /dev/null 2>&1 3 | rm -rf .gitattributes > /dev/null 2>&1 4 | rm -rf .gitignore > /dev/null 2>&1 5 | rm *.pt > /dev/null 2>&1 6 | rm *.json > /dev/null 2>&1 7 | -------------------------------------------------------------------------------- /tests/end2end/smudge/test.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # smudge test: This tests that a metadata file (produced by the clean filter) 3 | # can be used directly with the smudge filter to re-create the checkpoint, 4 | # regardless of which commit is currently checked out. 5 | # 6 | # This behavior makes some parts of merging easier, so we want to ensure that 7 | # it holds. Additionally, this property makes working with checkpoints/git 8 | # in git-theta easier. 9 | 10 | source ../utils.sh 11 | 12 | set -e 13 | 14 | test_init 15 | 16 | MODEL_SCRIPT=model.py 17 | MODEL_NAME=model.pt 18 | 19 | echo "Making model ${MODEL_NAME}" 20 | INIT_MODEL=`python ../${MODEL_SCRIPT} --action init --seed 1337 --model-name=${MODEL_NAME}` 21 | 22 | echo "Installing git-theta and tracking ${MODEL_NAME}" 23 | git theta install 24 | git theta track ${MODEL_NAME} 25 | 26 | echo "Adding ${MODEL_NAME} to git repo." 27 | git add ${MODEL_NAME} 28 | echo "Committing ${MODEL_NAME} to git repo." 29 | SHA=$(commit "first commit") 30 | echo "Initial model commit was at ${SHA}" 31 | 32 | git show ${SHA}:${MODEL_NAME} > init-metadata.json 33 | 34 | echo "Making Sparse Update to ${MODEL_NAME}" 35 | SPARSE_MODEL=`python ../${MODEL_SCRIPT} --action sparse --seed 1234 --model-name=${MODEL_NAME}` 36 | echo "Adding Sparse Update to the git repo." 37 | git theta add ${MODEL_NAME} --update-type=sparse --update-data=sparse-data.pt 38 | echo "Committing sparse update to repo." 39 | SPARSE_SHA=$(commit "sparse update") 40 | 41 | git show ${SPARSE_SHA}:${MODEL_NAME} > sparse-metadata.json 42 | 43 | echo "Making Dense update to ${MODEL_NAME}" 44 | DENSE_MODEL=`python ../${MODEL_SCRIPT} --action dense --seed 42 --model-name=${MODEL_NAME}` 45 | 46 | echo "Adding Dense update to git repo." 47 | git add ${MODEL_NAME} 48 | echo "Committing dense update to repo." 49 | commit "updated model" 50 | 51 | git show HEAD:${MODEL_NAME} > dense-metadata.json 52 | 53 | 54 | echo "Verifying smudges at the dense update commit."
55 | echo "Verifying dense model" 56 | git-theta-filter smudge ${MODEL_NAME} < dense-metadata.json > dense-model.pt 57 | python ../verify.py --old-model ${DENSE_MODEL} --new-model dense-model.pt 58 | echo "Verifying sparse model" 59 | git-theta-filter smudge ${MODEL_NAME} < sparse-metadata.json > sparse-model.pt 60 | python ../verify.py --old-model ${SPARSE_MODEL} --new-model sparse-model.pt 61 | echo "Verifying initial model" 62 | git-theta-filter smudge ${MODEL_NAME} < init-metadata.json > init-model.pt 63 | python ../verify.py --old-model ${INIT_MODEL} --new-model init-model.pt 64 | 65 | git checkout ${SPARSE_SHA} 66 | echo "Verifying smudges at the sparse update commit." 67 | echo "Verifying dense model" 68 | git-theta-filter smudge ${MODEL_NAME} < dense-metadata.json > dense-model.pt 69 | python ../verify.py --old-model ${DENSE_MODEL} --new-model dense-model.pt 70 | echo "Verifying sparse model" 71 | git-theta-filter smudge ${MODEL_NAME} < sparse-metadata.json > sparse-model.pt 72 | python ../verify.py --old-model ${SPARSE_MODEL} --new-model sparse-model.pt 73 | echo "Verifying initial model" 74 | git-theta-filter smudge ${MODEL_NAME} < init-metadata.json > init-model.pt 75 | python ../verify.py --old-model ${INIT_MODEL} --new-model init-model.pt 76 | 77 | git checkout ${SHA} 78 | echo "Verifying smudges at the initial commit." 79 | echo "Verifying dense model" 80 | git-theta-filter smudge ${MODEL_NAME} < dense-metadata.json > dense-model.pt 81 | python ../verify.py --old-model ${DENSE_MODEL} --new-model dense-model.pt 82 | echo "Verifying sparse model" 83 | git-theta-filter smudge ${MODEL_NAME} < sparse-metadata.json > sparse-model.pt 84 | python ../verify.py --old-model ${SPARSE_MODEL} --new-model sparse-model.pt 85 | echo "Verifying initial model" 86 | git-theta-filter smudge ${MODEL_NAME} < init-metadata.json > init-model.pt 87 | python ../verify.py --old-model ${INIT_MODEL} --new-model init-model.pt 88 | 89 | green_echo "Smudging at various commits passed!" 90 | -------------------------------------------------------------------------------- /tests/end2end/sparse/clean.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | rm -rf .git > /dev/null 2>&1 3 | rm -rf .gitignore > /dev/null 2>&1 4 | rm -rf .gitattributes > /dev/null 2>&1 5 | rm *.pt > /dev/null 2>&1 6 | -------------------------------------------------------------------------------- /tests/end2end/sparse/test.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # sparse test: Test committing and checking out sparse updates with side-loaded 3 | # information. 4 | 5 | source ../utils.sh 6 | 7 | set -e 8 | 9 | test_init 10 | 11 | MODEL_SCRIPT=model.py 12 | MODEL_NAME=model.pt 13 | 14 | echo "Making model ${MODEL_NAME}" 15 | INIT_MODEL=`python ../${MODEL_SCRIPT} --action init --seed 1337 --model-name=${MODEL_NAME}` 16 | 17 | echo "Installing git-theta and tracking ${MODEL_NAME}" 18 | git theta install 19 | git theta track ${MODEL_NAME} 20 | 21 | echo "Adding ${MODEL_NAME} to git repo." 22 | git add ${MODEL_NAME} 23 | echo "Committing ${MODEL_NAME} to git repo." 
24 | SHA=$(commit "first commit")
25 | 
26 | echo "Making a sparse update to ${MODEL_NAME}"
27 | SPARSE_MODEL=`python ../${MODEL_SCRIPT} --action sparse --seed 42 --model-name=${MODEL_NAME}`
28 | git theta add ${MODEL_NAME} --update-type="sparse" --update-data="sparse-data.pt"
29 | SPARSE_SHA=$(commit "sparse")
30 | 
31 | echo "Checking out initial model at ${SHA}"
32 | git checkout ${SHA}
33 | echo "Comparing checked out model (${MODEL_NAME}) to original save (${INIT_MODEL})"
34 | python ../verify.py --new-model "${MODEL_NAME}" --old-model "${INIT_MODEL}"
35 | 
36 | echo "Checking out the sparse update."
37 | git checkout main
38 | echo "Comparing checked out model (${MODEL_NAME}) to original save (${SPARSE_MODEL})"
39 | python ../verify.py --new-model "${MODEL_NAME}" --old-model "${SPARSE_MODEL}"
40 | 
41 | echo "Verify that 'git status' doesn't have a diff."
42 | # 'set -e' is active, so test the command directly instead of checking "$?" afterwards.
43 | if ! git diff-index --quiet HEAD --; then
44 |   exit 1
45 | fi
46 | 
47 | green_echo "sparse update test passed!"
48 | 
--------------------------------------------------------------------------------
/tests/end2end/utils.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | 
3 | 
4 | GREEN="\033[0;32m"
5 | RED="\033[0;31m"
6 | YELLOW="\033[0;33m"
7 | NORMAL="\033[0m"
8 | 
9 | function color_echo {
10 |   local text="${1}"
11 |   local color="${2}"
12 |   local newline="${3}"
13 |   # If a third argument is provided, skip the trailing newline; this lets us
14 |   # put two colors on one line easily.
15 |   if [[ -z "${newline}" ]]; then
16 |     echo -e "${color}${text}${NORMAL}"
17 |   else
18 |     echo -en "${color}${text}${NORMAL}"
19 |   fi
20 | }
21 | 
22 | function green_echo {
23 |   local text="${1}"
24 |   local newline="${2}"
25 |   color_echo "${text}" "${GREEN}" "${newline}"
26 | }
27 | 
28 | function red_echo {
29 |   local text="${1}"
30 |   local newline="${2}"
31 |   color_echo "${text}" "${RED}" "${newline}"
32 | }
33 | 
34 | function yellow_echo {
35 |   local text="${1}"
36 |   local newline="${2}"
37 |   color_echo "${text}" "${YELLOW}" "${newline}"
38 | }
39 | 
40 | function make_repo {
41 |   echo "Making Git Repo."
42 |   git init 2> /dev/null
43 |   git branch -m main
44 |   # Set the git user/email for the generated test repo
45 |   git config --local user.email "git-theta-tester@example.com"
46 |   git config --local user.name "Git Theta Tester"
47 | }
48 | 
49 | function test_init {
50 |   make_repo
51 |   if [[ ! ${-} =~ e ]]; then
52 |     red_echo "It seems that 'set -e' was not set in this test. This makes it easy for a test to fail but appear to pass. Please add 'set -e' and do specific error handling if a part of your test is allowed to fail."
53 |     exit 1
54 |   fi
55 | }
56 | 
57 | function commit {
58 |   local msg="${1}"
59 |   git commit -m "${msg}" > /dev/null
60 |   git rev-parse HEAD
61 | }
62 | 
--------------------------------------------------------------------------------
/tests/end2end/verify.py:
--------------------------------------------------------------------------------
1 | """Tool to verify that checkpoints match."""
2 | 
3 | import argparse
4 | 
5 | import numpy as np
6 | import torch
7 | 
8 | parser = argparse.ArgumentParser(description="Compare checkpoints for testing.")
9 | parser.add_argument("--new-model", help="The path to the new model.")
10 | parser.add_argument(
11 |     "--old-model", help="The path to the model we are comparing against."
12 | )
13 | parser.add_argument("--compare")
14 | 
15 | 
16 | def get_compare_function(compare):
17 |     """Eventually add configurable comparison functions?"""
18 | 
19 |     def _cmp(a, b):
20 |         return np.array_equal(a.numpy(), b.numpy())
21 | 
22 |     return _cmp
23 | 
24 | 
25 | def main(args):
26 |     old = torch.load(args.old_model)
27 |     new = torch.load(args.new_model)
28 | 
29 |     compare = get_compare_function(args.compare)
30 | 
31 |     if old.keys() != new.keys():
32 |         raise ValueError(
33 |             f"Parameter keys differ. Got: {args.old_model} -> "
34 |             f"{sorted(old.keys())} and {args.new_model} -> {sorted(new.keys())}"
35 |         )
36 | 
37 |     mismatched = set()
38 |     for name, value in new.items():
39 |         if not compare(value, old[name]):
40 |             mismatched.add(name)
41 | 
42 |     if mismatched:
43 |         raise ValueError(f"Parameters: {sorted(mismatched)} differ between models.")
44 | 
45 | 
46 | if __name__ == "__main__":
47 |     args = parser.parse_args()
48 |     main(args)
49 | 
--------------------------------------------------------------------------------
/tests/git_utils_test.py:
--------------------------------------------------------------------------------
1 | """Tests for git_utils.py"""
2 | 
3 | import os
4 | 
5 | import pytest
6 | 
7 | from git_theta import git_utils
8 | 
9 | 
10 | def test_add_theta_gitattributes_empty_file():
11 |     assert list(map(str, git_utils.add_theta_to_gitattributes([], "example"))) == [
12 |         "example filter=theta merge=theta diff=theta"
13 |     ]
14 | 
15 | 
16 | def test_is_theta_tracked_with_override():
17 |     attrs = [
18 |         git_utils.parse_gitattributes(a)
19 |         for a in (
20 |             "mymodel.pt filter=theta merge=theta diff=theta",
21 |             "*.pt filter=theta merge=theta diff=theta",
22 |         )
23 |     ]
24 |     print(attrs)
25 |     assert git_utils.is_theta_tracked("mymodel.pt", attrs)
26 | 
27 | 
28 | def test_is_theta_tracked_with_override_false():
29 |     attrs = [
30 |         git_utils.parse_gitattributes(a)
31 |         for a in (
32 |             "mymodel.pt filter=theta merge=theta diff=theta",
33 |             "*.pt filter=lfs merge=theta diff=lfs",
34 |         )
35 |     ]
36 |     assert git_utils.is_theta_tracked("mymodel.pt", attrs) == False
37 | 
38 | 
39 | def test_is_theta_tracked_no_lines():
40 |     assert git_utils.is_theta_tracked("mymodel.pt", []) == False
41 | 
42 | 
43 | def test_is_theta_tracked_no_attrs():
44 |     assert (
45 |         git_utils.is_theta_tracked(
46 |             "mymodel.pt", [git_utils.parse_gitattributes("mymodel.pt")]
47 |         )
48 |         == False
49 |     )
50 | 
51 | 
52 | def test_is_theta_tracked_with_following_filter():
53 |     attrs = [
54 |         git_utils.parse_gitattributes(a)
55 |         for a in (
56 |             "mymodel.pt filter=theta merge=theta diff=theta",
57 |             "*.pt filter=theta filter=lfs merge=lfs diff=lfs",
58 |         )
59 |     ]
60 |     assert git_utils.is_theta_tracked("mymodel.pt", attrs) == False
61 | 
62 | 
63 | def test_parse_gitattributes_uses_last():
64 |     attr = git_utils.parse_gitattributes("example.txt merge=theta merge=wrong")
65 |     assert attr.attributes["merge"] == "wrong"
66 | 
67 | 
68 | def test_parse_gitattributes_no_equal():
69 |     s = "example.txt thing"
70 |     attr = git_utils.parse_gitattributes(s)
71 |     assert attr.attributes == {"thing": None}
72 |     assert str(attr) == s
73 | 
74 | 
75 | def test_parse_gitattributes_raw_string():
76 |     og_string = "example.txt merge=theta merge=wrong"
77 |     attr = git_utils.parse_gitattributes(og_string)
78 |     assert str(attr) == og_string
79 |     attr.raw = None
80 |     assert str(attr) == "example.txt merge=wrong"
81 | 
82 | 
83 | def test_add_theta_gitattributes_no_match():
84 |     # Should add a new path
85 |     atts = [
86 |         git_utils.parse_gitattributes(a)
87 |         for a in (
| "Some-other-path filter=lfs", 89 | "*-cool-models.pt filter=theta merge=theta diff=theta", 90 | ) 91 | ] 92 | model_path = "path/to/my/model.pt" 93 | assert ( 94 | str(git_utils.add_theta_to_gitattributes(atts, model_path)[-1]) 95 | == f"{model_path} filter=theta merge=theta diff=theta" 96 | ) 97 | 98 | 99 | def test_add_theta_gitattributes_exact_match_with_conflicting_attributes(): 100 | model_path = "really/cool/model/yall.ckpt" 101 | atts = [git_utils.parse_gitattributes(f"{model_path} filter=lfs")] 102 | with pytest.raises(ValueError): 103 | new_attributes = git_utils.add_theta_to_gitattributes(atts, model_path) 104 | 105 | 106 | def test_add_theta_gitattributes_pattern_match_with_conflicting_attributes(): 107 | model_path = "literal-the-best-checkpoint.pt" 108 | atts = [git_utils.parse_gitattributes("*.pt thing merge=lfs")] 109 | with pytest.raises(ValueError): 110 | new_attributes = git_utils.add_theta_to_gitattributes(atts, model_path) 111 | 112 | 113 | def test_add_theta_gitattributes_exact_match_disjoint_attributes(): 114 | # Should create a new attribute with values copied over 115 | model_path = "my-test_model" 116 | atts = [ 117 | git_utils.parse_gitattributes(a) 118 | for a in ("my-test_model merge=theta diff=theta banana=fruit",) 119 | ] 120 | new_att = git_utils.add_theta_to_gitattributes(atts, model_path)[-1] 121 | assert new_att.attributes["banana"] == "fruit" 122 | assert new_att.attributes["filter"] == "theta" 123 | 124 | 125 | def test_add_theta_gitattributes_pattern_disjoint_attributes(): 126 | # Should create a new attribute with values copied over 127 | model_path = "my-test_model" 128 | atts = [ 129 | git_utils.parse_gitattributes(a) 130 | for a in ("my-test* merge=theta diff=theta banana=fruit",) 131 | ] 132 | new_att = git_utils.add_theta_to_gitattributes(atts, model_path)[-1] 133 | assert new_att.pattern == model_path 134 | assert new_att.attributes["banana"] == "fruit" 135 | assert new_att.attributes["filter"] == "theta" 136 | 137 | 138 | def test_add_theta_gitattributes_disjoint_attributes_multiple_matches(): 139 | # Should create a new attribute with values copied over 140 | model_path = "100-on-mnist.npy" 141 | atts = [ 142 | git_utils.parse_gitattributes(a) 143 | for a in ("*.npy other-filter", f"{model_path} target-filter") 144 | ] 145 | new_attributes = git_utils.add_theta_to_gitattributes(atts, model_path) 146 | # Note: target-filter is expected rather than other-filter because the *last* 147 | # filter in the file is the active one. 148 | assert ( 149 | str(new_attributes[-1]) 150 | == f"{model_path} target-filter filter=theta merge=theta diff=theta" 151 | ) 152 | 153 | 154 | def test_add_theta_gitattributes_match_with_theta_already(): 155 | # Should be a no-op 156 | model_path = "my-bad-model.chkp" 157 | atts = [ 158 | git_utils.parse_gitattributes(a) 159 | for a in ( 160 | "my-*-model.chkp filter=theta merge=theta diff=theta", 161 | "example.txt thing", 162 | ) 163 | ] 164 | new_attributes = git_utils.add_theta_to_gitattributes(atts, model_path) 165 | assert new_attributes == atts 166 | 167 | 168 | # This should fail until unsetting attributes are handled. 169 | @pytest.mark.xfail 170 | def test_add_theta_gitattributes_unset_diff(): 171 | # The attribute represention (the dict) my change when unsetting is implemented. 
172 |     attr = git_utils.GitAttributes("example.pt", {"-diff": None})
173 |     with pytest.raises(ValueError):
174 |         new_attributes = git_utils.add_theta_to_gitattributes([attr], "example.pt")
175 | 
176 | 
177 | def test_add_theta_gitattributes_rest_unchanged():
178 |     model_path = "model-v3.pt"
179 |     atts = [
180 |         git_utils.parse_gitattributes(a)
181 |         for a in (
182 |             "some-other-path filter=theta merge=theta diff=theta",
183 |             "really-reaaaally-big-files filter=lfs",
184 |             r"model-v\d.pt filter",
185 |             "another filter=theta merge=theta diff=theta",
186 |         )
187 |     ]
188 |     results = git_utils.add_theta_to_gitattributes(atts, model_path)
189 |     for i, (a, r) in enumerate(zip(atts, results)):
190 |         if i == 2:
191 |             continue
192 |         assert a == r
193 | 
194 | 
195 | @pytest.fixture
196 | def gitattributes():
197 |     return [
198 |         "*.pt filter=theta merge=theta diff=theta",
199 |         "*.png filter=lfs",
200 |         "really-big-file filter=lfs",
201 |         "something else",
202 |     ], [
203 |         git_utils.GitAttributes(
204 |             "*.pt", {"filter": "theta", "merge": "theta", "diff": "theta"}
205 |         ),
206 |         git_utils.GitAttributes("*.png", {"filter": "lfs"}),
207 |         git_utils.GitAttributes("really-big-file", {"filter": "lfs"}),
208 |         git_utils.GitAttributes("something", {"else": None}),
209 |     ]
210 | 
211 | 
212 | def test_read_gitattributes(gitattributes, tmp_path):
213 |     """Test that reading gitattributes removes newlines."""
214 |     attributes_text, gitattributes = gitattributes
215 |     gitattributes_file = tmp_path / ".gitattributes"
216 |     with open(gitattributes_file, "w") as wf:
217 |         wf.write("\n".join(attributes_text))
218 |     read_attributes = git_utils.read_gitattributes(gitattributes_file)
219 |     assert read_attributes == gitattributes
220 | 
221 | 
222 | def test_read_gitattributes_missing_file(tmp_path):
223 |     """Test that a missing gitattributes file yields an empty list."""
224 |     missing_file = tmp_path / ".gitattributes"
225 |     assert not os.path.exists(missing_file)
226 |     read_attributes = git_utils.read_gitattributes(missing_file)
227 |     assert read_attributes == []
228 | 
229 | 
230 | def test_read_gitattributes_empty_file(tmp_path):
231 |     """Test that an empty gitattributes file yields an empty list."""
232 |     empty_file = tmp_path / ".gitattributes"
233 |     empty_file.touch()
234 |     assert os.path.exists(empty_file)
235 |     read_attributes = git_utils.read_gitattributes(empty_file)
236 |     assert read_attributes == []
237 | 
238 | 
239 | def test_write_gitattributes(gitattributes, tmp_path):
240 |     """Test that attributes are written to file unchanged and include newlines."""
241 |     gitattributes = gitattributes[0]
242 |     attr_file = tmp_path / ".gitattributes"
243 |     for attr in gitattributes:
244 |         assert not attr.endswith("\n")
245 |     git_utils.write_gitattributes(attr_file, gitattributes)
246 |     with open(attr_file) as wf:
247 |         written_attrs = wf.readlines()
248 |     # Check for the newlines, which the reading code here purposely leaves in place.
249 |     for written_attr, attr in zip(written_attrs, gitattributes):
250 |         assert written_attr == f"{attr}\n"
251 | 
252 | 
253 | def test_write_gitattributes_ends_in_newline(gitattributes, tmp_path):
254 |     """Make sure we have a final newline when writing out the file."""
255 |     attr_file = tmp_path / ".gitattributes"
256 |     git_utils.write_gitattributes(attr_file, gitattributes)
257 |     with open(attr_file) as f:
258 |         attrs = f.read()
259 |     assert attrs[-1] == "\n"
260 | 
261 | 
262 | def test_write_gitattributes_creates_file(gitattributes, tmp_path):
263 |     """Make sure writing the git attributes can create the missing file before writing."""
264 |     attr_file = tmp_path / ".gitattributes"
265 |     assert not os.path.exists(attr_file)
266 |     git_utils.write_gitattributes(attr_file, gitattributes)
267 |     assert os.path.exists(attr_file)
268 | 
269 | 
270 | def test_read_write_gitattributes_write_read_round_trip(gitattributes, tmp_path):
271 |     """Test that we can write attributes, then read them back and they will match."""
272 |     attributes_text, gitattributes = gitattributes
273 |     attr_file = tmp_path / ".gitattributes"
274 |     git_utils.write_gitattributes(attr_file, attributes_text)
275 |     read_attrs = git_utils.read_gitattributes(attr_file)
276 |     assert read_attrs == gitattributes
277 | 
278 | 
279 | def test_read_write_gitattributes_read_write_round_trip(gitattributes, tmp_path):
280 |     """Test reading attrs from a file, writing them to a new file, and verifying the file contents match."""
281 |     attr_file = tmp_path / ".gitattributes"
282 |     with open(attr_file, "w") as wf:
283 |         wf.writelines([f"{attr}\n" for attr in gitattributes])
284 | 
285 |     new_attr_file = tmp_path / ".gitattributes-2"
286 |     read_attrs = git_utils.read_gitattributes(attr_file)
287 |     git_utils.write_gitattributes(new_attr_file, read_attrs)
288 | 
289 |     with open(attr_file) as old_f:
290 |         old_atts = old_f.read()
291 |     with open(new_attr_file) as new_f:
292 |         new_atts = new_f.read()
293 | 
294 |     assert old_atts == new_atts
295 | 
--------------------------------------------------------------------------------
/tests/helpers/__init__.py:
--------------------------------------------------------------------------------
1 | from .utils import *
2 | 
--------------------------------------------------------------------------------
/tests/helpers/utils.py:
--------------------------------------------------------------------------------
1 | """Helper utilities for unittests."""
2 | 
3 | import contextlib
4 | import os
5 | import tempfile
6 | 
7 | 
8 | @contextlib.contextmanager
9 | def named_temporary_file(**kwargs):
10 |     """A named temp file that is safe to use on Windows.
11 | 
12 |     When using this function to create a named tempfile, it is safe to call
13 |     operations like `.flush()` and `.close()` on the tempfile, which is needed
14 |     on Windows. Like the normal tempfile context manager, the file is removed
15 |     automatically when you exit the `with` scope.
16 |     """
17 |     # We force mode and delete below, so reject them if the caller passes them.
18 |     m = kwargs.pop("mode", None)
19 |     if m is not None:
20 |         raise RuntimeError(
21 |             f"'mode' argument should not be provided to 'named_temporary_file', got {m}."
22 |         )
23 |     d = kwargs.pop("delete", None)
24 |     if d is not None:
25 |         raise RuntimeError(
26 |             f"'delete' argument should not be provided to 'named_temporary_file', got {d}."
27 |         )
28 |     with tempfile.NamedTemporaryFile(mode="w", delete=False, **kwargs) as f:
29 |         try:
30 |             yield f
31 |         finally:
32 |             os.unlink(f.name)
33 | 
--------------------------------------------------------------------------------
/tests/metadata_test.py:
--------------------------------------------------------------------------------
1 | """Tests for metadata.py"""
2 | 
3 | import os
4 | 
5 | import helpers
6 | import numpy as np
7 | import pytest
8 | 
9 | from git_theta import metadata
10 | 
11 | 
12 | def metadata_equal(m1, m2):
13 |     m1_flat = m1.flatten()
14 |     m2_flat = m2.flatten()
15 |     if m1_flat.keys() != m2_flat.keys():
16 |         return False
17 |     for k, m1_v in m1_flat.items():
18 |         if m1_v != m2_flat[k]:
19 |             return False
20 |     return True
21 | 
22 | 
23 | def test_lfs_pointer(data_generator):
24 |     """
25 |     Test LfsMetadata creates and reads LFS pointers correctly
26 |     """
27 |     lfs_metadata1 = data_generator.random_lfs_metadata()
28 |     pointer_contents = lfs_metadata1.lfs_pointer
29 |     lfs_metadata2 = metadata.LfsMetadata.from_pointer(pointer_contents)
30 |     assert lfs_metadata1 == lfs_metadata2
31 | 
32 | 
33 | # TODO: This test will sometimes fail due to the current TensorMetadata equality check. Fix this eventually.
34 | @pytest.mark.xfail
35 | def test_tensor_metadata_machine_epsilon():
36 |     """
37 |     Test that TensorMetadata objects made from tensors with differences within machine epsilon are equal to one another
38 |     """
39 |     tensor1 = np.random.rand(10, 10)
40 |     tensor2 = (
41 |         tensor1
42 |         + (2 * np.random.choice(2, tensor1.shape) - 1) * np.finfo(np.float32).eps
43 |     )
44 |     tensor_metadata1 = metadata.TensorMetadata.from_tensor(tensor1)
45 |     tensor_metadata2 = metadata.TensorMetadata.from_tensor(tensor2)
46 |     assert tensor_metadata1 == tensor_metadata2
47 | 
48 | 
49 | def test_param_metadata_roundtrip(data_generator):
50 |     """
51 |     Test that ParamMetadata serializes to dict and can be generated from dict correctly
52 |     """
53 |     param_metadata = data_generator.random_param_metadata()
54 |     metadata_dict = param_metadata.serialize()
55 |     param_metadata_roundtrip = metadata.ParamMetadata.from_metadata_dict(metadata_dict)
56 | 
57 |     assert param_metadata == param_metadata_roundtrip
58 | 
59 | 
60 | def test_metadata_dict_roundtrip(data_generator):
61 |     """
62 |     Test that Metadata serializes to dict and can be generated from dict correctly
63 |     """
64 |     metadata_obj = data_generator.random_metadata()
65 |     metadata_dict = metadata_obj.serialize()
66 |     metadata_roundtrip = metadata.Metadata.from_metadata_dict(metadata_dict)
67 |     assert metadata_equal(metadata_obj, metadata_roundtrip)
68 | 
69 | 
70 | def test_metadata_file_roundtrip(data_generator):
71 |     """
72 |     Test that Metadata serializes to file and can be generated from file correctly
73 |     """
74 |     metadata_obj = data_generator.random_metadata()
75 |     with helpers.utils.named_temporary_file() as tmp:
76 |         metadata_obj.write(tmp)
77 |         tmp.flush()
78 |         tmp.close()
79 |         metadata_roundtrip = metadata.Metadata.from_file(tmp.name)
80 |     assert metadata_equal(metadata_obj, metadata_roundtrip)
81 | 
82 | 
83 | def test_metadata_flatten(data_generator):
84 |     """
85 |     Test that Metadata flattens and unflattens correctly
86 |     """
87 |     metadata_obj = data_generator.random_metadata()
88 |     metadata_obj_flat = metadata_obj.flatten()
89 |     metadata_obj_unflat = metadata_obj_flat.unflatten()
90 |     assert metadata_equal(metadata_obj, metadata_obj_unflat)
91 | 
--------------------------------------------------------------------------------
/tests/params_test.py:
--------------------------------------------------------------------------------
1 | """Tests for params.py"""
2 | 
3 | import asyncio
4 | import random
5 | 
6 | import numpy as np
7 | import pytest
8 | 
9 | from git_theta import params
10 | 
11 | 
12 | def test_tensorstore_serializer_roundtrip():
13 |     """
14 |     Test TensorStoreSerializer serializes and deserializes correctly
15 |     """
16 |     serializer = params.TensorStoreSerializer()
17 |     for num_dims in range(1, 6):
18 |         shape = tuple(np.random.randint(1, 20, size=num_dims).tolist())
19 |         t = np.random.rand(*shape)
20 |         serialized_t = asyncio.run(serializer.serialize(t))
21 |         deserialized_t = asyncio.run(serializer.deserialize(serialized_t))
22 |         np.testing.assert_array_equal(t, deserialized_t)
23 | 
24 | 
25 | def test_tensorstore_serializer_roundtrip_chunked():
26 |     """
27 |     Test TensorStoreSerializer serializes and deserializes correctly on a large tensor (that should get chunked)
28 |     """
29 |     serializer = params.TensorStoreSerializer()
30 |     t = np.random.rand(5000, 5000)
31 |     serialized_t = asyncio.run(serializer.serialize(t))
32 |     deserialized_t = asyncio.run(serializer.deserialize(serialized_t))
33 |     np.testing.assert_array_equal(t, deserialized_t)
34 | 
35 | 
36 | def test_msgpack_combiner_roundtrip():
37 |     """
38 |     Test MsgPackCombiner combines and splits correctly
39 |     """
40 |     combiner = params.MsgPackCombiner()
41 |     param_files = {
42 |         "param1": {"file1": b"abcdefg", "file2": b"hijklmnop"},
43 |         "param2": {"file1": b"01234", "file2": b"56789"},
44 |     }
45 |     combined_param_files = combiner.combine(param_files)
46 |     split_param_files = combiner.split(combined_param_files)
47 |     assert param_files == split_param_files
48 | 
49 | 
50 | def test_update_serializer_roundtrip():
51 |     """
52 |     Test UpdateSerializer made from TensorStoreSerializer and MsgPackCombiner serializes and deserializes correctly
53 |     """
54 |     serializer = params.UpdateSerializer(
55 |         params.TensorStoreSerializer(), params.MsgPackCombiner()
56 |     )
57 |     update_params = {
58 |         "param1": np.random.rand(100, 100),
59 |         "param2": np.random.rand(50, 10, 2),
60 |         "param3": np.random.rand(1000),
61 |     }
62 |     serialized_update_params = asyncio.run(serializer.serialize(update_params))
63 |     deserialized_update_params = asyncio.run(
64 |         serializer.deserialize(serialized_update_params)
65 |     )
66 | 
67 |     assert update_params.keys() == deserialized_update_params.keys()
68 | 
69 |     for param_name in update_params.keys():
70 |         np.testing.assert_array_equal(
71 |             update_params[param_name], deserialized_update_params[param_name]
72 |         )
73 | 
--------------------------------------------------------------------------------
/tests/theta_test.py:
--------------------------------------------------------------------------------
1 | """Tests for theta.py"""
2 | 
3 | import os
4 | import random
5 | 
6 | import helpers
7 | 
8 | from git_theta import theta
9 | 
10 | 
11 | def test_commit_info_serialization(data_generator):
12 |     """
13 |     Test that CommitInfo objects serialize/deserialize to/from files correctly
14 |     """
15 |     commit_info = data_generator.random_commit_info()
16 |     with helpers.utils.named_temporary_file() as tmpfile:
17 |         commit_info.write(tmpfile)
18 |         tmpfile.flush()
19 |         tmpfile.close()
20 |         commit_info_read = theta.CommitInfo.from_file(tmpfile.name)
21 |     assert commit_info == commit_info_read
22 | 
23 | 
24 | def test_get_commit_info(git_repo_with_commits):
25 |     """
26 |     Test getting the correct CommitInfo object for a certain commit hash using a ThetaCommits object
27 |     """
28 |     repo, commit_hashes, commit_infos = git_repo_with_commits
29 |     theta_commits = theta.ThetaCommits(repo)
30 | 
31 |     for commit_hash, commit_info in zip(commit_hashes, commit_infos):
32 |         assert theta_commits.get_commit_info(commit_hash) == commit_info
33 | 
34 | 
35 | def test_get_commit_info_range(git_repo_with_commits):
36 |     """
37 |     Test getting the correct CommitInfo objects for a certain commit hash range using a ThetaCommits object
38 |     """
39 |     repo, commit_hashes, commit_infos = git_repo_with_commits
40 |     theta_commits = theta.ThetaCommits(repo)
41 |     for _ in range(100):
42 |         start = random.randint(0, len(commit_hashes) - 1)  # Exclusive
43 |         end = random.randint(start, len(commit_hashes) - 1)  # Inclusive
44 |         # Results come back in reverse chronological order, so reverse them to match the order of commit_infos.
45 |         commit_info_range = list(
46 |             reversed(
47 |                 theta_commits.get_commit_info_range(
48 |                     commit_hashes[start], commit_hashes[end]
49 |                 )
50 |             )
51 |         )
52 |         assert len(commit_info_range) == (end - start)
53 |         for idx, commit_info in enumerate(commit_info_range):
54 |             assert commit_info == commit_infos[start + idx + 1]
55 | 
56 | 
57 | def test_get_commit_oids(git_repo_with_commits):
58 |     """
59 |     Test getting the correct object-ids for a certain commit hash using a ThetaCommits object
60 |     """
61 |     repo, commit_hashes, commit_infos = git_repo_with_commits
62 |     theta_commits = theta.ThetaCommits(repo)
63 |     for commit_hash, commit_info in zip(commit_hashes, commit_infos):
64 |         assert theta_commits.get_commit_oids(commit_hash) == commit_info.oids
65 | 
66 | 
67 | def test_combine_oid_sets():
68 |     """
69 |     Test combining multiple object-id sets into a single set
70 |     """
71 |     oid_sets = [set([1, 2, 3, 4]), set([1, 2]), set([1, 2, 6])]
72 |     combined_set = set([1, 2, 3, 4, 6])
73 |     assert theta.ThetaCommits.combine_oid_sets(oid_sets) == combined_set
74 | 
--------------------------------------------------------------------------------
/tests/trie_test.py:
--------------------------------------------------------------------------------
1 | """Tests for the Trie (O(n) prefix lookup)."""
2 | 
3 | from git_theta import utils
4 | 
5 | 
6 | def test_trie_insert():
7 |     word = "homework"
8 |     t = utils.Trie()
9 |     assert word not in t
10 |     t.insert(word)
11 |     assert word in t
12 | 
13 | 
14 | def test_trie_insert_then_prefix():
15 |     word = "homework"
16 |     t = utils.Trie()
17 |     t.insert(word)
18 |     assert t.prefix("home")
19 | 
20 | 
21 | def test_trie_contains_miss():
22 |     t = utils.Trie()
23 |     t.insert("homework")
24 |     t.insert("television")
25 |     t.insert("homeish")
26 |     assert "homelab" not in t
27 | 
28 | 
29 | def test_trie_contains_prefix():
30 |     t = utils.Trie()
31 |     t.insert("homework")
32 |     assert "homework" in t
33 |     assert t.prefix("home")
34 |     assert "home" not in t
35 | 
36 | 
37 | def test_trie_word_not_prefix():
38 |     base = "home"
39 |     full = f"{base}work"
40 |     t = utils.Trie()
41 |     t.insert(base)
42 |     assert base in t
43 |     assert not t.prefix(base)
44 |     assert full not in t
45 |     t.insert(full)
46 |     assert full in t
47 |     assert t.prefix(base)
48 | 
49 | 
50 | def test_trie_iterable():
51 |     added_words = [
52 |         "apple",
53 |         "banana",
54 |         "table",
55 |     ]
56 |     missing_words = ["laptop", "cable"]
57 |     t = utils.Trie.from_iterable(added_words)
58 |     for added in added_words:
59 |         assert added in t
60 |     for miss in missing_words:
61 |         assert miss not in t
62 | 
--------------------------------------------------------------------------------
/tests/updates/base_test.py:
--------------------------------------------------------------------------------
1 | """Tests for common update plugin functions."""
2 | 
3 | import os
4 | 
5 | import pytest
6 | 
7 | from git_theta import utils
8 | from git_theta.updates import base
9 | 
10 | ENV_UPDATE_TYPE = "GIT_THETA_UPDATE_TYPE"
11 | 
12 | 
13 | @pytest.fixture
14 | def env_var():
15 |     current_env = dict(os.environ)
16 |     os.environ[ENV_UPDATE_TYPE] = "sparse"
17 | 
18 |     yield
19 |     os.environ.clear()
20 |     os.environ.update(current_env)
21 | 
22 | 
23 | @pytest.fixture
24 | def no_env_var():
25 |     current_env = dict(os.environ)
26 |     os.environ.pop(ENV_UPDATE_TYPE, None)
27 | 
28 |     yield
29 |     os.environ.clear()
30 |     os.environ.update(current_env)
31 | 
32 | 
33 | @pytest.fixture
34 | def empty_env_var():
35 |     current_env = dict(os.environ)
36 |     os.environ[ENV_UPDATE_TYPE] = ""
37 | 
38 |     yield
39 |     os.environ.clear()
40 |     os.environ.update(current_env)
41 | 
42 | 
43 | def test_get_update_handler_name_prefers_user_input(env_var):
44 |     """Ensure that user input is always used if provided."""
45 |     user_input = "low-rank"
46 |     assert os.environ.get(ENV_UPDATE_TYPE) == "sparse"
47 |     assert base.get_update_handler_name(user_input) == user_input
48 | 
49 | 
50 | def test_get_update_handler_name_uses_env_variable(env_var):
51 |     """Ensure env variables are checked before defaulting."""
52 |     user_input = None
53 |     assert base.get_update_handler_name(user_input) == "sparse"
54 | 
55 | 
56 | def test_get_update_handler_name_default(no_env_var):
57 |     """Ensure there is a default when there is no input and no defined env var."""
58 |     user_input = None
59 |     assert os.environ.get(ENV_UPDATE_TYPE) is None
60 |     assert base.get_update_handler_name(user_input) == "dense"
61 | 
62 | 
63 | def test_get_update_handler_name_empty_env(empty_env_var):
64 |     """Ensure there is a default when there is no input and a defined, but empty, env var."""
65 |     user_input = None
66 |     assert ENV_UPDATE_TYPE in os.environ
67 |     assert os.environ[ENV_UPDATE_TYPE] == ""
68 |     assert base.get_update_handler_name(user_input) == "dense"
69 | 
--------------------------------------------------------------------------------
/tests/updates/ia3_test.py:
--------------------------------------------------------------------------------
1 | """Tests for our ia3 update."""
2 | 
3 | 
4 | import numpy as np
5 | import pytest
6 | 
7 | from git_theta import async_utils, params
8 | from git_theta.updates import ia3
9 | 
10 | SHAPE1 = 3
11 | SHAPE2 = 30
12 | SHAPE3 = 30
13 | SHAPE4 = 30
14 | 
15 | TRIALS = 50
16 | 
17 | 
18 | @pytest.fixture
19 | def updater():
20 |     return ia3.IA3Update(params.get_update_serializer())
21 | 
22 | 
23 | def test_ia3_round_trip_application(updater):
24 |     for _ in range(TRIALS):
25 |         parameter = np.random.randn(SHAPE1, SHAPE2, SHAPE3, SHAPE4)
26 |         update = np.random.randn(SHAPE1, SHAPE2, 1, SHAPE4)
27 |         updated_parameter = parameter * update
28 | 
29 |         calc_update = async_utils.run(
30 |             updater.calculate_update(updated_parameter, parameter, broadcast_dims=[2])
31 |         )
32 |         result = async_utils.run(updater.apply_update(calc_update, parameter))
33 | 
34 |         np.testing.assert_allclose(result, updated_parameter, rtol=1e-6)
35 | 
36 | 
37 | def test_ia3_round_trip_application_with_moredims(updater):
38 |     for _ in range(TRIALS):
39 |         parameter = np.random.randn(SHAPE1, SHAPE2, SHAPE3, SHAPE4)
40 |         update = np.random.randn(1, SHAPE2, SHAPE3, 1)
41 |         updated_parameter = parameter * update
42 | 
43 |         calc_update = async_utils.run(
44 |             updater.calculate_update(
45 |                 updated_parameter, parameter, broadcast_dims=[0, 3]
46 |             )
47 |         )
48 |         result = async_utils.run(updater.apply_update(calc_update, parameter))
49 | 
50 |         np.testing.assert_allclose(result, updated_parameter, rtol=1e-6)
51 | 
52 | 
53 | def test_ia3_round_trip_application_with_sparse_parameter(updater):
54 |     for _ in range(TRIALS):
55 |         parameter = np.random.randn(SHAPE1, SHAPE2, SHAPE3, SHAPE4)
56 |         update = np.random.randn(SHAPE1, SHAPE2, 1, 1)
57 |         threshold = np.quantile(parameter, 0.3)
58 |         parameter[parameter < threshold] = 0
59 |         updated_parameter = parameter * update
60 | 
61 |         calc_update = async_utils.run(
62 |             updater.calculate_update(
63 |                 updated_parameter, parameter, broadcast_dims=[2, 3]
64 |             )
65 |         )
66 |         result = async_utils.run(updater.apply_update(calc_update, parameter))
67 | 
68 |         np.testing.assert_allclose(result, updated_parameter, rtol=1e-6)
69 | 
--------------------------------------------------------------------------------
/tests/updates/low_rank_test.py:
--------------------------------------------------------------------------------
1 | """Tests for our low-rank update."""
2 | 
3 | 
4 | import numpy as np
5 | import pytest
6 | 
7 | from git_theta import async_utils, params
8 | from git_theta.updates import low_rank
9 | 
10 | K = 20
11 | INPUT_SIZE = 1024
12 | OUTPUT_SIZE = 1024
13 | TRIALS = 50
14 | 
15 | 
16 | @pytest.fixture
17 | def updater():
18 |     return low_rank.LowRankUpdate(params.get_update_serializer())
19 | 
20 | 
21 | def test_low_rank_update_rank_inference(updater):
22 |     for _ in range(TRIALS):
23 |         parameter = np.random.randn(INPUT_SIZE, OUTPUT_SIZE)
24 |         R = np.random.randn(INPUT_SIZE, K)
25 |         C = np.random.randn(K, OUTPUT_SIZE)
26 |         updated_parameter = R @ C + parameter
27 | 
28 |         update = async_utils.run(updater.calculate_update(updated_parameter, parameter))
29 |         assert update["R"].shape == R.shape
30 |         assert update["C"].shape == C.shape
31 | 
32 | 
33 | @pytest.mark.xfail(strict=False)
34 | def test_low_rank_update_application(updater):
35 |     for _ in range(TRIALS):
36 |         parameter = np.random.randn(INPUT_SIZE, OUTPUT_SIZE)
37 |         R = np.random.randn(INPUT_SIZE, K)
38 |         C = np.random.randn(K, OUTPUT_SIZE)
39 |         updated_parameter = R @ C + parameter
40 | 
41 |         update = async_utils.run(updater.calculate_update(updated_parameter, parameter))
42 |         result = async_utils.run(updater.apply_update(update, parameter))
43 | 
44 |         np.testing.assert_allclose(result, updated_parameter, rtol=1e-6)
45 | 
46 | 
47 | def test_low_rank_update_application_1d(updater):
48 |     parameter = np.random.randn(INPUT_SIZE)
49 |     update = np.random.randn(*parameter.shape)
50 | 
51 |     updated_parameter = update + parameter
52 | 
53 |     calculated_update = async_utils.run(
54 |         updater.calculate_update(updated_parameter, parameter)
55 |     )
56 |     calculated_result = async_utils.run(
57 |         updater.apply_update(calculated_update, parameter)
58 |     )
59 | 
60 |     np.testing.assert_allclose(calculated_result, updated_parameter)
61 | 
--------------------------------------------------------------------------------
/tests/updates/sparse_update_test.py:
--------------------------------------------------------------------------------
1 | """Tests for our sparse update."""
2 | 
3 | 
4 | import numpy as np
5 | import pytest
6 | import scipy.sparse
7 | 
8 | from git_theta import async_utils, params
9 | from git_theta.updates import sparse
10 | 
11 | SHAPE = 100
12 | NUM_UPDATES = 1000
13 | TRIALS = 50
14 | 
15 | 
16 | @pytest.fixture
17 | def updater():
18 |     return lambda threshold: sparse.SparseUpdate(
19 |         params.get_update_serializer(), threshold=threshold
20 |     )
21 | 
22 | 
23 | def test_sparse_round_trip_application(updater):
24 |     for _ in range(TRIALS):
25 |         parameter = np.random.randn(SHAPE, SHAPE, SHAPE)
26 |         x, y, z = np.random.choice(
27 |             np.arange(SHAPE), size=(3, NUM_UPDATES), replace=True
28 |         )
29 |         sparse_update = np.random.randn(NUM_UPDATES)
30 |         updated_parameter = parameter.copy()
31 |         updated_parameter[x, y, z] = sparse_update
32 | 
33 |         sparse_updater = updater(threshold=1e-12)
34 |         update = async_utils.run(
35 |             sparse_updater.calculate_update(updated_parameter, parameter)
36 |         )
37 |         result = async_utils.run(sparse_updater.apply_update(update, parameter))
38 | 
39 |         np.testing.assert_allclose(result, updated_parameter, rtol=1e-6)
40 | 
41 | 
42 | def test_known_sparsity(updater):
43 |     for _ in range(TRIALS):
44 |         parameter = np.random.randn(SHAPE, SHAPE, SHAPE)
45 |         diff_tensor = np.random.randn(SHAPE, SHAPE, SHAPE)
46 |         # Ensure there is no sparsity in the diff tensor in the first place.
47 |         diff_tensor[diff_tensor == 0] = 0.1
48 |         threshold = np.quantile(diff_tensor, 0.3)
49 |         diff_tensor[diff_tensor < threshold] = 0
50 |         updated_parameter = parameter + diff_tensor
51 | 
52 |         sparse_updater = updater(threshold=1e-12)
53 |         update_dict = async_utils.run(
54 |             sparse_updater.calculate_update(updated_parameter, parameter)
55 |         )
56 |         calc_sparsity = 1 - len(update_dict["data"]) / np.prod(parameter.shape)
57 |         np.testing.assert_allclose(calc_sparsity, 0.3, rtol=1e-5)
58 | 
59 | 
60 | def test_monotonic_increasing_sparseness(updater):
61 |     for _ in range(TRIALS):
62 |         parameter = np.random.randn(SHAPE, SHAPE, SHAPE)
63 |         diff_tensor = np.random.randn(SHAPE, SHAPE, SHAPE)
64 |         threshold = np.quantile(diff_tensor, 0.3)
65 |         diff_tensor[diff_tensor < threshold] = 0
66 |         updated_parameter = parameter + diff_tensor
67 |         sparseness = []
68 |         for threshold in [1e-4, 1e-3, 1e-2, 1e-1, 1, 10, 100]:
69 |             sparse_updater = updater(threshold=threshold)
70 |             update_dict = async_utils.run(
71 |                 sparse_updater.calculate_update(updated_parameter, parameter)
72 |             )
73 |             sparsity = 1 - len(update_dict["data"]) / np.prod(parameter.shape)
74 |             sparseness.append(sparsity)
75 |         assert all(
76 |             sparseness[i] <= sparseness[i + 1] for i in range(len(sparseness) - 1)
77 |         )
78 | 
--------------------------------------------------------------------------------
/tests/utils_test.py:
--------------------------------------------------------------------------------
1 | """Tests for utils.py"""
2 | 
3 | import operator as op
4 | import os
5 | import time
6 | 
7 | import pytest
8 | 
9 | from git_theta import utils
10 | 
11 | 
12 | def test_flatten_dict_empty_leaf():
13 |     """Test that empty leaves are ignored."""
14 |     nested = {
15 |         "a": {},
16 |         "b": {
17 |             "c": 1,
18 |             "d": {},
19 |         },
20 |     }
21 |     gold = {("b", "c"): 1}
22 |     assert utils.flatten(nested) == gold
23 | 
24 | 
25 | def test_flatten_dict_empty():
26 |     assert utils.flatten({}) == {}
27 | 
28 | 
29 | def test_sorted_flatten_dict_insertion_order():
30 |     """Test that key order is consistent across different insertion orders."""
31 |     nested_dict = {
32 |         "a": {
33 |             "b": {
34 |                 "c": 10,
35 |                 "d": 20,
36 |                 "e": 30,
37 |             },
38 |             "c": {
39 |                 "b": 40,
40 |                 "z": -1,
41 |             },
42 |         }
43 |     }
44 | 
45 |     nested_dict_new_order = {
46 |         "a": {
47 |             "c": {
48 |                 "z": -1,
49 |                 "b": 40,
50 |             },
51 |             "b": {
52 |                 "d": 20,
53 |                 "e": 30,
54 |                 "c": 10,
55 |             },
56 |         }
57 |     }
58 |     assert nested_dict == nested_dict_new_order
59 |     one_flat = utils.flatten(nested_dict)
60 |     two_flat = utils.flatten(nested_dict_new_order)
61 |     assert one_flat == two_flat
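    # (Editor's note: flatten() produces tuple keys, which sort element-wise,
    # so sorting both flattened dicts' items gives one deterministic order;
    # the loop below checks the pairs line up key-by-key and value-by-value.)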
62 |     for one, two in zip(sorted(one_flat.items()), sorted(two_flat.items())):
63 |         assert one[0] == two[0]
64 |         assert one[1] == two[1]
65 | 
66 | 
67 | def test_flattened_dict_keys_are_correct(data_generator):
68 |     """Test that indexing the nested dict with the keys yields the value."""
69 |     nested = data_generator.random_nested_dict()
70 |     for flat_key, flat_value in utils.flatten(nested).items():
71 |         curr = nested
72 |         for key in flat_key:
73 |             curr = curr[key]
74 |         assert curr == flat_value
75 | 
76 | 
77 | def test_flattened_dict_sorted_is_actually_sorted(data_generator):
78 |     """Test to ensure the leaves are actually sorted."""
79 |     nested = data_generator.random_nested_dict()
80 |     keys = tuple(map(op.itemgetter(0), sorted(utils.flatten(nested).items())))
81 |     string_keys = ["/".join(k) for k in keys]
82 |     sorted_string_keys = sorted(string_keys)
83 |     assert string_keys == sorted_string_keys
84 | 
85 | 
86 | def test_is_valid_oid(data_generator):
87 |     oids = [data_generator.random_oid() for _ in range(100)]
88 |     assert all([utils.is_valid_oid(oid) for oid in oids])
89 | 
90 | 
91 | def test_is_valid_commit_hash(data_generator):
92 |     commit_hashes = [data_generator.random_commit_hash() for _ in range(100)]
93 |     assert all(
94 |         [utils.is_valid_commit_hash(commit_hash) for commit_hash in commit_hashes]
95 |     )
96 | 
97 | 
98 | @pytest.fixture
99 | def test_file():
100 |     path = "./test-file.txt"
101 |     with open(path, "w") as w:
102 |         w.write("Testing")
103 |     yield path
104 |     os.remove(path)
105 | 
106 | 
107 | def test_touch(test_file):
108 |     old_stats = os.stat(test_file)
109 |     # Sleep so the new a/mtime differs by more than the timestamp resolution
110 |     # of the host operating system.
111 |     time.sleep(1)
112 |     utils.touch(test_file)
113 |     new_stats = os.stat(test_file)
114 |     assert old_stats.st_atime < new_stats.st_atime
115 |     assert old_stats.st_mtime < new_stats.st_mtime
116 | 
--------------------------------------------------------------------------------
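
Editor's addendum: the flatten() tests above pin down the behavior of
git_theta.utils.flatten on dicts (tuple keys, empty-dict leaves dropped). A
minimal sketch consistent with those tests follows; the name flatten_sketch is
hypothetical, and the real implementation may differ in details such as how it
treats non-dict leaves.

def flatten_sketch(nested, prefix=()):
    """Flatten a nested dict into {tuple-of-keys: leaf-value} pairs."""
    flat = {}
    for key, value in nested.items():
        if isinstance(value, dict):
            # Recursing into an empty dict adds nothing, so empty leaves vanish.
            flat.update(flatten_sketch(value, prefix + (key,)))
        else:
            flat[prefix + (key,)] = value
    return flat

# Mirrors test_flatten_dict_empty_leaf:
assert flatten_sketch({"a": {}, "b": {"c": 1, "d": {}}}) == {("b", "c"): 1}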