├── .codecov.yml ├── .editorconfig ├── .github ├── ISSUE_TEMPLATE │ ├── bug_report.md │ └── feature_request.md └── workflows │ ├── gpu_tests.yaml │ ├── lint.yml │ ├── publish.yml │ └── tests.yml ├── .gitignore ├── .pre-commit-config.yaml ├── .readthedocs.yaml ├── CITATION.bib ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── MANIFEST.in ├── README.md ├── docs ├── Makefile ├── _static │ └── images │ │ ├── colab-badge.svg │ │ ├── couplings.png │ │ ├── logoOTT.ico │ │ └── logoOTT.png ├── _templates │ └── autosummary │ │ └── class.rst ├── bibliography.rst ├── conf.py ├── contributing.rst ├── experimental │ ├── index.rst │ └── mmsinkhorn.rst ├── geometry.rst ├── glossary.rst ├── index.rst ├── initializers │ ├── index.rst │ ├── linear.rst │ └── quadratic.rst ├── make.bat ├── math.rst ├── neural │ ├── datasets.rst │ ├── index.rst │ ├── methods.rst │ └── networks.rst ├── problems │ ├── index.rst │ ├── linear.rst │ └── quadratic.rst ├── references.bib ├── solvers │ ├── index.rst │ ├── linear.rst │ └── quadratic.rst ├── spelling │ ├── misc.txt │ └── technical.txt ├── tools.rst ├── tutorials │ ├── barycenter │ │ ├── 000_Sinkhorn_Barycenters.ipynb │ │ ├── 100_wasserstein_barycenters_gmms.ipynb │ │ ├── 200_gmm_pair_demo.ipynb │ │ └── index.rst │ ├── basic_ot_between_datasets.ipynb │ ├── geometry │ │ ├── 000_point_cloud.ipynb │ │ ├── 100_grid.ipynb │ │ └── index.rst │ ├── index.rst │ ├── linear │ │ ├── 000_One_Sinkhorn.ipynb │ │ ├── 100_OTT_&_POT.ipynb │ │ ├── 200_sinkhorn_divergence_gradient_flow.ipynb │ │ ├── 300_LRSinkhorn.ipynb │ │ ├── 400_Hessians.ipynb │ │ ├── 500_sparse_monge_displacements.ipynb │ │ ├── 600_mmsink.ipynb │ │ ├── 700_progot.ipynb │ │ ├── 800_Unbalanced_OT.ipynb │ │ └── index.rst │ ├── misc │ │ ├── 000_tracking_progress.ipynb │ │ ├── 100_soft_sort.ipynb │ │ ├── 200_application_biology.ipynb │ │ ├── 300_otcp.ipynb │ │ └── index.rst │ ├── neural │ │ ├── 000_neural_dual.ipynb │ │ ├── 100_icnn_inits.ipynb │ │ ├── 200_Monge_Gap.ipynb │ │ ├── 
300_ENOT.ipynb │ │ ├── 400_MetaOT.ipynb │ │ └── index.rst │ └── quadratic │ │ ├── 000_gromov_wasserstein.ipynb │ │ ├── 100_GWLRSinkhorn.ipynb │ │ ├── 200_gromov_wasserstein_multiomics.ipynb │ │ └── index.rst └── utils.rst ├── pyproject.toml ├── setup.py ├── src └── ott │ ├── __init__.py │ ├── _version.py │ ├── datasets.py │ ├── experimental │ ├── __init__.py │ └── mmsinkhorn.py │ ├── geometry │ ├── __init__.py │ ├── costs.py │ ├── distrib_costs.py │ ├── epsilon_scheduler.py │ ├── geodesic.py │ ├── geometry.py │ ├── graph.py │ ├── grid.py │ ├── low_rank.py │ ├── pointcloud.py │ ├── regularizers.py │ └── segment.py │ ├── initializers │ ├── __init__.py │ ├── linear │ │ ├── __init__.py │ │ ├── initializers.py │ │ └── initializers_lr.py │ ├── neural │ │ ├── __init__.py │ │ └── meta_initializer.py │ └── quadratic │ │ ├── __init__.py │ │ └── initializers.py │ ├── math │ ├── __init__.py │ ├── fixed_point_loop.py │ ├── matrix_square_root.py │ ├── unbalanced_functions.py │ └── utils.py │ ├── neural │ ├── __init__.py │ ├── datasets.py │ ├── methods │ │ ├── __init__.py │ │ ├── expectile_neural_dual.py │ │ ├── flows │ │ │ ├── __init__.py │ │ │ ├── dynamics.py │ │ │ ├── genot.py │ │ │ └── otfm.py │ │ ├── monge_gap.py │ │ └── neuraldual.py │ └── networks │ │ ├── __init__.py │ │ ├── icnn.py │ │ ├── layers │ │ ├── __init__.py │ │ ├── conjugate.py │ │ ├── posdef.py │ │ └── time_encoder.py │ │ ├── potentials.py │ │ └── velocity_field.py │ ├── problems │ ├── __init__.py │ ├── linear │ │ ├── __init__.py │ │ ├── barycenter_problem.py │ │ ├── linear_problem.py │ │ └── potentials.py │ └── quadratic │ │ ├── __init__.py │ │ ├── gw_barycenter.py │ │ ├── quadratic_costs.py │ │ └── quadratic_problem.py │ ├── py.typed │ ├── solvers │ ├── __init__.py │ ├── linear │ │ ├── __init__.py │ │ ├── _solve.py │ │ ├── acceleration.py │ │ ├── continuous_barycenter.py │ │ ├── discrete_barycenter.py │ │ ├── implicit_differentiation.py │ │ ├── lineax_implicit.py │ │ ├── lr_utils.py │ │ ├── sinkhorn.py │ │ ├── 
sinkhorn_lr.py │ │ └── univariate.py │ ├── quadratic │ │ ├── __init__.py │ │ ├── _solve.py │ │ ├── gromov_wasserstein.py │ │ ├── gromov_wasserstein_lr.py │ │ ├── gw_barycenter.py │ │ └── lower_bound.py │ ├── utils.py │ └── was_solver.py │ ├── tools │ ├── __init__.py │ ├── conformal.py │ ├── gaussian_mixture │ │ ├── __init__.py │ │ ├── fit_gmm.py │ │ ├── fit_gmm_pair.py │ │ ├── gaussian.py │ │ ├── gaussian_mixture.py │ │ ├── gaussian_mixture_pair.py │ │ ├── linalg.py │ │ ├── probabilities.py │ │ └── scale_tril.py │ ├── k_means.py │ ├── plot.py │ ├── progot.py │ ├── segment_sinkhorn.py │ ├── sinkhorn_divergence.py │ ├── sliced.py │ ├── soft_sort.py │ └── unreg.py │ ├── types.py │ └── utils.py └── tests ├── __init__.py ├── conftest.py ├── experimental └── mmsinkhorn_test.py ├── geometry ├── costs_test.py ├── geodesic_test.py ├── geometry_test.py ├── graph_test.py ├── lr_cost_test.py ├── lr_kernel_test.py ├── pointcloud_test.py ├── regularizers_test.py └── scaling_cost_test.py ├── initializers ├── linear │ ├── sinkhorn_init_test.py │ └── sinkhorn_lr_init_test.py ├── neural │ ├── __init__.py │ └── meta_initializer_test.py └── quadratic │ └── gw_init_test.py ├── math ├── lse_test.py ├── math_utils_test.py └── matrix_square_root_test.py ├── neural ├── __init__.py ├── conftest.py ├── methods │ ├── genot_test.py │ ├── monge_gap_test.py │ ├── neuraldual_test.py │ └── otfm_test.py └── networks │ └── icnn_test.py ├── problems └── linear │ └── potentials_test.py ├── solvers ├── linear │ ├── continuous_barycenter_test.py │ ├── discrete_barycenter_test.py │ ├── sinkhorn_diff_test.py │ ├── sinkhorn_grid_test.py │ ├── sinkhorn_lr_test.py │ ├── sinkhorn_misc_test.py │ ├── sinkhorn_test.py │ └── univariate_test.py └── quadratic │ ├── fgw_test.py │ ├── gw_barycenter_test.py │ ├── gw_test.py │ └── lower_bound_test.py ├── tools ├── conformal_test.py ├── gaussian_mixture │ ├── fit_gmm_pair_test.py │ ├── fit_gmm_test.py │ ├── gaussian_mixture_pair_test.py │ ├── gaussian_mixture_test.py │ 
├── gaussian_test.py │ ├── linalg_test.py │ ├── probabilities_test.py │ └── scale_tril_test.py ├── k_means_test.py ├── plot_test.py ├── segment_sinkhorn_test.py ├── sinkhorn_divergence_test.py ├── sliced_test.py ├── soft_sort_test.py └── unreg_test.py └── utils_test.py /.codecov.yml: -------------------------------------------------------------------------------- 1 | codecov: 2 | require_ci_to_pass: yes 3 | strict_yaml_branch: main 4 | 5 | coverage: 6 | range: 75..100 7 | status: 8 | project: 9 | default: 10 | target: 1 11 | patch: off 12 | 13 | comment: 14 | layout: reach, diff, files 15 | behavior: default 16 | require_changes: true 17 | branches: 18 | - main 19 | -------------------------------------------------------------------------------- /.editorconfig: -------------------------------------------------------------------------------- 1 | root = true 2 | 3 | [*] 4 | end_of_line = lf 5 | insert_final_newline = true 6 | charset = utf-8 7 | 8 | [{*py,*.rst}] 9 | indent_size = 2 10 | indent_style = space 11 | max_line_length = 80 12 | 13 | [{*.yml,*.yaml}] 14 | indent_size = 2 15 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Create a report to help us improve 4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Describe the bug** 11 | A clear and concise description of what the bug is. 12 | 13 | **To Reproduce** 14 | Steps to reproduce the behavior: 15 | 1. Go to '...' 16 | 2. Click on '....' 17 | 3. Scroll down to '....' 18 | 4. See error 19 | 20 | **Expected behavior** 21 | A clear and concise description of what you expected to happen. 22 | 23 | **Screenshots** 24 | If applicable, add screenshots to help explain your problem. 25 | 26 | **Desktop (please complete the following information):** 27 | - OS: [e.g. iOS] 28 | - Browser [e.g. 
chrome, safari] 29 | - Version [e.g. 22] 30 | 31 | **Smartphone (please complete the following information):** 32 | - Device: [e.g. iPhone6] 33 | - OS: [e.g. iOS8.1] 34 | - Browser [e.g. stock browser, safari] 35 | - Version [e.g. 22] 36 | 37 | **Additional context** 38 | Add any other context about the problem here. 39 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature request 3 | about: Suggest an idea for this project 4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Is your feature request related to a problem? Please describe.** 11 | A clear and concise description of what the problem is. Ex. I'm always frustrated when [...] 12 | 13 | **Describe the solution you'd like** 14 | A clear and concise description of what you want to happen. 15 | 16 | **Describe alternatives you've considered** 17 | A clear and concise description of any alternative solutions or features you've considered. 18 | 19 | **Additional context** 20 | Add any other context or screenshots about the feature request here. 
21 | -------------------------------------------------------------------------------- /.github/workflows/gpu_tests.yaml: -------------------------------------------------------------------------------- 1 | name: GPU Tests 2 | 3 | on: 4 | schedule: 5 | - cron: 00 00 * * 1 6 | push: 7 | branches: [main] 8 | pull_request: 9 | branches: [main] 10 | 11 | concurrency: 12 | group: ${{ github.workflow }}-${{ github.ref }} 13 | cancel-in-progress: true 14 | 15 | jobs: 16 | gpu-tests: 17 | name: Python 3.10 on ubuntu-22.04 18 | runs-on: [self-hosted, ott-gpu] 19 | container: 20 | image: docker://michalk8/cuda:12.3.2-cudnn9-devel-ubuntu22.04 21 | options: --gpus="device=2" 22 | steps: 23 | - uses: actions/checkout@v4 24 | 25 | - name: Install dependencies 26 | run: | 27 | python3 -m pip install --upgrade pip 28 | python3 -m pip install -e".[test]" 29 | python3 -m pip install "jax[cuda12]" 30 | 31 | - name: Run nvidia-smi 32 | run: | 33 | nvidia-smi 34 | 35 | - name: Run tests 36 | run: | 37 | python3 -m pytest -m "fast and not cpu" --memray --durations 10 -vv 38 | env: 39 | XLA_PYTHON_CLIENT_PREALLOCATE: 'false' 40 | -------------------------------------------------------------------------------- /.github/workflows/lint.yml: -------------------------------------------------------------------------------- 1 | name: Lint 2 | 3 | on: 4 | schedule: 5 | - cron: 00 00 * * 1 6 | push: 7 | branches: [main] 8 | pull_request: 9 | branches: [main] 10 | 11 | concurrency: 12 | group: ${{ github.workflow }}-${{ github.ref }} 13 | cancel-in-progress: true 14 | 15 | jobs: 16 | lint: 17 | name: ${{ matrix.lint-kind }} 18 | runs-on: ubuntu-latest 19 | strategy: 20 | fail-fast: false 21 | matrix: 22 | lint-kind: [code, docs] 23 | 24 | steps: 25 | - uses: actions/checkout@v4 26 | - name: Set up Python 3.10 27 | uses: actions/setup-python@v5 28 | with: 29 | python-version: '3.10' 30 | 31 | - name: Cache pre-commit 32 | uses: actions/cache@v4 33 | if: ${{ matrix.lint-kind == 'code' }} 34 | with: 
35 | path: ~/.cache/pre-commit 36 | key: pre-commit-${{ env.pythonLocation }}-${{ hashFiles('**/.pre-commit-config.yaml') }} 37 | 38 | - name: Install dependencies 39 | run: | 40 | python -m pip install --upgrade pip 41 | python -m pip install tox 42 | 43 | - name: Install PyEnchant 44 | if: ${{ matrix.lint-kind == 'docs' }} 45 | run: | 46 | sudo apt-get update -y 47 | sudo apt-get install libenchant-2-dev 48 | python -m pip install pyenchant 49 | 50 | - name: Lint ${{ matrix.lint-kind }} 51 | run: | 52 | tox -e lint-${{ matrix.lint-kind }} 53 | -------------------------------------------------------------------------------- /.github/workflows/publish.yml: -------------------------------------------------------------------------------- 1 | name: Upload package 2 | 3 | on: 4 | release: 5 | types: [created] 6 | 7 | jobs: 8 | deploy: 9 | runs-on: ubuntu-latest 10 | permissions: 11 | id-token: write 12 | environment: publish-pypi 13 | 14 | steps: 15 | - uses: actions/checkout@v4 16 | - name: Set up Python 3.10 17 | uses: actions/setup-python@v5 18 | with: 19 | python-version: '3.10' 20 | 21 | - name: Install dependencies 22 | run: | 23 | python -m pip install --upgrade pip 24 | pip install tox 25 | 26 | - name: Build package 27 | run: tox -e build-package 28 | 29 | - name: Publish package on PyPI 30 | uses: pypa/gh-action-pypi-publish@release/v1 31 | with: 32 | skip-existing: true 33 | verify-metadata: true 34 | verbose: true 35 | -------------------------------------------------------------------------------- /.github/workflows/tests.yml: -------------------------------------------------------------------------------- 1 | name: Tests 2 | 3 | on: 4 | schedule: 5 | - cron: 00 00 * * 1 6 | push: 7 | branches: [main] 8 | pull_request: 9 | branches: [main] 10 | 11 | concurrency: 12 | group: ${{ github.workflow }}-${{ github.ref }} 13 | cancel-in-progress: true 14 | 15 | jobs: 16 | fast-tests: 17 | name: Python ${{ matrix.python-version }} ${{ matrix.jax-version }} (fast) 
18 | runs-on: ubuntu-latest 19 | strategy: 20 | fail-fast: false 21 | matrix: 22 | python-version: ['3.12'] 23 | jax-version: [jax-default, jax-latest] 24 | 25 | steps: 26 | - uses: actions/checkout@v4 27 | - name: Set up Python 3.9 28 | uses: actions/setup-python@v5 29 | with: 30 | python-version: ${{ matrix.python-version }} 31 | 32 | - name: Install dependencies 33 | run: | 34 | python -m pip install --upgrade pip 35 | python -m pip install tox 36 | 37 | - name: Setup environment 38 | run: | 39 | tox -e py${{ matrix.python-version }}-${{ matrix.jax-version }} --notest -v 40 | 41 | - name: Run tests 42 | continue-on-error: ${{ matrix.jax-version == 'jax-latest' }} 43 | run: | 44 | tox -e py${{ matrix.python-version }}-${{ matrix.jax-version }} --skip-pkg-install -- -m fast --memray -n auto -vv 45 | 46 | tests: 47 | name: Python ${{ matrix.python-version }} on ${{ matrix.os }} 48 | runs-on: ${{ matrix.os }} 49 | strategy: 50 | fail-fast: false 51 | matrix: 52 | python-version: ['3.10', '3.11', '3.12', '3.13'] 53 | os: [ubuntu-latest] 54 | include: 55 | - python-version: '3.10' 56 | os: macos-14 57 | - python-version: '3.12' 58 | os: macos-15 59 | 60 | steps: 61 | - uses: actions/checkout@v4 62 | - name: Set up Python ${{ matrix.python-version }} 63 | uses: actions/setup-python@v5 64 | with: 65 | python-version: ${{ matrix.python-version }} 66 | 67 | - name: Install dependencies 68 | run: | 69 | python -m pip install --upgrade pip 70 | python -m pip install tox 71 | 72 | - name: Setup environment 73 | run: | 74 | tox -e py${{ matrix.python-version }} --notest -v 75 | 76 | - name: Run tests 77 | run: | 78 | tox -e py${{ matrix.python-version }} --skip-pkg-install 79 | env: 80 | PYTEST_ADDOPTS: --memray -vv 81 | 82 | - name: Upload coverage 83 | uses: codecov/codecov-action@v4 84 | with: 85 | files: ./coverage.xml 86 | flags: tests-${{ matrix.os }}-${{ matrix.python-version }} 87 | name: unittests 88 | token: ${{ secrets.CODECOV_TOKEN }} 89 | env_vars: OS,PYTHON 90 | 
fail_ci_if_error: false 91 | verbose: true 92 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into 
this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | .idea/ 161 | 162 | # vscode 163 | .vscode/ 164 | 165 | # generated documentation 166 | docs/html 167 | **/_autosummary 168 | 169 | # macos 170 | **/.DS_Store 171 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | fail_fast: false 2 | default_language_version: 3 | python: python3 4 | default_stages: 5 | - commit 6 | - push 7 | minimum_pre_commit_version: 3.0.0 8 | repos: 9 | - repo: https://github.com/pre-commit/pre-commit-hooks 10 | rev: v4.6.0 11 | hooks: 12 | - id: detect-private-key 13 | - id: check-ast 14 | - id: check-toml 15 | - id: end-of-file-fixer 16 | - id: mixed-line-ending 17 | args: [--fix=lf] 18 | - id: file-contents-sorter 19 | files: docs/spelling/.*\.txt 20 | - id: trailing-whitespace 21 | - id: check-case-conflict 22 | - repo: https://github.com/charliermarsh/ruff-pre-commit 23 | rev: v0.4.10 24 | hooks: 25 | - id: ruff 26 | args: [--fix, --exit-non-zero-on-fix] 27 | - repo: https://github.com/pycqa/isort 28 | rev: 5.13.2 29 | hooks: 30 | - id: isort 31 | name: isort 32 | - repo: https://github.com/google/yapf 33 | rev: v0.40.2 34 | hooks: 35 | - id: yapf 36 | additional_dependencies: [toml] 37 | - repo: https://github.com/nbQA-dev/nbQA 38 | rev: 1.8.5 39 | hooks: 40 | - id: nbqa-pyupgrade 41 | args: [--py39-plus] 42 | - id: nbqa-black 43 | - id: nbqa-isort 44 | - repo: https://github.com/macisamuele/language-formatters-pre-commit-hooks 45 | rev: v2.13.0 46 | hooks: 47 | - id: pretty-format-yaml 48 | args: [--autofix, --indent, '2'] 49 | - repo: https://github.com/rstcheck/rstcheck 50 | rev: v6.2.0 51 | hooks: 52 | - id: rstcheck 53 | additional_dependencies: [tomli] 54 | args: [--config=pyproject.toml] 55 | - repo: https://github.com/PyCQA/doc8 56 | rev: v1.1.1 57 | hooks: 
58 | - id: doc8 59 | args: [--config=pyproject.toml] 60 | -------------------------------------------------------------------------------- /.readthedocs.yaml: -------------------------------------------------------------------------------- 1 | version: 2 2 | 3 | build: 4 | os: ubuntu-22.04 5 | tools: 6 | python: '3.10' 7 | 8 | sphinx: 9 | builder: html 10 | configuration: docs/conf.py 11 | fail_on_warning: false 12 | 13 | python: 14 | install: 15 | - method: pip 16 | path: . 17 | extra_requirements: [docs, neural] 18 | -------------------------------------------------------------------------------- /CITATION.bib: -------------------------------------------------------------------------------- 1 | @misc{cuturi2022optimal, 2 | author = {Cuturi, Marco and Meng-Papaxanthos, Laetitia and Tian, Yingtao and Bunne, Charlotte and Davis, Geoff and Teboul, Olivier}, 3 | eprint = {2201.12324}, 4 | eprintclass = {cs.LG}, 5 | eprinttype = {arXiv}, 6 | title = {Optimal Transport Tools (OTT): A JAX Toolbox for all things Wasserstein}, 7 | year = {2022}, 8 | } 9 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to OTT 2 | We'd love to accept your contributions to this project. 3 | 4 | There are many ways to contribute to OTT, with the most common ones being contribution of code, documentation 5 | to the project, participating in discussions or raising issues. 6 | 7 | ## Contributing code or documentation 8 | 1. fork the repository using the [Fork](https://github.com/ott-jax/ott/fork) button on GitHub. 9 | 2. ```shell 10 | git clone https://github.com//ott.git ott && \ 11 | cd ott && \ 12 | pip install -e '.[dev]' && \ 13 | pre-commit install 14 | ``` 15 | 16 | When committing changes, sometimes you might want or need to bypass the pre-commit checks. 
This can be 17 | done via the `--no-verify` flag as: 18 | ```shell 19 | git commit --no-verify -m "" 20 | ``` 21 | 22 | ## Running tests 23 | In order to run tests, we utilize [tox](https://tox.wiki/): 24 | ```shell 25 | tox run # run linter and all tests on all available Python versions 26 | tox run -- -n auto -m fast # run linter and fast tests in parallel 27 | tox -e py3.9 -- -k "test_euclidean_point_cloud" # run tests matching the expression on Python3.9 28 | tox -e py3.10 -- --memray # test also memory on Python3.10 29 | ``` 30 | Alternatively, tests can be also run using the [pytest](https://docs.pytest.org/): 31 | ```shell 32 | python -m pytest 33 | ``` 34 | This requires the `'[test]'` extra requirements to be installed as `pip install -e.'[test]'`. 35 | 36 | ## Documentation 37 | From the root of the repository, run: 38 | ```shell 39 | tox -e clean-docs,build-docs,lint-docs # remove, build and lint the documentation 40 | ``` 41 | Installing `PyEnchant` is required to run spellchecker, please refer to the 42 | [installation instructions](https://pyenchant.github.io/pyenchant/install.html). On macOS Silicon, it may be necessary 43 | to also set `PYENCHANT_LIBRARY_PATH` environment variable, as, e.g., 44 | `export PYENCHANT_LIBRARY_PATH=/opt/homebrew/lib/libenchant-2.2.dylib`. False positives and correctly spelled words can 45 | be added to one of the files in [docs/spelling](https://github.com/ott-jax/ott/tree/main/docs/spelling). 46 | 47 | ## Building the package 48 | The package can be built using: 49 | ```shell 50 | tox -e build-package 51 | ``` 52 | Afterwards, the built package will be located under `dist/`. 53 | 54 | ## Code reviews 55 | All submissions, including submissions by project members, require review. We use GitHub 56 | [pull requests](https://github.com/ott-jax/ott/pulls) for this purpose. Consult 57 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more information on using pull requests. 
58 | 59 | # Community guidelines 60 | We abide by the principles of openness, respect, and consideration of others of the Python Software Foundation's 61 | [code of conduct](https://www.python.org/psf/codeofconduct/). 62 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | prune docs 2 | prune tests 3 | prune .github 4 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | logo 2 | 3 | # Optimal Transport Tools (OTT) 4 | [![Downloads](https://static.pepy.tech/badge/ott-jax)](https://pypi.org/project/ott-jax/) 5 | [![Tests](https://img.shields.io/github/actions/workflow/status/ott-jax/ott/tests.yml?branch=main)](https://github.com/ott-jax/ott/actions/workflows/tests.yml) 6 | [![Docs](https://img.shields.io/readthedocs/ott-jax/latest)](https://ott-jax.readthedocs.io/en/latest/) 7 | [![Coverage](https://img.shields.io/codecov/c/github/ott-jax/ott/main)](https://app.codecov.io/gh/ott-jax/ott) 8 | 9 | **See the [full documentation](https://ott-jax.readthedocs.io/en/latest/).** 10 | 11 | ## What is OTT-JAX? 12 | A ``JAX`` powered library to compute optimal transport at scale and on accelerators, ``OTT-JAX`` includes the fastest 13 | implementation of the Sinkhorn algorithm you will find around. We have implemented all tweaks (scheduling, momentum, acceleration, initializations) and extensions (low-rank, entropic maps). They can be used directly between two datasets, or within more advanced problems 14 | (Gromov-Wasserstein, barycenters). 
Some of ``JAX`` features, including 15 | [JIT](https://jax.readthedocs.io/en/latest/notebooks/quickstart.html#Using-jit-to-speed-up-functions), 16 | [auto-vectorization](https://jax.readthedocs.io/en/latest/notebooks/quickstart.html#Auto-vectorization-with-vmap) and 17 | [implicit differentiation](https://jax.readthedocs.io/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html) 18 | work towards the goal of having end-to-end differentiable outputs. ``OTT-JAX`` is led by a team of researchers at Apple, with contributions from Google and Meta researchers, as well as many academic partners, including TU München, Oxford, ENSAE/IP Paris, ENS Paris and the Hebrew University. 19 | 20 | ## Installation 21 | Install ``OTT-JAX`` from [PyPI](https://pypi.org/project/ott-jax/) as: 22 | ```bash 23 | pip install ott-jax 24 | ``` 25 | or with ``conda`` via [conda-forge](https://anaconda.org/conda-forge/ott-jax) as: 26 | ```bash 27 | conda install -c conda-forge ott-jax 28 | ``` 29 | 30 | ## What is optimal transport? 31 | Optimal transport can be loosely described as the branch of mathematics and optimization that studies 32 | *matching problems*: given two families of points, and a cost function on pairs of points, find a "good" (low cost) way 33 | to associate bijectively to every point in the first family another in the second. 34 | 35 | Such problems appear in all areas of science, are easy to describe, yet hard to solve. Indeed, while matching optimally 36 | two sets of $n$ points using a pairwise cost can be solved with the 37 | [Hungarian algorithm](https://en.wikipedia.org/wiki/Hungarian_algorithm), solving it costs an order of $O(n^3)$ 38 | operations, and lacks flexibility, since one may want to couple families of different sizes. 
39 | 40 | Optimal transport extends all of this, through faster algorithms (in $n^2$ or even linear in $n$) along with numerous 41 | generalizations that can help it handle weighted sets of different size, partial matchings, and even more evolved 42 | so-called quadratic matching problems. 43 | 44 | In the simple toy example below, we compute the optimal coupling matrix between two point clouds sampled randomly 45 | (2D vectors, compared with the squared Euclidean distance): 46 | 47 | ## Example 48 | ```python 49 | import jax 50 | import jax.numpy as jnp 51 | 52 | from ott.geometry import pointcloud 53 | from ott.problems.linear import linear_problem 54 | from ott.solvers.linear import sinkhorn 55 | 56 | # sample two point clouds and their weights. 57 | rngs = jax.random.split(jax.random.key(0), 4) 58 | n, m, d = 12, 14, 2 59 | x = jax.random.normal(rngs[0], (n,d)) + 1 60 | y = jax.random.uniform(rngs[1], (m,d)) 61 | a = jax.random.uniform(rngs[2], (n,)) 62 | b = jax.random.uniform(rngs[3], (m,)) 63 | a, b = a / jnp.sum(a), b / jnp.sum(b) 64 | # Computes the couplings using the Sinkhorn algorithm. 65 | geom = pointcloud.PointCloud(x, y) 66 | prob = linear_problem.LinearProblem(geom, a, b) 67 | 68 | solver = sinkhorn.Sinkhorn() 69 | out = solver(prob) 70 | ``` 71 | 72 | The call to `solver(prob)` above works out the optimal transport solution. The `out` object contains a transport matrix 73 | (here of size $12\times 14$) that quantifies the association strength between each point of the first point cloud, to one or 74 | more points from the second, as illustrated in the plot below. We provide more flexibility to define custom cost 75 | functions, objectives, and solvers, as detailed in the [full documentation](https://ott-jax.readthedocs.io/en/latest/). 
76 | 77 | ![obtained coupling](https://raw.githubusercontent.com/ott-jax/ott/main/docs/_static/images/couplings.png) 78 | 79 | ## Citation 80 | If you have found this work useful, please consider citing this reference: 81 | 82 | ``` 83 | @article{cuturi2022optimal, 84 | title={Optimal Transport Tools (OTT): A JAX Toolbox for all things Wasserstein}, 85 | author={Cuturi, Marco and Meng-Papaxanthos, Laetitia and Tian, Yingtao and Bunne, Charlotte and 86 | Davis, Geoff and Teboul, Olivier}, 87 | journal={arXiv preprint arXiv:2201.12324}, 88 | year={2022} 89 | } 90 | ``` 91 | ## See also 92 | The [moscot](https://moscot.readthedocs.io/en/latest/index.html) package for OT analysis of multi-omics data also uses OTT as a backbone. 93 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = . 9 | BUILDDIR = _build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile clean 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 
19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | 22 | clean: 23 | @rm -rf $(BUILDDIR)/ 24 | @rm -rf $(SOURCEDIR)/_autosummary 25 | @rm -rf $(SOURCEDIR)/**/_autosummary 26 | -------------------------------------------------------------------------------- /docs/_static/images/colab-badge.svg: -------------------------------------------------------------------------------- 1 | Open in ColabOpen in Colab 2 | -------------------------------------------------------------------------------- /docs/_static/images/couplings.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ott-jax/ott/d28d5d45b0bd26d2e3d54fe1085f2835dec5f5d6/docs/_static/images/couplings.png -------------------------------------------------------------------------------- /docs/_static/images/logoOTT.ico: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ott-jax/ott/d28d5d45b0bd26d2e3d54fe1085f2835dec5f5d6/docs/_static/images/logoOTT.ico -------------------------------------------------------------------------------- /docs/_static/images/logoOTT.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ott-jax/ott/d28d5d45b0bd26d2e3d54fe1085f2835dec5f5d6/docs/_static/images/logoOTT.png -------------------------------------------------------------------------------- /docs/_templates/autosummary/class.rst: -------------------------------------------------------------------------------- 1 | {{ fullname | escape | underline }} 2 | 3 | .. currentmodule:: {{ module }} 4 | 5 | .. autoclass:: {{ objname }} 6 | {% block methods %} 7 | {%- if methods %} 8 | .. rubric:: {{ _('Methods') }} 9 | 10 | .. autosummary:: 11 | :toctree: . 
12 | {% for item in methods %} 13 | {%- if item not in ['__init__', 'tree_flatten', 'tree_unflatten', 'bind', 'tabulate', 'module_paths'] %} 14 | ~{{ name }}.{{ item }} 15 | {%- endif %} 16 | {%- endfor %} 17 | {%- endif %} 18 | {%- endblock %} 19 | {% block attributes %} 20 | {%- if attributes %} 21 | .. rubric:: {{ _('Attributes') }} 22 | 23 | .. autosummary:: 24 | :toctree: . 25 | {% for item in attributes %} 26 | ~{{ name }}.{{ item }} 27 | {%- endfor %} 28 | {%- endif %} 29 | {% endblock %} 30 | -------------------------------------------------------------------------------- /docs/bibliography.rst: -------------------------------------------------------------------------------- 1 | Bibliography 2 | ============ 3 | 4 | .. bibliography:: 5 | :cited: 6 | -------------------------------------------------------------------------------- /docs/contributing.rst: -------------------------------------------------------------------------------- 1 | Contributing Guide 2 | ================== 3 | 4 | .. include:: ../CONTRIBUTING.md 5 | :parser: myst_parser.sphinx_ 6 | -------------------------------------------------------------------------------- /docs/experimental/index.rst: -------------------------------------------------------------------------------- 1 | ott.experimental 2 | ================ 3 | .. module:: ott.experimental 4 | 5 | The :mod:`ott.experimental` module groups experimental code that might be useful 6 | for users but whose API we expect to change significantly in coming months. 7 | 8 | .. toctree:: 9 | :maxdepth: 2 10 | 11 | mmsinkhorn 12 | -------------------------------------------------------------------------------- /docs/experimental/mmsinkhorn.rst: -------------------------------------------------------------------------------- 1 | ott.experimental.mmsinkhorn 2 | =========================== 3 | .. module:: ott.experimental.mmsinkhorn 4 | .. 
currentmodule:: ott.experimental 5 | 6 | Solvers for multimarginal entropic OT problems, defined using :math:`k` point 7 | clouds of variable sizes in dimension :math:`d`, as proposed in 8 | :cite:`benamou:15`, and presented in :cite:`piran:24` (Algorithm 1). 9 | 10 | Multimarginal Sinkhorn 11 | ---------------------- 12 | .. autosummary:: 13 | :toctree: _autosummary 14 | 15 | mmsinkhorn.MMSinkhorn 16 | mmsinkhorn.MMSinkhornOutput 17 | -------------------------------------------------------------------------------- /docs/geometry.rst: -------------------------------------------------------------------------------- 1 | ott.geometry 2 | ============ 3 | .. module:: ott.geometry 4 | .. currentmodule:: ott.geometry 5 | 6 | This package implements several classes to define a geometry, arguably the most 7 | influential ingredient of an optimal transport problem. In its full generality, a 8 | :class:`~ott.geometry.geometry.Geometry` defines source points (input measure), 9 | target points (target measure) and a ground cost function (resp. a positive 10 | kernel function) that quantifies how expensive (resp. easy) it is to displace a 11 | unit of mass from any of the input points to the target points. 12 | 13 | The geometry package proposes a few simple geometries. The simplest of all would 14 | be that for which input and target points coincide, and the geometry between 15 | them simplifies to a symmetric cost or kernel matrix. In the very particular 16 | case where these points happen to lie on a grid (a Cartesian product in full 17 | generality, e.g., 2- or 3-dimensional grids), the 18 | :class:`~ott.geometry.grid.Grid` geometry will prove useful. 19 | 20 | For more general settings where input/target points do not coincide, one can 21 | alternatively instantiate a :class:`~ott.geometry.geometry.Geometry` through a 22 | rectangular cost matrix.
23 | 24 | However, it is often preferable in applications to define ground costs 25 | "symbolically", by listing instead points in the input/target point clouds, to 26 | specify directly a cost *function* between them. Such functions should follow 27 | the :class:`~ott.geometry.costs.CostFn` class description. We provide a few 28 | standard cost functions that are meaningful in an OT context, notably the 29 | (unbalanced, regularized) Bures distances between Gaussians :cite:`janati:20`. 30 | That cost can be used for instance to compute a distance between Gaussian 31 | mixtures, as proposed in :cite:`chen:19a` and revisited in :cite:`delon:20`. 32 | 33 | To be useful with :class:`~ott.solvers.linear.sinkhorn.Sinkhorn` solvers, 34 | :class:`Geometries ` typically need to provide 35 | an ``epsilon`` regularization parameter. We propose either to set that value 36 | once for all, or implement an annealing 37 | :class:`~ott.geometry.epsilon_scheduler.Epsilon` scheduler. 38 | 39 | Geometries 40 | ---------- 41 | .. autosummary:: 42 | :toctree: _autosummary 43 | 44 | geometry.Geometry 45 | pointcloud.PointCloud 46 | grid.Grid 47 | graph.Graph 48 | geodesic.Geodesic 49 | low_rank.LRCGeometry 50 | low_rank.LRKGeometry 51 | epsilon_scheduler.Epsilon 52 | epsilon_scheduler.DEFAULT_EPSILON_SCALE 53 | 54 | Cost Functions 55 | -------------- 56 | .. autosummary:: 57 | :toctree: _autosummary 58 | 59 | costs.CostFn 60 | costs.TICost 61 | costs.SqPNorm 62 | costs.PNormP 63 | costs.SqEuclidean 64 | costs.RegTICost 65 | costs.Euclidean 66 | costs.EuclideanP 67 | costs.Cosine 68 | costs.Arccos 69 | costs.Bures 70 | costs.UnbalancedBures 71 | costs.SoftDTW 72 | distrib_costs.UnivariateWasserstein 73 | 74 | Regularizers 75 | ------------ 76 | .. 
autosummary:: 77 | :toctree: _autosummary 78 | 79 | regularizers.ProximalOperator 80 | regularizers.PostComposition 81 | regularizers.Regularization 82 | regularizers.Orthogonal 83 | regularizers.Quadratic 84 | regularizers.L1 85 | regularizers.SqL2 86 | regularizers.STVS 87 | regularizers.SqKOverlap 88 | 89 | Utilities 90 | --------- 91 | .. autosummary:: 92 | :toctree: _autosummary 93 | 94 | segment.segment_point_cloud 95 | -------------------------------------------------------------------------------- /docs/initializers/index.rst: -------------------------------------------------------------------------------- 1 | ott.initializers 2 | ================ 3 | .. module:: ott.initializers 4 | 5 | The :mod:`ott.initializers` module implements simple strategies to initialize 6 | solvers. For convex solvers, these initializations can be used to gain 7 | computational efficiency, but only have an impact in that respect. 8 | When used on more advanced and non-convex problems, these initializations 9 | play a far more important role. 10 | 11 | Two problems and their solvers fall in the convex category: those are the 12 | :class:`~ott.problems.linear.linear_problem.LinearProblem` solved with a 13 | :class:`~ott.solvers.linear.sinkhorn.Sinkhorn` solver, or the fixed-support 14 | Wasserstein barycenter problems :cite:`cuturi:14` described in 15 | :class:`~ott.problems.linear.barycenter_problem.FixedBarycenterProblem` and 16 | solved with a :class:`~ott.solvers.linear.discrete_barycenter.FixedBarycenter` 17 | solver. 18 | 19 | When the problem is *not* convex, which describes pretty much all other pairings 20 | of problems/solvers in ``OTT``, notably the quadratic problem, initializers 21 | play a more important role: different initializations will very likely result 22 | in different end solutions. 23 | 24 | ..
toctree:: 25 | :maxdepth: 2 26 | 27 | linear 28 | quadratic 29 | -------------------------------------------------------------------------------- /docs/initializers/linear.rst: -------------------------------------------------------------------------------- 1 | ott.initializers.linear 2 | ======================= 3 | .. module:: ott.initializers.linear 4 | .. currentmodule:: ott.initializers.linear 5 | 6 | Initializers for linear OT problems, focusing on Sinkhorn and low-rank solvers. 7 | 8 | Sinkhorn Initializers 9 | --------------------- 10 | .. autosummary:: 11 | :toctree: _autosummary 12 | 13 | initializers.DefaultInitializer 14 | initializers.GaussianInitializer 15 | initializers.SortingInitializer 16 | initializers.SubsampleInitializer 17 | 18 | Low-rank Sinkhorn Initializers 19 | ------------------------------ 20 | .. autosummary:: 21 | :toctree: _autosummary 22 | 23 | initializers_lr.LRInitializer 24 | initializers_lr.RandomInitializer 25 | initializers_lr.Rank2Initializer 26 | initializers_lr.KMeansInitializer 27 | initializers_lr.GeneralizedKMeansInitializer 28 | -------------------------------------------------------------------------------- /docs/initializers/quadratic.rst: -------------------------------------------------------------------------------- 1 | ott.initializers.quadratic 2 | ========================== 3 | .. module:: ott.initializers.quadratic 4 | .. currentmodule:: ott.initializers.quadratic 5 | 6 | Two families of initializers are described in the following to provide the first 7 | iteration of Gromov-Wasserstein solvers. They apply respectively to the simpler 8 | GW entropic solver :cite:`peyre:16`. 9 | 10 | Gromov-Wasserstein Initializers 11 | ------------------------------- 12 | .. 
autosummary:: 13 | :toctree: _autosummary 14 | 15 | initializers.QuadraticInitializer 16 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=. 11 | set BUILDDIR=_build 12 | 13 | %SPHINXBUILD% >NUL 2>NUL 14 | if errorlevel 9009 ( 15 | echo. 16 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 17 | echo.installed, then set the SPHINXBUILD environment variable to point 18 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 19 | echo.may add the Sphinx directory to PATH. 20 | echo. 21 | echo.If you don't have Sphinx installed, grab it from 22 | echo.https://www.sphinx-doc.org/ 23 | exit /b 1 24 | ) 25 | 26 | if "%1" == "" goto help 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /docs/math.rst: -------------------------------------------------------------------------------- 1 | ott.math 2 | ======== 3 | .. currentmodule:: ott.math 4 | .. automodule:: ott.math 5 | 6 | The :mod:`ott.math` module holds low level computational primitives that 7 | appear in some more advanced optimal transport problems. 8 | :mod:`ott.math.fixed_point_loop` implements a fixed-point iteration `while` loop 9 | that can be automatically differentiated, and which might 10 | be of more general interest to other `JAX` users. 11 | :mod:`ott.math.matrix_square_root` contains an implementation of the 12 | matrix square-root using the Newton-Schulz iterations. 
That implementation is 13 | itself differentiable using either :term:`implicit differentiation` or 14 | :term:`unrolling` of the updates of these iterations. 15 | :mod:`ott.math.utils` contains various low-level routines re-implemented for 16 | their usage in `JAX`. Of particular interest are the custom jvp/vjp 17 | re-implementations for `logsumexp` and `norm` that have a behavior that differs, 18 | in terms of differentiability, from the standard `JAX` implementations. 19 | 20 | 21 | Fixed-point Iteration 22 | --------------------- 23 | .. autosummary:: 24 | :toctree: _autosummary 25 | 26 | fixed_point_loop.fixpoint_iter 27 | 28 | Matrix Square Root 29 | ------------------ 30 | .. autosummary:: 31 | :toctree: _autosummary 32 | 33 | matrix_square_root.sqrtm 34 | 35 | Miscellaneous 36 | ------------- 37 | .. autosummary:: 38 | :toctree: _autosummary 39 | 40 | utils.norm 41 | utils.logsumexp 42 | utils.softmin 43 | utils.lambertw 44 | -------------------------------------------------------------------------------- /docs/neural/datasets.rst: -------------------------------------------------------------------------------- 1 | ott.neural.datasets 2 | =================== 3 | .. module:: ott.neural.datasets 4 | .. currentmodule:: ott.neural 5 | 6 | The :mod:`ott.neural.datasets` contains datasets needed for solving 7 | (conditional) neural optimal transport problems. 8 | 9 | Datasets 10 | -------- 11 | .. autosummary:: 12 | :toctree: _autosummary 13 | 14 | datasets.OTData 15 | datasets.OTDataset 16 | -------------------------------------------------------------------------------- /docs/neural/index.rst: -------------------------------------------------------------------------------- 1 | ott.neural 2 | ========== 3 | .. module:: ott.neural 4 | 5 | In contrast to most methods presented in :mod:`ott.solvers`, which output 6 | vectors or matrices, the goal of the :mod:`ott.neural` module is to parameterize 7 | optimal transport maps and couplings as neural networks.
These neural networks 8 | can generalize to new samples, in the sense that they can be conveniently 9 | evaluated outside training samples. This module implements layers, models 10 | and solvers to estimate such neural networks. 11 | 12 | .. toctree:: 13 | :maxdepth: 2 14 | 15 | datasets 16 | methods 17 | networks 18 | -------------------------------------------------------------------------------- /docs/neural/methods.rst: -------------------------------------------------------------------------------- 1 | ott.neural.methods 2 | ================== 3 | .. module:: ott.neural.methods 4 | .. currentmodule:: ott.neural.methods 5 | 6 | Monge Gap 7 | --------- 8 | .. autosummary:: 9 | :toctree: _autosummary 10 | 11 | monge_gap.monge_gap 12 | monge_gap.monge_gap_from_samples 13 | monge_gap.MongeGapEstimator 14 | 15 | Neural Dual 16 | ----------- 17 | .. autosummary:: 18 | :toctree: _autosummary 19 | 20 | neuraldual.W2NeuralDual 21 | expectile_neural_dual.ExpectileNeuralDual 22 | expectile_neural_dual.ENOTPotentials 23 | 24 | ott.neural.methods.flows 25 | ======================== 26 | .. module:: ott.neural.methods.flows 27 | .. currentmodule:: ott.neural.methods.flows 28 | 29 | Flows 30 | ----- 31 | .. autosummary:: 32 | :toctree: _autosummary 33 | 34 | otfm.OTFlowMatching 35 | genot.GENOT 36 | dynamics.BaseFlow 37 | dynamics.StraightFlow 38 | dynamics.ConstantNoiseFlow 39 | dynamics.BrownianBridge 40 | -------------------------------------------------------------------------------- /docs/neural/networks.rst: -------------------------------------------------------------------------------- 1 | ott.neural.networks 2 | =================== 3 | .. module:: ott.neural.networks 4 | .. currentmodule:: ott.neural.networks 5 | 6 | Networks 7 | -------- 8 | .. 
autosummary:: 9 | :toctree: _autosummary 10 | 11 | icnn.ICNN 12 | velocity_field.VelocityField 13 | potentials.BasePotential 14 | potentials.PotentialMLP 15 | potentials.MLP 16 | potentials.PotentialTrainState 17 | 18 | 19 | ott.neural.networks.layers 20 | ========================== 21 | .. module:: ott.neural.networks.layers 22 | .. currentmodule:: ott.neural.networks.layers 23 | 24 | Layers 25 | ------ 26 | .. autosummary:: 27 | :toctree: _autosummary 28 | 29 | conjugate.FenchelConjugateSolver 30 | conjugate.FenchelConjugateLBFGS 31 | conjugate.ConjugateResults 32 | posdef.PositiveDense 33 | posdef.PosDefPotentials 34 | time_encoder.cyclical_time_encoder 35 | -------------------------------------------------------------------------------- /docs/problems/index.rst: -------------------------------------------------------------------------------- 1 | ott.problems 2 | ============ 3 | .. module:: ott.problems 4 | 5 | The :mod:`ott.problems` module describes the low level optimal transport 6 | problems that are solved by :mod:`ott.solvers`. 7 | 8 | .. toctree:: 9 | :maxdepth: 2 10 | 11 | linear 12 | quadratic 13 | -------------------------------------------------------------------------------- /docs/problems/linear.rst: -------------------------------------------------------------------------------- 1 | ott.problems.linear 2 | =================== 3 | .. module:: ott.problems.linear 4 | .. currentmodule:: ott.problems.linear 5 | 6 | The :mod:`ott.problems.linear` describes the simplest family of optimal 7 | transport problems, those that involve computing the Kantorovich problem itself, 8 | also known as the linear optimal transport problem, or, more generally, 9 | objective functions that are sums of such optimal transport costs, which 10 | includes the two variants of Wasserstein barycenter problems. 11 | 12 | The module also holds dual potential variables, a class of functions that act 13 | as optimization variables for the dual optimal transport problem. 
14 | 15 | OT Problems 16 | ----------- 17 | .. autosummary:: 18 | :toctree: _autosummary 19 | 20 | linear_problem.LinearProblem 21 | barycenter_problem.FixedBarycenterProblem 22 | barycenter_problem.FreeBarycenterProblem 23 | 24 | Dual Potentials 25 | --------------- 26 | .. autosummary:: 27 | :toctree: _autosummary 28 | 29 | potentials.DualPotentials 30 | potentials.EntropicPotentials 31 | -------------------------------------------------------------------------------- /docs/problems/quadratic.rst: -------------------------------------------------------------------------------- 1 | ott.problems.quadratic 2 | ====================== 3 | .. module:: ott.problems.quadratic 4 | .. currentmodule:: ott.problems.quadratic 5 | 6 | The :mod:`ott.problems.quadratic` module describes the quadratic assignment 7 | problem and its generalizations, including notably the fused-problem (including 8 | a linear term) and the more advanced GW barycenter problem. 9 | 10 | OT Problems 11 | ----------- 12 | .. autosummary:: 13 | :toctree: _autosummary 14 | 15 | quadratic_problem.QuadraticProblem 16 | gw_barycenter.GWBarycenterProblem 17 | 18 | Costs 19 | ----- 20 | .. autosummary:: 21 | :toctree: _autosummary 22 | 23 | quadratic_costs.GWLoss 24 | quadratic_costs.make_square_loss 25 | quadratic_costs.make_kl_loss 26 | -------------------------------------------------------------------------------- /docs/solvers/index.rst: -------------------------------------------------------------------------------- 1 | ott.solvers 2 | =========== 3 | .. module:: ott.solvers 4 | 5 | The :mod:`ott.solvers` module contains the main algorithmic engines of the 6 | ``OTT`` package. The biggest component in this module are without a doubt the 7 | linear solvers in :mod:`ott.solvers.linear`, designed to solve linear OT 8 | problems. More advanced solvers, notably quadratic in 9 | :mod:`ott.solvers.quadratic`, rely on calls to linear solvers as subroutines. 
10 | That property itself is implemented in the more abstract 11 | :class:`~ott.solvers.was_solver.WassersteinSolver` class, which provides a 12 | lower-level template at the interface between the two. 13 | 14 | .. toctree:: 15 | :maxdepth: 2 16 | 17 | linear 18 | quadratic 19 | 20 | Wasserstein Solver 21 | ------------------ 22 | .. autosummary:: 23 | :toctree: _autosummary 24 | 25 | was_solver.WassersteinSolver 26 | 27 | Utilities 28 | --------- 29 | .. autosummary:: 30 | :toctree: _autosummary 31 | 32 | utils.match_linear 33 | utils.match_quadratic 34 | utils.sample_joint 35 | utils.sample_conditional 36 | utils.uniform_sampler 37 | -------------------------------------------------------------------------------- /docs/solvers/linear.rst: -------------------------------------------------------------------------------- 1 | ott.solvers.linear 2 | ================== 3 | .. module:: ott.solvers.linear 4 | .. currentmodule:: ott.solvers.linear 5 | 6 | Linear solvers are the bread-and-butter of OT solvers. They can be called on 7 | their own, either the Sinkhorn 8 | :class:`~ott.solvers.linear.sinkhorn.Sinkhorn` or Low-Rank 9 | :class:`~ott.solvers.linear.sinkhorn_lr.LRSinkhorn` solvers, to match two 10 | datasets. They also appear as subroutines for more advanced solvers in the 11 | :mod:`ott.solvers` module, notably :mod:`ott.solvers.quadratic`. 12 | 13 | Sinkhorn Solvers 14 | ---------------- 15 | .. autosummary:: 16 | :toctree: _autosummary 17 | 18 | solve 19 | sinkhorn.Sinkhorn 20 | sinkhorn.SinkhornState 21 | sinkhorn.SinkhornOutput 22 | sinkhorn_lr.LRSinkhorn 23 | sinkhorn_lr.LRSinkhornState 24 | sinkhorn_lr.LRSinkhornOutput 25 | 26 | Barycenter Solvers 27 | ------------------ 28 | .. 
autosummary:: 29 | :toctree: _autosummary 30 | 31 | continuous_barycenter.FreeWassersteinBarycenter 32 | continuous_barycenter.FreeBarycenterState 33 | discrete_barycenter.FixedBarycenter 34 | discrete_barycenter.SinkhornBarycenterOutput 35 | 36 | Univariate Solvers 37 | ------------------ 38 | .. autosummary:: 39 | :toctree: _autosummary 40 | 41 | solve_univariate 42 | univariate.uniform_solver 43 | univariate.quantile_solver 44 | univariate.north_west_solver 45 | univariate.UnivariateOutput 46 | 47 | Sinkhorn Acceleration 48 | --------------------- 49 | .. autosummary:: 50 | :toctree: _autosummary 51 | 52 | acceleration.Momentum 53 | acceleration.AndersonAcceleration 54 | 55 | Implicit Differentiation 56 | ------------------------ 57 | .. autosummary:: 58 | :toctree: _autosummary 59 | 60 | implicit_differentiation.ImplicitDiff 61 | implicit_differentiation.solve_jax_cg 62 | lineax_implicit.solve_lineax 63 | 64 | Low-rank Sinkhorn Utilities 65 | --------------------------- 66 | .. autosummary:: 67 | :toctree: _autosummary 68 | 69 | lr_utils.unbalanced_dykstra_lse 70 | lr_utils.unbalanced_dykstra_kernel 71 | -------------------------------------------------------------------------------- /docs/solvers/quadratic.rst: -------------------------------------------------------------------------------- 1 | ott.solvers.quadratic 2 | ===================== 3 | .. module:: ott.solvers.quadratic 4 | .. currentmodule:: ott.solvers.quadratic 5 | 6 | The :mod:`ott.solvers.quadratic` module holds the important family of 7 | Gromov-Wasserstein (GW) solvers and related GW barycenters. They are designed to 8 | solve :mod:`ott.problems.quadratic` problems. 9 | 10 | Gromov-Wasserstein Solvers 11 | -------------------------- 12 | .. 
autosummary:: 13 | :toctree: _autosummary 14 | 15 | solve 16 | gromov_wasserstein.GromovWasserstein 17 | gromov_wasserstein.GWOutput 18 | gromov_wasserstein_lr.LRGromovWasserstein 19 | gromov_wasserstein_lr.LRGWOutput 20 | lower_bound.third_lower_bound 21 | 22 | 23 | Barycenter Solvers 24 | ------------------ 25 | .. autosummary:: 26 | :toctree: _autosummary 27 | 28 | gw_barycenter.GWBarycenterState 29 | gw_barycenter.GromovWassersteinBarycenter 30 | -------------------------------------------------------------------------------- /docs/spelling/misc.txt: -------------------------------------------------------------------------------- 1 | Eulerian 2 | Utils 3 | alg 4 | arg 5 | args 6 | coef 7 | cond 8 | dtype 9 | eps 10 | eq 11 | euler 12 | fn 13 | fns 14 | fu 15 | gaussian 16 | gmm 17 | gv 18 | init 19 | iter 20 | iters 21 | jnp 22 | jvp 23 | kwargs 24 | laplacian 25 | loc 26 | logsumexp 27 | lse 28 | maxiter 29 | ndarray 30 | nicolson 31 | num 32 | params 33 | pre 34 | rng 35 | rngs 36 | sqeucl 37 | th 38 | topk 39 | vec 40 | vjp 41 | wolfe 42 | xy 43 | yy 44 | -------------------------------------------------------------------------------- /docs/spelling/technical.txt: -------------------------------------------------------------------------------- 1 | Barycenters 2 | Brenier 3 | Bures 4 | Chebyshev 5 | Cholesky 6 | Conformalize 7 | DTW 8 | Danskin 9 | Datasets 10 | Dykstra 11 | Expectile 12 | Fenchel 13 | Frobenius 14 | Gangbo 15 | Gangbo-McCann 16 | Gaussians 17 | Gromov 18 | Hessians 19 | Higham 20 | Jacobian 21 | Jacobians 22 | Kantorovich 23 | Kullback 24 | Leibler 25 | Mahalanobis 26 | McCann 27 | Monge 28 | Moreau 29 | Postcomposition 30 | SGD 31 | Schrödinger 32 | Schur 33 | Seidel 34 | Sinkhorn 35 | UNet 36 | Unbalancedness 37 | Wasserstein 38 | adaptively 39 | backend 40 | backpropagates 41 | backpropagation 42 | barycenter 43 | barycenters 44 | barycentric 45 | bijective 46 | binarized 47 | boolean 48 | centroids 49 | checkpointing 50 | chromatin 51 | 
collinear 52 | combinatorial 53 | conformally 54 | covariance 55 | covariances 56 | covariates 57 | dataclass 58 | dataloaders 59 | dataset 60 | datasets 61 | debiased 62 | debiasing 63 | differentiability 64 | dimensionality 65 | discretization 66 | discretize 67 | downweighted 68 | dualize 69 | dualizing 70 | duals 71 | eigendecomposition 72 | elementwise 73 | embeddings 74 | entropic 75 | epigenetic 76 | expectile 77 | featurized 78 | grayscale 79 | heterogeneous 80 | histone 81 | hyperparameter 82 | hyperparameters 83 | iPSCs 84 | iPSCs 85 | iff 86 | initializations 87 | initializer 88 | initializers 89 | instantiation 90 | invertible 91 | iso 92 | iteratively 93 | jax 94 | jit 95 | jitting 96 | linearization 97 | linearized 98 | logit 99 | macOS 100 | methylation 101 | minimizer 102 | minimizers 103 | multimarginal 104 | neuroimaging 105 | normed 106 | numerics 107 | omics 108 | optimality 109 | overfitting 110 | parallelization 111 | parallelize 112 | parameterization 113 | parameterizing 114 | piecewise 115 | pluripotent 116 | polymatching 117 | polynomials 118 | polytope 119 | positivity 120 | postfix 121 | potentials 122 | precompile 123 | precompute 124 | precomputes 125 | preconditioner 126 | preprocess 127 | preprocessing 128 | proteome 129 | prox 130 | pytree 131 | quantile 132 | quantiles 133 | quantizes 134 | recenter 135 | recentered 136 | regularizer 137 | regularizers 138 | reimplementation 139 | renormalize 140 | renormalized 141 | reparameterization 142 | reproducibility 143 | rescale 144 | rescaled 145 | rescaling 146 | reweighted 147 | reweighting 148 | reweightings 149 | runtime 150 | scRNA 151 | scalings 152 | semidefinite 153 | sigmoid 154 | simplicial 155 | softplus 156 | stateful 157 | sublinear 158 | submodule 159 | submodules 160 | suboptimal 161 | subpopulation 162 | subpopulations 163 | subsample 164 | subsampled 165 | subsamples 166 | subsampling 167 | thresholding 168 | transcriptome 169 | undirected 170 | univariate 171 | 
unnormalized 172 | unregularized 173 | unscaled 174 | url 175 | vectorized 176 | voxel 177 | -------------------------------------------------------------------------------- /docs/tools.rst: -------------------------------------------------------------------------------- 1 | ott.tools 2 | ========= 3 | .. module:: ott.tools 4 | .. currentmodule:: ott.tools 5 | 6 | The :mod:`~ott.tools` package contains high level functions that build on 7 | outputs produced by lower-level components in the toolbox, such as 8 | :mod:`~ott.solvers`. 9 | 10 | In particular, we provide user-friendly APIs to unregularized OT quantities, 11 | such as the :term:`Wasserstein distance` for two point clouds of the same size. 12 | We also provide functions to pad efficiently point clouds when doing large scale 13 | OT between them in parallel, implementations of the Sinkhorn 14 | divergence :cite:`genevay:18,sejourne:19`, sliced Wasserstein distances 15 | :cite:`rabin:12`, differentiable approximations to ranks and quantile functions 16 | :cite:`cuturi:19`, and various tools to study Gaussians with the 17 | 2-:term:`Wasserstein distance` :cite:`gelbrich:90,delon:20`. 18 | 19 | Unregularized Optimal Transport 20 | ------------------------------- 21 | .. autosummary:: 22 | :toctree: _autosummary 23 | 24 | unreg.hungarian 25 | unreg.HungarianOutput 26 | unreg.wassdis_p 27 | 28 | 29 | Segmented Sinkhorn 30 | ------------------ 31 | .. autosummary:: 32 | :toctree: _autosummary 33 | 34 | segment_sinkhorn.segment_sinkhorn 35 | 36 | Sinkhorn Divergence 37 | ------------------- 38 | .. autosummary:: 39 | :toctree: _autosummary 40 | 41 | sinkhorn_divergence.sinkdiv 42 | sinkhorn_divergence.sinkhorn_divergence 43 | sinkhorn_divergence.SinkhornDivergenceOutput 44 | sinkhorn_divergence.segment_sinkhorn_divergence 45 | 46 | Sliced Wasserstein Distance 47 | --------------------------- 48 | .. 
autosummary:: 49 | :toctree: _autosummary 50 | 51 | sliced.random_proj_sphere 52 | sliced.sliced_wasserstein 53 | 54 | ProgOT 55 | ------ 56 | .. autosummary:: 57 | :toctree: _autosummary 58 | 59 | progot.ProgOT 60 | progot.ProgOTOutput 61 | progot.get_alpha_schedule 62 | progot.get_epsilon_schedule 63 | 64 | Conformal Prediction 65 | -------------------- 66 | .. autosummary:: 67 | :toctree: _autosummary 68 | 69 | conformal.OTCP 70 | conformal.sobol_ball_sampler 71 | 72 | Soft Sorting Algorithms 73 | ----------------------- 74 | .. autosummary:: 75 | :toctree: _autosummary 76 | 77 | soft_sort.multivariate_cdf_quantile_maps 78 | soft_sort.quantile 79 | soft_sort.quantile_normalization 80 | soft_sort.quantize 81 | soft_sort.ranks 82 | soft_sort.sort 83 | soft_sort.sort_with 84 | soft_sort.topk_mask 85 | 86 | Clustering 87 | ---------- 88 | .. autosummary:: 89 | :toctree: _autosummary 90 | 91 | k_means.k_means 92 | k_means.KMeansOutput 93 | 94 | Plotting 95 | -------- 96 | .. autosummary:: 97 | :toctree: _autosummary 98 | 99 | plot.Plot 100 | 101 | ott.tools.gaussian_mixture package 102 | ---------------------------------- 103 | .. currentmodule:: ott.tools.gaussian_mixture 104 | .. automodule:: ott.tools.gaussian_mixture 105 | 106 | This package implements various tools to manipulate Gaussian mixtures with a 107 | slightly modified Wasserstein geometry: here a Gaussian mixture is no longer 108 | strictly regarded as a density :math:`\mathbb{R}^d`, but instead as a point 109 | cloud in the space of Gaussians in :math:`\mathbb{R}^d`. This viewpoint provides 110 | a new approach to compare, and fit Gaussian mixtures, as described for instance 111 | in :cite:`delon:20` and references therein. 112 | 113 | Gaussian Mixtures 114 | ^^^^^^^^^^^^^^^^^ 115 | .. 
autosummary:: 116 | :toctree: _autosummary 117 | 118 | gaussian.Gaussian 119 | gaussian_mixture.GaussianMixture 120 | gaussian_mixture_pair.GaussianMixturePair 121 | fit_gmm.initialize 122 | fit_gmm.fit_model_em 123 | fit_gmm_pair.get_fit_model_em_fn 124 | -------------------------------------------------------------------------------- /docs/tutorials/barycenter/index.rst: -------------------------------------------------------------------------------- 1 | Barycenter 2 | ========== 3 | 4 | .. toctree:: 5 | :maxdepth: 1 6 | :glob: 7 | 8 | * 9 | -------------------------------------------------------------------------------- /docs/tutorials/geometry/index.rst: -------------------------------------------------------------------------------- 1 | Geometry 2 | ======== 3 | 4 | .. toctree:: 5 | :maxdepth: 1 6 | :glob: 7 | 8 | * 9 | -------------------------------------------------------------------------------- /docs/tutorials/index.rst: -------------------------------------------------------------------------------- 1 | Tutorials 2 | ========= 3 | 4 | .. toctree:: 5 | :maxdepth: 2 6 | 7 | geometry/index 8 | linear/index 9 | quadratic/index 10 | neural/index 11 | barycenter/index 12 | misc/index 13 | -------------------------------------------------------------------------------- /docs/tutorials/linear/index.rst: -------------------------------------------------------------------------------- 1 | Linear OT 2 | ========= 3 | 4 | .. toctree:: 5 | :maxdepth: 1 6 | :glob: 7 | 8 | * 9 | -------------------------------------------------------------------------------- /docs/tutorials/misc/index.rst: -------------------------------------------------------------------------------- 1 | Miscellaneous 2 | ============= 3 | 4 | .. 
toctree:: 5 | :maxdepth: 1 6 | :glob: 7 | 8 | * 9 | -------------------------------------------------------------------------------- /docs/tutorials/neural/index.rst: -------------------------------------------------------------------------------- 1 | Neural OT 2 | ========= 3 | 4 | .. toctree:: 5 | :maxdepth: 1 6 | :glob: 7 | 8 | * 9 | -------------------------------------------------------------------------------- /docs/tutorials/quadratic/index.rst: -------------------------------------------------------------------------------- 1 | Quadratic OT 2 | ============ 3 | 4 | .. toctree:: 5 | :maxdepth: 1 6 | :glob: 7 | 8 | * 9 | -------------------------------------------------------------------------------- /docs/utils.rst: -------------------------------------------------------------------------------- 1 | ott.utils 2 | ========= 3 | .. module:: ott.utils 4 | .. currentmodule:: ott.utils 5 | 6 | This package contains miscellaneous functions, e.g., progress callback 7 | function for :class:`~ott.solvers.linear.sinkhorn.Sinkhorn`. 8 | 9 | .. autosummary:: 10 | :toctree: _autosummary 11 | 12 | default_progress_fn 13 | tqdm_progress_fn 14 | batched_vmap 15 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # Copyright OTT-JAX 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | from setuptools import setup 15 | 16 | # for packaging tools not supporting, e.g., PEP 517, PEP 660 17 | setup() 18 | -------------------------------------------------------------------------------- /src/ott/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright OTT-JAX 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | import contextlib 15 | 16 | from . import ( 17 | datasets, 18 | experimental, 19 | geometry, 20 | initializers, 21 | math, 22 | problems, 23 | solvers, 24 | tools, 25 | utils, 26 | ) 27 | 28 | with contextlib.suppress(ImportError): 29 | from . import neural 30 | 31 | from ._version import __version__ 32 | 33 | del contextlib 34 | -------------------------------------------------------------------------------- /src/ott/_version.py: -------------------------------------------------------------------------------- 1 | # Copyright OTT-JAX 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
@dataclasses.dataclass
class GaussianMixture:
  """A mixture of Gaussians.

  Args:
    name: the name specifying the centers of the mixture components:

      - ``simple`` - data clustered in one center,
      - ``circle`` - two-dimensional Gaussians arranged on a circle,
      - ``square_five`` - two-dimensional Gaussians on a square with
        one Gaussian in the center, and
      - ``square_four`` - two-dimensional Gaussians in the corners of a
        rectangle

    batch_size: batch size of the samples
    rng: initial PRNG key
    scale: scale of the Gaussian means
    std: the standard deviation of the individual Gaussian samples
  """
  name: Name_t
  batch_size: int
  rng: jax.Array
  scale: float = 5.0
  std: float = 0.5

  def __post_init__(self) -> None:
    # Unit-circle directions: the four axis-aligned points plus the
    # four diagonals at distance one from the origin.
    diag = 1.0 / np.sqrt(2)
    center_lookup = {
        "simple":
            np.array([[0, 0]]),
        "circle":
            np.array([
                (1, 0),
                (-1, 0),
                (0, 1),
                (0, -1),
                (diag, diag),
                (diag, -diag),
                (-diag, diag),
                (-diag, -diag),
            ]),
        "square_five":
            np.array([[0, 0], [1, 1], [-1, 1], [-1, -1], [1, -1]]),
        "square_four":
            np.array([[1, 0], [0, 1], [-1, 0], [0, -1]]),
    }
    if self.name not in center_lookup:
      raise ValueError(
          f"{self.name} is not a valid dataset for GaussianMixture"
      )
    self.centers = center_lookup[self.name]

  def __iter__(self) -> Iterator[jnp.array]:
    """Random sample generator from Gaussian mixture.

    Returns:
      A generator of samples from the Gaussian mixture.
    """
    return self._create_sample_generators()

  def _create_sample_generators(self) -> Iterator[jnp.array]:
    # Infinite stream: each step consumes two fresh subkeys and carries
    # the remaining key forward.
    key = self.rng
    while True:
      key_choice, key_noise, key = jax.random.split(key, 3)
      picked = jax.random.choice(key_choice, self.centers, (self.batch_size,))
      noise = jax.random.normal(key_noise, (self.batch_size, 2))
      # NOTE(review): the noise is scaled by ``std ** 2`` (a variance), not
      # ``std`` — reproduced from the original; confirm this is intended.
      yield self.scale * picked + (self.std ** 2) * noise
def create_gaussian_mixture_samplers(
    name_source: Name_t,
    name_target: Name_t,
    train_batch_size: int = 2048,
    valid_batch_size: int = 2048,
    rng: Optional[jax.Array] = None,
) -> Tuple[Dataset, Dataset, int]:
  """Gaussian samplers.

  Args:
    name_source: name of the source sampler
    name_target: name of the target sampler
    train_batch_size: the training batch size
    valid_batch_size: the validation batch size
    rng: initial PRNG key

  Returns:
    The dataset and dimension of the data.
  """
  rng = utils.default_prng_key(rng)
  # One independent subkey per sampler: train/valid x source/target.
  keys = jax.random.split(rng, 4)

  def _sampler(name: Name_t, batch_size: int, key: jax.Array):
    # Wrap a mixture in its infinite sample iterator.
    return iter(GaussianMixture(name, batch_size=batch_size, rng=key))

  train_dataset = Dataset(
      source_iter=_sampler(name_source, train_batch_size, keys[0]),
      target_iter=_sampler(name_target, train_batch_size, keys[1]),
  )
  valid_dataset = Dataset(
      source_iter=_sampler(name_source, valid_batch_size, keys[2]),
      target_iter=_sampler(name_target, valid_batch_size, keys[3]),
  )
  # All mixtures in this module are two-dimensional.
  return train_dataset, valid_dataset, 2
@jtu.register_pytree_node_class
class UnivariateWasserstein(costs.CostFn):
  """1D Wasserstein cost for two 1D distributions.

  This ground cost treats each input vector as a family of scalar values.
  The Wasserstein distance between two such vectors is the 1D OT cost,
  computed with a user-defined ground cost.

  Args:
    solve_fn: 1D optimal transport solver, e.g.,
      :func:`~ott.solvers.linear.univariate.uniform_distance`.
    ground_cost: Cost used to compute the 1D optimal transport between vectors.
      Should be a translation-invariant (TI) cost for correctness.
      If :obj:`None`, defaults to :class:`~ott.geometry.costs.SqEuclidean`.
  """

  def __init__(
      self,
      solve_fn: Callable[[linear_problem.LinearProblem],
                         univariate.UnivariateOutput],
      ground_cost: Optional[costs.TICost] = None,
  ):
    super().__init__()
    if ground_cost is None:
      ground_cost = costs.SqEuclidean()
    self.ground_cost = ground_cost
    self._solve_fn = solve_fn

  def __call__(self, x: jnp.ndarray, y: jnp.ndarray) -> float:
    """Wasserstein distance between :math:`x` and :math:`y` seen as a 1D dist.

    Args:
      x: Array of shape ``[n,]``.
      y: Array of shape ``[m,]``.

    Returns:
      The transport cost.
    """
    # Lift the 1D samples to [n, 1] / [m, 1] point clouds and solve the
    # resulting univariate linear OT problem.
    cloud = pointcloud.PointCloud(
        x[:, None], y[:, None], cost_fn=self.ground_cost
    )
    solution = self._solve_fn(linear_problem.LinearProblem(cloud))
    return jnp.squeeze(solution.ot_costs)

  def tree_flatten(self):  # noqa: D102
    # ground cost is a traced child; the solver callable is static aux data.
    return (self.ground_cost,), (self._solve_fn,)

  @classmethod
  def tree_unflatten(cls, aux_data, children):  # noqa: D102
    return cls(solve_fn=aux_data[0], ground_cost=children[0])
@jtu.register_pytree_node_class
class Epsilon:
  r"""Scheduler class for the regularization parameter epsilon.

  An epsilon scheduler outputs a regularization strength, to be used by the
  :term:`Sinkhorn algorithm` or variant, at any iteration count. That value is
  either the final, targeted regularization, or one that is larger, obtained by
  geometric decay of an initial multiplier.

  Args:
    target: The epsilon regularizer that is targeted.
    init: Initial value when using epsilon scheduling, understood as a multiple
      of the ``target``, following :math:`\text{init} \text{decay}^{\text{it}}`.
    decay: Geometric decay factor, :math:`\leq 1`.
  """

  def __init__(self, target: jnp.array, init: float = 1.0, decay: float = 1.0):
    assert decay <= 1.0, f"Decay must be <= 1, found {decay}."
    self.target = target
    self.init = init
    self.decay = decay

  def __call__(self, it: Optional[int]) -> jnp.array:
    """Intermediate regularizer value at a given iteration number.

    Args:
      it: Current iteration. If :obj:`None`, return :attr:`target`.

    Returns:
      The epsilon value at the iteration.
    """
    if it is None:
      return self.target
    # Geometric decay of the initial multiplier, floored at 1 so the
    # schedule never undershoots the target.
    scale = jnp.maximum(self.init * self.decay ** it, 1.0)
    return scale * self.target

  def __repr__(self) -> str:
    return (
        f"{self.__class__.__name__}(target={self.target:.4f}, "
        f"init={self.init:.4f}, decay={self.decay:.4f})"
    )

  def tree_flatten(self):  # noqa: D102
    # target is traced; init/decay are static scheduling constants.
    return (self.target,), {"init": self.init, "decay": self.decay}

  @classmethod
  def tree_unflatten(cls, aux_data, children):  # noqa: D102
    return cls(*children, **aux_data)
import neural 20 | 21 | del contextlib 22 | -------------------------------------------------------------------------------- /src/ott/initializers/linear/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright OTT-JAX 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | from . import initializers, initializers_lr 15 | -------------------------------------------------------------------------------- /src/ott/initializers/neural/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright OTT-JAX 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | from . 
import meta_initializer 15 | -------------------------------------------------------------------------------- /src/ott/initializers/quadratic/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright OTT-JAX 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | from . import initializers 15 | -------------------------------------------------------------------------------- /src/ott/math/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright OTT-JAX 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | from . 
def phi_star(h: jnp.ndarray, rho: float) -> jnp.ndarray:
  """Legendre transform of KL, :cite:`sejourne:19`, p. 9."""
  return rho * (jnp.exp(h / rho) - 1)


def derivative_phi_star(f: jnp.ndarray, rho: float) -> jnp.ndarray:
  """Derivative of the Legendre transform of KL, see :func:`phi_star`."""
  # TODO(cuturi): use jax.grad directly.
  return jnp.exp(f / rho)


def grad_of_marginal_fit(
    c: jnp.ndarray, h: jnp.ndarray, tau: float, epsilon: float
) -> jnp.ndarray:
  """Compute grad of terms linked to marginals in objective.

  Computes gradient w.r.t. f (or g) of terms in :cite:`sejourne:19`,
  left-hand-side of eq. 15 (terms involving phi_star).

  Args:
    c: jnp.ndarray, first target marginal (either a or b in practice)
    h: jnp.ndarray, potential (either f or g in practice)
    tau: float, strength (in ]0,1]) of regularizer w.r.t. marginal
    epsilon: regularization

  Returns:
    a vector of the same size as c or h
  """
  # tau == 1 is the balanced case: the marginal constraint is exact and the
  # gradient reduces to the marginal itself (also avoids rho's 1/(1 - tau)).
  if tau == 1.0:
    return c
  r = rho(epsilon, tau)
  # Zero out entries where the marginal has no mass.
  return jnp.where(c > 0, c * derivative_phi_star(-h, r), 0.0)


def second_derivative_phi_star(f: jnp.ndarray, rho: float) -> jnp.ndarray:
  """Second derivative of the Legendre transform of KL, see :func:`phi_star`."""
  return jnp.exp(f / rho) / rho


def diag_jacobian_of_marginal_fit(
    c: jnp.ndarray, h: jnp.ndarray, tau: float, epsilon: float,
    derivative: Callable[[jnp.ndarray], jnp.ndarray]
):
  """Compute diagonal of the Jacobian of terms linked to marginals.

  Computes second derivative w.r.t. f (or g) of terms in :cite:`sejourne:19`,
  left-hand-side of eq. 32 (terms involving phi_star).

  Args:
    c: jnp.ndarray, first target marginal (either a or b in practice)
    h: jnp.ndarray, potential (either f or g in practice)
    tau: float, strength (in ]0,1]) of regularizer w.r.t. marginal
    epsilon: regularization
    derivative: unary callable applied to the fitted marginal
      ``c * derivative_phi_star(-h, rho)``.

  Returns:
    a vector of the same size as c or h.
  """
  if tau == 1.0:
    return 0.0

  r = rho(epsilon, tau)
  # here no minus sign because we are taking derivative w.r.t -h
  return jnp.where(
      c > 0,
      c * second_derivative_phi_star(-h, r) *
      derivative(c * derivative_phi_star(-h, r)), 0.0
  )


def rho(epsilon: float, tau: float) -> float:
  r"""Compute :math:`\rho = \epsilon \tau / (1 - \tau)`.

  Callers must guard against ``tau == 1`` (balanced case), which would
  divide by zero here.
  """
  return (epsilon * tau) / (1.0 - tau)
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | from . import datasets, methods, networks 15 | -------------------------------------------------------------------------------- /src/ott/neural/datasets.py: -------------------------------------------------------------------------------- 1 | # Copyright OTT-JAX 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | import collections 15 | import dataclasses 16 | from typing import Any, Dict, Optional, Sequence 17 | 18 | import numpy as np 19 | 20 | __all__ = ["OTData", "OTDataset"] 21 | 22 | Item_t = Dict[str, np.ndarray] 23 | 24 | 25 | @dataclasses.dataclass(repr=False, frozen=True) 26 | class OTData: 27 | """Distribution data for (conditional) optimal transport problems. 28 | 29 | Args: 30 | lin: Linear term of the samples. 31 | quad: Quadratic term of the samples. 32 | condition: Condition corresponding to the data distribution. 
33 | """ 34 | lin: Optional[np.ndarray] = None 35 | quad: Optional[np.ndarray] = None 36 | condition: Optional[np.ndarray] = None 37 | 38 | def __getitem__(self, ix: int) -> Item_t: 39 | return {k: v[ix] for k, v in self.__dict__.items() if v is not None} 40 | 41 | def __len__(self) -> int: 42 | if self.lin is not None: 43 | return len(self.lin) 44 | if self.quad is not None: 45 | return len(self.quad) 46 | return 0 47 | 48 | 49 | class OTDataset: 50 | """Dataset for optimal transport problems. 51 | 52 | Args: 53 | src_data: Samples from the source distribution. 54 | tgt_data: Samples from the target distribution. 55 | src_conditions: Conditions for the source data. 56 | tgt_conditions: Conditions for the target data. 57 | is_aligned: Whether the samples from the source and the target data 58 | are paired. If yes, the source and the target conditions must match. 59 | seed: Random seed used to match source and target when not aligned. 60 | """ 61 | SRC_PREFIX = "src" 62 | TGT_PREFIX = "tgt" 63 | 64 | def __init__( 65 | self, 66 | src_data: OTData, 67 | tgt_data: OTData, 68 | src_conditions: Optional[Sequence[Any]] = None, 69 | tgt_conditions: Optional[Sequence[Any]] = None, 70 | is_aligned: bool = False, 71 | seed: Optional[int] = None, 72 | ): 73 | self.src_data = src_data 74 | self.tgt_data = tgt_data 75 | 76 | if src_conditions is None: 77 | src_conditions = [None] * len(src_data) 78 | self.src_conditions = list(src_conditions) 79 | if tgt_conditions is None: 80 | tgt_conditions = [None] * len(tgt_data) 81 | self.tgt_conditions = list(tgt_conditions) 82 | 83 | self._tgt_cond_to_ix = collections.defaultdict(list) 84 | for ix, cond in enumerate(tgt_conditions): 85 | self._tgt_cond_to_ix[cond].append(ix) 86 | 87 | self.is_aligned = is_aligned 88 | self._rng = np.random.default_rng(seed) 89 | 90 | self._verify_integrity() 91 | 92 | def _verify_integrity(self) -> None: 93 | assert len(self.src_data) == len(self.src_conditions) 94 | assert len(self.tgt_data) == 
len(self.tgt_conditions)

    if self.is_aligned:
      # aligned datasets must pair samples 1:1 and in identical condition order
      assert len(self.src_data) == len(self.tgt_data)
      assert self.src_conditions == self.tgt_conditions
    else:
      # unaligned datasets only need the same *set* of conditions on both sides
      sym_diff = set(self.src_conditions
                    ).symmetric_difference(self.tgt_conditions)
      assert not sym_diff, sym_diff

  def _sample_from_target(self, src_ix: int) -> Item_t:
    """Randomly draw a target item sharing the source item's condition."""
    src_cond = self.src_conditions[src_ix]
    tgt_ixs = self._tgt_cond_to_ix[src_cond]
    # NOTE(review): assumes `self._rng` is a NumPy-style Generator with
    # a `choice` method — confirm against the (unseen) class initializer.
    ix = self._rng.choice(tgt_ixs)
    return self.tgt_data[ix]

  def __getitem__(self, ix: int) -> Item_t:
    """Return the source/target pair at ``ix``, keys prefixed per side."""
    src = self.src_data[ix]
    src = {f"{self.SRC_PREFIX}_{k}": v for k, v in src.items()}

    # aligned: positional pairing; otherwise draw a random target with the
    # same condition as the source item
    tgt = self.tgt_data[ix] if self.is_aligned else self._sample_from_target(ix)
    tgt = {f"{self.TGT_PREFIX}_{k}": v for k, v in tgt.items()}

    return {**src, **tgt}

  def __len__(self) -> int:
    """Dataset length, defined by the number of source samples."""
    return len(self.src_data)
# --- /src/ott/neural/methods/__init__.py ---
# Copyright OTT-JAX
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from . import monge_gap, neuraldual
# --- /src/ott/neural/methods/flows/__init__.py ---
# Copyright OTT-JAX
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from . import dynamics, genot, otfm
# --- /src/ott/neural/networks/__init__.py ---
# Copyright OTT-JAX
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from . import icnn, layers, potentials, velocity_field
# --- /src/ott/neural/networks/layers/__init__.py ---
# Copyright OTT-JAX
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from . import conjugate, posdef, time_encoder
# --- /src/ott/neural/networks/layers/conjugate.py ---
# Copyright OTT-JAX
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import abc
from typing import Callable, Literal, NamedTuple, Optional

import jax.numpy as jnp
from jaxopt import LBFGS

from ott import utils

__all__ = [
    "ConjugateResults",
    "FenchelConjugateSolver",
    "FenchelConjugateLBFGS",
    "DEFAULT_CONJUGATE_SOLVER",
]


class ConjugateResults(NamedTuple):
  r"""Holds the results of numerically conjugating a function.

  Args:
    val: the conjugate value, i.e., :math:`f^\star(y)`
    grad: the gradient, i.e., :math:`\nabla f^\star(y)`
    num_iter: the number of iterations taken by the solver
  """
  val: float
  grad: jnp.ndarray
  num_iter: int


class FenchelConjugateSolver(abc.ABC):
  r"""Abstract conjugate solver class.

  Given a function :math:`f`, numerically estimate the Fenchel conjugate
  :math:`f^\star(y) := -\inf_{x\in\mathbb{R}^n} f(x)-\langle x, y\rangle`.
  """

  @abc.abstractmethod
  def solve(
      self,
      f: Callable[[jnp.ndarray], jnp.ndarray],
      y: jnp.ndarray,
      x_init: Optional[jnp.ndarray] = None
  ) -> ConjugateResults:
    """Solve for the conjugate.

    Args:
      f: function to conjugate
      y: point to conjugate
      x_init: initial point to search over

    Returns:
      The solution to the conjugation.
    """


@utils.register_pytree_node
class FenchelConjugateLBFGS(FenchelConjugateSolver):
  """Solve for the conjugate using :class:`~jaxopt.LBFGS`.

  Args:
    gtol: gradient tolerance
    max_iter: maximum number of iterations
    max_linesearch_iter: maximum number of line search iterations
    linesearch_type: type of line search
    linesearch_init: strategy for line search initialization
    increase_factor: factor by which to increase the step size during
      the line search
  """

  gtol: float = 1e-3
  max_iter: int = 10
  # NOTE(review): `max_linesearch_iter` is stored but never forwarded to
  # `LBFGS` (its `maxls` argument) below — confirm whether this is intended.
  max_linesearch_iter: int = 10
  linesearch_type: Literal["zoom", "backtracking",
                           "hager-zhang"] = "backtracking"
  linesearch_init: Literal["increase", "max", "current"] = "increase"
  increase_factor: float = 1.5

  def solve(  # noqa: D102
      self,
      f: Callable[[jnp.ndarray], jnp.ndarray],
      y: jnp.ndarray,
      # fix: `jnp.array` is a function, not a type; the annotation
      # must be `jnp.ndarray` (matches the abstract method's signature).
      x_init: Optional[jnp.ndarray] = None
  ) -> ConjugateResults:
    # scalars are not supported: the objective flattens `y` via `ravel`
    assert y.ndim

    # minimize x -> f(x) - <x, y>; the negated optimum is f*(y) and, by the
    # envelope theorem, the minimizer is the gradient of f* at y
    solver = LBFGS(
        fun=lambda x: f(x) - x.ravel().dot(y.ravel()),
        tol=self.gtol,
        maxiter=self.max_iter,
        linesearch=self.linesearch_type,
        linesearch_init=self.linesearch_init,
        increase_factor=self.increase_factor,
        implicit_diff=False,
        unroll=False
    )

    # warm-start at `y` when no explicit initialization is given
    out = solver.run(y if x_init is None else x_init)
    return ConjugateResults(
        val=-out.state.value, grad=out.params, num_iter=out.state.iter_num
    )


DEFAULT_CONJUGATE_SOLVER = FenchelConjugateLBFGS(
    gtol=1e-5,
    max_iter=20,
    max_linesearch_iter=20,
    linesearch_type="backtracking",
)
# --- /src/ott/neural/networks/layers/time_encoder.py ---
# Copyright OTT-JAX
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | import jax.numpy as jnp 15 | 16 | __all__ = ["cyclical_time_encoder"] 17 | 18 | 19 | def cyclical_time_encoder(t: jnp.ndarray, n_freqs: int = 128) -> jnp.ndarray: 20 | r"""Encode time :math:`t` into a cyclical representation. 21 | 22 | Time :math:`t` is encoded as :math:`cos(\hat{t})` and :math:`sin(\hat{t})` 23 | where :math:`\hat{t} = [2\pi t, 2\pi 2 t,\dots, 2\pi n_f t]`. 24 | 25 | Args: 26 | t: Time of shape ``[n, 1]``. 27 | n_freqs: Frequency :math:`n_f` of the cyclical encoding. 28 | 29 | Returns: 30 | Encoded time of shape ``[n, 2 * n_freqs]``. 31 | """ 32 | freq = 2 * jnp.arange(n_freqs) * jnp.pi 33 | t = freq * t 34 | return jnp.concatenate([jnp.cos(t), jnp.sin(t)], axis=-1) 35 | -------------------------------------------------------------------------------- /src/ott/neural/networks/velocity_field.py: -------------------------------------------------------------------------------- 1 | # Copyright OTT-JAX 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
from typing import Callable, Optional, Sequence

import jax
import jax.numpy as jnp

import optax
from flax import linen as nn
from flax.training import train_state

from ott.neural.networks.layers import time_encoder

__all__ = ["VelocityField"]


class VelocityField(nn.Module):
  r"""Neural vector field.

  This class learns a map :math:`v: \mathbb{R}\times \mathbb{R}^d
  \rightarrow \mathbb{R}^d` solving the ODE :math:`\frac{dx}{dt} = v(t, x)`.
  Given a source distribution at time :math:`t_0`, the velocity field can be
  used to transport the source distribution given at :math:`t_0` to
  a target distribution given at :math:`t_1` by integrating :math:`v(t, x)`
  from :math:`t=t_0` to :math:`t=t_1`.

  Args:
    hidden_dims: Dimensionality of the embedding of the data.
    output_dims: Dimensionality of the embedding of the output.
    condition_dims: Dimensionality of the embedding of the condition.
      If :obj:`None`, the velocity field has no conditions.
    time_dims: Dimensionality of the time embedding.
      If :obj:`None`, ``hidden_dims`` is used.
    time_encoder: Time encoder for the velocity field.
    act_fn: Activation function.
    dropout_rate: Dropout rate applied after every hidden layer.
  """
  hidden_dims: Sequence[int]
  output_dims: Sequence[int]
  condition_dims: Optional[Sequence[int]] = None
  time_dims: Optional[Sequence[int]] = None
  time_encoder: Callable[[jnp.ndarray],
                         jnp.ndarray] = time_encoder.cyclical_time_encoder
  act_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.silu
  dropout_rate: float = 0.0

  @nn.compact
  def __call__(
      self,
      t: jnp.ndarray,
      x: jnp.ndarray,
      condition: Optional[jnp.ndarray] = None,
      train: bool = True,
  ) -> jnp.ndarray:
    """Forward pass through the neural vector field.

    Args:
      t: Time of shape ``[batch, 1]``.
      x: Data of shape ``[batch, ...]``.
      condition: Conditioning vector of shape ``[batch, ...]``.
      train: If `True`, enables dropout for training.

    Returns:
      Output of the neural vector field of shape ``[batch, output_dim]``.
    """
    time_dims = self.hidden_dims if self.time_dims is None else self.time_dims

    # time, data and (optionally) the condition are embedded by separate MLPs
    t = self.time_encoder(t)
    for time_dim in time_dims:
      t = self.act_fn(nn.Dense(time_dim)(t))
      t = nn.Dropout(rate=self.dropout_rate, deterministic=not train)(t)

    for hidden_dim in self.hidden_dims:
      x = self.act_fn(nn.Dense(hidden_dim)(x))
      x = nn.Dropout(rate=self.dropout_rate, deterministic=not train)(x)

    if self.condition_dims is not None:
      assert condition is not None, "No condition was passed."
      for cond_dim in self.condition_dims:
        condition = self.act_fn(nn.Dense(cond_dim)(condition))
        condition = nn.Dropout(
            rate=self.dropout_rate, deterministic=not train
        )(
            condition
        )
      # the per-input embeddings are fused before the output head
      feats = jnp.concatenate([t, x, condition], axis=-1)
    else:
      feats = jnp.concatenate([t, x], axis=-1)

    for output_dim in self.output_dims[:-1]:
      feats = self.act_fn(nn.Dense(output_dim)(feats))
      feats = nn.Dropout(rate=self.dropout_rate, deterministic=not train)(feats)

    # No activation function for the final layer
    return nn.Dense(self.output_dims[-1])(feats)

  def create_train_state(
      self,
      rng: jax.Array,
      # improved annotation: the value is passed as ``tx`` to
      # ``TrainState.create``, which expects a gradient transformation
      # (e.g. ``optax.adam(...)``), not an optimizer *state*.
      optimizer: optax.GradientTransformation,
      input_dim: int,
      condition_dim: Optional[int] = None,
  ) -> train_state.TrainState:
    """Create the training state.

    Args:
      rng: Random number generator.
      optimizer: Optax optimizer (a gradient transformation).
      input_dim: Dimensionality of the velocity field.
      condition_dim: Dimensionality of the condition of the velocity field.

    Returns:
      The training state.
    """
    # dummy inputs used only to trace shapes during parameter initialization
    t, x = jnp.ones((1, 1)), jnp.ones((1, input_dim))
    if self.condition_dims is None:
      cond = None
    else:
      assert condition_dim > 0, "Condition dimension must be positive."
      cond = jnp.ones((1, condition_dim))

    params = self.init(rng, t, x, cond, train=False)["params"]
    return train_state.TrainState.create(
        apply_fn=self.apply, params=params, tx=optimizer
    )
# --- /src/ott/problems/__init__.py ---
# Copyright OTT-JAX
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from . import linear, quadratic
# --- /src/ott/problems/linear/__init__.py ---
# Copyright OTT-JAX
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from . import barycenter_problem, linear_problem, potentials
# --- /src/ott/problems/linear/linear_problem.py ---
# Copyright OTT-JAX
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Dict, Optional, Sequence, Tuple

import jax
import jax.numpy as jnp

from ott.geometry import geometry

__all__ = ["LinearProblem"]

# TODO(michalk8): move to typing.py when refactoring the types
MarginalFunc = Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray]
TransportAppFunc = Callable[[jnp.ndarray, jnp.ndarray, jnp.ndarray, int],
                            jnp.ndarray]


@jax.tree_util.register_pytree_node_class
class LinearProblem:
  r"""Linear OT problem.

  Bundles the main ingredients of a linear OT problem: a ``geom`` object
  (cost structure/points) describing point clouds or the support of measures,
  together with probability masses ``a`` and ``b``. Unbalancedness is tracked
  through two coefficients ``tau_a`` and ``tau_b``, each kept between 0 and 1
  (1 corresponding to a balanced OT problem).

  Args:
    geom: The ground geometry cost of the linear problem.
    a: The first marginal. If ``None``, it will be uniform.
    b: The second marginal. If ``None``, it will be uniform.
    tau_a: If :math:`<1`, defines how much unbalanced the problem is
      on the first marginal.
    tau_b: If :math:`< 1`, defines how much unbalanced the problem is
      on the second marginal.
  """

  def __init__(
      self,
      geom: geometry.Geometry,
      a: Optional[jnp.ndarray] = None,
      b: Optional[jnp.ndarray] = None,
      tau_a: float = 1.0,
      tau_b: float = 1.0
  ):
    self.geom = geom
    self._a = a
    self._b = b
    self.tau_a = tau_a
    self.tau_b = tau_b

  @property
  def a(self) -> jnp.ndarray:
    """First marginal, uniform over rows when none was provided."""
    if self._a is None:
      num_rows = self.geom.shape[0]
      return jnp.full((num_rows,), fill_value=1.0 / num_rows, dtype=self.dtype)
    return self._a

  @property
  def b(self) -> jnp.ndarray:
    """Second marginal, uniform over columns when none was provided."""
    if self._b is None:
      num_cols = self.geom.shape[1]
      return jnp.full((num_cols,), fill_value=1.0 / num_cols, dtype=self.dtype)
    return self._b

  @property
  def is_balanced(self) -> bool:
    """Whether the problem is balanced."""
    return (self.tau_a, self.tau_b) == (1.0, 1.0)

  @property
  def is_uniform(self) -> bool:
    """True if no weights ``a,b`` were passed, and have defaulted to uniform."""
    return self._a is None and self._b is None

  @property
  def is_equal_size(self) -> bool:
    """True if square shape, i.e. ``n == m``."""
    num_rows, num_cols = self.geom.shape
    return num_rows == num_cols

  @property
  def is_assignment(self) -> bool:
    """True if assignment problem."""
    return all((self.is_equal_size, self.is_uniform, self.is_balanced))

  @property
  def epsilon(self) -> float:
    """Entropic regularization."""
    return self.geom.epsilon

  @property
  def dtype(self) -> jnp.dtype:
    """The data type of the geometry."""
    return self.geom.dtype

  def get_transport_functions(
      self, lse_mode: bool
  ) -> Tuple[MarginalFunc, MarginalFunc, TransportAppFunc]:
    """Instantiate useful functions for Sinkhorn depending on ``lse_mode``."""
    geom = self.geom

    if lse_mode:
      # work directly with potentials (log-sum-exp mode)

      def marginal_a(f, g):
        return geom.marginal_from_potentials(f, g, 1)

      def marginal_b(f, g):
        return geom.marginal_from_potentials(f, g, 0)

      return marginal_a, marginal_b, geom.apply_transport_from_potentials

    # kernel mode: convert potentials to scalings first

    def marginal_a(f, g):
      u, v = geom.scaling_from_potential(f), geom.scaling_from_potential(g)
      return geom.marginal_from_scalings(u, v, 1)

    def marginal_b(f, g):
      u, v = geom.scaling_from_potential(f), geom.scaling_from_potential(g)
      return geom.marginal_from_scalings(u, v, 0)

    def app_transport(f, g, z, axis):
      u, v = geom.scaling_from_potential(f), geom.scaling_from_potential(g)
      return geom.apply_transport_from_scalings(u, v, z, axis)

    return marginal_a, marginal_b, app_transport

  def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]:  # noqa: D102
    children = [self.geom, self._a, self._b]
    aux = {"tau_a": self.tau_a, "tau_b": self.tau_b}
    return children, aux

  @classmethod
  def tree_unflatten(  # noqa: D102
      cls, aux_data: Dict[str, Any], children: Sequence[Any]
  ) -> "LinearProblem":
    return cls(*children, **aux_data)
# --- /src/ott/problems/quadratic/__init__.py ---
-------------------------------------------------------------------------------- 1 | # Copyright OTT-JAX 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | from . import gw_barycenter, quadratic_costs, quadratic_problem 15 | -------------------------------------------------------------------------------- /src/ott/problems/quadratic/quadratic_costs.py: -------------------------------------------------------------------------------- 1 | # Copyright OTT-JAX 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
from typing import Callable, NamedTuple

import jax.numpy as jnp
import jax.scipy as jsp

__all__ = ["make_square_loss", "make_kl_loss"]


class Loss(NamedTuple):  # noqa: D101
  func: Callable[[jnp.ndarray], jnp.ndarray]
  is_linear: bool


class GWLoss(NamedTuple):
  r"""Efficient decomposition of the Gromov-Wasserstein loss function.

  The loss function :math:`L` is assumed to match the form given in eq. 5. of
  :cite:`peyre:16`:

  .. math::
    L(x, y) = f_1(x) + f_2(y) - h_1(x) h_2(y)

  Args:
    f1: First linear term.
    f2: Second linear term.
    h1: First quadratic term.
    h2: Second quadratic term.
  """
  f1: Loss
  f2: Loss
  h1: Loss
  h2: Loss


def make_square_loss() -> GWLoss:
  """Squared Euclidean loss for Gromov-Wasserstein.

  See Prop. 1 and Remark 1 of :cite:`peyre:16` for more information.

  Returns:
    The squared Euclidean loss.
  """
  # (x - y)^2 = x^2 + y^2 - x * (2y)
  return GWLoss(
      f1=Loss(lambda x: x ** 2, is_linear=False),
      f2=Loss(lambda y: y ** 2, is_linear=False),
      h1=Loss(lambda x: x, is_linear=True),
      h2=Loss(lambda y: 2.0 * y, is_linear=True),
  )


def make_kl_loss(clipping_value: float = 1e-8) -> GWLoss:
  r"""Kullback-Leibler loss for Gromov-Wasserstein.

  See Prop. 1 and Remark 1 of :cite:`peyre:16` for more information.

  Args:
    clipping_value: Value used to avoid :math:`\log(0)`.

  Returns:
    The KL loss.
  """

  def xlogx_minus_x(x: jnp.ndarray) -> jnp.ndarray:
    # -entr(x) - x == x * log(x) - x
    return -jsp.special.entr(x) - x

  def clipped_log(y: jnp.ndarray) -> jnp.ndarray:
    # clip from below before taking the log to avoid -inf
    return jnp.log(jnp.clip(y, clipping_value))

  return GWLoss(
      f1=Loss(xlogx_minus_x, is_linear=False),
      f2=Loss(lambda y: y, is_linear=True),
      h1=Loss(lambda x: x, is_linear=True),
      h2=Loss(clipped_log, is_linear=False),
  )
# --- /src/ott/py.typed --- (binary pointer in the original dump, omitted)
# --- /src/ott/solvers/__init__.py ---
# Copyright OTT-JAX
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# from . import linear, quadratic, utils, was_solver  # (solvers/__init__.py)
# --- /src/ott/solvers/linear/__init__.py ---
# Copyright OTT-JAX
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | from . import ( 15 | acceleration, 16 | continuous_barycenter, 17 | discrete_barycenter, 18 | implicit_differentiation, 19 | lr_utils, 20 | sinkhorn, 21 | sinkhorn_lr, 22 | univariate, 23 | ) 24 | from ._solve import solve, solve_univariate 25 | 26 | __all__ = [ 27 | "acceleration", 28 | "continuous_barycenter", 29 | "discrete_barycenter", 30 | "implicit_differentiation", 31 | "lr_utils", 32 | "sinkhorn", 33 | "sinkhorn_lr", 34 | "univariate", 35 | "solve", 36 | "solve_univariate", 37 | ] 38 | -------------------------------------------------------------------------------- /src/ott/solvers/linear/_solve.py: -------------------------------------------------------------------------------- 1 | # Copyright OTT-JAX 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
from typing import Any, Optional, Union

import jax.numpy as jnp

from ott.geometry import geometry, pointcloud
from ott.problems.linear import linear_problem
from ott.solvers.linear import sinkhorn, sinkhorn_lr, univariate

__all__ = ["solve", "solve_univariate"]


def solve(
    geom: geometry.Geometry,
    a: Optional[jnp.ndarray] = None,
    b: Optional[jnp.ndarray] = None,
    tau_a: float = 1.0,
    tau_b: float = 1.0,
    rank: int = -1,
    **kwargs: Any
) -> Union[sinkhorn.SinkhornOutput, sinkhorn_lr.LRSinkhornOutput]:
  """Solve linear regularized OT problem using Sinkhorn iterations.

  Args:
    geom: The ground geometry of the linear problem.
    a: The first marginal. If :obj:`None`, it will be uniform.
    b: The second marginal. If :obj:`None`, it will be uniform.
    tau_a: If :math:`< 1`, defines how much unbalanced the problem is
      on the first marginal.
    tau_b: If :math:`< 1`, defines how much unbalanced the problem is
      on the second marginal.
    rank: Rank constraint on the coupling to minimize the linear OT problem
      :cite:`scetbon:21`. If :math:`-1`, no rank constraint is used.
    kwargs: Keyword arguments for
      :class:`~ott.solvers.linear.sinkhorn.Sinkhorn` or
      :class:`~ott.solvers.linear.sinkhorn_lr.LRSinkhorn`,
      depending on the ``rank``.

  Returns:
    The Sinkhorn output.
  """
  problem = linear_problem.LinearProblem(
      geom, a=a, b=b, tau_a=tau_a, tau_b=tau_b
  )
  # a positive rank selects the low-rank solver, otherwise vanilla Sinkhorn
  if rank > 0:
    return sinkhorn_lr.LRSinkhorn(rank=rank, **kwargs)(problem)
  return sinkhorn.Sinkhorn(**kwargs)(problem)


def solve_univariate(
    geom: pointcloud.PointCloud,
    a: Optional[jnp.ndarray] = None,
    b: Optional[jnp.ndarray] = None,
    *,
    return_transport: bool = False,
    return_dual_variables: bool = False,
) -> univariate.UnivariateOutput:
  """Solve 1D OT problems between two :math:`d`-dimensional point clouds.

  The underlying solver is selected based on the following criteria:

  - :func:`~ott.solvers.linear.univariate.north_west_solver` - if
    ``return_dual_variables = True``.
  - :func:`~ott.solvers.linear.univariate.uniform_solver` - if ``a`` and
    ``b`` are both uniform and have the same size.
  - :func:`~ott.solvers.linear.univariate.quantile_solver` - otherwise.

  Args:
    geom: Geometry containing two :math:`d`-dimensional point clouds and
      a ground translation-invariant cost.
    a: The first marginal. If :obj:`None`, it will be uniform.
    b: The second marginal. If :obj:`None`, it will be uniform.
    return_transport: Whether to also return the mapped pairs used to compute
      the :attr:`~ott.solvers.linear.univariate.UnivariateOutput.transport_matrices`.
    return_dual_variables: Whether to also return the dual variables.

  Returns:
    The univariate output.
  """  # noqa: E501
  prob = linear_problem.LinearProblem(geom, a=a, b=b)
  # only the north-west corner solver computes dual variables
  if return_dual_variables:
    return univariate.north_west_solver(prob)
  use_uniform = prob.is_uniform and prob.is_equal_size
  solver_fn = (
      univariate.uniform_solver if use_uniform else univariate.quantile_solver
  )
  return solver_fn(prob, return_transport=return_transport)
# --- /src/ott/solvers/linear/lineax_implicit.py ---
# Copyright OTT-JAX
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Optional, TypeVar

import equinox as eqx
import lineax as lx
from jaxtyping import Array, Float, PyTree

import jax
import jax.numpy as jnp
import jax.tree_util as jtu

_T = TypeVar("_T")
# a pytree flattened into its (leaves, treedef) representation
_FlatPyTree = tuple[list[_T], jtu.PyTreeDef]

__all__ = ["CustomTransposeLinearOperator", "solve_lineax"]


class CustomTransposeLinearOperator(lx.FunctionLinearOperator):
  """Implement a linear operator that can specify its transpose directly."""
  fn: Callable[[PyTree[Float[Array, "..."]]], PyTree[Float[Array, "..."]]]
  fn_t: Callable[[PyTree[Float[Array, "..."]]], PyTree[Float[Array, "..."]]]
  input_structure: _FlatPyTree[jax.ShapeDtypeStruct] = eqx.static_field()
  input_structure_t: _FlatPyTree[jax.ShapeDtypeStruct] = eqx.static_field()
  # NOTE(review): `tags` is declared here but assigned by the parent
  # constructor below — presumably required by equinox field validation.
  tags: frozenset[object]

  def __init__(self, fn, fn_t, input_structure, input_structure_t, tags=()):
    # `fn`, `input_structure` and `tags` are stored by the parent constructor;
    # only the transpose data is handled here.
    super().__init__(fn, input_structure, tags)
    # closure-convert so the transpose function is a valid equinox module leaf
    self.fn_t = eqx.filter_closure_convert(fn_t, input_structure_t)
    self.input_structure_t = input_structure_t

  def transpose(self):
    """Provide custom transposition operator from function."""
    return lx.FunctionLinearOperator(self.fn_t, self.input_structure_t)


def solve_lineax(
    lin: Callable[[jnp.ndarray], jnp.ndarray],
    b: jnp.ndarray,
    lin_t: Optional[Callable[[jnp.ndarray], jnp.ndarray]] = None,
    symmetric: bool = False,
    nonsym_solver: Optional[lx.AbstractLinearSolver] = None,
    ridge_identity: float = 0.0,
    ridge_kernel: float = 0.0,
    **kwargs: Any
) -> jnp.ndarray:
  """Wrapper around lineax solvers.

  Args:
    lin: Linear operator
    b: vector. Returned `x` is such that `lin(x)=b`
    lin_t: Linear operator, corresponding to transpose of `lin`.
    symmetric: whether `lin` is symmetric.
    nonsym_solver: solver used when handling non-symmetric cases. Note that
      :class:`~lineax.CG` is used by default in the symmetric case.
    ridge_identity: handles rank deficient transport matrices (this happens
      typically when rows/cols in cost/kernel matrices are collinear, or,
      equivalently when two points from either measure are close).
    ridge_kernel: promotes zero-sum solutions. Only use if `tau_a = tau_b = 1.0`
    kwargs: arguments passed to :class:`~lineax.AbstractLinearSolver` linear
      solver.

  Returns:
    The solution ``x`` of the (possibly regularized) system ``lin(x) = b``.
  """
  input_structure = jax.eval_shape(lambda: b)
  kwargs.setdefault("rtol", 1e-6)
  kwargs.setdefault("atol", 1e-6)

  if ridge_kernel > 0.0 or ridge_identity > 0.0:
    # regularized operator: lin(x) + ridge_kernel * sum(x) + ridge_identity * x
    lin_reg = lambda x: lin(x) + ridge_kernel * jnp.sum(x) + ridge_identity * x
    lin_t_reg = lambda x: lin_t(x) + ridge_kernel * jnp.sum(
        x
    ) + ridge_identity * x
  else:
    lin_reg, lin_t_reg = lin, lin_t

  if symmetric:
    # symmetric PSD case: conjugate gradient on the operator itself
    solver = lx.CG(**kwargs)
    fn_operator = lx.FunctionLinearOperator(
        lin_reg, input_structure, tags=lx.positive_semidefinite_tag
    )
    return lx.linear_solve(fn_operator, b, solver).value
  # In the non-symmetric case, use NormalCG by default, but consider
  # user defined choice of alternative lx solver.
  solver_type = lx.NormalCG if nonsym_solver is None else nonsym_solver
  solver = solver_type(**kwargs)
  fn_operator = CustomTransposeLinearOperator(
      lin_reg, lin_t_reg, input_structure, input_structure
  )
  return lx.linear_solve(fn_operator, b, solver).value
# --- /src/ott/solvers/quadratic/__init__.py ---
# Copyright OTT-JAX
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | from . import ( 15 | gromov_wasserstein, 16 | gromov_wasserstein_lr, 17 | gw_barycenter, 18 | lower_bound, 19 | ) 20 | from ._solve import solve 21 | -------------------------------------------------------------------------------- /src/ott/solvers/quadratic/_solve.py: -------------------------------------------------------------------------------- 1 | # Copyright OTT-JAX 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | from typing import Any, Dict, Literal, Optional, Union 15 | 16 | import jax.numpy as jnp 17 | 18 | from ott.geometry import geometry 19 | from ott.problems.quadratic import quadratic_costs, quadratic_problem 20 | from ott.solvers.linear import sinkhorn 21 | from ott.solvers.quadratic import gromov_wasserstein as gw 22 | from ott.solvers.quadratic import gromov_wasserstein_lr as lrgw 23 | 24 | __all__ = ["solve"] 25 | 26 | 27 | def solve( 28 | geom_xx: geometry.Geometry, 29 | geom_yy: geometry.Geometry, 30 | geom_xy: Optional[geometry.Geometry] = None, 31 | fused_penalty: float = 1.0, 32 | a: Optional[jnp.ndarray] = None, 33 | b: Optional[jnp.ndarray] = None, 34 | tau_a: float = 1.0, 35 | tau_b: float = 1.0, 36 | loss: Union[Literal["sqeucl", "kl"], quadratic_costs.GWLoss] = "sqeucl", 37 | gw_unbalanced_correction: bool = True, 38 | rank: int = -1, 39 | linear_solver_kwargs: Optional[Dict[str, Any]] = None, 40 | **kwargs: Any, 41 | ) -> Union[gw.GWOutput, lrgw.LRGWOutput]: 42 | """Solve quadratic regularized OT problem using a Gromov-Wasserstein solver. 43 | 44 | Args: 45 | geom_xx: Ground geometry of the first space. 46 | geom_yy: Ground geometry of the second space. 47 | geom_xy: Geometry defining the linear penalty term for 48 | fused Gromov-Wasserstein :cite:`vayer:19`. If :obj:`None`, the problem 49 | reduces to a plain Gromov-Wasserstein problem :cite:`peyre:16`. 50 | fused_penalty: Multiplier of the linear term in fused Gromov-Wasserstein, 51 | i.e. ``problem = purely quadratic + fused_penalty * linear problem``. 52 | a: The first marginal. If :obj:`None`, it will be uniform. 53 | b: The second marginal. If :obj:`None`, it will be uniform. 54 | tau_a: If :math:`< 1`, defines how much unbalanced the problem is 55 | on the first marginal. 56 | tau_b: If :math:`< 1`, defines how much unbalanced the problem is 57 | on the second marginal. 
58 | loss: Gromov-Wasserstein loss function, see 59 | :class:`~ott.problems.quadratic.quadratic_costs.GWLoss` for more 60 | information. If ``rank > 0``, ``'sqeucl'`` is always used. 61 | gw_unbalanced_correction: Whether the unbalanced version of 62 | :cite:`sejourne:21` is used. Otherwise, ``tau_a`` and ``tau_b`` 63 | only affect the resolution of the linearization of the GW problem 64 | in the inner loop. Only used when ``rank = -1``. 65 | rank: Rank constraint on the coupling to minimize the quadratic OT problem 66 | :cite:`scetbon:22`. If :math:`-1`, no rank constraint is used. 67 | linear_solver_kwargs: Keyword arguments for 68 | :class:`~ott.solvers.linear.sinkhorn.Sinkhorn`, if ``rank > 0``. 69 | kwargs: Keyword arguments for 70 | :class:`~ott.solvers.quadratic.gromov_wasserstein.GromovWasserstein` or 71 | :class:`~ott.solvers.quadratic.gromov_wasserstein_lr.LRGromovWasserstein`, 72 | depending on the ``rank`` 73 | 74 | Returns: 75 | The Gromov-Wasserstein output. 76 | """ 77 | prob = quadratic_problem.QuadraticProblem( 78 | geom_xx=geom_xx, 79 | geom_yy=geom_yy, 80 | geom_xy=geom_xy, 81 | fused_penalty=fused_penalty, 82 | a=a, 83 | b=b, 84 | tau_a=tau_a, 85 | tau_b=tau_b, 86 | loss=loss, 87 | gw_unbalanced_correction=gw_unbalanced_correction 88 | ) 89 | 90 | if rank > 0: 91 | solver = lrgw.LRGromovWasserstein(rank, **kwargs) 92 | else: 93 | if linear_solver_kwargs is None: 94 | linear_solver_kwargs = {} 95 | linear_solver = sinkhorn.Sinkhorn(**linear_solver_kwargs) 96 | solver = gw.GromovWasserstein(linear_solver, **kwargs) 97 | 98 | return solver(prob) 99 | -------------------------------------------------------------------------------- /src/ott/solvers/quadratic/lower_bound.py: -------------------------------------------------------------------------------- 1 | # Copyright OTT-JAX 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | from typing import TYPE_CHECKING, Any, Optional 15 | 16 | from ott.geometry import pointcloud 17 | from ott.problems.quadratic import quadratic_problem 18 | from ott.solvers import linear 19 | from ott.solvers.linear import sinkhorn 20 | 21 | if TYPE_CHECKING: 22 | from ott.geometry import distrib_costs 23 | 24 | __all__ = ["third_lower_bound"] 25 | 26 | 27 | def third_lower_bound( 28 | prob: quadratic_problem.QuadraticProblem, 29 | distrib_cost: "distrib_costs.UnivariateWasserstein", 30 | epsilon: Optional[float] = None, 31 | **kwargs: Any, 32 | ) -> sinkhorn.SinkhornOutput: 33 | """Computes the third lower bound distance from :cite:`memoli:11`, def. 6.3. 34 | 35 | Args: 36 | prob: Quadratic OT problem. 37 | distrib_cost: Univariate Wasserstein cost used to compare two point clouds 38 | in different spaces. Each point is seen as its distribution of costs 39 | to other points in its respective point cloud. 40 | epsilon: Entropy regularization. 41 | kwargs: Keyword arguments for :func:`~ott.solvers.linear.solve`. 42 | 43 | Returns: 44 | An approximation of the GW coupling that can be used to initialize 45 | the solution of the quadratic OT problem. 
46 | """ 47 | dists_xx = prob.geom_xx.cost_matrix 48 | dists_yy = prob.geom_yy.cost_matrix 49 | geom_xy = pointcloud.PointCloud( 50 | dists_xx, dists_yy, cost_fn=distrib_cost, epsilon=epsilon 51 | ) 52 | 53 | return linear.solve(geom_xy, **kwargs) 54 | -------------------------------------------------------------------------------- /src/ott/solvers/was_solver.py: -------------------------------------------------------------------------------- 1 | # Copyright OTT-JAX 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | from typing import TYPE_CHECKING, Any, Dict, Sequence, Tuple, Union 15 | 16 | import jax 17 | import jax.numpy as jnp 18 | 19 | from ott.solvers.linear import sinkhorn, sinkhorn_lr 20 | 21 | if TYPE_CHECKING: 22 | from ott.solvers.linear import continuous_barycenter 23 | 24 | __all__ = ["WassersteinSolver"] 25 | 26 | State = Union[sinkhorn.SinkhornState, sinkhorn_lr.LRSinkhornState, 27 | "continuous_barycenter.FreeBarycenterState"] 28 | 29 | 30 | @jax.tree_util.register_pytree_node_class 31 | class WassersteinSolver: 32 | """A generic solver for problems that use a linear problem in inner loop.""" 33 | 34 | def __init__( 35 | self, 36 | linear_solver: Union["sinkhorn.Sinkhorn", "sinkhorn_lr.LRSinkhorn"], 37 | threshold: float = 1e-3, 38 | min_iterations: int = 5, 39 | max_iterations: int = 50, 40 | store_inner_errors: bool = False, 41 | ): 42 | self.linear_solver = linear_solver 43 | self.min_iterations = min_iterations 44 | self.max_iterations = max_iterations 45 | self.threshold = threshold 46 | self.store_inner_errors = store_inner_errors 47 | 48 | @property 49 | def rank(self) -> int: 50 | """Rank of the linear OT solver.""" 51 | return self.linear_solver.rank if self.is_low_rank else -1 52 | 53 | @property 54 | def is_low_rank(self) -> bool: 55 | """Whether the solver is low-rank.""" 56 | return isinstance(self.linear_solver, sinkhorn_lr.LRSinkhorn) 57 | 58 | def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]: # noqa: D102 59 | return ([self.linear_solver, self.threshold], { 60 | "min_iterations": self.min_iterations, 61 | "max_iterations": self.max_iterations, 62 | "store_inner_errors": self.store_inner_errors, 63 | }) 64 | 65 | @classmethod 66 | def tree_unflatten( # noqa: D102 67 | cls, aux_data: Dict[str, Any], children: Sequence[Any] 68 | ) -> "WassersteinSolver": 69 | return cls(*children, **aux_data) 70 | 71 | def _converged(self, state: State, iteration: int) -> bool: 72 | costs, i, tol = state.costs, iteration, self.threshold 73 | 
return jnp.logical_and( 74 | i >= 2, jnp.isclose(costs[i - 2], costs[i - 1], rtol=tol) 75 | ) 76 | 77 | def _diverged(self, state: State, iteration: int) -> bool: 78 | return jnp.logical_not(jnp.isfinite(state.costs[iteration - 1])) 79 | 80 | def _continue(self, state: State, iteration: int) -> bool: 81 | """Continue while not(converged) and not(diverged).""" 82 | return jnp.logical_or( 83 | iteration <= 2, 84 | jnp.logical_and( 85 | jnp.logical_not(self._diverged(state, iteration)), 86 | jnp.logical_not(self._converged(state, iteration)) 87 | ) 88 | ) 89 | -------------------------------------------------------------------------------- /src/ott/tools/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright OTT-JAX 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | from . import ( 15 | conformal, 16 | gaussian_mixture, 17 | k_means, 18 | plot, 19 | progot, 20 | segment_sinkhorn, 21 | sinkhorn_divergence, 22 | sliced, 23 | soft_sort, 24 | unreg, 25 | ) 26 | -------------------------------------------------------------------------------- /src/ott/tools/gaussian_mixture/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright OTT-JAX 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | from . import fit_gmm_pair, gaussian, gaussian_mixture, gaussian_mixture_pair 15 | -------------------------------------------------------------------------------- /src/ott/tools/gaussian_mixture/linalg.py: -------------------------------------------------------------------------------- 1 | # Copyright OTT-JAX 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | from typing import Callable, Iterable, List, Tuple 15 | 16 | import jax 17 | import jax.numpy as jnp 18 | 19 | 20 | def get_mean_and_var( 21 | points: jnp.ndarray, # (n, d) 22 | weights: jnp.ndarray, # (n,) 23 | ) -> Tuple[jnp.ndarray, jnp.ndarray]: 24 | """Get the mean and variance of a weighted set of points.""" 25 | weights_sum = jnp.sum(weights, axis=-1) # (1,) 26 | mean = ( 27 | # matmul((1, n), (n, d)) -> (1, d) 28 | jnp.matmul(weights, points) / weights_sum 29 | ) 30 | # center points 31 | centered = points - mean[None, :] # (n, d) - (1, d) 32 | var = ( 33 | # matmul((1, n), (n, d)) -> (1, d) 34 | jnp.matmul(weights, centered ** 2) / weights_sum 35 | ) 36 | return mean, var 37 | 38 | 39 | def get_mean_and_cov( 40 | points: jnp.ndarray, # (n, d) 41 | weights: jnp.ndarray, # (n,) 42 | ) -> Tuple[jnp.ndarray, jnp.ndarray]: 43 | """Get the mean and covariance of a weighted set of points.""" 44 | weights_sum = jnp.sum(weights, axis=-1, keepdims=True) # (1,) 45 | mean = ( 46 | # matmul((1, n), (n, d)) -> (1, d) 47 | jnp.matmul(weights, points) / weights_sum 48 | ) 49 | # center points 50 | centered = points - mean[None, :] # (n, d) - (1, d) 51 | cov = ( 52 | jnp.matmul( 53 | # (1, n) (d, n) 54 | weights[None, :] * jnp.swapaxes(centered, axis1=-2, axis2=-1), 55 | # (n, d) 56 | centered 57 | ) / weights_sum 58 | ) 59 | return mean, cov 60 | 61 | 62 | def flat_to_tril(x: jnp.ndarray, size: int) -> jnp.ndarray: 63 | """Map flat values to lower triangular matrices. 64 | 65 | Args: 66 | x: flat values 67 | size: size of lower triangular matrices. x should have shape 68 | (..., size(size+1)/2), and the final matrices should have shape 69 | (..., size, size). 70 | 71 | Returns: 72 | Lower triangular matrices. 73 | """ 74 | m = jnp.zeros(x.shape[:-1] + (size, size)) 75 | tril = jnp.tril_indices(size) 76 | return m.at[..., tril[0], tril[1]].set(x) 77 | 78 | 79 | def tril_to_flat(m: jnp.ndarray) -> jnp.ndarray: 80 | """Flatten lower triangular matrices. 
81 | 82 | Args: 83 | m: lower triangular matrices of shape (..., size, size) 84 | 85 | Returns: 86 | A vector of shape (..., size (size+1) // 2) 87 | """ 88 | size = m.shape[-1] 89 | tril = jnp.tril_indices(size) 90 | return m[..., tril[0], tril[1]] 91 | 92 | 93 | def apply_to_diag( 94 | m: jnp.ndarray, fn: Callable[[jnp.ndarray], jnp.ndarray] 95 | ) -> jnp.ndarray: 96 | """Apply a function to the diagonal of a matrix.""" 97 | size = m.shape[-1] 98 | diag = jnp.diagonal(m, axis1=-2, axis2=-1) 99 | ind = jnp.arange(size) 100 | return m.at[..., ind, ind].set(fn(diag)) 101 | 102 | 103 | def matrix_powers( 104 | m: jnp.ndarray, 105 | powers: Iterable[float], 106 | ) -> List[jnp.ndarray]: 107 | """Raise a real, symmetric matrix to multiple powers.""" 108 | eigs, q = jnp.linalg.eigh(m) 109 | qt = jnp.swapaxes(q, axis1=-2, axis2=-1) 110 | ret = [] 111 | for power in powers: 112 | ret.append(jnp.matmul(jnp.expand_dims(eigs ** power, -2) * q, qt)) 113 | return ret 114 | 115 | 116 | def invmatvectril( 117 | m: jnp.ndarray, x: jnp.ndarray, lower: bool = True 118 | ) -> jnp.ndarray: 119 | """Multiply x by the inverse of a triangular matrix. 
120 | 121 | Args: 122 | m: triangular matrix, shape (d, d) 123 | x: array of points, shape (n, d) 124 | lower: if True, m is lower triangular; otherwise m is upper triangular 125 | 126 | Returns: 127 | m^{-1} x 128 | """ 129 | return jnp.transpose( 130 | jax.scipy.linalg.solve_triangular(m, jnp.transpose(x), lower=lower) 131 | ) 132 | 133 | 134 | def get_random_orthogonal(rng: jax.Array, dim: int) -> jnp.ndarray: 135 | """Get a random orthogonal matrix with the specified dimension.""" 136 | m = jax.random.normal(rng, shape=[dim, dim]) 137 | q, _ = jnp.linalg.qr(m) 138 | return q 139 | -------------------------------------------------------------------------------- /src/ott/tools/gaussian_mixture/probabilities.py: -------------------------------------------------------------------------------- 1 | # Copyright OTT-JAX 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | from typing import Optional 15 | 16 | import jax 17 | import jax.numpy as jnp 18 | 19 | __all__ = ["Probabilities"] 20 | 21 | 22 | @jax.tree_util.register_pytree_node_class 23 | class Probabilities: 24 | """Parameterized array of probabilities of length n. 25 | 26 | The internal representation is a length n-1 unconstrained array. We convert 27 | to a length n simplex by appending a 0 and taking a softmax. 
28 | """ 29 | 30 | _params: jnp.ndarray 31 | 32 | def __init__(self, params): 33 | self._params = params 34 | 35 | @classmethod 36 | def from_random( 37 | cls, 38 | rng: jax.Array, 39 | n_dimensions: int, 40 | stdev: Optional[float] = 0.1, 41 | ) -> "Probabilities": 42 | """Construct a random Probabilities.""" 43 | return cls(params=jax.random.normal(rng, shape=(n_dimensions - 1,)) * stdev) 44 | 45 | @classmethod 46 | def from_probs(cls, probs: jnp.ndarray) -> "Probabilities": 47 | """Construct Probabilities from a vector of probabilities.""" 48 | log_probs = jnp.log(probs) 49 | log_probs_normalized, norm = log_probs[:-1], log_probs[-1] 50 | log_probs_normalized -= norm 51 | return cls(params=log_probs_normalized) 52 | 53 | @property 54 | def params(self): # noqa: D102 55 | return self._params 56 | 57 | @property 58 | def dtype(self): # noqa: D102 59 | return self._params.dtype 60 | 61 | def unnormalized_log_probs(self) -> jnp.ndarray: 62 | """Get the unnormalized log probabilities.""" 63 | return jnp.concatenate([self._params, jnp.zeros((1,))], axis=-1) 64 | 65 | def log_probs(self) -> jnp.ndarray: 66 | """Get the log probabilities.""" 67 | return jax.nn.log_softmax(self.unnormalized_log_probs()) 68 | 69 | def probs(self) -> jnp.ndarray: 70 | """Get the probabilities.""" 71 | return jax.nn.softmax(self.unnormalized_log_probs()) 72 | 73 | def sample(self, rng: jax.Array, size: int) -> jnp.ndarray: 74 | """Sample from the distribution.""" 75 | return jax.random.categorical( 76 | rng, logits=self.unnormalized_log_probs(), shape=(size,) 77 | ) 78 | 79 | def tree_flatten(self): # noqa: D102 80 | children = (self.params,) 81 | aux_data = {} 82 | return children, aux_data 83 | 84 | @classmethod 85 | def tree_unflatten(cls, aux_data, children): # noqa: D102 86 | return cls(*children, **aux_data) 87 | 88 | def __repr__(self): 89 | class_name = type(self).__name__ 90 | children, aux = self.tree_flatten() 91 | return "{}({})".format( 92 | class_name, ", ".join([repr(c) for c 
in children] + 93 | [f"{k}: {repr(v)}" for k, v in aux.items()]) 94 | ) 95 | 96 | def __hash__(self): 97 | return jax.tree_util.tree_flatten(self).__hash__() 98 | 99 | def __eq__(self, other): 100 | return jax.tree_util.tree_flatten(self) == jax.tree_util.tree_flatten(other) 101 | -------------------------------------------------------------------------------- /src/ott/tools/sliced.py: -------------------------------------------------------------------------------- 1 | # Copyright OTT-JAX 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | from typing import Any, Callable, Optional, Tuple 15 | 16 | import jax 17 | import jax.numpy as jnp 18 | 19 | from ott import utils 20 | from ott.geometry import costs, pointcloud 21 | from ott.solvers import linear 22 | from ott.solvers.linear import univariate 23 | 24 | __all__ = ["random_proj_sphere", "sliced_wasserstein"] 25 | 26 | Projector = Callable[[jnp.ndarray, int, jax.Array], jnp.ndarray] 27 | 28 | 29 | def random_proj_sphere( 30 | x: jnp.ndarray, 31 | n_proj: int = 1000, 32 | rng: Optional[jax.Array] = None 33 | ) -> jnp.ndarray: 34 | """Project data on directions sampled randomly from sphere. 35 | 36 | Args: 37 | x: Array of size ``[n, dim]``. 38 | n_proj: Number of randomly generated projections. 39 | rng: Key used to sample feature extractors. 40 | 41 | Returns: 42 | Array of size ``[n, n_proj]`` features. 
43 | """ 44 | rng = utils.default_prng_key(rng) 45 | dim = x.shape[-1] 46 | proj_m = jax.random.normal(rng, (n_proj, dim)) 47 | proj_m /= jnp.linalg.norm(proj_m, axis=1, keepdims=True) 48 | return x @ proj_m.T 49 | 50 | 51 | def sliced_wasserstein( 52 | x: jnp.ndarray, 53 | y: jnp.ndarray, 54 | a: Optional[jnp.ndarray] = None, 55 | b: Optional[jnp.ndarray] = None, 56 | cost_fn: Optional[costs.CostFn] = None, 57 | proj_fn: Optional[Projector] = None, 58 | weights: Optional[jnp.ndarray] = None, 59 | return_transport: bool = False, 60 | return_dual_variables: bool = False, 61 | **kwargs: Any, 62 | ) -> Tuple[jnp.ndarray, univariate.UnivariateOutput]: 63 | r"""Compute the Sliced Wasserstein distance between two weighted point clouds. 64 | 65 | Follows the approach outlined in :cite:`rabin:12` to compute a proxy for OT 66 | distances that relies on creating features (possibly randomly) for data, 67 | through e.g., projections, and then sum the 1D Wasserstein distances between 68 | these features' univariate distributions on both source and target samples. 69 | 70 | Args: 71 | x: Array of shape ``[n, dim]`` of source points' coordinates. 72 | y: Array of shape ``[m, dim]`` of target points' coordinates. 73 | a: Array of shape ``[n,]`` of source probability weights. 74 | b: Array of shape ``[m,]`` of target probability weights. 75 | cost_fn: Cost function. Must be a submodular function of two real arguments, 76 | i.e. such that :math:`\partial c(x,y)/\partial x \partial y <0`. If 77 | :obj:`None`, use :class:`~ott.geometry.costs.SqEuclidean`. 78 | proj_fn: Projection function, mapping any ``[b, dim]`` matrix of coordinates 79 | to ``[b, n_proj]`` matrix of features, on which 1D transports (for 80 | ``n_proj`` directions) are subsequently computed independently. 81 | By default, use :func:`~ott.tools.sliced.random_proj_sphere`. 
82 | weights: Array of shape ``[n_proj,]`` of weights used to average the 83 | ``n_proj`` 1D Wasserstein contributions (one for each feature) and form 84 | the sliced Wasserstein distance. Uniform by default, resulting in average 85 | of all these values. 86 | return_transport: Whether to store ``n_proj`` transport plans in the output. 87 | return_dual_variables: Whether to store ``n_proj`` pairs of dual vectors 88 | in the output. 89 | kwargs: Keyword arguments to ``proj_fn``. Could for instance 90 | include, as done with default projector, number of ``n_proj`` projections, 91 | as well as a ``rng`` key to sample as many directions. 92 | 93 | Returns: 94 | The sliced Wasserstein distance with the corresponding output object. 95 | """ 96 | if proj_fn is None: 97 | proj_fn = random_proj_sphere 98 | 99 | x_proj, y_proj = proj_fn(x, **kwargs), proj_fn(y, **kwargs), 100 | geom = pointcloud.PointCloud(x_proj, y_proj, cost_fn=cost_fn) 101 | 102 | out = linear.solve_univariate( 103 | geom, 104 | a, 105 | b, 106 | return_transport=return_transport, 107 | return_dual_variables=return_dual_variables 108 | ) 109 | return jnp.average(out.ot_costs, weights=weights), out 110 | -------------------------------------------------------------------------------- /src/ott/tools/unreg.py: -------------------------------------------------------------------------------- 1 | # Copyright OTT-JAX 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | from typing import NamedTuple, Optional, Tuple 15 | 16 | import jax.experimental.sparse as jesp 17 | import jax.numpy as jnp 18 | 19 | from optax import assignment 20 | 21 | from ott.geometry import costs, geometry, pointcloud 22 | 23 | __all__ = ["hungarian"] 24 | 25 | 26 | class HungarianOutput(NamedTuple): 27 | r"""Output of the Hungarian solver. 28 | 29 | Args: 30 | geom: geometry object 31 | paired_indices: Array of shape ``[2, n]``, of :math:`n` pairs 32 | of indices, for which the optimal transport assigns mass. Namely, for each 33 | index :math:`0 \leq k < n`, if one has 34 | :math:`i := \text{paired_indices}[0, k]` and 35 | :math:`j := \text{paired_indices}[1, k]`, then point :math:`i` in 36 | the first geometry sends mass to point :math:`j` in the second. 37 | """ 38 | geom: geometry.Geometry 39 | paired_indices: Optional[jnp.ndarray] = None 40 | 41 | @property 42 | def matrix(self) -> jesp.BCOO: 43 | """``[n, n]`` transport matrix in sparse format, with ``n`` NNZ entries.""" 44 | n, _ = self.geom.shape 45 | unit_mass = jnp.ones((n,)) / n 46 | indices = self.paired_indices.swapaxes(0, 1) 47 | return jesp.BCOO((unit_mass, indices), shape=(n, n)) 48 | 49 | 50 | def hungarian(geom: geometry.Geometry) -> Tuple[jnp.ndarray, HungarianOutput]: 51 | """Solve matching problem using :term:`Hungarian algorithm` from :mod:`optax`. 52 | 53 | Args: 54 | geom: Geometry object with square (shape ``[n,n]``) 55 | :attr:`~ott.geometry.geometry.Geomgetry.cost matrix`. 56 | 57 | Returns: 58 | The value of the unregularized OT problem, along with an output 59 | object listing relevant information on outputs. 60 | """ 61 | n, m = geom.shape 62 | assert n == m, f"Hungarian can only match same # of points, got {n} and {m}." 
63 | i, j = assignment.hungarian_algorithm(geom.cost_matrix) 64 | 65 | hungarian_out = HungarianOutput(geom=geom, paired_indices=jnp.stack((i, j))) 66 | return jnp.sum(geom.cost_matrix[i, j]) / n, hungarian_out 67 | 68 | 69 | def wassdis_p(x: jnp.ndarray, y: jnp.ndarray, *, p: float = 2.0) -> float: 70 | """Compute the :term:`Wasserstein distance`, uses :term:`Hungarian algorithm`. 71 | 72 | Uses :func:`hungarian` to solve the :term:`optimal matching problem` between 73 | two point clouds of the same size, to compute a :term:`Wasserstein distance` 74 | estimator. 75 | 76 | Note: 77 | At the moment, only supports point clouds of the same size to be easily 78 | cast as an optimal matching problem. 79 | 80 | Args: 81 | x: ``[n,d]`` point cloud 82 | y: ``[n,d]`` point cloud of the same size 83 | p: order of the Wasserstein distance, non-negative float. 84 | 85 | Returns: 86 | The `p`-Wasserstein distance between these point clouds. 87 | """ 88 | geom = pointcloud.PointCloud(x, y, cost_fn=costs.EuclideanP(p=p)) 89 | cost, _ = hungarian(geom) 90 | return cost ** (1. / p) 91 | -------------------------------------------------------------------------------- /src/ott/types.py: -------------------------------------------------------------------------------- 1 | # Copyright OTT-JAX 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
from typing import Protocol

import jax.numpy as jnp

__all__ = ["Transport"]

# TODO(michalk8): introduce additional types here


class Transport(Protocol):
  """Interface for the solution of a transport problem.

  Classes implementing those function do not have to inherit from it, the
  class can however be used in type hints to support duck typing.
  """

  @property
  def matrix(self) -> jnp.ndarray:
    """Transport matrix of the solution."""
    ...

  def apply(self, inputs: jnp.ndarray, axis: int) -> jnp.ndarray:
    """Apply the transport matrix to ``inputs`` along ``axis``."""
    ...

  def marginal(self, axis: int = 0) -> jnp.ndarray:
    """Marginal of the transport matrix along ``axis``."""
    ...
-------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ott-jax/ott/d28d5d45b0bd26d2e3d54fe1085f2835dec5f5d6/tests/__init__.py -------------------------------------------------------------------------------- /tests/conftest.py: --------------------------------------------------------------------------------
# Copyright OTT-JAX
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
14 | import itertools 15 | from collections import abc 16 | from typing import Any, Mapping, Optional, Sequence 17 | 18 | import pytest 19 | 20 | import jax 21 | import jax.experimental 22 | 23 | import matplotlib as mpl 24 | 25 | 26 | def pytest_sessionstart(session: pytest.Session) -> None: 27 | mpl.use("Agg") 28 | 29 | 30 | def pytest_generate_tests(metafunc: pytest.Metafunc) -> None: 31 | if not hasattr(metafunc.function, "pytestmark"): 32 | # no annotation 33 | return 34 | 35 | fast_marks = [m for m in metafunc.function.pytestmark if m.name == "fast"] 36 | if fast_marks: 37 | mark, = fast_marks 38 | selected: Optional[Mapping[str, Any]] = mark.kwargs.pop("only_fast", None) 39 | ids: Optional[Sequence[str]] = mark.kwargs.pop("ids", None) 40 | 41 | if mark.args: 42 | argnames, argvalues = mark.args 43 | else: 44 | argnames = tuple(mark.kwargs.keys()) 45 | argvalues = [(vs,) if not isinstance(vs, (str, abc.Iterable)) else vs 46 | for vs in mark.kwargs.values()] 47 | argvalues = list(itertools.product(*argvalues)) 48 | 49 | opt = str(metafunc.config.getoption("-m")) 50 | if "fast" in opt: # filter if `-m fast` was passed 51 | if selected is None: 52 | combinations = argvalues 53 | elif isinstance(selected, dict): 54 | combinations = [] 55 | for vs in argvalues: 56 | if selected == dict(zip(argnames, vs)): 57 | combinations.append(vs) 58 | elif isinstance(selected, (tuple, list)): 59 | # TODO(michalk8): support passing ids? 
60 | combinations = [argvalues[s] for s in selected] 61 | ids = None if ids is None else [ids[s] for s in selected] 62 | elif isinstance(selected, int): 63 | combinations = [argvalues[selected]] 64 | ids = None if ids is None else [ids[selected]] 65 | else: 66 | raise TypeError(f"Invalid fast selection type `{type(selected)}`.") 67 | else: 68 | combinations = argvalues 69 | 70 | if argnames: 71 | metafunc.parametrize(argnames, combinations, ids=ids) 72 | 73 | 74 | @pytest.fixture() 75 | def rng() -> jax.Array: 76 | return jax.random.key(0) 77 | 78 | 79 | @pytest.fixture() 80 | def enable_x64() -> bool: 81 | with jax.experimental.enable_x64(True): 82 | try: 83 | yield 84 | finally: 85 | pass 86 | -------------------------------------------------------------------------------- /tests/experimental/mmsinkhorn_test.py: -------------------------------------------------------------------------------- 1 | # Copyright OTT-JAX 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
import pytest

import jax
import jax.numpy as jnp
import numpy as np

from ott.experimental import mmsinkhorn
from ott.geometry import costs, pointcloud
from ott.solvers import linear


class TestMMSinkhorn:
  """Tests for the multi-marginal Sinkhorn solver."""

  @pytest.mark.fast.with_args(
      a_none=[True, False], b_none=[True, False], only_fast=0
  )
  def test_match_2sinkhorn(self, a_none: bool, b_none: bool, rng: jax.Array):
    """Test consistency of MMSinkhorn for 2 margins vs regular Sinkhorn."""
    n, m, d = 5, 10, 7
    rngs = jax.random.split(rng, 5)
    x = jax.random.normal(rngs[0], (n, d))
    y = jax.random.normal(rngs[1], (m, d)) + 1
    if a_none:
      a = None
    else:
      a = jax.random.uniform(rngs[2], (n,))
      # zero out one weight to also exercise empty support points
      a = a.at[0].set(0.0)
      a /= jnp.sum(a)

    if b_none:
      b = None
    else:
      b = jax.random.uniform(rngs[3], (m,))
      # Fix: `.at[...].set(...)` is out-of-place and returns a new array; the
      # previous `b.at[2].set(0.0)` discarded the result, so — unlike `a`
      # above — the entry was never actually zeroed. Rebind the result.
      b = b.at[2].set(0.0)
      b /= jnp.sum(b)
    cost_fn = costs.PNormP(1.8)
    geom = pointcloud.PointCloud(x, y, cost_fn=cost_fn)
    out = linear.solve(geom, a=a, b=b, threshold=1e-5)

    ab = None if a is None and b is None else [a, b]
    solver = jax.jit(mmsinkhorn.MMSinkhorn(threshold=1e-5))
    out_ms = solver([x, y], ab, cost_fns=cost_fn)
    assert out.converged
    assert out_ms.converged

    # potentials are only defined up to an additive constant: recenter both
    for axis in [0, 1]:
      f = out.potentials[axis]
      f_ms = out_ms.potentials[axis]
      f -= jnp.mean(f)
      f_ms -= jnp.mean(f_ms)
      np.testing.assert_allclose(f, f_ms, rtol=1e-3, atol=1e-3)
    np.testing.assert_allclose(
        out.ot_prob.geom.epsilon, out_ms.epsilon, rtol=1e-5, atol=1e-5
    )
    np.testing.assert_allclose(out.matrix, out_ms.tensor, rtol=1e-2, atol=1e-3)
    np.testing.assert_allclose(
        out.ent_reg_cost, out_ms.ent_reg_cost, rtol=1e-6, atol=1e-6
    )

  @pytest.mark.fast.with_args(
      a_s_none=[True, False], costs_none=[True, False], only_fast=0
  )
  def test_mm_sinkhorn(self, a_s_none: bool, costs_none: bool, rng: jax.Array):
    """Test correctness of MMSinkhorn for 4 marginals."""
    n_s, d = [13, 5, 10, 3], 7

    rngs = jax.random.split(rng, len(n_s))
    x_s = [jax.random.normal(rng, (n, d)) for rng, n in zip(rngs, n_s)]

    if a_s_none:
      a_s = None
    else:
      a_s = [jax.random.uniform(rng, (n,)) for rng, n in zip(rngs, n_s)]
      a_s = [a / jnp.sum(a) for a in a_s]

    if costs_none:
      cost_fns = None
    else:
      # 4 marginals -> 6 pairwise cost functions
      cost_fns = [costs.PNormP(1.5) for _ in range(3)]
      cost_fns += [costs.PNormP(1.1) for _ in range(3)]

    out_ms = jax.jit(mmsinkhorn.MMSinkhorn(norm_error=1.1)
                    )(x_s, a_s, cost_fns=cost_fns)
    assert out_ms.converged
    np.testing.assert_array_equal(out_ms.tensor.shape, n_s)
    # the coupling tensor's marginals must match the prescribed weights
    for i in range(len(n_s)):
      np.testing.assert_allclose(
          out_ms.marginals[i], out_ms.a_s[i], rtol=1e-3, atol=1e-3
      )

  def test_mm_sinkhorn_diff(self, rng: jax.Array):
    """Test differentiability (Danskin) of MMSinkhorn's ent_reg_cost."""
    n_s, d = [13, 5, 7, 3], 2

    rngs = jax.random.split(rng, 2 * len(n_s) + 1)
    x_s = [
        jax.random.normal(rng, (n, d)) for rng, n in zip(rngs[:len(n_s)], n_s)
    ]

    deltas = [
        jax.random.normal(rng, (n, d)) for rng, n in zip(rngs[len(n_s):], n_s)
    ]
    eps = 1e-3
    x_s_p = [x + eps * delta for x, delta in zip(x_s, deltas)]
    x_s_m = [x - eps * delta for x, delta in zip(x_s, deltas)]

    solver = mmsinkhorn.MMSinkhorn(threshold=1e-5)
    ent_reg = jax.jit(lambda x_s: solver(x_s).ent_reg_cost)
    out_p = ent_reg(x_s_p)
    out_m = ent_reg(x_s_m)
    ent_g = jax.grad(ent_reg)
    g_s = ent_g(x_s)
    # compare the autodiff gradient against a central finite difference
    first_order = 0
    for g, delta in zip(g_s, deltas):
      first_order += jnp.sum(g * delta)

    np.testing.assert_allclose((out_p - out_m) / (2 * eps),
                               first_order,
                               rtol=1e-3,
                               atol=1e-3)
--------------------------------------------------------------------------------
/tests/geometry/geometry_test.py:
--------------------------------------------------------------------------------
# Copyright OTT-JAX
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest

import jax
import jax.numpy as jnp
import numpy as np

from ott.geometry import epsilon_scheduler, geometry, pointcloud


@pytest.mark.fast()
class TestCostMeanStd:
  """Tests for mean/std statistics of cost matrices and relative epsilon."""

  @pytest.mark.parametrize("geom_type", ["pc", "geometry"])
  def test_cost_stdmean(self, rng: jax.Array, geom_type: str):
    """Test consistency of std evaluation."""
    n, m, d = 5, 18, 10
    default_scale = epsilon_scheduler.DEFAULT_EPSILON_SCALE
    rngs = jax.random.split(rng, 5)
    x = jax.random.normal(rngs[0], (n, d))
    y = jax.random.normal(rngs[1], (m, d)) + 1

    geom = pointcloud.PointCloud(x, y)
    if geom_type == "geometry":
      # exercise the same statistics on a dense Geometry with identical costs
      geom = geometry.Geometry(cost_matrix=geom.cost_matrix)

    std = jnp.std(geom.cost_matrix)
    mean = jnp.mean(geom.cost_matrix)
    np.testing.assert_allclose(geom.std_cost_matrix, std, rtol=1e-5, atol=1e-5)

    # relative_epsilon="mean"/"std" should scale the default by that statistic
    eps = pointcloud.PointCloud(x, y, relative_epsilon="mean").epsilon
    np.testing.assert_allclose(default_scale * mean, eps, rtol=1e-5, atol=1e-5)

    eps = pointcloud.PointCloud(x, y, relative_epsilon="std").epsilon
    np.testing.assert_allclose(default_scale * std, eps, rtol=1e-5, atol=1e-5)

--------------------------------------------------------------------------------
/tests/geometry/lr_kernel_test.py:
--------------------------------------------------------------------------------
from typing import Literal, Optional

import pytest

import jax
import jax.numpy as jnp
import numpy as np

from ott.geometry import costs, low_rank, pointcloud
from ott.solvers import linear


@pytest.mark.fast()
class TestLRCGeometry:
  """Tests for low-rank kernel geometries (random-feature approximations)."""

  @pytest.mark.parametrize("std", [1e-1, 1.0, 1e2])
  @pytest.mark.parametrize("kernel", ["gaussian", "arccos"])
  def test_positive_features(
      self, rng: jax.Array, kernel: Literal["gaussian", "arccos"], std: float
  ):
    """Feature factors must be non-negative and have the expected rank."""
    rng1, rng2 = jax.random.split(rng, 2)
    x = jax.random.normal(rng1, (10, 2))
    y = jax.random.normal(rng2, (12, 2))
    rank = 5

    geom = low_rank.LRKGeometry.from_pointcloud(
        x, y, kernel=kernel, std=std, rank=rank
    )

    if kernel == "gaussian":
      assert geom.rank == rank
    else:
      # the arccos kernel adds one extra feature dimension
      assert geom.rank == rank + 1
    np.testing.assert_array_equal(geom.k1 >= 0.0, True)
    np.testing.assert_array_equal(geom.k2 >= 0.0, True)

  @pytest.mark.parametrize("n", [0, 1, 2])
  def test_arccos_j_function(self, rng: jax.Array, n: int):
    """Compare `costs.Arccos._j` against its closed-form expression."""

    def j(theta: float) -> float:
      # closed-form J_n(theta) for the order-n arccos kernel
      if n == 0:
        return jnp.pi - theta
      if n == 1:
        return jnp.sin(theta) + (jnp.pi - theta) * jnp.cos(theta)
      if n == 2:
        return 3.0 * jnp.sin(theta) * jnp.cos(theta) + (jnp.pi - theta) * (
            1.0 + 2.0 * jnp.cos(theta) ** 2
        )
      raise NotImplementedError(n)

    x = jnp.abs(jax.random.normal(rng, (32,)))
    cost_fn = costs.Arccos(n)

    gt = jax.vmap(j)(x)
    pred = jax.vmap(cost_fn._j)(x)

    np.testing.assert_allclose(gt, pred, rtol=1e-4, atol=1e-4)

  @pytest.mark.parametrize("std", [1e-2, 1e-1, 1.0])
  @pytest.mark.parametrize("kernel", ["gaussian", "arccos"])
  def test_kernel_approximation(
      self, rng: jax.Array, kernel: Literal["gaussian", "arccos"], std: float
  ):
    """Higher feature rank should approximate the true cost matrix better."""
    rng, rng1, rng2 = jax.random.split(rng, 3)
    x = jax.random.normal(rng1, (230, 5))
    y = jax.random.normal(rng2, (260, 5))
    n = 1

    cost_fn = costs.SqEuclidean() if kernel == "gaussian" else costs.Arccos(n)
    pc = pointcloud.PointCloud(x, y, epsilon=std, cost_fn=cost_fn)
    gt_cost = pc.cost_matrix

    max_abs_diff = []
    for rank in [50, 100, 400]:
      rng, rng_approx = jax.random.split(rng, 2)
      geom = low_rank.LRKGeometry.from_pointcloud(
          x, y, rank=rank, kernel=kernel, std=std, n=n, rng=rng_approx
      )
      pred_cost = geom.cost_matrix
      max_abs_diff.append(np.max(np.abs(gt_cost - pred_cost)))

    # test higher rank better approximates the cost
    np.testing.assert_array_equal(np.diff(max_abs_diff) <= 0.0, True)

  @pytest.mark.parametrize(("kernel", "n", "std"), [("gaussian", None, 1e-2),
                                                    ("gaussian", None, 1e-1),
                                                    ("arccos", 0, 1.0001),
                                                    ("arccos", 1, 2.0),
                                                    ("arccos", 2, 1.05)])
  def test_sinkhorn_approximation(
      self,
      rng: jax.Array,
      kernel: Literal["gaussian", "arccos"],
      std: float,
      n: Optional[int],
  ):
    """Higher feature rank should approximate the Sinkhorn solution better."""
    rng, rng1, rng2 = jax.random.split(rng, 3)
    x = jax.random.normal(rng1, (83, 5))
    # NOTE(review): `norm(..., keepdims=True)` without `axis` normalizes by
    # the Frobenius norm of the whole matrix, not per-row — confirm intended.
    x /= jnp.linalg.norm(x, keepdims=True)
    y = jax.random.normal(rng2, (96, 5))
    y /= jnp.linalg.norm(y, keepdims=True)
    solve_fn = jax.jit(lambda g: linear.solve(g, lse_mode=False))

    cost_fn = costs.SqEuclidean() if kernel == "gaussian" else costs.Arccos(n)
    geom = pointcloud.PointCloud(x, y, epsilon=std, cost_fn=cost_fn)
    gt_out = solve_fn(geom)

    cs = []
    for rank in [5, 40, 80]:
      rng, rng_approx = jax.random.split(rng, 2)
      geom = low_rank.LRKGeometry.from_pointcloud(
          x, y, rank=rank, kernel=kernel, std=std, n=n, rng=rng_approx
      )

      pred_out = solve_fn(geom)
      cs.append(pred_out.reg_ot_cost)

    diff = np.diff(np.abs(gt_out.reg_ot_cost - np.array(cs)))
    try:
      # test higher rank better approximates the Sinkhorn solution
      np.testing.assert_array_equal(diff <= 0.0, True)
    except AssertionError:
      # monotonicity can fail due to randomness; accept near-zero differences
      np.testing.assert_allclose(diff, 0.0, rtol=1e-2, atol=1e-2)

--------------------------------------------------------------------------------
/tests/initializers/linear/sinkhorn_lr_init_test.py:
--------------------------------------------------------------------------------
# Copyright OTT-JAX
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest

import jax
import jax.numpy as jnp
import numpy as np

from ott.geometry import pointcloud
from ott.initializers.linear import initializers_lr
from ott.problems.linear import linear_problem
from ott.solvers.linear import sinkhorn_lr


class TestLRInitializers:
  """Tests for low-rank Sinkhorn initializers."""

  @pytest.mark.fast.with_args("rank", [2, 4, 10, 13], only_fast=True)
  def test_generalized_k_means_has_correct_rank(
      self, rng: jax.Array, rank: int
  ):
    """The q/r factors produced by the initializer must have full rank."""
    n, d = 27, 5
    x = jax.random.normal(rng, (n, d))
    pc = pointcloud.PointCloud(x, epsilon=0.5)
    prob = linear_problem.LinearProblem(pc)

    initializer = initializers_lr.GeneralizedKMeansInitializer(rank)
    # g (the inner scaling) is unused here; only the factors' ranks matter
    q, r, g = initializer(prob)

    assert jnp.linalg.matrix_rank(q) == rank
    assert jnp.linalg.matrix_rank(r) == rank

  def test_better_initialization_helps(self, rng: jax.Array):
    """K-means initialization should beat random init in speed and cost."""
    n, d, rank = 81, 13, 3
    epsilon = 1e-1
    rng1, rng2 = jax.random.split(rng, 2)
    x = jax.random.normal(rng1, (n, d))
    y = jax.random.normal(rng2, (n, d))
    pc = pointcloud.PointCloud(x, y, epsilon=5e-1)
    prob = linear_problem.LinearProblem(pc)

    solver_random = sinkhorn_lr.LRSinkhorn(
        rank=rank,
        epsilon=epsilon,
        initializer=initializers_lr.RandomInitializer(rank),
        max_iterations=10000,
    )
    solver_init = sinkhorn_lr.LRSinkhorn(
        rank=rank,
        epsilon=epsilon,
        initializer=initializers_lr.KMeansInitializer(rank),
        max_iterations=10000
    )

    out_random = solver_random(prob)
    out_init = solver_init(prob)

    assert out_random.converged
    assert out_init.converged
    # converged earlier (fewer recorded error entries)
    np.testing.assert_array_less((out_init.errors > -1).sum(),
                                 (out_random.errors > -1).sum())
    # converged to a better solution
    assert out_init.reg_ot_cost <= out_random.reg_ot_cost
-------------------------------------------------------------------------------- /tests/initializers/neural/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright OTT-JAX 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | import pytest 15 | 16 | _ = pytest.importorskip("ott.initializers.neural") 17 | -------------------------------------------------------------------------------- /tests/initializers/neural/meta_initializer_test.py: -------------------------------------------------------------------------------- 1 | # Copyright OTT-JAX 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
from typing import Optional

import pytest

import jax
import jax.numpy as jnp

from flax import linen as nn

from ott.geometry import pointcloud
from ott.initializers.neural import meta_initializer as meta_init
from ott.problems.linear import linear_problem
from ott.solvers.linear import sinkhorn


class MetaMLP(nn.Module):
  """Simple ReLU MLP predicting a Sinkhorn potential of size `potential_size`."""

  potential_size: int
  num_hidden_units: int = 512
  num_hidden_layers: int = 3

  @nn.compact
  def __call__(self, z: jnp.ndarray) -> jnp.ndarray:
    for _ in range(self.num_hidden_layers):
      z = nn.relu(nn.Dense(self.num_hidden_units)(z))
    return nn.Dense(self.potential_size)(z)


def create_ot_problem(
    rng: jax.Array,
    n: int,
    m: int,
    d: int,
    epsilon: float = 1e-2,
    batch_size: Optional[int] = None
) -> linear_problem.LinearProblem:
  """Build a linear OT problem between two shifted Gaussian point clouds."""
  # define ot problem
  x_rng, y_rng = jax.random.split(rng)
  mu_a = jnp.array([-1.0, 1.0]) * 5
  mu_b = jnp.array([0.0, 0.0])
  x = jax.random.normal(x_rng, (n, d)) + mu_a
  y = jax.random.normal(y_rng, (m, d)) + mu_b
  geom = pointcloud.PointCloud(x, y, epsilon=epsilon, batch_size=batch_size)
  return linear_problem.LinearProblem(geom=geom)


@pytest.mark.fast()
class TestMetaInitializer:
  """Tests for the learned (meta) Sinkhorn initializer."""

  @pytest.mark.parametrize("lse_mode", [True, False])
  def test_meta_initializer(self, rng: jax.Array, lse_mode: bool):
    """Tests Meta initializer"""
    n, m, d = 32, 30, 2
    epsilon = 1e-2

    ot_problem = create_ot_problem(rng, n, m, d, epsilon=epsilon, batch_size=3)

    # run sinkhorn
    solver = sinkhorn.Sinkhorn(lse_mode=lse_mode, max_iterations=3000)
    sink_out = jax.jit(solver)(ot_problem)

    # overfit the initializer to the problem.
    meta_model = MetaMLP(n)
    meta_initializer = meta_init.MetaInitializer(ot_problem.geom, meta_model)
    for _ in range(50):
      _, _, meta_initializer.state = meta_initializer.update(
          meta_initializer.state, a=ot_problem.a, b=ot_problem.b
      )

    solver = sinkhorn.Sinkhorn(
        initializer=meta_initializer, lse_mode=lse_mode, max_iterations=3000
    )
    meta_out = jax.jit(solver)(ot_problem)

    # check initializer is better (fewer Sinkhorn iterations to converge)
    if lse_mode:
      assert sink_out.converged
      assert meta_out.converged
      assert sink_out.n_iters > meta_out.n_iters
    else:
      assert sink_out.n_iters >= meta_out.n_iters

--------------------------------------------------------------------------------
/tests/initializers/quadratic/gw_init_test.py:
--------------------------------------------------------------------------------
# Copyright OTT-JAX
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest

import jax
import numpy as np

from ott.geometry import pointcloud
from ott.initializers.linear import initializers_lr
from ott.problems.quadratic import quadratic_problem
from ott.solvers.quadratic import gromov_wasserstein_lr


class TestQuadraticInitializers:
  """Tests for initializers of low-rank Gromov-Wasserstein."""

  def test_explicit_initializer_lr(self):
    """An explicitly passed initializer must be stored on the solver as-is."""
    rank = 10
    q_init = initializers_lr.Rank2Initializer(rank)
    solver = gromov_wasserstein_lr.LRGromovWasserstein(
        rank=rank,
        initializer=q_init,
        min_iterations=0,
        inner_iterations=10,
        max_iterations=2000
    )

    assert solver.initializer is q_init
    assert solver.initializer.rank == rank

  @pytest.mark.parametrize("eps", [0.0, 1e-1])
  def test_gw_better_initialization_helps(self, rng: jax.Array, eps: float):
    """K-means initialization should reach a lower GW cost than random init."""
    n, m, d1, d2, rank = 83, 84, 12, 8, 5
    rng1, rng2, rng3, rng4 = jax.random.split(rng, 4)

    geom_x = pointcloud.PointCloud(
        jax.random.normal(rng1, (n, d1)),
        jax.random.normal(rng2, (n, d1)),
        epsilon=eps,
    )
    geom_y = pointcloud.PointCloud(
        jax.random.normal(rng3, (m, d2)),
        jax.random.normal(rng4, (m, d2)),
        epsilon=eps,
    )
    problem = quadratic_problem.QuadraticProblem(geom_x, geom_y)
    solver_random = gromov_wasserstein_lr.LRGromovWasserstein(
        rank=rank,
        initializer=initializers_lr.RandomInitializer(rank),
        epsilon=eps,
        min_iterations=0,
        inner_iterations=10,
        max_iterations=2000
    )
    solver_kmeans = gromov_wasserstein_lr.LRGromovWasserstein(
        rank=rank,
        initializer=initializers_lr.KMeansInitializer(rank),
        epsilon=eps,
        min_iterations=0,
        inner_iterations=10,
        max_iterations=2000
    )

    out_random = solver_random(problem)
    out_kmeans = solver_kmeans(problem)

    random_cost = out_random.reg_gw_cost
    # entries equal to -1 are unused slots in the errors buffer
    random_errors = out_random.errors[out_random.errors > -1]
    kmeans_cost = out_kmeans.reg_gw_cost
    kmeans_errors = out_kmeans.errors[out_kmeans.errors > -1]

    np.testing.assert_array_less(kmeans_cost, random_cost)
    np.testing.assert_array_equal(random_errors >= 0.0, True)
    np.testing.assert_array_equal(kmeans_errors >= 0.0, True)

--------------------------------------------------------------------------------
/tests/math/lse_test.py:
--------------------------------------------------------------------------------
# Copyright OTT-JAX
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest

import jax
import jax.numpy as jnp
import numpy as np

from ott.math import utils as mu


@pytest.mark.fast()
class TestGeometryLse:
  """Tests for the custom log-sum-exp used by OTT."""

  def test_lse(self, rng: jax.Array):
    """Test consistency of custom lse's jvp."""
    n, m = 12, 8
    rngs = jax.random.split(rng, 5)
    mat = jax.random.normal(rngs[0], (n, m))
    # picking potentially negative weights on purpose
    b_0 = jax.random.normal(rngs[1], (m,))
    b_1 = jax.random.normal(rngs[2], (n, 1))

    def lse_(x, axis, b, return_sign):
      # sum so that the scalar can be differentiated with value_and_grad
      out = mu.logsumexp(x, axis, False, b, return_sign)
      return jnp.sum(out[0] if return_sign else out)

    lse = jax.value_and_grad(lse_, argnums=(0, 2))
    for axis in (0, 1):
      _, g = lse(mat, axis, None, False)
      delta_mat = jax.random.normal(rngs[3], (n, m))
      eps = 1e-3
      # central finite difference vs autodiff gradient
      val_peps = lse(mat + eps * delta_mat, axis, None, False)[0]
      val_meps = lse(mat - eps * delta_mat, axis, None, False)[0]
      np.testing.assert_allclose((val_peps - val_meps) / (2 * eps),
                                 jnp.sum(delta_mat * g[0]),
                                 rtol=1e-3,
                                 atol=1e-2)
    # NOTE(review): `delta_mat` below is reused from the last iteration of
    # the loop above — it is only defined once that loop has run.
    for b, dim, axis in zip((b_0, b_1), (m, n), (1, 0)):
      delta_b = jax.random.normal(rngs[4], (dim,)).reshape(b.shape)
      _, g = lse(mat, axis, b, True)
      eps = 1e-3
      val_peps = lse(mat + eps * delta_mat, axis, b + eps * delta_b, True)[0]
      val_meps = lse(mat - eps * delta_mat, axis, b - eps * delta_b, True)[0]
      np.testing.assert_allclose((val_peps - val_meps) / (2 * eps),
                                 jnp.sum(delta_mat * g[0]) +
                                 jnp.sum(delta_b * g[1]),
                                 rtol=1e-3,
                                 atol=1e-2)

--------------------------------------------------------------------------------
/tests/math/math_utils_test.py:
--------------------------------------------------------------------------------
# Copyright OTT-JAX
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools

import pytest

import jax
import jax.numpy as jnp
import numpy as np

from ott.math import utils as mu


class TestNorm:
  """Tests for the custom norm with well-defined gradients at 0."""

  @pytest.mark.parametrize("ord", [1.1, 2.0, jnp.inf])
  def test_norm(
      self,
      rng: jax.Array,
      ord,
  ):
    d = 5
    # Test that mu.norm returns properly 0 gradients at 0, when comparing point
    # against itself.
    f = lambda x: mu.norm(x - x, ord=ord)
    x = jax.random.uniform(rng, (d,))
    np.testing.assert_array_equal(jax.grad(f)(x), 0.0)

    # Test same property when using an axis, on jacobians.
    ds = (4, 7, 3)
    z = jax.random.uniform(rng, ds)
    h = lambda x: mu.norm(x - x, ord=ord, axis=1)
    jac = jax.jacobian(h)(z)
    np.testing.assert_array_equal(jac, 0.0)

    # Check native jax's norm still returns NaNs when ord is float.
    if ord != jnp.inf:
      g = lambda x: jnp.linalg.norm(x - x, ord=ord)
      np.testing.assert_array_equal(jnp.isnan(jax.grad(g)(x)), True)
      h = lambda x: jnp.linalg.norm(x - x, ord=ord, axis=1)
      jac = jax.jacobian(h)(z)
      np.testing.assert_array_equal(jnp.isnan(jac), True)

    # Check both coincide outside of 0.
    f = functools.partial(mu.norm, ord=ord)
    g = functools.partial(jnp.linalg.norm, ord=ord)

    np.testing.assert_array_equal(jax.grad(f)(x), jax.grad(g)(x))
    np.testing.assert_array_equal(jax.hessian(f)(x), jax.hessian(g)(x))

    # Check both coincide outside of 0 when using axis.
    f = functools.partial(mu.norm, ord=ord, axis=1)
    g = functools.partial(jnp.linalg.norm, ord=ord, axis=1)

    np.testing.assert_array_equal(jax.jacobian(f)(z), jax.jacobian(g)(z))

--------------------------------------------------------------------------------
/tests/neural/__init__.py:
--------------------------------------------------------------------------------
# Copyright OTT-JAX
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest

# skip the whole package when the optional neural dependencies are missing
_ = pytest.importorskip("ott.neural")

--------------------------------------------------------------------------------
/tests/neural/methods/genot_test.py:
--------------------------------------------------------------------------------
# Copyright OTT-JAX
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Literal

import pytest

import jax
import jax.numpy as jnp
import jax.tree_util as jtu
import numpy as np

import optax

from ott.neural.methods.flows import dynamics, genot
from ott.neural.networks import velocity_field
from ott.solvers import utils as solver_utils


def get_match_fn(typ: Literal["lin", "quad", "fused"]):
  """Pick the matching function for a given problem type.

  NOTE(review): "quad" and "fused" currently map to the same
  `match_quadratic` — confirm whether "fused" should pass fused arguments.
  """
  if typ == "lin":
    return solver_utils.match_linear
  if typ == "quad":
    return solver_utils.match_quadratic
  if typ == "fused":
    return solver_utils.match_quadratic
  raise NotImplementedError(typ)


class TestGENOT:
  """Smoke tests for the GENOT neural solver."""

  @pytest.mark.parametrize(
      "dl", [
          "lin_dl", "quad_dl", "fused_dl", "lin_cond_dl", "quad_cond_dl",
          "fused_cond_dl"
      ]
  )
  def test_genot(self, rng: jax.Array, dl: str, request):
    """Train GENOT for a few iterations and check transported output is finite."""
    rng_init, rng_call, rng_data = jax.random.split(rng, 3)
    problem_type = dl.split("_")[0]
    # resolve the dataloader fixture by name
    dl = request.getfixturevalue(dl)

    src_dim = dl.lin_dim + dl.quad_src_dim
    tgt_dim = dl.lin_dim + dl.quad_tgt_dim
    cond_dim = dl.cond_dim

    vf = velocity_field.VelocityField(
        hidden_dims=[7, 7, 7],
        output_dims=[15, tgt_dim],
        condition_dims=None if cond_dim is None else [1, 3, 2],
        dropout_rate=0.5,
    )
    model = genot.GENOT(
        vf,
        flow=dynamics.ConstantNoiseFlow(0.0),
        data_match_fn=get_match_fn(problem_type),
        source_dim=src_dim,
        target_dim=tgt_dim,
        condition_dim=cond_dim,
        rng=rng_init,
        optimizer=optax.adam(learning_rate=1e-4),
    )

    _logs = model(dl.loader, n_iters=2, rng=rng_call)

    batch = next(iter(dl.loader))
    batch = jtu.tree_map(jnp.asarray, batch)
    src_cond = batch.get("src_condition")
    batch_size = 4 if src_cond is None else src_cond.shape[0]
    src = jax.random.normal(rng_data, (batch_size, src_dim))

    res = model.transport(src, condition=src_cond)

    assert len(_logs["loss"]) == 2
    np.testing.assert_array_equal(jnp.isfinite(res), True)
    assert res.shape == (batch_size, tgt_dim)

--------------------------------------------------------------------------------
/tests/neural/methods/otfm_test.py:
--------------------------------------------------------------------------------
# Copyright OTT-JAX
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest

import jax
import jax.numpy as jnp
import jax.tree_util as jtu
import numpy as np

import optax

from ott.neural.methods.flows import dynamics, otfm
from ott.neural.networks import velocity_field
from ott.solvers import utils as solver_utils


class TestOTFM:
  """Smoke tests for OT flow matching."""

  @pytest.mark.parametrize("dl", ["lin_dl", "lin_cond_dl"])
  def test_otfm(self, rng: jax.Array, dl: str, request):
    """Train OTFM briefly, transport forward and backward, check finiteness."""
    dl = request.getfixturevalue(dl)
    dim, cond_dim = dl.lin_dim, dl.cond_dim

    vf = velocity_field.VelocityField(
        hidden_dims=[5, 5, 5],
        output_dims=[7, dim],
        condition_dims=None if cond_dim is None else [4, 3, 2],
    )
    fm = otfm.OTFlowMatching(
        vf,
        dynamics.ConstantNoiseFlow(0.0),
        match_fn=jax.jit(solver_utils.match_linear),
        rng=rng,
        optimizer=optax.adam(learning_rate=1e-3),
        condition_dim=cond_dim,
    )

    _logs = fm(dl.loader, n_iters=3)

    batch = next(iter(dl.loader))
    batch = jtu.tree_map(jnp.asarray, batch)
    src_cond = batch.get("src_condition")

    # t0=1.0, t1=0.0 integrates the flow in reverse (target -> source)
    res_fwd = fm.transport(batch["src_lin"], condition=src_cond)
    res_bwd = fm.transport(batch["tgt_lin"], t0=1.0, t1=0.0, condition=src_cond)

    assert len(_logs["loss"]) == 3

    assert res_fwd.shape == batch["src_lin"].shape
    assert res_bwd.shape == batch["tgt_lin"].shape
    np.testing.assert_array_equal(jnp.isfinite(res_fwd), True)
    np.testing.assert_array_equal(jnp.isfinite(res_bwd), True)

--------------------------------------------------------------------------------
/tests/neural/networks/icnn_test.py:
--------------------------------------------------------------------------------
# Copyright OTT-JAX
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
@pytest.mark.fast()
class TestICNN:
  """Tests for input-convex neural networks (ICNN)."""

  def test_icnn_convexity(self, rng: jax.Array):
    """The ICNN output must lie below every chord between two inputs."""
    num_points, num_features = 10, 2
    hidden = (64, 64)

    # Build and initialize the input-convex network.
    model = icnn.ICNN(num_features, dim_hidden=hidden)
    rng_init, rng_other = jax.random.split(rng, 2)
    params = model.init(rng_init, jnp.ones((1, num_features)))

    # Two batches of points between which convexity is checked.
    u = jax.random.normal(rng_init, (num_points, num_features)) * 0.1
    v = jax.random.normal(rng_other, (num_points, num_features))

    f_u = model.apply(params, u)
    f_v = model.apply(params, v)

    # Convexity gap t*f(u) + (1-t)*f(v) - f(t*u + (1-t)*v) must be >= 0
    # along the whole segment.
    gaps = [
        (t * f_u + (1 - t) * f_v) - model.apply(params, t * u + (1 - t) * v)
        for t in jnp.linspace(0, 1)
    ]
    np.testing.assert_array_equal(np.array(gaps) >= 0.0, True)

  def test_icnn_hessian(self, rng: jax.Array):
    """The Hessian of a convex function is positive semi-definite."""
    num_features = 2
    hidden = (64, 64)
    model = icnn.ICNN(num_features, dim_hidden=hidden)

    rng_init, rng_data = jax.random.split(rng)
    params = model.init(rng_init, jnp.ones((1, num_features)))

    point = jax.random.normal(rng_data, (num_features,))

    # Add a batch dimension inside the closure before differentiating twice.
    hess = jax.hessian(lambda z: model.apply(params, z[None]))(point)

    # Symmetrize to absorb numerical noise, then check eigenvalues.
    eigvals = jnp.linalg.eigvalsh((hess + hess.T) / 2.0)
    np.testing.assert_array_equal(eigvals >= 0, True)
class TestDiscreteBarycenter:
  """Tests for the fixed-support discrete Wasserstein barycenter solver."""

  @pytest.mark.parametrize(("lse_mode", "debiased", "epsilon"),
                           [(True, True, 1e-2), (False, False, 2e-2)],
                           ids=["lse-deb", "scal-no-deb"])
  def test_discrete_barycenter_grid(
      self, lse_mode: bool, debiased: bool, epsilon: float
  ):
    """Tests the discrete barycenters on a 5x5x5 grid.

    Puts two masses on opposing ends of the hypercube with small noise in
    between. Check that their W barycenter sits (mostly) at the middle of the
    hypercube (e.g. index (5x5x5-1)/2)

    Args:
      lse_mode: bool, lse or scaling computations.
      debiased: bool, use (or not) debiasing as proposed in
        https://arxiv.org/abs/2006.02575
      epsilon: float, regularization parameter
    """
    size = jnp.array([5, 5, 5])
    grid_3d = grid.Grid(grid_size=size, epsilon=epsilon)
    # Nearly uniform histograms with a large spike at opposite corners.
    a1 = jnp.ones(size).ravel()
    a2 = jnp.ones(size).ravel()
    a1 = a1.at[0].set(10000)
    a2 = a2.at[-1].set(10000)
    a1 /= jnp.sum(a1)
    a2 /= jnp.sum(a2)
    threshold = 1e-2

    fixed_bp = bp.FixedBarycenterProblem(
        geom=grid_3d, a=jnp.stack((a1, a2)), weights=jnp.array([0.5, 0.5])
    )
    solver = db.FixedBarycenter(
        threshold=threshold, lse_mode=lse_mode, debiased=debiased
    )
    out = solver(fixed_bp)
    bar, errors = out.histogram, out.errors

    # Most of the barycenter's mass should sit at the central grid index.
    assert bar[(jnp.prod(size) - 1) // 2] > 0.7
    assert bar[(jnp.prod(size) - 1) // 2] < 1
    # The last finite error recorded must be below the convergence threshold.
    err = errors[jnp.isfinite(errors)][-1]
    assert threshold > err

  @pytest.mark.parametrize(("lse_mode", "epsilon"), [(True, 1e-3),
                                                     (False, 1e-2)],
                           ids=["lse", "scale"])
  def test_discrete_barycenter_pointcloud(self, lse_mode: bool, epsilon: float):
    """Tests the discrete barycenters on pointclouds.

    Two measures supported on the same set of points (a 1D grid), barycenter is
    evaluated on a different set of points (still in 1D).

    Args:
      lse_mode: bool, lse or scaling computations
      epsilon: float
    """
    n = 50
    ma = 0.2
    mb = 0.8
    # define two narrow Gaussian bumps in segment [0,1]
    a = jnp.exp(-(jnp.arange(0, n) / (n - 1) - ma) ** 2 / 0.01) + 1e-10
    b = jnp.exp(-(jnp.arange(0, n) / (n - 1) - mb) ** 2 / 0.01) + 1e-10
    a = a / jnp.sum(a)
    b = b / jnp.sum(b)

    # positions on the real line where weights are supported.
    x = jnp.atleast_2d(jnp.arange(0, n) / (n - 1)).T

    # choose a different support, half the size, for the barycenter.
    # note this is the reason why we do not use debiasing in this case.
    x_support_bar = jnp.atleast_2d((jnp.arange(0, (n / 2)) /
                                    (n / 2 - 1) - 0.5) * 0.9 + 0.5).T

    geom = pointcloud.PointCloud(x, x_support_bar, epsilon=epsilon)
    fixed_bp = bp.FixedBarycenterProblem(geom, a=jnp.stack((a, b)))

    out = db.FixedBarycenter(lse_mode=lse_mode)(fixed_bp)
    bar = out.histogram
    # check the barycenter has bump in the middle.
    assert bar[n // 4] > 0.1
@pytest.mark.fast()
class TestLowerBound:
  """Tests for the third lower bound of the Gromov-Wasserstein distance."""

  @pytest.fixture(autouse=True)
  def initialize(self, rng: jax.Array):
    # Two point clouds of different sizes and ambient dimensions, with
    # uniform marginal weights.
    d_x = 2
    d_y = 3
    self.n, self.m = 13, 15
    rngs = jax.random.split(rng, 4)
    self.x = jax.random.uniform(rngs[0], (self.n, d_x))
    self.y = jax.random.uniform(rngs[1], (self.m, d_y))
    a = jnp.ones(self.n)
    b = jnp.ones(self.m)
    self.a = a / jnp.sum(a)
    self.b = b / jnp.sum(b)
    # NOTE(review): `cx`/`cy` are not used by the tests visible here —
    # presumably kept for other tests or future use; verify before removing.
    self.cx = jax.random.uniform(rngs[2], (self.n, self.n))
    self.cy = jax.random.uniform(rngs[3], (self.m, self.m))

  @pytest.mark.parametrize(
      "ground_cost",
      [costs.SqEuclidean(), costs.PNormP(1.3)]
  )
  def test_lb_pointcloud(self, ground_cost: costs.TICost):
    """The lower bound on a quadratic problem is a finite value."""
    x, y = self.x, self.y

    geom_x = pointcloud.PointCloud(x)
    geom_y = pointcloud.PointCloud(y)
    prob = quadratic_problem.QuadraticProblem(
        geom_x, geom_y, a=self.a, b=self.b
    )
    solve_fn = univariate.quantile_solver
    distrib_cost = distrib_costs.UnivariateWasserstein(
        solve_fn, ground_cost=ground_cost
    )

    solver = jax.jit(lower_bound.third_lower_bound)
    out = solver(prob, distrib_cost, epsilon=1e-3)

    assert jnp.isfinite(out.reg_ot_cost)

  @pytest.mark.parametrize(("ground_cost", "uniform", "eps"),
                           [(costs.SqEuclidean(), True, 1e-2),
                            (costs.PNormP(1.3), False, 1e-1)])
  def test_lb_different_solvers(
      self, ground_cost: costs.TICost, uniform: bool, eps: float
  ):
    """All univariate OT solvers must yield the same lower bound."""
    x, y, a, b = self.x, self.y, self.a, self.b
    if uniform:
      # The uniform solver requires equally-sized measures; truncate both.
      k = min(self.n, self.m)
      x, y, a, b = x[:k], y[:k], a[:k], b[:k]

    geom_x = pointcloud.PointCloud(x)
    geom_y = pointcloud.PointCloud(y)
    prob = quadratic_problem.QuadraticProblem(geom_x, geom_y, a=a, b=b)

    # Same distributional cost computed with three different 1D OT solvers.
    distrib_cost_unif = distrib_costs.UnivariateWasserstein(
        solve_fn=univariate.uniform_solver, ground_cost=ground_cost
    )
    distrib_cost_quant = distrib_costs.UnivariateWasserstein(
        solve_fn=univariate.quantile_solver, ground_cost=ground_cost
    )
    distrib_cost_nw = distrib_costs.UnivariateWasserstein(
        solve_fn=univariate.north_west_solver, ground_cost=ground_cost
    )

    solver = jax.jit(lower_bound.third_lower_bound)

    out_unif = solver(prob, distrib_cost_unif, epsilon=eps) if uniform else None
    out_quant = solver(prob, distrib_cost_quant, epsilon=eps)
    out_nw = solver(prob, distrib_cost_nw, epsilon=eps)

    np.testing.assert_allclose(
        out_quant.reg_ot_cost, out_nw.reg_ot_cost, rtol=1e-6, atol=1e-6
    )
    np.testing.assert_allclose(
        out_quant.matrix, out_nw.matrix, rtol=1e-6, atol=1e-6
    )
    if out_unif is not None:
      np.testing.assert_allclose(
          out_quant.reg_ot_cost, out_unif.reg_ot_cost, rtol=1e-6, atol=1e-6
      )
      np.testing.assert_allclose(
          out_quant.matrix, out_unif.matrix, rtol=1e-6, atol=1e-6
      )
def get_model_and_data(
    *,
    n_samples: int,
    target_dim: int,
    random_state: int = 0,
) -> Tuple[Callable[[jnp.ndarray], jnp.ndarray], Tuple[jnp.ndarray, ...]]:
  """Fit a linear regression model on synthetic data and return it as JAX.

  Args:
    n_samples: total number of synthetic regression samples.
    target_dim: dimensionality of the regression targets.
    random_state: seed for data generation and splits.

  Returns:
    A jitted prediction function and the tuple
    ``(x_trn, x_calib, x_test, y_trn, y_calib, y_test)`` as JAX arrays.
  """
  x, y = datasets.make_regression(
      n_samples=n_samples,
      n_features=5,
      n_targets=target_dim,
      random_state=random_state
  )
  # 70% train; the remainder is split evenly into calibration and test.
  x_trn, x_calib, y_trn, y_calib = model_selection.train_test_split(
      x, y, train_size=0.7, random_state=random_state
  )
  x_calib, x_test, y_calib, y_test = model_selection.train_test_split(
      x_calib, y_calib, train_size=0.5, random_state=random_state
  )
  model = linear_model.LinearRegression().fit(x_trn, y_trn)

  # Re-express the fitted sklearn model as a pure, jittable JAX function.
  A = jnp.asarray(model.coef_)
  b = jnp.asarray(model.intercept_[None])
  model_fn = jax.jit(lambda x: x @ A.T + b)
  data = jax.tree.map(
      jnp.asarray, (x_trn, x_calib, x_test, y_trn, y_calib, y_test)
  )
  return model_fn, data


class TestOTCP:
  """Tests for optimal-transport conformal prediction."""

  @pytest.mark.parametrize("shape", [(16, 2), (58, 9), (128, 9)])
  def test_sobol_ball_sampler(self, shape: Tuple[int, int], rng: jax.Array):
    """Sampler returns the expected number of points and normalized weights."""
    n, d = shape
    # Mirror the sampler's split of n into (radii x points-per-sphere);
    # any remainder contributes a single extra point.
    n_per_radius = math.ceil(math.sqrt(n))
    n_sphere, n_0s = divmod(n, n_per_radius)
    n_expected = n_sphere * n_per_radius + (n_0s > 0)

    sample_fn = jax.jit(conformal.sobol_ball_sampler, static_argnames="shape")
    points, weights = sample_fn(rng, shape, n_per_radius=None)

    assert weights.shape == (n_expected,)
    assert points.shape == (n_expected, d)
    np.testing.assert_allclose(weights.sum(), 1.0, rtol=1e-4, atol=1e-4)
    if n_0s:
      # The leftover point is placed at the origin.
      np.testing.assert_array_equal(points[-1], 0.0)

  @pytest.mark.parametrize(("target_dim", "epsilon", "sampler_fn"), [(
      3, 1e-1, functools.partial(conformal.sobol_ball_sampler, n_per_radius=10)
  ), (5, 1e-2, None)])
  def test_otcp(
      self, rng: jax.Array, target_dim: int, epsilon: float,
      sampler_fn: Callable
  ):
    """End-to-end OTCP: fit transport, calibrate, predict both directions."""
    n_samples = 32
    n_target_measure = n_samples
    model, data = get_model_and_data(n_samples=n_samples, target_dim=target_dim)
    x_trn, x_calib, x_test, y_trn, y_calib, y_test = data

    otcp = conformal.OTCP(model, sampler=sampler_fn)
    otcp = jax.jit(
        otcp.fit_transport, static_argnames=["n_target"]
    )(x_trn, y_trn, epsilon=epsilon, n_target=n_target_measure, rng=rng)
    otcp = jax.jit(otcp.calibrate)(x_calib, y_calib)
    predict_fn = jax.jit(otcp.predict)

    # Scores recomputed on calibration data must match the stored ones.
    calib_scores = otcp.calibration_scores
    np.testing.assert_array_equal(
        jax.jit(otcp.get_scores)(x_calib, y_calib), calib_scores
    )

    # predict backward
    preds = predict_fn(x_test[0])
    assert preds.shape == (len(otcp.target_measure), target_dim)
    preds = predict_fn(x_test)  # vectorized
    assert preds.shape == (len(x_test), len(otcp.target_measure), target_dim)

    # predict forward
    preds = predict_fn(x_test[0], y_candidates=y_test)
    assert preds.shape == (len(y_test),)
    assert jnp.isdtype(preds.dtype, jnp.bool)
    preds = predict_fn(x_test, y_candidates=y_test)  # vectorized
    assert preds.shape == (len(x_test), len(y_test))
    assert jnp.isdtype(preds.dtype, jnp.bool)
class TestFitGmmPair:
  """Integration test for jointly fitting a pair of coupled GMMs."""

  @pytest.fixture(autouse=True)
  def initialize(self, rng: jax.Array):
    """Draw samples from two related 3-component 2D mixtures."""
    mean_generator0 = jnp.array([[2.0, -1.0], [-2.0, 0.0], [4.0, 3.0]])
    cov_generator0 = jnp.array([[[0.2, 0.0], [0.0, 0.1]],
                                [[0.6, 0.0], [0.0, 0.3]],
                                [[0.5, 0.4], [0.4, 0.5]]])
    weights_generator0 = jnp.array([0.3, 0.3, 0.4])

    gmm_generator0 = (
        gaussian_mixture.GaussianMixture.from_mean_cov_component_weights(
            mean=mean_generator0,
            cov=cov_generator0,
            component_weights=weights_generator0
        )
    )

    # shift the means to the right by varying amounts
    mean_generator1 = mean_generator0 + jnp.array([[1.0, -0.5], [-1.0, -1.0],
                                                   [-1.0, 0.0]])
    cov_generator1 = cov_generator0
    weights_generator1 = weights_generator0 + jnp.array([0.0, 0.1, -0.1])

    gmm_generator1 = (
        gaussian_mixture.GaussianMixture.from_mean_cov_component_weights(
            mean=mean_generator1,
            cov=cov_generator1,
            component_weights=weights_generator1
        )
    )

    # Unbalanced-transport parameters: tau < 1 relaxes marginal constraints.
    self.epsilon = 1e-2
    self.rho = 0.1
    self.tau = self.rho / (self.rho + self.epsilon)

    self.rng, subrng0, subrng1 = jax.random.split(rng, num=3)
    self.samples_gmm0 = gmm_generator0.sample(rng=subrng0, size=2000)
    self.samples_gmm1 = gmm_generator1.sample(rng=subrng1, size=2000)

  # requires Schur decomposition, which jax does not implement on GPU
  @pytest.mark.cpu()
  @pytest.mark.fast.with_args(
      balanced=[False, True], weighted=[False, True], only_fast=0
  )
  def test_fit_gmm(self, balanced, weighted):
    # dumb integration test that makes sure nothing crashes
    tau = 1.0 if balanced else self.tau

    if weighted:
      weights0 = jnp.ones(self.samples_gmm0.shape[0])
      # Fix: size the second weight vector from the second sample set.
      # (Previously used `samples_gmm0.shape[0]`, which only worked because
      # both sample sets happen to have the same size.)
      weights1 = jnp.ones(self.samples_gmm1.shape[0])
      weights_pooled = jnp.concatenate([weights0, weights1], axis=0)
    else:
      weights0 = None
      weights1 = None
      weights_pooled = None

    # Fit a GMM to the pooled samples
    samples = jnp.concatenate([self.samples_gmm0, self.samples_gmm1])
    gmm_init = fit_gmm.initialize(
        rng=self.rng,
        points=samples,
        point_weights=weights_pooled,
        n_components=3,
        verbose=False
    )
    gmm = fit_gmm.fit_model_em(
        gmm=gmm_init, points=samples, point_weights=None, steps=20
    )
    # use the same mixture model for gmm0 and gmm1 initially
    pair_init = gaussian_mixture_pair.GaussianMixturePair(
        gmm0=gmm, gmm1=gmm, epsilon=self.epsilon, tau=tau
    )
    fit_model_em_fn = fit_gmm_pair.get_fit_model_em_fn(
        weight_transport=0.1, jit=True
    )
    fit_model_em_fn(
        pair=pair_init,
        points0=self.samples_gmm0,
        points1=self.samples_gmm1,
        point_weights0=weights0,
        point_weights1=weights1,
        em_steps=1,
        m_steps=10,
        verbose=False
    )
@pytest.mark.fast()
class TestFitGmm:
  """Smoke tests for fitting a single Gaussian mixture with EM."""

  @pytest.fixture(autouse=True)
  def initialize(self, rng: jax.Array):
    """Draw samples from a known 3-component 2D mixture."""
    means = jnp.array([[2.0, -1.0], [-2.0, 0.0], [4.0, 3.0]])
    covariances = jnp.array([[[0.2, 0.0], [0.0, 0.1]],
                             [[0.6, 0.0], [0.0, 0.3]],
                             [[0.5, 0.4], [0.4, 0.5]]])
    component_weights = jnp.array([0.3, 0.3, 0.4])

    generator = (
        gaussian_mixture.GaussianMixture.from_mean_cov_component_weights(
            mean=means,
            cov=covariances,
            component_weights=component_weights
        )
    )

    self.rng, sample_rng = jax.random.split(rng)
    self.samples = generator.sample(rng=sample_rng, size=2000)

  def test_integration(self):
    """Make sure initialization followed by EM fitting does not crash."""
    initial_gmm = fit_gmm.initialize(
        rng=self.rng,
        points=self.samples,
        point_weights=None,
        n_components=3,
        verbose=False
    )
    _ = fit_gmm.fit_model_em(
        gmm=initial_gmm, points=self.samples, point_weights=None, steps=20
    )
@pytest.mark.fast()
class TestProbabilities:
  """Tests for the parameterized ``Probabilities`` simplex helper."""

  def test_probs(self):
    """Two free parameters yield a strictly positive 3-point distribution."""
    dist = probabilities.Probabilities(jnp.array([1.0, 2.0]))
    p = dist.probs()
    np.testing.assert_array_equal(p.shape, (3,))
    np.testing.assert_allclose(jnp.sum(p), 1.0, rtol=1e-6, atol=1e-6)
    np.testing.assert_array_equal(p > 0.0, True)

  def test_log_probs(self):
    """Exponentiated log-probabilities form a valid distribution."""
    dist = probabilities.Probabilities(jnp.array([1.0, 2.0]))
    log_p = dist.log_probs()
    p = jnp.exp(log_p)

    np.testing.assert_array_equal(log_p.shape, (3,))
    np.testing.assert_array_equal(p.shape, (3,))
    np.testing.assert_allclose(jnp.sum(p), 1.0, rtol=1e-6, atol=1e-6)
    np.testing.assert_array_equal(p > 0.0, True)

  def test_from_random(self, rng: jax.Array):
    """Randomly initialized probabilities have the requested dimension."""
    n_dimensions = 4
    dist = probabilities.Probabilities.from_random(
        rng=rng, n_dimensions=n_dimensions, stdev=0.1
    )
    np.testing.assert_array_equal(dist.probs().shape, (4,))

  def test_from_probs(self):
    """Round-trip through the parameterization preserves probabilities."""
    target = jnp.array([0.1, 0.2, 0.3, 0.4])
    dist = probabilities.Probabilities.from_probs(target)
    np.testing.assert_allclose(target, dist.probs(), rtol=1e-6, atol=1e-6)

  def test_sample(self, rng: jax.Array):
    """Sampled class frequencies match the underlying Bernoulli weight."""
    success = 0.4
    dist = probabilities.Probabilities.from_probs(
        jnp.array([success, 1.0 - success])
    )
    draws = dist.sample(rng, size=10000)
    stdev = jnp.sqrt(success * (1.0 - success))
    # Empirical frequency of class 0 should be within 3 sigma of `success`.
    np.testing.assert_allclose(
        jnp.mean(draws == 0.0), success, atol=3.0 * stdev
    )

  def test_flatten_unflatten(self):
    """Pytree flatten/unflatten reconstructs an equal object."""
    dist = probabilities.Probabilities.from_probs(
        jnp.array([0.1, 0.2, 0.3, 0.4])
    )
    leaves, treedef = jtu.tree_flatten(dist)
    rebuilt = jtu.tree_unflatten(treedef, leaves)
    np.testing.assert_array_equal(dist.params, rebuilt.params)
    assert dist == rebuilt

  def test_pytree_mapping(self):
    """tree_map acts on the raw parameters, not the probabilities."""
    dist = probabilities.Probabilities.from_probs(
        jnp.array([0.1, 0.2, 0.3, 0.4])
    )
    doubled = jtu.tree_map(lambda x: 2 * x, dist)
    np.testing.assert_allclose(
        2.0 * dist.params, doubled.params, rtol=1e-6, atol=1e-6
    )
@pytest.fixture()
def chol() -> scale_tril.ScaleTriL:
  """A 2x2 ScaleTriL whose Cholesky factor is [[1, 0], [2, 3]]."""
  # Diagonal entries are stored in log-space: exp(0) = 1, exp(log 3) = 3.
  params = jnp.array([0.0, 2.0, jnp.log(3.0)])
  return scale_tril.ScaleTriL(params=params, size=2)


@pytest.mark.fast()
class TestScaleTriL:
  """Tests for the lower-triangular covariance-factor parameterization."""

  def test_cholesky(self, chol: scale_tril.ScaleTriL):
    # Must reconstruct the factor encoded by the fixture's params.
    expected = jnp.array([[1.0, 0.0], [2.0, 3.0]])
    np.testing.assert_allclose(chol.cholesky(), expected, atol=1e-4, rtol=1e-4)

  def test_covariance(self, chol: scale_tril.ScaleTriL):
    # Covariance is L @ L.T for the Cholesky factor L.
    expected = jnp.array([[1.0, 0.0], [2.0, 3.0]])
    np.testing.assert_allclose(chol.covariance(), expected @ expected.T)

  def test_covariance_sqrt(self, chol: scale_tril.ScaleTriL):
    # Matrix square root of the covariance (not the Cholesky factor).
    actual = chol.covariance_sqrt()
    expected = matrix_square_root.sqrtm_only(chol.covariance())
    np.testing.assert_allclose(expected, actual, atol=1e-4, rtol=1e-4)

  def test_log_det_covariance(self, chol: scale_tril.ScaleTriL):
    expected = jnp.log(jnp.linalg.det(chol.covariance()))
    actual = chol.log_det_covariance()
    np.testing.assert_almost_equal(actual, expected)

  def test_from_random(self, rng: jax.Array):
    n_dimensions = 4
    cov = scale_tril.ScaleTriL.from_random(
        rng=rng, n_dimensions=n_dimensions, stdev=0.1
    )
    np.testing.assert_array_equal(
        cov.cholesky().shape, (n_dimensions, n_dimensions)
    )

  def test_from_cholesky(self, rng: jax.Array):
    # Round trip: factor -> ScaleTriL -> factor.
    n_dimensions = 4
    cholesky = scale_tril.ScaleTriL.from_random(
        rng=rng, n_dimensions=n_dimensions, stdev=1.0
    ).cholesky()
    scale = scale_tril.ScaleTriL.from_cholesky(cholesky)
    np.testing.assert_allclose(cholesky, scale.cholesky(), atol=1e-4, rtol=1e-4)

  def test_w2_dist(self, rng: jax.Array):
    # make sure distance between a random normal and itself is 0
    rng, subrng = jax.random.split(rng)
    s = scale_tril.ScaleTriL.from_random(rng=subrng, n_dimensions=3)
    w2 = s.w2_dist(s)
    expected = 0.0
    np.testing.assert_allclose(expected, w2, atol=1e-4, rtol=1e-4)

    # When covariances commute (e.g. if covariance is diagonal), have
    # distance between covariances = Frobenius norm^2 of (delta sqrt(cov)), see
    # see https://djalil.chafai.net/blog/2010/04/30/wasserstein-distance-between-two-gaussians/ # noqa: E501
    size = 4
    rng, subrng0, subrng1 = jax.random.split(rng, num=3)
    diag0 = jnp.exp(jax.random.normal(subrng0, shape=(size,)))
    diag1 = jnp.exp(jax.random.normal(subrng1, shape=(size,)))
    s0 = scale_tril.ScaleTriL.from_covariance(jnp.diag(diag0))
    s1 = scale_tril.ScaleTriL.from_covariance(jnp.diag(diag1))
    w2 = s0.w2_dist(s1)
    delta_sigma = jnp.sum((jnp.sqrt(diag0) - jnp.sqrt(diag1)) ** 2)
    np.testing.assert_allclose(delta_sigma, w2, atol=1e-4, rtol=1e-4)

  def test_transport(self, rng: jax.Array):
    # For diagonal covariances the transport map is a coordinate-wise
    # rescaling by sqrt(diag1) / sqrt(diag0).
    size = 4
    rng, subrng0, subrng1 = jax.random.split(rng, num=3)
    diag0 = jnp.exp(jax.random.normal(subrng0, shape=(size,)))
    s0 = scale_tril.ScaleTriL.from_covariance(jnp.diag(diag0))
    diag1 = jnp.exp(jax.random.normal(subrng1, shape=(size,)))
    s1 = scale_tril.ScaleTriL.from_covariance(jnp.diag(diag1))

    rng, subrng = jax.random.split(rng)
    x = jax.random.normal(subrng, shape=(100, size))
    transported = s0.transport(s1, points=x)
    expected = x * jnp.sqrt(diag1)[None] / jnp.sqrt(diag0)[None]
    np.testing.assert_allclose(expected, transported, atol=1e-4, rtol=1e-4)

  def test_flatten_unflatten(self, rng: jax.Array):
    # Pytree round trip must reconstruct an equal object.
    scale = scale_tril.ScaleTriL.from_random(rng=rng, n_dimensions=3)
    children, aux_data = jtu.tree_flatten(scale)
    scale_new = jtu.tree_unflatten(aux_data, children)
    np.testing.assert_array_equal(scale.params, scale_new.params)
    assert scale == scale_new

  def test_pytree_mapping(self, rng: jax.Array):
    # tree_map acts on the raw params leaf.
    scale = scale_tril.ScaleTriL.from_random(rng=rng, n_dimensions=3)
    scale_x_2 = jtu.tree_map(lambda x: 2 * x, scale)
    np.testing.assert_allclose(2.0 * scale.params, scale_x_2.params)
class TestPlotting:
  """Smoke tests for static and animated transport plots."""

  def test_plot(self, rng: jax.Array, monkeypatch):
    """Render one static plot and a two-frame animation without crashing."""
    # Prevent figures from blocking the test run.
    monkeypatch.setattr(plt, "show", lambda: None)
    n_src, n_tgt, ndim = 12, 7, 3
    rng_a, rng_b, rng_y = jax.random.split(rng, 3)
    sources = [
        jax.random.normal(rng_a, (n_src, ndim)) + 1,
        jax.random.normal(rng_b, (n_src, ndim)) + 1
    ]
    target = jax.random.uniform(rng_y, (n_tgt, ndim))

    # Solve one OT problem per source cloud against the shared target.
    run_solver = sinkhorn.Sinkhorn()
    outputs = []
    for src in sources:
      geom = pointcloud.PointCloud(src, target)
      outputs.append(run_solver(linear_problem.LinearProblem(geom)))

    plotter = plot.Plot()
    _ = plotter(outputs[0])

    fig = plt.figure(figsize=(8, 5))
    plotter = plot.Plot(fig=fig, title="test")
    plotter.animate(outputs, frame_rate=2, titles=["test1", "test2"])
from typing import Callable, Optional, Tuple

import pytest

import jax
import jax.numpy as jnp
import numpy as np

from ott.geometry import costs, pointcloud
from ott.solvers import linear
from ott.tools import sliced

# A projection routine maps points, an (optional) RNG key and a number of
# projections to projected points. FIX: argument order now mirrors
# ``custom_proj``'s actual signature (previously ``[ndarray, int, Array]``).
Projector = Callable[[jnp.ndarray, Optional[jax.Array], int], jnp.ndarray]


def custom_proj(
    x: jnp.ndarray,
    rng: Optional[jax.Array] = None,
    n_proj: int = 27
) -> jnp.ndarray:
  """Project ``x`` onto ``n_proj`` random directions and square the result.

  Args:
    x: Array of shape ``[n, dim]``.
    rng: Key used to draw the projection matrix; defaults to a fixed key so
      the projection is deterministic when none is supplied.
    n_proj: Number of random directions.

  Returns:
    Array of shape ``[n, n_proj]`` of squared projections.
  """
  dim = x.shape[1]
  if rng is None:
    rng = jax.random.key(42)
  proj_m = jax.random.uniform(rng, (n_proj, dim))
  return (x @ proj_m.T) ** 2


def gen_data(
    rng: jax.Array, n: int, m: int, dim: int
) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]:
  """Sample two uniform point clouds and two normalized weight vectors."""
  rng_x, rng_y, rng_a, rng_b = jax.random.split(rng, 4)
  x = jax.random.uniform(rng_x, (n, dim))
  y = jax.random.uniform(rng_y, (m, dim))
  a = jax.random.uniform(rng_a, (n,))
  b = jax.random.uniform(rng_b, (m,))
  return a / jnp.sum(a), x, b / jnp.sum(b), y


class TestSliced:
  """Tests for ``ott.tools.sliced.sliced_wasserstein``."""

  @pytest.mark.parametrize("proj_fn", [None, custom_proj])
  @pytest.mark.parametrize("cost_fn", [costs.PNormP(1.3), None])
  def test_random_projs(
      self, rng: jax.Array, cost_fn: Optional[costs.CostFn],
      proj_fn: Optional[Projector]
  ):
    """Cost is positive and equals the weighted mean of per-slice costs."""
    n, m, dim, n_proj = 12, 17, 5, 13
    rng1, rng2 = jax.random.split(rng, 2)
    a, x, b, y = gen_data(rng1, n, m, dim)
    weights = jax.random.uniform(rng2, (n_proj,))

    cost, out = sliced.sliced_wasserstein(
        x,
        y,
        a,
        b,
        cost_fn=cost_fn,
        proj_fn=proj_fn,
        n_proj=n_proj,
        rng=rng2,
        weights=weights
    )
    assert cost > 0.0
    np.testing.assert_array_equal(
        cost, jnp.average(out.ot_costs, weights=weights)
    )

  @pytest.mark.parametrize("cost_fn", [costs.SqPNorm(1.4), None])
  def test_consistency_with_id(
      self, rng: jax.Array, cost_fn: Optional[costs.CostFn]
  ):
    """With identity projections, matches the univariate solver's mean cost."""
    n, m, dim = 11, 12, 4
    a, x, b, y = gen_data(rng, n, m, dim)

    cost, _ = sliced.sliced_wasserstein(
        x, y, proj_fn=lambda x: x, cost_fn=cost_fn
    )
    geom = pointcloud.PointCloud(x=x, y=y, cost_fn=cost_fn)
    out_lin = jnp.mean(linear.solve_univariate(geom).ot_costs)
    np.testing.assert_allclose(out_lin, cost, rtol=1e-6, atol=1e-6)

  @pytest.mark.parametrize("proj_fn", [None, custom_proj])
  def test_diff(self, rng: jax.Array, proj_fn: Optional[Projector]):
    """Gradient of the sliced cost matches a finite-difference estimate."""
    eps = 1e-4
    n, m, dim = 13, 16, 7
    a, x, b, y = gen_data(rng, n, m, dim)

    # Uniform samples (default weights) make the finite-difference check more
    # accurate by avoiding ties, so ``a`` and ``b`` are deliberately unused.
    dx = jax.random.uniform(rng, (n, dim)) - 0.5
    # BUG FIX: ``proj_fn`` was parametrized but never forwarded, so both
    # parametrizations exercised identical code. Forward it to every call.
    cost_p, _ = sliced.sliced_wasserstein(x + eps * dx, y, proj_fn=proj_fn)
    cost_m, _ = sliced.sliced_wasserstein(x - eps * dx, y, proj_fn=proj_fn)
    grad_fn = jax.jit(jax.grad(sliced.sliced_wasserstein, has_aux=True))
    g, _ = grad_fn(x, y, proj_fn=proj_fn)

    np.testing.assert_allclose(
        jnp.sum(g * dx), (cost_p - cost_m) / (2 * eps), atol=1e-3, rtol=1e-3
    )


# --- file: tests/tools/unreg_test.py ------------------------------------------
# Copyright OTT-JAX
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Tuple

import pytest

import jax
import jax.numpy as jnp
import numpy as np

from ott.geometry import costs, pointcloud
from ott.solvers import linear
from ott.tools import unreg


class TestHungarian:
  """Compare the unregularized Hungarian solver against Sinkhorn."""

  @pytest.mark.parametrize("cost_fn", [costs.PNormP(1.3), None])
  def test_matches_sink(self, rng: jax.Array, cost_fn: Optional[costs.CostFn]):
    """Hungarian cost and coupling agree with low-entropy Sinkhorn."""
    n, m, dim = 12, 12, 5
    rng1, _ = jax.random.split(rng, 2)
    x, y = gen_data(rng1, n, m, dim)
    # Tiny epsilon so the Sinkhorn coupling is close to a permutation.
    geom = pointcloud.PointCloud(x, y, cost_fn=cost_fn, epsilon=0.0005)
    cost_hung, out_hung = jax.jit(unreg.hungarian)(geom)
    out_sink = jax.jit(linear.solve)(geom)
    np.testing.assert_allclose(
        out_sink.primal_cost, cost_hung, rtol=1e-3, atol=1e-3
    )
    np.testing.assert_allclose(
        out_sink.matrix, out_hung.matrix.todense(), rtol=1e-3, atol=1e-3
    )

  @pytest.mark.parametrize("p", [1.3, 2.3])
  def test_wass(self, rng: jax.Array, p: float):
    """``wassdis_p`` equals the p-th root of the Hungarian cost."""
    n, m, dim = 12, 12, 5
    rng1, _ = jax.random.split(rng, 2)
    x, y = gen_data(rng1, n, m, dim)
    geom = pointcloud.PointCloud(x, y, cost_fn=costs.EuclideanP(p=p))
    cost_hung, _ = jax.jit(unreg.hungarian)(geom)
    w_p = jax.jit(unreg.wassdis_p)(x, y, p=p)
    np.testing.assert_allclose(w_p, cost_hung ** (1. / p), rtol=1e-3, atol=1e-3)


def gen_data(rng: jax.Array, n: int, m: int,
             dim: int) -> Tuple[jnp.ndarray, jnp.ndarray]:
  """Sample two uniform point clouds of sizes ``n`` and ``m``."""
  # NOTE: split into 4 keys (only 2 used) to keep the historical RNG stream.
  rngs = jax.random.split(rng, 4)
  x = jax.random.uniform(rngs[0], (n, dim))
  y = jax.random.uniform(rngs[1], (m, dim))
  return x, y