├── .gitignore ├── .gitmodules ├── .travis.yml ├── CMakeLists.txt ├── LICENSE ├── MANIFEST.in ├── ReadMe.md ├── appveyor.yml ├── lapsolver ├── __init__.py ├── benchmarks │ ├── __init__.py │ ├── plot_results.py │ └── test_benchmark_solvers.py ├── data │ └── dense │ │ ├── costs0.npz │ │ ├── costs1.npz │ │ ├── costs10.npz │ │ ├── costs2.npz │ │ ├── costs3.npz │ │ ├── costs4.npz │ │ ├── costs5.npz │ │ ├── costs6.npz │ │ ├── costs7.npz │ │ ├── costs8.npz │ │ └── costs9.npz ├── etc │ ├── benchmark-dtype-int.png │ └── benchmark-dtype-numpy.float32.png └── tests │ ├── __init__.py │ ├── test_dense.py │ └── test_files.py ├── setup.py └── src ├── dense.hpp ├── dense_wrap.hpp └── lapsolverc.cpp /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | 49 | # Translations 50 | *.mo 51 | *.pot 52 | 53 | # Django stuff: 54 | *.log 55 | local_settings.py 56 | 57 | # Flask stuff: 58 | instance/ 59 | .webassets-cache 60 | 61 | # Scrapy stuff: 62 | .scrapy 63 | 64 | # Sphinx documentation 65 | docs/_build/ 66 | 67 | # PyBuilder 68 | target/ 69 | 70 | # Jupyter Notebook 71 | .ipynb_checkpoints 72 | 73 | # pyenv 74 | .python-version 75 | 76 | # celery beat schedule file 77 | celerybeat-schedule 78 | 79 | # SageMath parsed files 80 | *.sage.py 81 | 82 | # dotenv 83 | .env 84 | 85 | # virtualenv 86 | .venv 87 | venv/ 88 | ENV/ 89 | 90 | # Spyder project settings 91 | .spyderproject 92 | .spyproject 93 | 94 | # Rope project settings 95 | .ropeproject 96 | 97 | # mkdocs documentation 98 | /site 99 | 100 | # mypy 101 | .mypy_cache/ 102 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "pybind11"] 2 | path = pybind11 3 | url = https://github.com/pybind/pybind11.git 4 | branch = v2.2 5 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: python 2 | env: 3 | - PYTHON=3.5 4 | - PYTHON=3.6 5 | install: 6 | # Install conda 7 | - if [[ "$PYTHON" == "2.7" ]]; then 8 | wget https://repo.continuum.io/miniconda/Miniconda2-latest-Linux-x86_64.sh -O miniconda.sh; 9 | else 10 | wget https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh -O miniconda.sh; 11 | fi 12 | - bash miniconda.sh -b -p $HOME/miniconda 13 | - export PATH="$HOME/miniconda/bin:$PATH" 14 | - hash -r 15 | - conda config --set always_yes yes --set changeps1 no 16 | - conda update -q conda 17 | - conda info -a 18 | 19 | # Install deps 20 | - deps='pip numpy' 21 | - conda create -q -n pyenv 
python=$PYTHON $deps 22 | - source activate pyenv 23 | - pip install pytest 24 | - pip install scikit-build 25 | - pip install cmake 26 | - pip install . 27 | 28 | script: pytest --pyargs lapsolver -------------------------------------------------------------------------------- /CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.28) 2 | project(lapsolverc) 3 | 4 | add_subdirectory(pybind11) 5 | 6 | pybind11_add_module(lapsolverc src/lapsolverc.cpp) 7 | 8 | target_compile_options(lapsolverc PUBLIC "$<$<CONFIG:Release>:-O2>") 9 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018-2020 Christoph Heindl 4 | Copyright (c) 2019-2020 Jack Valmadre 5 | 6 | Permission is hereby granted, free of charge, to any person obtaining a copy 7 | of this software and associated documentation files (the "Software"), to deal 8 | in the Software without restriction, including without limitation the rights 9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | copies of the Software, and to permit persons to whom the Software is 11 | furnished to do so, subject to the following conditions: 12 | 13 | The above copyright notice and this permission notice shall be included in all 14 | copies or substantial portions of the Software. 15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | SOFTWARE. 23 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include README.md LICENSE 2 | global-include CMakeLists.txt *.cmake 3 | recursive-include src * 4 | recursive-include pybind11/include *.h 5 | recursive-include lapsolver/data * 6 | recursive-include lapsolver/etc * -------------------------------------------------------------------------------- /ReadMe.md: -------------------------------------------------------------------------------- 1 | ## py-lapsolver 2 | 3 | **py-lapsolver** implements a linear sum assignment problem (LAP) solver in Python for dense matrices, based on shortest augmenting paths. In practice, it solves 5000x5000 problems in around 3 seconds. 4 | 5 | ### Install 6 | 7 | ``` 8 | pip install [--pre] lapsolver 9 | ``` 10 | 11 | Windows binary wheels are provided for Python 3.5/3.6. Source wheels otherwise. 12 | 13 | ### Install from source 14 | 15 | Clone this repository 16 | 17 | ``` 18 | git clone --recursive https://github.com/cheind/py-lapsolver.git 19 | ``` 20 | 21 | Then build the project and execute the tests 22 | 23 | ``` 24 | python setup.py develop 25 | python setup.py test 26 | ``` 27 | 28 | Executing the tests requires `pytest` and optionally `pytest-benchmark` for generating benchmarks. 29 | 30 | ### Usage 31 | 32 | ```python 33 | import numpy as np 34 | from lapsolver import solve_dense 35 | 36 | costs = np.array([ 37 | [6, 9, 1], 38 | [10, 3, 2], 39 | [8, 7, 4.] 40 | ], dtype=np.float32) 41 | 42 | rids, cids = solve_dense(costs) 43 | 44 | for r,c in zip(rids, cids): 45 | print(r,c) # Row/column pairings 46 | """ 47 | 0 2 48 | 1 1 49 | 2 0 50 | """ 51 | ```
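The returned row/column indices can also be used to recover the total assignment cost, and a maximization problem can be solved by negating the cost matrix (see the note in `src/dense.hpp`). The snippet below is an illustrative addition by the editor and not part of the original ReadMe:

```python
import numpy as np
from lapsolver import solve_dense

costs = np.array([[6, 9, 1], [10, 3, 2], [8, 7, 4.]], dtype=np.float32)

rids, cids = solve_dense(costs)
total_cost = costs[rids, cids].sum()  # 1 + 3 + 8 = 12 for the matrix above

# Maximization: negate the costs, solve, then evaluate on the original matrix.
rids_max, cids_max = solve_dense(-costs)
max_total = costs[rids_max, cids_max].sum()  # 9 + 10 + 4 = 23
```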
52 | 53 | You may also want to mark certain pairings as impossible 54 | 55 | ```python 56 | # Matrix with non-allowed pairings 57 | costs = np.array([ 58 | [5, 9, np.nan], 59 | [10, np.nan, 2], 60 | [8, 7, 4.]] 61 | ) 62 | 63 | rids, cids = solve_dense(costs) 64 | 65 | for r,c in zip(rids, cids): 66 | print(r,c) # Row/column pairings 67 | """ 68 | 0 0 69 | 1 2 70 | 2 1 71 | """ 72 | ``` 73 | 74 | ### Benchmarks 75 | 76 | Comparisons below are generated by scripts in [`./lapsolver/benchmarks`](./lapsolver/benchmarks). 77 | 78 | Currently, the following solvers are tested: 79 | - `lapjv` - https://github.com/gatagat/lap 80 | - `munkres` - http://software.clapper.org/munkres/ 81 | - `ortools` - https://github.com/google/or-tools ** 82 | - `scipy` - https://github.com/scipy/scipy/tree/master/scipy 83 | - `lapsolver` - this project 84 | 85 | ** Reduced performance due to the costly dense-matrix-to-graph conversion. If you know a better way, please let me know. 86 | 87 | Please note that the x-axis is scaled logarithmically. Missing bars indicate excessive runtime or errors in the returned result. 88 | ![](./lapsolver/etc/benchmark-dtype-int.png) 89 | ![](./lapsolver/etc/benchmark-dtype-numpy.float32.png) 90 | 91 | #### Additional Benchmarks 92 | 93 | Berhane performs an in-depth analysis of Python3 linear assignment problem solvers at https://github.com/berhane/LAP-solvers 94 | 95 | ### References 96 | **py-lapsolver** heavily relies on [code](https://github.com/jaehyunp/stanfordacm/blob/9e1375cd4eba68a59dd7b1e2f81692653e9908a9/code/MinCostMatching.cc) published by @jaehyunp at https://github.com/jaehyunp/ 97 | -------------------------------------------------------------------------------- /appveyor.yml: -------------------------------------------------------------------------------- 1 | environment: 2 | TWINE_USERNAME: cheind 3 | TWINE_PASSWORD: 4 | secure: hnxMBvmJAGM1rQVOUbkGvQ== 5 | 6 | # http://www.appveyor.com/docs/installed-software#python 7 | matrix: 8 | - PYTHON: "C:\\Miniconda36-x64" 9 | PYTHON_VERSION: "3.6" 10 | PYTHON_ARCH: "64" 11 | - PYTHON: "C:\\Miniconda36" 12 | PYTHON_VERSION: "3.6" 13 | PYTHON_ARCH: "32" 14 | - PYTHON: "C:\\Miniconda35-x64" 15 | PYTHON_VERSION: "3.5" 16 | PYTHON_ARCH: "64" 17 | - PYTHON: "C:\\Miniconda35" 18 | PYTHON_VERSION: "3.5" 19 | PYTHON_ARCH: "32" 20 | 21 | install: 22 | - set "CONDA_ROOT=%PYTHON%" 23 | - set "PATH=%CONDA_ROOT%;%CONDA_ROOT%\Scripts;%CONDA_ROOT%\Library\bin;%PATH%" 24 | - conda config --set always_yes yes 25 | - conda update -q conda 26 | - conda config --set auto_update_conda no 27 | - conda install -q pip pytest numpy cmake 28 | - python -m pip install --upgrade pip 29 | - pip install wheel 30 | - pip install --upgrade --ignore-installed setuptools 31 | - git submodule update --init --recursive 32 | - ps: >- 33 | if ($env:APPVEYOR_REPO_BRANCH -eq 'develop') { 34 | $env:LAPSOLVER_RELEASE_TYPE="dev"; 35 | $env:LAPSOLVER_DEV_NUM=$env:APPVEYOR_BUILD_NUMBER 36 | } else { 37 | $env:LAPSOLVER_RELEASE_TYPE="stable" 38 | } 39 | 40 | build_script: 41 | - python setup.py sdist 42 | - python setup.py bdist_wheel 43 | 44 | test_script: 45 | # Try building source wheel and install. 46 | # Redirect stderr of pip within powershell to avoid error.
47 | - ps: >- 48 | $wheel = cmd /r dir .\dist\*.tar.gz /b/s; 49 | cmd /c "pip install --verbose $wheel 2>&1" 50 | - pytest --pyargs lapsolver 51 | - pip uninstall -y lapsolver 52 | # Try building binary wheel and install. 53 | # Redirect stderr of pip within powershell to avoid error. 54 | - ps: >- 55 | $wheel = cmd /r dir .\dist\*.whl /b/s; 56 | cmd /c "pip install --verbose $wheel 2>&1" 57 | - pytest --pyargs lapsolver 58 | - pip uninstall -y lapsolver 59 | 60 | on_success: 61 | ps: >- 62 | if ($env:APPVEYOR_REPO_TAG -eq $true -Or $env:APPVEYOR_REPO_BRANCH -eq "master") { 63 | Write-Output ("Deploying to PyPI") 64 | pip install --upgrade twine 65 | # If powershell ever sees anything on stderr it thinks it's a fail. 66 | # So we use cmd to redirect stderr to stdout before PS can see it. 67 | cmd /c 'twine upload --skip-existing dist\* 2>&1' 68 | } else { 69 | Write-Output "Not deploying as this is not a tagged commit or commit on master" 70 | } 71 | 72 | artifacts: 73 | - path: "dist\\*.whl" 74 | - path: "dist\\*.tar.gz" 75 | name: Wheels 76 | 77 | notifications: 78 | - provider: Email 79 | to: 80 | - christoph.heindl@email.com 81 | on_build_success: true 82 | on_build_failure: true 83 | 84 | branches: 85 | only: 86 | - master 87 | - develop 88 | -------------------------------------------------------------------------------- /lapsolver/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from lapsolverc import solve_dense 3 | 4 | # Needs to be last line 5 | __version__ = '1.1.0' 6 | -------------------------------------------------------------------------------- /lapsolver/benchmarks/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cheind/py-lapsolver/7bbb1ed64b460a0f0252e5e9be73707082b7b546/lapsolver/benchmarks/__init__.py -------------------------------------------------------------------------------- /lapsolver/benchmarks/plot_results.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import re 4 | import pandas as pd 5 | import seaborn as sns 6 | import matplotlib.pyplot as plt 7 | 8 | def parse_args(): 9 | parser = argparse.ArgumentParser(description='Plot benchmark results.', formatter_class=argparse.RawTextHelpFormatter) 10 | 11 | parser.add_argument('benchmarkfile', type=str, help='Json file containing benchmark results') 12 | #parser.add_argument('tests', type=str, help='Directory containing tracker result files') 13 | #parser.add_argument('--loglevel', type=str, help='Log level', default='info') 14 | #parser.add_argument('--fmt', type=str, help='Data format', default='mot15-2D') 15 | #parser.add_argument('--solver', type=str, help='LAP solver to use') 16 | return parser.parse_args() 17 | 18 | def build_dataframe(args): 19 | dre = re.compile(r"'(.*)'") 20 | 21 | with open(args.benchmarkfile) as f: 22 | data = json.load(f) 23 | 24 | events = [] 25 | for b in data['benchmarks']: 26 | ei = b['extra_info'] 27 | dtype = dre.search(ei['scalar']).group(1) 28 | events.append(('{}x{}'.format(ei['size'][0], ei['size'][1]), ei['solver'], dtype, b['stats']['mean'], b['stats']['stddev'])) 29 | 30 | return pd.DataFrame(events, columns=['matrix-size', 'solver', 'scalar', 'mean-time', 'stddev']) 31 | 32 | def draw_plots(df): 33 | sns.set_style("whitegrid") 34 | for s, g in df.groupby('scalar'): 35 | print(g) 36 | plt.figure(figsize=(8, 5.5)) 37 | title='Benchmark results for dtype={}'.format(s) 
38 | ax = sns.barplot(x='mean-time', y='matrix-size', hue='solver', data=g, errwidth=0, palette="muted") 39 | ax.set_xscale("log") 40 | ax.set_xlabel('mean-time (sec)') 41 | plt.legend(loc='upper right') 42 | plt.title(title) 43 | plt.tight_layout() 44 | plt.savefig('benchmark-dtype-{}.png'.format(s), transparent=True, ) 45 | plt.show() 46 | 47 | 48 | def main(): 49 | args = parse_args() 50 | 51 | df = build_dataframe(args) 52 | draw_plots(df) 53 | 54 | 55 | if __name__ == '__main__': 56 | main() 57 | 58 | -------------------------------------------------------------------------------- /lapsolver/benchmarks/test_benchmark_solvers.py: -------------------------------------------------------------------------------- 1 | import collections 2 | from scipy.optimize import linear_sum_assignment 3 | import numpy as np 4 | import pytest 5 | import importlib 6 | import sys 7 | 8 | def load_solver_lapsolver(): 9 | from lapsolver import solve_dense 10 | 11 | def run(costs): 12 | rids, cids = solve_dense(costs) 13 | return costs[rids, cids].sum() 14 | 15 | return run 16 | 17 | def load_solver_scipy(): 18 | from scipy.optimize import linear_sum_assignment 19 | 20 | def run(costs): 21 | rids, cids = linear_sum_assignment(costs) 22 | return costs[rids, cids].sum() 23 | 24 | return run 25 | 26 | def load_solver_munkres(): 27 | from munkres import Munkres, DISALLOWED 28 | 29 | def run(costs): 30 | m = Munkres() 31 | idx = np.array(m.compute(costs), dtype=int) 32 | return costs[idx[:,0], idx[:,1]].sum() 33 | 34 | return run 35 | 36 | def load_solver_lapjv(): 37 | from lap import lapjv 38 | 39 | def run(costs): 40 | r = lapjv(costs, return_cost=True, extend_cost=True) 41 | return r[0] 42 | 43 | return run 44 | 45 | def load_solver_ortools(): 46 | from ortools.graph import pywrapgraph 47 | 48 | def run(costs): 49 | f = 1e3 50 | valid = np.isfinite(costs) 51 | # A lot of time in ortools is being spent in constructing the graph. 
52 | assignment = pywrapgraph.LinearSumAssignment() 53 | for r in range(costs.shape[0]): 54 | for c in range(costs.shape[1]): 55 | if valid[r,c]: 56 | assignment.AddArcWithCost(r, c, int(costs[r,c]*f)) 57 | 58 | # No error checking for now 59 | assignment.Solve() 60 | return assignment.OptimalCost() / f 61 | 62 | return run 63 | 64 | def load_solvers(): 65 | loaders = [ 66 | ('lapsolver', load_solver_lapsolver), 67 | ('lapjv', load_solver_lapjv), 68 | ('scipy', load_solver_scipy), 69 | ('munkres', load_solver_munkres), 70 | ('ortools', load_solver_ortools), 71 | ] 72 | 73 | solvers = {} 74 | for l in loaders: 75 | try: 76 | solvers[l[0]] = l[1]() 77 | except: 78 | pass 79 | return solvers 80 | 81 | 82 | solvers = load_solvers() 83 | size_to_expected = collections.OrderedDict([ 84 | ('10x5', -39518.0), 85 | ('10x10', -80040.0), 86 | ('20x20', -175988.0), 87 | ('50x20', -193922.0), 88 | ('50x50', -467118.0), 89 | ('100x100', -970558.0), 90 | ('200x200', -1967491.0), 91 | ('500x500', -4968156.0), 92 | ('1000x1000', -9968874.0), 93 | ('5000x5000', -49969853.0), 94 | ]) 95 | size_max = [5000,5000] 96 | 97 | np.random.seed(123) 98 | icosts = np.random.randint(-1e4, 1e4, size=size_max) 99 | 100 | 101 | @pytest.mark.benchmark( 102 | min_time=1, 103 | min_rounds=2, 104 | disable_gc=False, 105 | warmup=True, 106 | warmup_iterations=1 107 | ) 108 | @pytest.mark.parametrize('solver', solvers.keys()) 109 | @pytest.mark.parametrize('scalar', [int, np.float32]) 110 | @pytest.mark.parametrize('size', [k for k, v in size_to_expected.items()]) 111 | 112 | def test_benchmark_solver(benchmark, solver, scalar, size): 113 | dims = _parse_size(size) 114 | expected = size_to_expected[size] 115 | 116 | exclude_above = { 117 | 'munkres' : 200, 118 | 'ortools' : 5000 119 | } 120 | 121 | benchmark.extra_info = { 122 | 'solver': solver, 123 | 'size': size, 124 | 'scalar': str(scalar) 125 | } 126 | 127 | s = np.array(dims) 128 | if (s > exclude_above.get(solver, sys.maxsize)).any(): 129 | benchmark.extra_info['success'] = False 130 | return 131 | 132 | costs = icosts[:dims[0], :dims[1]].astype(scalar).copy() 133 | r = benchmark(solvers[solver], costs) 134 | if r != expected: 135 | benchmark.extra_info['success'] = False 136 | 137 | 138 | def _parse_size(size_str): 139 | """Parses a string of the form 'MxN'.""" 140 | m, n = (int(x) for x in size_str.split('x')) 141 | return m, n 142 | 143 | 144 | # pytest lapsolver -k test_benchmark_solver -v --benchmark-group-by=param:size,param:scalar -s --benchmark-save=bench 145 | -------------------------------------------------------------------------------- /lapsolver/data/dense/costs0.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cheind/py-lapsolver/7bbb1ed64b460a0f0252e5e9be73707082b7b546/lapsolver/data/dense/costs0.npz -------------------------------------------------------------------------------- /lapsolver/data/dense/costs1.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cheind/py-lapsolver/7bbb1ed64b460a0f0252e5e9be73707082b7b546/lapsolver/data/dense/costs1.npz -------------------------------------------------------------------------------- /lapsolver/data/dense/costs10.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cheind/py-lapsolver/7bbb1ed64b460a0f0252e5e9be73707082b7b546/lapsolver/data/dense/costs10.npz 
-------------------------------------------------------------------------------- /lapsolver/data/dense/costs2.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cheind/py-lapsolver/7bbb1ed64b460a0f0252e5e9be73707082b7b546/lapsolver/data/dense/costs2.npz -------------------------------------------------------------------------------- /lapsolver/data/dense/costs3.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cheind/py-lapsolver/7bbb1ed64b460a0f0252e5e9be73707082b7b546/lapsolver/data/dense/costs3.npz -------------------------------------------------------------------------------- /lapsolver/data/dense/costs4.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cheind/py-lapsolver/7bbb1ed64b460a0f0252e5e9be73707082b7b546/lapsolver/data/dense/costs4.npz -------------------------------------------------------------------------------- /lapsolver/data/dense/costs5.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cheind/py-lapsolver/7bbb1ed64b460a0f0252e5e9be73707082b7b546/lapsolver/data/dense/costs5.npz -------------------------------------------------------------------------------- /lapsolver/data/dense/costs6.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cheind/py-lapsolver/7bbb1ed64b460a0f0252e5e9be73707082b7b546/lapsolver/data/dense/costs6.npz -------------------------------------------------------------------------------- /lapsolver/data/dense/costs7.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cheind/py-lapsolver/7bbb1ed64b460a0f0252e5e9be73707082b7b546/lapsolver/data/dense/costs7.npz -------------------------------------------------------------------------------- /lapsolver/data/dense/costs8.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cheind/py-lapsolver/7bbb1ed64b460a0f0252e5e9be73707082b7b546/lapsolver/data/dense/costs8.npz -------------------------------------------------------------------------------- /lapsolver/data/dense/costs9.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cheind/py-lapsolver/7bbb1ed64b460a0f0252e5e9be73707082b7b546/lapsolver/data/dense/costs9.npz -------------------------------------------------------------------------------- /lapsolver/etc/benchmark-dtype-int.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cheind/py-lapsolver/7bbb1ed64b460a0f0252e5e9be73707082b7b546/lapsolver/etc/benchmark-dtype-int.png -------------------------------------------------------------------------------- /lapsolver/etc/benchmark-dtype-numpy.float32.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cheind/py-lapsolver/7bbb1ed64b460a0f0252e5e9be73707082b7b546/lapsolver/etc/benchmark-dtype-numpy.float32.png -------------------------------------------------------------------------------- /lapsolver/tests/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/cheind/py-lapsolver/7bbb1ed64b460a0f0252e5e9be73707082b7b546/lapsolver/tests/__init__.py -------------------------------------------------------------------------------- /lapsolver/tests/test_dense.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import numpy as np 3 | import lapsolver as lap 4 | 5 | @pytest.mark.parametrize('dtype', ['float', 'int', 'float32', 'float64', 'int32', 'int64']) 6 | def test_small(dtype): 7 | costs = np.array([[6, 9, 1],[10, 3, 2],[8, 7, 4]], dtype=dtype) 8 | r = lap.solve_dense(costs) 9 | expected = np.array([[0, 1, 2], [2, 1, 0]]) 10 | np.testing.assert_equal(r, expected) 11 | 12 | def test_plain_array(): 13 | costs = [[6, 9, 1],[10, 3, 2],[8, 7, 4.]] 14 | r = lap.solve_dense(costs) 15 | expected = np.array([[0, 1, 2], [2, 1, 0]]) 16 | np.testing.assert_allclose(r, expected) 17 | 18 | def test_plain_array_integer(): 19 | # Integer problem whose solution is changed by fractional modification. 20 | costs = [[6, 9, 1],[10, 3, 2],[8, 5, 4]] 21 | r = lap.solve_dense(costs) 22 | expected = np.array([[0, 1, 2], [2, 1, 0]]) 23 | np.testing.assert_allclose(r, expected) 24 | 25 | def test_plain_array_fractional(): 26 | # Add fractional costs that change the solution. 27 | # Before: (1 + 3 + 8) = 12 < 13 = (6 + 5 + 2) 28 | # After: (1.4 + 3.4 + 8.4) = 13.2 < 13 29 | # This confirms that pylib11 did not cast float to int. 30 | costs = [[6, 9, 1.4],[10, 3.4, 2],[8.4, 5, 4]] 31 | r = lap.solve_dense(costs) 32 | expected = np.array([[0, 1, 2], [0, 2, 1]]) 33 | np.testing.assert_allclose(r, expected) 34 | 35 | def test_nonsquare(): 36 | costs = np.array([[6, 9],[10, 3],[8, 7]], dtype=float) 37 | 38 | r = lap.solve_dense(costs) 39 | expected = np.array([[0, 1], [0, 1]]) 40 | np.testing.assert_allclose(r, expected) 41 | 42 | r = lap.solve_dense(costs.T) # view test 43 | expected = np.array([[0, 1], [0, 1]]) 44 | np.testing.assert_allclose(r, expected) 45 | 46 | costs = np.array( 47 | [[ -17.13614455, -536.59009819], 48 | [ 292.64662837, 187.49841358], 49 | [ 664.70501771, 948.09658792]]) 50 | 51 | expected = np.array([[0, 1], [1, 0]]) 52 | r = lap.solve_dense(costs) 53 | np.testing.assert_allclose(r, expected) 54 | 55 | def test_views(): 56 | costs = np.array([[6, 9],[10, 3],[8, 7]], dtype=float) 57 | np.testing.assert_allclose(lap.solve_dense(costs.T[1:, :]), [[0], [1]]) 58 | 59 | def test_large(): 60 | costs = np.random.uniform(size=(5000,5000)) 61 | r = lap.solve_dense(costs) 62 | 63 | def test_solve_nan(): 64 | costs = np.array([[5, 9, np.nan],[10, np.nan, 2],[8, 7, 4.]]) 65 | r = lap.solve_dense(costs) 66 | expected = np.array([[0, 1, 2], [0, 2, 1]]) 67 | np.testing.assert_allclose(r, expected) 68 | 69 | def test_solve_inf(): 70 | costs = np.array([[5, 9, np.inf],[10, np.inf, 2],[8, 7, 4.]]) 71 | r = lap.solve_dense(costs) 72 | expected = np.array([[0, 1, 2], [0, 2, 1]]) 73 | np.testing.assert_allclose(r, expected) 74 | 75 | def test_missing_edge_negative(): 76 | costs = np.array([[-1000, -1], [-1, np.nan]]) 77 | r = lap.solve_dense(costs) 78 | # The optimal solution is (0, 1), (1, 0) with cost -1 + -1. 79 | # If the implementation does not use a large enough constant, it may choose 80 | # (0, 0), (1, 1) with cost -1000 + L. 
81 | expected = np.array([[0, 1], [1, 0]]) 82 | np.testing.assert_allclose(r, expected) 83 | 84 | def test_missing_edge_positive(): 85 | costs = np.array([ 86 | [np.nan, 1000, np.nan], 87 | [np.nan, 1, 1000], 88 | [1000, np.nan, 1], 89 | ]) 90 | costs_copy = costs.copy() 91 | r = lap.solve_dense(costs) 92 | # The optimal solution is (0, 1), (1, 2), (2, 0) with cost 1000 + 1000 + 1000. 93 | # If the implementation does not use a large enough constant, it may choose 94 | # (0, 0), (1, 1), (2, 2) with cost (L + 1 + 1) instead. 95 | expected = np.array([[0, 1, 2], [1, 2, 0]]) 96 | np.testing.assert_allclose(r, expected) 97 | -------------------------------------------------------------------------------- /lapsolver/tests/test_files.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import lapsolver as lap 4 | import glob 5 | 6 | DATA_DIR = os.path.join(os.path.dirname(__file__), '..', 'data') 7 | 8 | def test_files_for_dense(): 9 | files = glob.glob(os.path.join(DATA_DIR, 'dense', '*.npz')) 10 | print(DATA_DIR) 11 | assert len(files) > 0 12 | for f in files: 13 | data = np.load(f) 14 | rids, cids = lap.solve_dense(data['costs']) 15 | 16 | assert data['costs'][rids, cids].sum() == data['total_cost'] 17 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import sys 4 | import platform 5 | import subprocess 6 | 7 | from setuptools import setup, Extension 8 | from setuptools.command.build_ext import build_ext 9 | from setuptools.command.test import test 10 | from distutils.version import LooseVersion 11 | 12 | 13 | class CMakeExtension(Extension): 14 | def __init__(self, name, sourcedir=''): 15 | super(CMakeExtension, self).__init__(name, sources=[]) 16 | self.sourcedir = os.path.abspath(sourcedir) 17 | 18 | 19 | class CMakeBuild(build_ext): 20 | def run(self): 21 | try: 22 | out = subprocess.check_output(['cmake', '--version']) 23 | except OSError: 24 | raise RuntimeError("CMake must be installed to build the following extensions: " + 25 | ", ".join(e.name for e in self.extensions)) 26 | 27 | if platform.system() == "Windows": 28 | cmake_version = LooseVersion(re.search(r'version\s*([\d.]+)', out.decode()).group(1)) 29 | if cmake_version < '3.1.0': 30 | raise RuntimeError("CMake >= 3.1.0 is required on Windows") 31 | 32 | for ext in self.extensions: 33 | self.build_extension(ext) 34 | 35 | def build_extension(self, ext): 36 | extdir = os.path.abspath(os.path.dirname(self.get_ext_fullpath(ext.name))) 37 | cmake_args = ['-DCMAKE_LIBRARY_OUTPUT_DIRECTORY=' + extdir, 38 | '-DPYTHON_EXECUTABLE=' + sys.executable] 39 | 40 | cfg = 'Debug' if self.debug else 'Release' 41 | build_args = ['--config', cfg] 42 | 43 | if platform.system() == "Windows": 44 | cmake_args += ['-DCMAKE_LIBRARY_OUTPUT_DIRECTORY_{}={}'.format(cfg.upper(), extdir)] 45 | if sys.maxsize > 2**32: 46 | cmake_args += ['-A', 'x64'] 47 | build_args += ['--', '/m'] 48 | else: 49 | cmake_args += ['-DCMAKE_BUILD_TYPE=' + cfg] 50 | build_args += ['--', '-j2'] 51 | 52 | env = os.environ.copy() 53 | env['CXXFLAGS'] = '{} -DVERSION_INFO=\\"{}\\"'.format(env.get('CXXFLAGS', ''), 54 | self.distribution.get_version()) 55 | if not os.path.exists(self.build_temp): 56 | os.makedirs(self.build_temp) 57 | subprocess.check_call(['cmake', ext.sourcedir] + cmake_args, cwd=self.build_temp, env=env) 58 | 
subprocess.check_call(['cmake', '--build', '.'] + build_args, cwd=self.build_temp) 59 | 60 | class PyTest(test): 61 | def run_tests(self): 62 | # import here, because outside the eggs aren't loaded 63 | import pytest 64 | errno = pytest.main(['lapsolver/tests']) 65 | sys.exit(errno) 66 | 67 | 68 | version = open('lapsolver/__init__.py').readlines()[-1].split()[-1].strip('\'') 69 | if os.getenv('LAPSOLVER_RELEASE_TYPE', 'stable') == 'dev': 70 | version = version + '.dev' + os.getenv('LAPSOLVER_DEV_NUM', '0') # this allows pre-releases: pip install --pre lapsolver 71 | 72 | setup( 73 | name='lapsolver', 74 | version=version, 75 | author='Christoph Heindl, Jack Valmadre', 76 | url='https://github.com/cheind/py-lapsolver', 77 | description='Fast linear assignment problem solvers', 78 | license='MIT', 79 | long_description='', 80 | packages=['lapsolver', 'lapsolver.tests'], 81 | include_package_data=True, 82 | ext_modules=[CMakeExtension('lapsolverc')], 83 | cmdclass=dict(build_ext=CMakeBuild, test=PyTest), 84 | zip_safe=False, 85 | python_requires='>=3', 86 | setup_requires=['pytest-runner'], 87 | tests_require=['pytest'], 88 | keywords='hungarian munkres kuhn linear-sum-assignment bipartite-graph lap' 89 | ) -------------------------------------------------------------------------------- /src/dense.hpp: -------------------------------------------------------------------------------- 1 | #include <algorithm> 2 | #include <cmath> 3 | #include <vector> 4 | 5 | /** 6 | Min cost bipartite matching via shortest augmenting paths 7 | 8 | This is an O(n^3) implementation of a shortest augmenting path 9 | algorithm for finding min cost perfect matchings in dense 10 | graphs. In practice, it solves 1000x1000 problems in around 1 11 | second. 12 | 13 | cost[i][j] = cost for pairing left node i with right node j 14 | Lmate[i] = index of right node that left node i pairs with 15 | Rmate[j] = index of left node that right node j pairs with 16 | The values in cost[i][j] may be positive or negative. To perform 17 | maximization, simply negate the cost[][] matrix. 18 | 19 | Taken from https://github.com/jaehyunp/ 20 | Adapted by https://github.com/cheind 21 | */ 22 | template <class T> 23 | void solve_dense(const std::vector< std::vector<T> > &cost, std::vector<int> &Lmate, std::vector<int> &Rmate) 24 | { 25 | 26 | ////////////////////////////////////////////////////////////////////// 27 | // Min cost bipartite matching via shortest augmenting paths 28 | // 29 | // This is an O(n^3) implementation of a shortest augmenting path 30 | // algorithm for finding min cost perfect matchings in dense 31 | // graphs. In practice, it solves 1000x1000 problems in around 1 32 | // second. 33 | // 34 | // cost[i][j] = cost for pairing left node i with right node j 35 | // Lmate[i] = index of right node that left node i pairs with 36 | // Rmate[j] = index of left node that right node j pairs with 37 | // 38 | // The values in cost[i][j] may be positive or negative. To perform 39 | // maximization, simply negate the cost[][] matrix. 40 | ////////////////////////////////////////////////////////////////////// 41 |
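(Editorial aside, not part of the original header.) For readers following the implementation below, it helps to state the problem the code solves. With binary variables $x_{ij}$ indicating that left node $i$ is paired with right node $j$, the solver addresses

$$\min_{x} \sum_{i,j} c_{ij}\, x_{ij} \quad \text{s.t.} \quad \sum_j x_{ij} = 1 \;\; \forall i, \qquad \sum_i x_{ij} = 1 \;\; \forall j,$$

whose linear-programming dual is

$$\max_{u,v} \sum_i u_i + \sum_j v_j \quad \text{s.t.} \quad u_i + v_j \le c_{ij} \;\; \forall i, j.$$

The vectors `u` and `v` in the code are exactly these dual variables: the initialization keeps $u_i + v_j \le c_{ij}$ feasible, matches are only created on tight edges where $c_{ij} - u_i - v_j \approx 0$ (the complementary-slackness test), and each Dijkstra pass adjusts the duals so that one more augmenting path over tight edges becomes available.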
42 | typedef std::vector<T> VD; 43 | typedef std::vector<VD> VVD; 44 | typedef std::vector<int> VI; 45 | 46 | // assumes square matrices 47 | const int n = int(cost.size()); 48 | 49 | // construct dual feasible solution 50 | VD u(n); 51 | VD v(n); 52 | for (int i = 0; i < n; i++) { 53 | u[i] = cost[i][0]; 54 | for (int j = 1; j < n; j++) u[i] = std::min(u[i], cost[i][j]); 55 | } 56 | for (int j = 0; j < n; j++) { 57 | v[j] = cost[0][j] - u[0]; 58 | for (int i = 1; i < n; i++) v[j] = std::min(v[j], cost[i][j] - u[i]); 59 | } 60 | 61 | // construct primal solution satisfying complementary slackness 62 | Lmate = VI(n, -1); 63 | Rmate = VI(n, -1); 64 | int mated = 0; 65 | for (int i = 0; i < n; i++) { 66 | for (int j = 0; j < n; j++) { 67 | if (Rmate[j] != -1) continue; 68 | if (fabs(cost[i][j] - u[i] - v[j]) < 1e-10) { 69 | Lmate[i] = j; 70 | Rmate[j] = i; 71 | mated++; 72 | break; 73 | } 74 | } 75 | } 76 | 77 | VD dist(n); 78 | VI dad(n); 79 | VI seen(n); 80 | 81 | // repeat until primal solution is feasible 82 | while (mated < n) { 83 | 84 | // find an unmatched left node 85 | int s = 0; 86 | while (Lmate[s] != -1) s++; 87 | 88 | // initialize Dijkstra 89 | fill(dad.begin(), dad.end(), -1); 90 | fill(seen.begin(), seen.end(), 0); 91 | for (int k = 0; k < n; k++) 92 | dist[k] = cost[s][k] - u[s] - v[k]; 93 | 94 | int j = 0; 95 | while (true) { 96 | 97 | // find closest 98 | j = -1; 99 | for (int k = 0; k < n; k++) { 100 | if (seen[k]) continue; 101 | if (j == -1 || dist[k] < dist[j]) j = k; 102 | } 103 | seen[j] = 1; 104 | 105 | // termination condition 106 | if (Rmate[j] == -1) break; 107 | 108 | // relax neighbors 109 | const int i = Rmate[j]; 110 | for (int k = 0; k < n; k++) { 111 | if (seen[k]) continue; 112 | const T new_dist = dist[j] + cost[i][k] - u[i] - v[k]; 113 | if (dist[k] > new_dist) { 114 | dist[k] = new_dist; 115 | dad[k] = j; 116 | } 117 | } 118 | } 119 | 120 | // update dual variables 121 | for (int k = 0; k < n; k++) { 122 | if (k == j || !seen[k]) continue; 123 | const int i = Rmate[k]; 124 | v[k] += dist[k] - dist[j]; 125 | u[i] -= dist[k] - dist[j]; 126 | } 127 | u[s] += dist[j]; 128 | 129 | // augment along path 130 | while (dad[j] >= 0) { 131 | const int d = dad[j]; 132 | Rmate[j] = Rmate[d]; 133 | Lmate[Rmate[j]] = j; 134 | j = d; 135 | } 136 | Rmate[j] = s; 137 | Lmate[s] = j; 138 | 139 | mated++; 140 | } 141 | } -------------------------------------------------------------------------------- /src/dense_wrap.hpp: -------------------------------------------------------------------------------- 1 | #include <pybind11/pybind11.h> 2 | #include <pybind11/numpy.h> 3 | #include <pybind11/stl.h> 4 | #include <vector> 5 | #include <algorithm> 6 | #include <cmath> 7 | 8 | #include "dense.hpp" 9 | 10 | namespace py = pybind11; 11 | 12 | template <typename T> 13 | py::tuple solve_dense_wrap(py::array_t<T, py::array::c_style | py::array::forcecast> input1) { 14 | auto buf1 = input1.request(); 15 | 16 | if (buf1.ndim != 2) 17 | throw std::runtime_error("Number of dimensions must be two"); 18 | 19 | const int nrows = int(buf1.shape[0]); 20 | const int ncols = int(buf1.shape[1]); 21 | 22 | if (nrows == 0 || ncols == 0) { 23 | return py::make_tuple(py::array(), py::array()); 24 | } 25 | 26 | T *data = (T *)buf1.ptr; 27 | 28 | bool any_finite = false; 29 | T max_abs_cost = 0; 30 | for(int i = 0; i < nrows*ncols; ++i) { 31 | if (std::isfinite((double)data[i])) { 32 | any_finite = true; 33 | // Careful: Note that std::abs() is not a template.
34 | // https://en.cppreference.com/w/cpp/numeric/math/abs 35 | // https://en.cppreference.com/w/cpp/numeric/math/fabs 36 | max_abs_cost = std::max(max_abs_cost, std::abs(data[i])); 37 | } 38 | } 39 | 40 | if (!any_finite) { 41 | return py::make_tuple(py::array(), py::array()); 42 | } 43 | 44 | const int r = std::min(nrows, ncols); 45 | const int n = std::max(nrows, ncols); 46 | const T LARGE_COST = 2 * r * max_abs_cost + 1; 47 | std::vector<std::vector<T>> costs(n, std::vector<T>(n, LARGE_COST)); 48 | 49 | for (int i = 0; i < nrows; i++) 50 | { 51 | T *cptr = data + i*ncols; 52 | for (int j =0; j < ncols; j++) 53 | { 54 | const T c = cptr[j]; 55 | if (std::isfinite((double)c)) 56 | costs[i][j] = c; 57 | } 58 | } 59 | 60 | 61 | std::vector<int> Lmate, Rmate; 62 | solve_dense(costs, Lmate, Rmate); 63 | 64 | std::vector<int> rowids, colids; 65 | 66 | for (int i = 0; i < nrows; i++) 67 | { 68 | int mate = Lmate[i]; 69 | if (Lmate[i] < ncols && costs[i][mate] != LARGE_COST) 70 | { 71 | rowids.push_back(i); 72 | colids.push_back(mate); 73 | } 74 | } 75 | 76 | return py::make_tuple(py::array(rowids.size(), rowids.data()), py::array(colids.size(), colids.data())); 77 | } 78 | -------------------------------------------------------------------------------- /src/lapsolverc.cpp: -------------------------------------------------------------------------------- 1 | #include <pybind11/pybind11.h> 2 | #include <pybind11/numpy.h> 3 | #include <pybind11/stl.h> 4 | 5 | #include "dense_wrap.hpp" 6 | 7 | namespace py = pybind11; 8 | 9 | const char *doc_dense = R"pbdoc( 10 | Min cost bipartite matching via shortest augmenting paths for dense matrices 11 | 12 | This is an O(n^3) implementation of a shortest augmenting path 13 | algorithm for finding min cost perfect matchings in dense 14 | graphs. In practice, it solves 1000x1000 problems in around 1 15 | second. 16 | 17 | rids, cids = solve_dense(costs) 18 | total_cost = costs[rids, cids].sum() 19 | 20 | Params 21 | ------ 22 | costs : MxN array 23 | Array containing costs. 24 | 25 | Returns 26 | ------- 27 | rids : array 28 | Array of row ids of matching pairings 29 | cids : array 30 | Array of column ids of matching pairings 31 | 32 | )pbdoc"; 33 | 34 | 35 | PYBIND11_MODULE(lapsolverc, m) { 36 | m.doc() = R"pbdoc( 37 | Linear assignment problem solvers using native c-extensions. 38 | )pbdoc"; 39 | 40 | // pybind11 first tries each overload (in order) without conversion. 41 | // If no match is found, it tries again with conversion, unless disallowed. 42 | // This conversion will cast e.g. double to int. 43 | // https://pybind11.readthedocs.io/en/stable/advanced/functions.html#overload-resolution-order 44 | m.def("solve_dense", solve_dense_wrap<int32_t>, py::arg().noconvert()); 45 | m.def("solve_dense", solve_dense_wrap<int64_t>, py::arg().noconvert()); 46 | m.def("solve_dense", solve_dense_wrap<float>, py::arg().noconvert()); 47 | m.def("solve_dense", solve_dense_wrap<double>); 48 | 49 | #ifdef VERSION_INFO 50 | m.attr("__version__") = VERSION_INFO; 51 | #else 52 | m.attr("__version__") = "dev"; 53 | #endif 54 | } 55 | --------------------------------------------------------------------------------
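As a closing illustration (an editorial addition, not a file of the repository), the pieces above can be exercised together as follows. This assumes the package has been built and installed, e.g. via `pip install .`, and that the snippet is run from the repository root so the bundled test data is found:

```python
import glob
import os

import numpy as np
from lapsolver import solve_dense

# Rectangular problems are supported: the smaller dimension is fully assigned.
costs = np.array([[6, 9], [10, 3], [8, 7]], dtype=float)
rids, cids = solve_dense(costs)
print(rids, cids)  # [0 1] [0 1], as checked by tests/test_dense.py::test_nonsquare

# The bundled .npz files store a cost matrix together with its known optimal total cost.
for path in glob.glob(os.path.join('lapsolver', 'data', 'dense', '*.npz')):
    data = np.load(path)
    rids, cids = solve_dense(data['costs'])
    assert data['costs'][rids, cids].sum() == data['total_cost']
```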