├── .codecov.yml ├── .github ├── ISSUE_TEMPLATE │ ├── bug_report.md │ ├── config.yml │ ├── doc_improvement.md │ └── enhancement_proposal.md └── workflows │ └── main.yml ├── .gitignore ├── .landscape.yml ├── LICENSE.txt ├── README.rst ├── bench ├── .gitignore ├── asv.conf.json └── benchmarks │ ├── __init__.py │ └── iris.py ├── doc ├── .gitignore ├── Makefile ├── _static │ ├── .gitignore │ └── css │ │ └── styles.css ├── _templates │ ├── .gitignore │ └── class.rst ├── conf.py ├── getting_started.rst ├── index.rst ├── introduction.rst ├── make.bat ├── metric_learn.rst ├── preprocessor.rst ├── supervised.rst ├── unsupervised.rst ├── user_guide.rst └── weakly_supervised.rst ├── examples ├── README.txt ├── plot_metric_learning_examples.py └── plot_sandwich.py ├── metric_learn ├── __init__.py ├── _util.py ├── _version.py ├── base_metric.py ├── constraints.py ├── covariance.py ├── exceptions.py ├── itml.py ├── lfda.py ├── lmnn.py ├── lsml.py ├── mlkr.py ├── mmc.py ├── nca.py ├── rca.py ├── scml.py ├── sdml.py └── sklearn_shims.py ├── pytest.ini ├── setup.cfg ├── setup.py └── test ├── __init__.py ├── metric_learn_test.py ├── test_base_metric.py ├── test_components_metric_conversion.py ├── test_constraints.py ├── test_fit_transform.py ├── test_mahalanobis_mixin.py ├── test_pairs_classifiers.py ├── test_quadruplets_classifiers.py ├── test_sklearn_compat.py ├── test_triplets_classifiers.py └── test_utils.py /.codecov.yml: -------------------------------------------------------------------------------- 1 | ignore: 2 | - "test" 3 | 4 | # taken from scikit-learn: 5 | # https://github.com/scikit-learn/scikit-learn/blob/a7e17117bb15eb3f51ebccc1bd53e42fcb4e6cd8/.codecov.yml 6 | comment: false 7 | 8 | coverage: 9 | status: 10 | project: 11 | default: 12 | # Commits pushed to master should not make the overall 13 | # project coverage decrease by more than 1%: 14 | target: auto 15 | threshold: 1% 16 | patch: 17 | default: 18 | # Be tolerant on slight code coverage diff on PRs to limit 19 | # noisy red coverage status on github PRs. 20 | # Note The coverage stats are still uploaded 21 | # to codecov so that PR reviewers can see uncovered lines 22 | # in the github diff if they install the codecov browser 23 | # extension: 24 | # https://github.com/codecov/browser-extension 25 | target: auto 26 | threshold: 1% 27 | 28 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Reproducible bug report 3 | about: Create a reproducible bug report. Not for support requests. 4 | labels: 'bug' 5 | --- 6 | 7 | #### Description 8 | 9 | 10 | #### Steps/Code to Reproduce 11 | 29 | 30 | #### Expected Results 31 | 32 | 33 | #### Actual Results 34 | 35 | 36 | #### Versions 37 | 50 | 51 | 52 | --- 53 | 54 | **Message from the maintainers**: 55 | 56 | Impacted by this bug? Give it a 👍. We prioritise the issues with the most 👍. -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/config.yml: -------------------------------------------------------------------------------- 1 | blank_issues_enabled: false 2 | 3 | contact_links: 4 | - name: Have you read the docs? 
5 | url: http://contrib.scikit-learn.org/metric-learn/ 6 | about: Much help can be found in the docs 7 | - name: Ask a question 8 | url: https://github.com/scikit-learn-contrib/metric-learn/discussions/new 9 | about: Ask a question or start a discussion about metric-learn 10 | - name: Stack Overflow 11 | url: https://stackoverflow.com 12 | about: Please ask and answer metric-learn usage questions (API, installation...) on Stack Overflow 13 | - name: Cross Validated 14 | url: https://stats.stackexchange.com 15 | about: Please ask and answer metric learning questions (use cases, algorithms & theory...) on Cross Validated 16 | - name: Blank issue 17 | url: https://github.com/scikit-learn-contrib/metric-learn/issues/new 18 | about: Please note that Github Discussions should be used in most cases instead 19 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/doc_improvement.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Documentation improvement 3 | about: Create a report to help us improve the documentation. Alternatively you can just open a pull request with the suggested change. 4 | labels: Documentation 5 | --- 6 | 7 | #### Describe the issue linked to the documentation 8 | 9 | 12 | 13 | #### Suggest a potential alternative/fix 14 | 15 | 18 | 19 | --- 20 | 21 | **Message from the maintainers**: 22 | 23 | Confused by this part of the doc too? Give it a 👍. We prioritise the issues with the most 👍. -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/enhancement_proposal.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Enhancement proposal 3 | about: Propose an enhancement for metric-learn 4 | labels: 'enhancement' 5 | --- 6 | # Summary 7 | 8 | What change needs making? 9 | 10 | # Use Cases 11 | 12 | When would you use this? 13 | 14 | --- 15 | 16 | **Message from the maintainers**: 17 | 18 | Want to see this feature happen? Give it a 👍. We prioritise the issues with the most 👍. 
-------------------------------------------------------------------------------- /.github/workflows/main.yml: -------------------------------------------------------------------------------- 1 | name: CI 2 | 3 | # Controls when the workflow will run 4 | on: 5 | # Triggers the workflow on push or pull request events but only for the master branch 6 | push: 7 | branches: [ master ] 8 | pull_request: 9 | branches: [ master ] 10 | 11 | jobs: 12 | # Run normal testing with the latest versions of all dependencies 13 | build: 14 | runs-on: ${{ matrix.os }} 15 | strategy: 16 | matrix: 17 | os: [ubuntu-latest] 18 | python-version: ['3.8', '3.9', '3.10', '3.11'] 19 | steps: 20 | - uses: actions/checkout@v2 21 | - name: Set up Python 22 | uses: actions/setup-python@v2 23 | with: 24 | python-version: ${{ matrix.python-version }} 25 | - name: Run Tests without skggm 26 | run: | 27 | sudo apt-get install liblapack-dev 28 | pip install --upgrade pip pytest 29 | pip install wheel cython numpy scipy codecov pytest-cov scikit-learn 30 | pytest test --cov 31 | bash <(curl -s https://codecov.io/bash) 32 | - name: Run Tests with skggm 33 | env: 34 | SKGGM_VERSION: a0ed406586c4364ea3297a658f415e13b5cbdaf8 35 | run: | 36 | pip install git+https://github.com/skggm/skggm.git@${SKGGM_VERSION} 37 | pytest test --cov 38 | bash <(curl -s https://codecov.io/bash) 39 | - name: Syntax checking with flake8 40 | run: | 41 | pip install flake8 42 | flake8 --extend-ignore=E111,E114 --show-source; 43 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | build/ 3 | dist/ 4 | *.egg-info 5 | .coverage 6 | htmlcov/ 7 | .cache/ 8 | .pytest_cache/ 9 | doc/auto_examples/* 10 | doc/generated/* 11 | venv/ 12 | .vscode/ 13 | -------------------------------------------------------------------------------- /.landscape.yml: -------------------------------------------------------------------------------- 1 | strictness: medium 2 | pep8: 3 | disable: 4 | - E111 5 | - E114 6 | - E231 7 | - E225 8 | - E402 9 | - W503 10 | pylint: 11 | disable: 12 | - bad-indentation 13 | - invalid-name 14 | - too-many-arguments 15 | ignore-paths: 16 | - bench/ 17 | -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | Copyright (c) 2015 CJ Carey and Yuan Tang 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy 4 | of this software and associated documentation files (the "Software"), to deal 5 | in the Software without restriction, including without limitation the rights 6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | copies of the Software, and to permit persons to whom the Software is 8 | furnished to do so, subject to the following conditions: 9 | 10 | The above copyright notice and this permission notice shall be included in 11 | all copies or substantial portions of the Software. 12 | 13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 19 | THE SOFTWARE. -------------------------------------------------------------------------------- /README.rst: -------------------------------------------------------------------------------- 1 | |GitHub Actions Build Status| |License| |PyPI version| |Code coverage| 2 | 3 | metric-learn: Metric Learning in Python 4 | ======================================= 5 | 6 | metric-learn contains efficient Python implementations of several popular supervised and weakly-supervised metric learning algorithms. As part of `scikit-learn-contrib `_, the API of metric-learn is compatible with `scikit-learn `_, the leading library for machine learning in Python. This allows to use all the scikit-learn routines (for pipelining, model selection, etc) with metric learning algorithms through a unified interface. 7 | 8 | **Algorithms** 9 | 10 | - Large Margin Nearest Neighbor (LMNN) 11 | - Information Theoretic Metric Learning (ITML) 12 | - Sparse Determinant Metric Learning (SDML) 13 | - Least Squares Metric Learning (LSML) 14 | - Sparse Compositional Metric Learning (SCML) 15 | - Neighborhood Components Analysis (NCA) 16 | - Local Fisher Discriminant Analysis (LFDA) 17 | - Relative Components Analysis (RCA) 18 | - Metric Learning for Kernel Regression (MLKR) 19 | - Mahalanobis Metric for Clustering (MMC) 20 | 21 | **Dependencies** 22 | 23 | - Python 3.6+ (the last version supporting Python 2 and Python 3.5 was 24 | `v0.5.0 `_) 25 | - numpy>= 1.11.0, scipy>= 0.17.0, scikit-learn>=0.21.3 26 | 27 | **Optional dependencies** 28 | 29 | - For SDML, using skggm will allow the algorithm to solve problematic cases 30 | (install from commit `a0ed406 `_). 31 | ``pip install 'git+https://github.com/skggm/skggm.git@a0ed406586c4364ea3297a658f415e13b5cbdaf8'`` to install the required version of skggm from GitHub. 32 | - For running the examples only: matplotlib 33 | 34 | **Installation/Setup** 35 | 36 | - If you use Anaconda: ``conda install -c conda-forge metric-learn``. See more options `here `_. 37 | 38 | - To install from PyPI: ``pip install metric-learn``. 39 | 40 | - For a manual install of the latest code, download the source repository and run ``python setup.py install``. You may then run ``pytest test`` to run all tests (you will need to have the ``pytest`` package installed). 41 | 42 | **Usage** 43 | 44 | See the `sphinx documentation`_ for full documentation about installation, API, usage, and examples. 45 | 46 | **Citation** 47 | 48 | If you use metric-learn in a scientific publication, we would appreciate 49 | citations to the following paper: 50 | 51 | `metric-learn: Metric Learning Algorithms in Python 52 | `_, de Vazelhes 53 | *et al.*, Journal of Machine Learning Research, 21(138):1-6, 2020. 54 | 55 | Bibtex entry:: 56 | 57 | @article{metric-learn, 58 | title = {metric-learn: {M}etric {L}earning {A}lgorithms in {P}ython}, 59 | author = {{de Vazelhes}, William and {Carey}, CJ and {Tang}, Yuan and 60 | {Vauquier}, Nathalie and {Bellet}, Aur{\'e}lien}, 61 | journal = {Journal of Machine Learning Research}, 62 | year = {2020}, 63 | volume = {21}, 64 | number = {138}, 65 | pages = {1--6} 66 | } 67 | 68 | .. _sphinx documentation: http://contrib.scikit-learn.org/metric-learn/ 69 | 70 | .. 
|GitHub Actions Build Status| image:: https://github.com/scikit-learn-contrib/metric-learn/workflows/CI/badge.svg 71 | :target: https://github.com/scikit-learn-contrib/metric-learn/actions?query=event%3Apush+branch%3Amaster 72 | .. |License| image:: http://img.shields.io/:license-mit-blue.svg?style=flat 73 | :target: http://badges.mit-license.org 74 | .. |PyPI version| image:: https://badge.fury.io/py/metric-learn.svg 75 | :target: http://badge.fury.io/py/metric-learn 76 | .. |Code coverage| image:: https://codecov.io/gh/scikit-learn-contrib/metric-learn/branch/master/graph/badge.svg 77 | :target: https://codecov.io/gh/scikit-learn-contrib/metric-learn 78 | -------------------------------------------------------------------------------- /bench/.gitignore: -------------------------------------------------------------------------------- 1 | results 2 | env 3 | metric-learn 4 | html 5 | -------------------------------------------------------------------------------- /bench/asv.conf.json: -------------------------------------------------------------------------------- 1 | { 2 | // The version of the config file format. Do not change, unless 3 | // you know what you are doing. 4 | "version": 1, 5 | 6 | // The name of the project being benchmarked 7 | "project": "metric-learn", 8 | 9 | // The project's homepage 10 | "project_url": "https://github.com/all-umass/metric-learn", 11 | 12 | // The URL or local path of the source code repository for the 13 | // project being benchmarked 14 | "repo": "..", 15 | 16 | // List of branches to benchmark. If not provided, defaults to "master" 17 | // (for git) or "tip" (for mercurial). 18 | "branches": ["master"], // for git 19 | // "branches": ["tip"], // for mercurial 20 | 21 | // The DVCS being used. If not set, it will be automatically 22 | // determined from "repo" by looking at the protocol in the URL 23 | // (if remote), or by looking for special directories, such as 24 | // ".git" (if local). 25 | "dvcs": "git", 26 | 27 | // The tool to use to create environments. May be "conda", 28 | // "virtualenv" or other value depending on the plugins in use. 29 | // If missing or the empty string, the tool will be automatically 30 | // determined by looking for tools on the PATH environment 31 | // variable. 32 | "environment_type": "virtualenv", 33 | 34 | // the base URL to show a commit for the project. 35 | "show_commit_url": "http://github.com/all-umass/metric-learn/commit/", 36 | 37 | // The Pythons you'd like to test against. If not provided, defaults 38 | // to the current version of Python used to run `asv`. 39 | // "pythons": ["2.7", "3.3"], 40 | 41 | // The matrix of dependencies to test. Each key is the name of a 42 | // package (in PyPI) and the values are version numbers. An empty 43 | // list indicates to just test against the default (latest) 44 | // version. 45 | "matrix": { 46 | "numpy": ["1.12"], 47 | "scipy": ["0.18"], 48 | "scikit-learn": ["0.18"] 49 | }, 50 | 51 | // The directory (relative to the current directory) that benchmarks are 52 | // stored in. If not provided, defaults to "benchmarks" 53 | // "benchmark_dir": "benchmarks", 54 | 55 | // The directory (relative to the current directory) to cache the Python 56 | // environments in. If not provided, defaults to "env" 57 | // "env_dir": "env", 58 | 59 | // The directory (relative to the current directory) that raw benchmark 60 | // results are stored in. If not provided, defaults to "results". 
61 | // "results_dir": "results", 62 | 63 | // The directory (relative to the current directory) that the html tree 64 | // should be written to. If not provided, defaults to "html". 65 | // "html_dir": "html", 66 | 67 | // The number of characters to retain in the commit hashes. 68 | // "hash_length": 8, 69 | 70 | // `asv` will cache wheels of the recent builds in each 71 | // environment, making them faster to install next time. This is 72 | // number of builds to keep, per environment. 73 | "wheel_cache_size": 4 74 | } 75 | -------------------------------------------------------------------------------- /bench/benchmarks/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scikit-learn-contrib/metric-learn/dc7e4499b1a9e522f03c87ba8dc249f9747cac82/bench/benchmarks/__init__.py -------------------------------------------------------------------------------- /bench/benchmarks/iris.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from sklearn.datasets import load_iris 3 | 4 | import metric_learn 5 | 6 | CLASSES = { 7 | 'Covariance': metric_learn.Covariance(), 8 | 'ITML_Supervised': metric_learn.ITML_Supervised(n_constraints=200), 9 | 'LFDA': metric_learn.LFDA(k=2, dim=2), 10 | 'LMNN': metric_learn.LMNN(n_neighbors=5, learn_rate=1e-6, verbose=False), 11 | 'LSML_Supervised': metric_learn.LSML_Supervised(n_constraints=200), 12 | 'MLKR': metric_learn.MLKR(), 13 | 'NCA': metric_learn.NCA(max_iter=700, n_components=2), 14 | 'RCA_Supervised': metric_learn.RCA_Supervised(dim=2, n_chunks=30, 15 | chunk_size=2), 16 | 'SDML_Supervised': metric_learn.SDML_Supervised(n_constraints=1500) 17 | } 18 | 19 | 20 | class IrisDataset(object): 21 | params = [sorted(CLASSES)] 22 | param_names = ['alg'] 23 | 24 | def setup(self, alg): 25 | iris_data = load_iris() 26 | self.iris_points = iris_data['data'] 27 | self.iris_labels = iris_data['target'] 28 | 29 | def time_fit(self, alg): 30 | np.random.seed(5555) 31 | CLASSES[alg].fit(self.iris_points, self.iris_labels) 32 | -------------------------------------------------------------------------------- /doc/.gitignore: -------------------------------------------------------------------------------- 1 | _build/ 2 | -------------------------------------------------------------------------------- /doc/Makefile: -------------------------------------------------------------------------------- 1 | # Makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line. 5 | SPHINXOPTS = 6 | SPHINXBUILD = sphinx-build 7 | PAPER = 8 | BUILDDIR = _build 9 | 10 | # User-friendly check for sphinx-build 11 | ifeq ($(shell which $(SPHINXBUILD) >/dev/null 2>&1; echo $$?), 1) 12 | $(error The '$(SPHINXBUILD)' command was not found. Make sure you have Sphinx installed, then set the SPHINXBUILD environment variable to point to the full path of the '$(SPHINXBUILD)' executable. Alternatively you can add the directory with the executable to your PATH. If you don't have Sphinx installed, grab it from http://sphinx-doc.org/) 13 | endif 14 | 15 | # Internal variables. 16 | PAPEROPT_a4 = -D latex_paper_size=a4 17 | PAPEROPT_letter = -D latex_paper_size=letter 18 | ALLSPHINXOPTS = -d $(BUILDDIR)/doctrees $(PAPEROPT_$(PAPER)) $(SPHINXOPTS) . 19 | # the i18n builder cannot share the environment and doctrees with the others 20 | I18NSPHINXOPTS = $(PAPEROPT_$(PAPER)) $(SPHINXOPTS) . 
21 | 22 | .PHONY: help clean html dirhtml singlehtml pickle json htmlhelp qthelp devhelp epub latex latexpdf text man changes linkcheck doctest coverage gettext 23 | 24 | help: 25 | @echo "Please use \`make ' where is one of" 26 | @echo " html to make standalone HTML files" 27 | @echo " dirhtml to make HTML files named index.html in directories" 28 | @echo " singlehtml to make a single large HTML file" 29 | @echo " pickle to make pickle files" 30 | @echo " json to make JSON files" 31 | @echo " htmlhelp to make HTML files and a HTML help project" 32 | @echo " qthelp to make HTML files and a qthelp project" 33 | @echo " applehelp to make an Apple Help Book" 34 | @echo " devhelp to make HTML files and a Devhelp project" 35 | @echo " epub to make an epub" 36 | @echo " latex to make LaTeX files, you can set PAPER=a4 or PAPER=letter" 37 | @echo " latexpdf to make LaTeX files and run them through pdflatex" 38 | @echo " latexpdfja to make LaTeX files and run them through platex/dvipdfmx" 39 | @echo " text to make text files" 40 | @echo " man to make manual pages" 41 | @echo " texinfo to make Texinfo files" 42 | @echo " info to make Texinfo files and run them through makeinfo" 43 | @echo " gettext to make PO message catalogs" 44 | @echo " changes to make an overview of all changed/added/deprecated items" 45 | @echo " xml to make Docutils-native XML files" 46 | @echo " pseudoxml to make pseudoxml-XML files for display purposes" 47 | @echo " linkcheck to check all external links for integrity" 48 | @echo " doctest to run all doctests embedded in the documentation (if enabled)" 49 | @echo " coverage to run coverage check of the documentation (if enabled)" 50 | 51 | clean: 52 | rm -rf $(BUILDDIR)/* 53 | 54 | html: 55 | $(SPHINXBUILD) -b html $(ALLSPHINXOPTS) $(BUILDDIR)/html 56 | @echo 57 | @echo "Build finished. The HTML pages are in $(BUILDDIR)/html." 58 | 59 | dirhtml: 60 | $(SPHINXBUILD) -b dirhtml $(ALLSPHINXOPTS) $(BUILDDIR)/dirhtml 61 | @echo 62 | @echo "Build finished. The HTML pages are in $(BUILDDIR)/dirhtml." 63 | 64 | singlehtml: 65 | $(SPHINXBUILD) -b singlehtml $(ALLSPHINXOPTS) $(BUILDDIR)/singlehtml 66 | @echo 67 | @echo "Build finished. The HTML page is in $(BUILDDIR)/singlehtml." 68 | 69 | pickle: 70 | $(SPHINXBUILD) -b pickle $(ALLSPHINXOPTS) $(BUILDDIR)/pickle 71 | @echo 72 | @echo "Build finished; now you can process the pickle files." 73 | 74 | json: 75 | $(SPHINXBUILD) -b json $(ALLSPHINXOPTS) $(BUILDDIR)/json 76 | @echo 77 | @echo "Build finished; now you can process the JSON files." 78 | 79 | htmlhelp: 80 | $(SPHINXBUILD) -b htmlhelp $(ALLSPHINXOPTS) $(BUILDDIR)/htmlhelp 81 | @echo 82 | @echo "Build finished; now you can run HTML Help Workshop with the" \ 83 | ".hhp project file in $(BUILDDIR)/htmlhelp." 84 | 85 | qthelp: 86 | $(SPHINXBUILD) -b qthelp $(ALLSPHINXOPTS) $(BUILDDIR)/qthelp 87 | @echo 88 | @echo "Build finished; now you can run "qcollectiongenerator" with the" \ 89 | ".qhcp project file in $(BUILDDIR)/qthelp, like this:" 90 | @echo "# qcollectiongenerator $(BUILDDIR)/qthelp/metric-learn.qhcp" 91 | @echo "To view the help file:" 92 | @echo "# assistant -collectionFile $(BUILDDIR)/qthelp/metric-learn.qhc" 93 | 94 | applehelp: 95 | $(SPHINXBUILD) -b applehelp $(ALLSPHINXOPTS) $(BUILDDIR)/applehelp 96 | @echo 97 | @echo "Build finished. The help book is in $(BUILDDIR)/applehelp." 98 | @echo "N.B. You won't be able to view it unless you put it in" \ 99 | "~/Library/Documentation/Help or install it in your application" \ 100 | "bundle." 
101 | 102 | devhelp: 103 | $(SPHINXBUILD) -b devhelp $(ALLSPHINXOPTS) $(BUILDDIR)/devhelp 104 | @echo 105 | @echo "Build finished." 106 | @echo "To view the help file:" 107 | @echo "# mkdir -p $$HOME/.local/share/devhelp/metric-learn" 108 | @echo "# ln -s $(BUILDDIR)/devhelp $$HOME/.local/share/devhelp/metric-learn" 109 | @echo "# devhelp" 110 | 111 | epub: 112 | $(SPHINXBUILD) -b epub $(ALLSPHINXOPTS) $(BUILDDIR)/epub 113 | @echo 114 | @echo "Build finished. The epub file is in $(BUILDDIR)/epub." 115 | 116 | latex: 117 | $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex 118 | @echo 119 | @echo "Build finished; the LaTeX files are in $(BUILDDIR)/latex." 120 | @echo "Run \`make' in that directory to run these through (pdf)latex" \ 121 | "(use \`make latexpdf' here to do that automatically)." 122 | 123 | latexpdf: 124 | $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex 125 | @echo "Running LaTeX files through pdflatex..." 126 | $(MAKE) -C $(BUILDDIR)/latex all-pdf 127 | @echo "pdflatex finished; the PDF files are in $(BUILDDIR)/latex." 128 | 129 | latexpdfja: 130 | $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex 131 | @echo "Running LaTeX files through platex and dvipdfmx..." 132 | $(MAKE) -C $(BUILDDIR)/latex all-pdf-ja 133 | @echo "pdflatex finished; the PDF files are in $(BUILDDIR)/latex." 134 | 135 | text: 136 | $(SPHINXBUILD) -b text $(ALLSPHINXOPTS) $(BUILDDIR)/text 137 | @echo 138 | @echo "Build finished. The text files are in $(BUILDDIR)/text." 139 | 140 | man: 141 | $(SPHINXBUILD) -b man $(ALLSPHINXOPTS) $(BUILDDIR)/man 142 | @echo 143 | @echo "Build finished. The manual pages are in $(BUILDDIR)/man." 144 | 145 | texinfo: 146 | $(SPHINXBUILD) -b texinfo $(ALLSPHINXOPTS) $(BUILDDIR)/texinfo 147 | @echo 148 | @echo "Build finished. The Texinfo files are in $(BUILDDIR)/texinfo." 149 | @echo "Run \`make' in that directory to run these through makeinfo" \ 150 | "(use \`make info' here to do that automatically)." 151 | 152 | info: 153 | $(SPHINXBUILD) -b texinfo $(ALLSPHINXOPTS) $(BUILDDIR)/texinfo 154 | @echo "Running Texinfo files through makeinfo..." 155 | make -C $(BUILDDIR)/texinfo info 156 | @echo "makeinfo finished; the Info files are in $(BUILDDIR)/texinfo." 157 | 158 | gettext: 159 | $(SPHINXBUILD) -b gettext $(I18NSPHINXOPTS) $(BUILDDIR)/locale 160 | @echo 161 | @echo "Build finished. The message catalogs are in $(BUILDDIR)/locale." 162 | 163 | changes: 164 | $(SPHINXBUILD) -b changes $(ALLSPHINXOPTS) $(BUILDDIR)/changes 165 | @echo 166 | @echo "The overview file is in $(BUILDDIR)/changes." 167 | 168 | linkcheck: 169 | $(SPHINXBUILD) -b linkcheck $(ALLSPHINXOPTS) $(BUILDDIR)/linkcheck 170 | @echo 171 | @echo "Link check complete; look for any errors in the above output " \ 172 | "or in $(BUILDDIR)/linkcheck/output.txt." 173 | 174 | doctest: 175 | $(SPHINXBUILD) -b doctest $(ALLSPHINXOPTS) $(BUILDDIR)/doctest 176 | @echo "Testing of doctests in the sources finished, look at the " \ 177 | "results in $(BUILDDIR)/doctest/output.txt." 178 | 179 | coverage: 180 | $(SPHINXBUILD) -b coverage $(ALLSPHINXOPTS) $(BUILDDIR)/coverage 181 | @echo "Testing of coverage in the sources finished, look at the " \ 182 | "results in $(BUILDDIR)/coverage/python.txt." 183 | 184 | xml: 185 | $(SPHINXBUILD) -b xml $(ALLSPHINXOPTS) $(BUILDDIR)/xml 186 | @echo 187 | @echo "Build finished. The XML files are in $(BUILDDIR)/xml." 188 | 189 | pseudoxml: 190 | $(SPHINXBUILD) -b pseudoxml $(ALLSPHINXOPTS) $(BUILDDIR)/pseudoxml 191 | @echo 192 | @echo "Build finished. 
The pseudo-XML files are in $(BUILDDIR)/pseudoxml." 193 | -------------------------------------------------------------------------------- /doc/_static/.gitignore: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scikit-learn-contrib/metric-learn/dc7e4499b1a9e522f03c87ba8dc249f9747cac82/doc/_static/.gitignore -------------------------------------------------------------------------------- /doc/_static/css/styles.css: -------------------------------------------------------------------------------- 1 | .hatnote { 2 | border-color: #e1e4e5 ; 3 | border-style: solid ; 4 | border-width: 1px ; 5 | font-size: x-small ; 6 | font-style: italic ; 7 | margin-left: auto ; 8 | margin-right: auto ; 9 | margin-bottom: 24px; 10 | padding: 12px; 11 | } 12 | .hatnote-gray { 13 | background-color: #f5f5f5 14 | } 15 | .hatnote li { 16 | list-style-type: square; 17 | margin-left: 12px !important; 18 | } 19 | .hatnote ul { 20 | list-style-type: square; 21 | margin-left: 0px !important; 22 | margin-bottom: 0px !important; 23 | } 24 | .deprecated { 25 | color: #b94a48; 26 | background-color: #F3E5E5; 27 | border-color: #eed3d7; 28 | margin-top: 0.5rem; 29 | padding: 0.5rem; 30 | border-radius: 0.5rem; 31 | margin-bottom: 0.5rem; 32 | } 33 | 34 | .deprecated p { 35 | margin-bottom: 0 !important; 36 | } -------------------------------------------------------------------------------- /doc/_templates/.gitignore: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scikit-learn-contrib/metric-learn/dc7e4499b1a9e522f03c87ba8dc249f9747cac82/doc/_templates/.gitignore -------------------------------------------------------------------------------- /doc/_templates/class.rst: -------------------------------------------------------------------------------- 1 | :mod:`{{module}}`.{{objname}} 2 | {{ underline }}============== 3 | 4 | .. currentmodule:: {{ module }} 5 | 6 | .. autoclass:: {{ objname }} 7 | :members: 8 | :undoc-members: 9 | :inherited-members: 10 | :special-members: __init__ 11 | 12 | .. include:: {{module}}.{{objname}}.examples 13 | 14 | .. raw:: html 15 | 16 |
17 | -------------------------------------------------------------------------------- /doc/conf.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import sys 3 | import os 4 | import warnings 5 | 6 | extensions = [ 7 | 'sphinx.ext.autodoc', 8 | 'sphinx.ext.autosummary', 9 | 'sphinx.ext.todo', 10 | 'sphinx.ext.viewcode', 11 | 'sphinx.ext.mathjax', 12 | 'numpydoc', 13 | 'sphinx_gallery.gen_gallery', 14 | 'sphinx.ext.doctest', 15 | 'sphinx.ext.intersphinx' 16 | ] 17 | 18 | templates_path = ['_templates'] 19 | source_suffix = '.rst' 20 | master_doc = 'index' 21 | 22 | # General information about the project. 23 | project = u'metric-learn' 24 | copyright = (u'2015-2023, CJ Carey, Yuan Tang, William de Vazelhes, Aurélien ' 25 | u'Bellet and Nathalie Vauquier') 26 | author = (u'CJ Carey, Yuan Tang, William de Vazelhes, Aurélien Bellet and ' 27 | u'Nathalie Vauquier') 28 | version = '0.7.0' 29 | release = '0.7.0' 30 | language = 'en' 31 | 32 | exclude_patterns = ['_build'] 33 | pygments_style = 'sphinx' 34 | todo_include_todos = True 35 | 36 | # Options for HTML output 37 | html_theme = 'sphinx_rtd_theme' 38 | html_static_path = ['_static'] 39 | htmlhelp_basename = 'metric-learndoc' 40 | 41 | # Option to hide doctests comments in the documentation (like # doctest: 42 | # +NORMALIZE_WHITESPACE for instance) 43 | trim_doctest_flags = True 44 | 45 | # intersphinx configuration 46 | intersphinx_mapping = { 47 | 'python': ('https://docs.python.org/{.major}'.format( 48 | sys.version_info), None), 49 | 'numpy': ('https://docs.scipy.org/doc/numpy/', None), 50 | 'scipy': ('https://docs.scipy.org/doc/scipy/reference', None), 51 | 'scikit-learn': ('https://scikit-learn.org/stable/', None) 52 | } 53 | 54 | 55 | # sphinx-gallery configuration 56 | sphinx_gallery_conf = { 57 | # to generate mini-galleries at the end of each docstring in the API 58 | # section: (see https://sphinx-gallery.github.io/configuration.html 59 | # #references-to-examples) 60 | 'doc_module': 'metric_learn', 61 | 'backreferences_dir': os.path.join('generated'), 62 | } 63 | 64 | # generate autosummary even if no references 65 | autosummary_generate = True 66 | 67 | 68 | # Temporary work-around for spacing problem between parameter and parameter 69 | # type in the doc, see https://github.com/numpy/numpydoc/issues/215. The bug 70 | # has been fixed in sphinx (https://github.com/sphinx-doc/sphinx/pull/5976) but 71 | # through a change in sphinx basic.css except rtd_theme does not use basic.css. 72 | # In an ideal world, this would get fixed in this PR: 73 | # https://github.com/readthedocs/sphinx_rtd_theme/pull/747/files 74 | def setup(app): 75 | app.add_js_file('js/copybutton.js') 76 | app.add_css_file('css/styles.css') 77 | 78 | 79 | # Remove matplotlib agg warnings from generated doc when using plt.show 80 | warnings.filterwarnings("ignore", category=UserWarning, 81 | message='Matplotlib is currently using agg, which is a' 82 | ' non-GUI backend, so cannot show the figure.') 83 | -------------------------------------------------------------------------------- /doc/getting_started.rst: -------------------------------------------------------------------------------- 1 | ############### 2 | Getting started 3 | ############### 4 | 5 | Installation and Setup 6 | ====================== 7 | 8 | **Installation** 9 | 10 | metric-learn can be installed in either of the following ways: 11 | 12 | - If you use Anaconda: ``conda install -c conda-forge metric-learn``. 
See more options `here `_. 13 | 14 | - To install from PyPI: ``pip install metric-learn``. 15 | 16 | - For a manual install of the latest code, download the source repository and run ``python setup.py install``. You may then run ``pytest test`` to run all tests (you will need to have the ``pytest`` package installed). 17 | 18 | **Dependencies** 19 | 20 | - Python 3.6+ (the last version supporting Python 2 and Python 3.5 was 21 | `v0.5.0 `_) 22 | - numpy>= 1.11.0, scipy>= 0.17.0, scikit-learn>=0.21.3 23 | 24 | **Optional dependencies** 25 | 26 | - For SDML, using skggm will allow the algorithm to solve problematic cases 27 | (install from commit `a0ed406 `_). 28 | ``pip install 'git+https://github.com/skggm/skggm.git@a0ed406586c4364ea3297a658f415e13b5cbdaf8'`` to install the required version of skggm from GitHub. 29 | - For running the examples only: matplotlib 30 | 31 | Quick start 32 | =========== 33 | 34 | This example loads the iris dataset, and evaluates a k-nearest neighbors 35 | algorithm on an embedding space learned with `NCA`. 36 | 37 | :: 38 | 39 | from metric_learn import NCA 40 | from sklearn.datasets import load_iris 41 | from sklearn.model_selection import cross_val_score 42 | from sklearn.pipeline import make_pipeline 43 | from sklearn.neighbors import KNeighborsClassifier 44 | 45 | X, y = load_iris(return_X_y=True) 46 | clf = make_pipeline(NCA(), KNeighborsClassifier()) 47 | cross_val_score(clf, X, y) 48 | -------------------------------------------------------------------------------- /doc/index.rst: -------------------------------------------------------------------------------- 1 | metric-learn: Metric Learning in Python 2 | ======================================= 3 | |GitHub Actions Build Status| |License| |PyPI version| |Code coverage| 4 | 5 | `metric-learn `_ 6 | contains efficient Python implementations of several popular supervised and 7 | weakly-supervised metric learning algorithms. As part of `scikit-learn-contrib 8 | `_, the API of metric-learn is compatible with `scikit-learn 9 | `_, the leading library for machine learning in 10 | Python. This allows to use all the scikit-learn routines (for pipelining, 11 | model selection, etc) with metric learning algorithms through a unified 12 | interface. 13 | 14 | If you use metric-learn in a scientific publication, we would appreciate 15 | citations to the following paper: 16 | 17 | `metric-learn: Metric Learning Algorithms in Python 18 | `_, de Vazelhes 19 | *et al.*, Journal of Machine Learning Research, 21(138):1-6, 2020. 20 | 21 | Bibtex entry:: 22 | 23 | @article{metric-learn, 24 | title = {metric-learn: {M}etric {L}earning {A}lgorithms in {P}ython}, 25 | author = {{de Vazelhes}, William and {Carey}, CJ and {Tang}, Yuan and 26 | {Vauquier}, Nathalie and {Bellet}, Aur{\'e}lien}, 27 | journal = {Journal of Machine Learning Research}, 28 | year = {2020}, 29 | volume = {21}, 30 | number = {138}, 31 | pages = {1--6} 32 | } 33 | 34 | 35 | Documentation outline 36 | --------------------- 37 | 38 | .. toctree:: 39 | :maxdepth: 2 40 | 41 | getting_started 42 | 43 | .. toctree:: 44 | :maxdepth: 2 45 | 46 | user_guide 47 | 48 | .. toctree:: 49 | :maxdepth: 2 50 | 51 | Package Contents 52 | 53 | .. toctree:: 54 | :maxdepth: 2 55 | 56 | auto_examples/index 57 | 58 | :ref:`genindex` | :ref:`search` 59 | 60 | .. 
|GitHub Actions Build Status| image:: https://github.com/scikit-learn-contrib/metric-learn/workflows/CI/badge.svg 61 | :target: https://github.com/scikit-learn-contrib/metric-learn/actions?query=event%3Apush+branch%3Amaster 62 | .. |PyPI version| image:: https://badge.fury.io/py/metric-learn.svg 63 | :target: http://badge.fury.io/py/metric-learn 64 | .. |License| image:: http://img.shields.io/:license-mit-blue.svg?style=flat 65 | :target: http://badges.mit-license.org 66 | .. |Code coverage| image:: https://codecov.io/gh/scikit-learn-contrib/metric-learn/branch/master/graph/badge.svg 67 | :target: https://codecov.io/gh/scikit-learn-contrib/metric-learn 68 | -------------------------------------------------------------------------------- /doc/introduction.rst: -------------------------------------------------------------------------------- 1 | .. _intro_metric_learning: 2 | 3 | ======================== 4 | What is Metric Learning? 5 | ======================== 6 | 7 | Many approaches in machine learning require a measure of distance between data 8 | points. Traditionally, practitioners would choose a standard distance metric 9 | (Euclidean, City-Block, Cosine, etc.) using a priori knowledge of the 10 | domain. However, it is often difficult to design metrics that are well-suited 11 | to the particular data and task of interest. 12 | 13 | Distance metric learning (or simply, metric learning) aims at 14 | automatically constructing task-specific distance metrics from (weakly) 15 | supervised data, in a machine learning manner. The learned distance metric can 16 | then be used to perform various tasks (e.g., k-NN classification, clustering, 17 | information retrieval). 18 | 19 | Problem Setting 20 | =============== 21 | 22 | Metric learning problems fall into two main categories depending on the type 23 | of supervision available about the training data: 24 | 25 | - :doc:`Supervised learning `: the algorithm has access to 26 | a set of data points, each of them belonging to a class (label) as in a 27 | standard classification problem. 28 | Broadly speaking, the goal in this setting is to learn a distance metric 29 | that puts points with the same label close together while pushing away 30 | points with different labels. 31 | - :doc:`Weakly supervised learning `: the 32 | algorithm has access to a set of data points with supervision only 33 | at the tuple level (typically pairs, triplets, or quadruplets of 34 | data points). A classic example of such weaker supervision is a set of 35 | positive and negative pairs: in this case, the goal is to learn a distance 36 | metric that puts positive pairs close together and negative pairs far away. 37 | 38 | Based on the above (weakly) supervised data, the metric learning problem is 39 | generally formulated as an optimization problem where one seeks to find the 40 | parameters of a distance function that optimize some objective function 41 | measuring the agreement with the training data. 42 | 43 | .. _mahalanobis_distances: 44 | 45 | Mahalanobis Distances 46 | ===================== 47 | 48 | In the metric-learn package, all algorithms currently implemented learn 49 | so-called Mahalanobis distances. Given a real-valued parameter matrix 50 | :math:`L` of shape ``(num_dims, n_features)`` where ``n_features`` is the 51 | number features describing the data, the Mahalanobis distance associated with 52 | :math:`L` is defined as follows: 53 | 54 | .. 
math:: D(x, x') = \sqrt{(Lx-Lx')^\top(Lx-Lx')} 55 | 56 | In other words, a Mahalanobis distance is a Euclidean distance after a 57 | linear transformation of the feature space defined by :math:`L` (taking 58 | :math:`L` to be the identity matrix recovers the standard Euclidean distance). 59 | Mahalanobis distance metric learning can thus be seen as learning a new 60 | embedding space of dimension ``num_dims``. Note that when ``num_dims`` is 61 | smaller than ``n_features``, this achieves dimensionality reduction. 62 | 63 | Strictly speaking, Mahalanobis distances are "pseudo-metrics": they satisfy 64 | three of the `properties of a metric `_ (non-negativity, symmetry, triangle inequality) but not 66 | necessarily the identity of indiscernibles. 67 | 68 | .. note:: 69 | 70 | Mahalanobis distances can also be parameterized by a `positive semi-definite 71 | (PSD) matrix 72 | `_ 73 | :math:`M`: 74 | 75 | .. math:: D(x, x') = \sqrt{(x-x')^\top M(x-x')} 76 | 77 | Using the fact that a PSD matrix :math:`M` can always be decomposed as 78 | :math:`M=L^\top L` for some :math:`L`, one can show that both 79 | parameterizations are equivalent. In practice, an algorithm may thus solve 80 | the metric learning problem with respect to either :math:`M` or :math:`L`. 81 | 82 | .. _use_cases: 83 | 84 | Use-cases 85 | ========= 86 | 87 | There are many use-cases for metric learning. We list here a few popular 88 | examples (for code illustrating some of these use-cases, see the 89 | :doc:`examples ` section of the documentation): 90 | 91 | - `Nearest neighbors models 92 | `_: the learned 93 | metric can be used to improve nearest neighbors learning models for 94 | classification, regression, anomaly detection... 95 | - `Clustering `_: 96 | metric learning provides a way to bias the clusters found by algorithms like 97 | K-Means towards the intended semantics. 98 | - Information retrieval: the learned metric can be used to retrieve the 99 | elements of a database that are semantically closest to a query element. 100 | - Dimensionality reduction: metric learning may be seen as a way to reduce the 101 | data dimension in a (weakly) supervised setting. 102 | - More generally, the learned transformation :math:`L` can be used to project 103 | the data into a new embedding space before feeding it into another machine 104 | learning algorithm. 105 | 106 | The API of metric-learn is compatible with `scikit-learn 107 | `_, the leading library for machine 108 | learning in Python. This allows to easily pipeline metric learners with other 109 | scikit-learn estimators to realize the above use-cases, to perform joint 110 | hyperparameter tuning, etc. 
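To make the Mahalanobis equivalence above concrete, here is a minimal sketch
(using the standard metric-learn accessors ``components_``,
``get_mahalanobis_matrix``, ``get_metric`` and ``transform``; the choice of
`NCA` and of the iris data is only for illustration) that checks numerically
that :math:`M = L^\top L` and that the learned distance is the Euclidean
distance after the linear map :math:`L`::

    import numpy as np
    from sklearn.datasets import load_iris
    from metric_learn import NCA

    X, y = load_iris(return_X_y=True)

    nca = NCA(n_components=2)
    nca.fit(X, y)

    L = nca.components_               # learned linear map, shape (num_dims, n_features)
    M = nca.get_mahalanobis_matrix()  # equivalent PSD parameterization
    assert np.allclose(M, L.T @ L)    # M = L^T L

    # distance between two points under the learned metric ...
    d_learned = nca.get_metric()(X[0], X[1])
    # ... equals the Euclidean distance between their embeddings Lx
    X_e = nca.transform(X[:2])
    assert np.isclose(d_learned, np.linalg.norm(X_e[0] - X_e[1]))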
111 | 112 | Further reading 113 | =============== 114 | 115 | For more information about metric learning and its applications, one can refer 116 | to the following resources: 117 | 118 | - **Tutorial:** `Similarity and Distance Metric Learning with Applications to 119 | Computer Vision 120 | `_ (2015) 121 | - **Surveys:** `A Survey on Metric Learning for Feature Vectors and Structured 122 | Data `_ (2013), `Metric Learning: A 123 | Survey `_ (2012) 124 | - **Book:** `Metric Learning 125 | `_ (2015) 126 | -------------------------------------------------------------------------------- /doc/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | REM Command file for Sphinx documentation 4 | 5 | if "%SPHINXBUILD%" == "" ( 6 | set SPHINXBUILD=sphinx-build 7 | ) 8 | set BUILDDIR=_build 9 | set ALLSPHINXOPTS=-d %BUILDDIR%/doctrees %SPHINXOPTS% . 10 | set I18NSPHINXOPTS=%SPHINXOPTS% . 11 | if NOT "%PAPER%" == "" ( 12 | set ALLSPHINXOPTS=-D latex_paper_size=%PAPER% %ALLSPHINXOPTS% 13 | set I18NSPHINXOPTS=-D latex_paper_size=%PAPER% %I18NSPHINXOPTS% 14 | ) 15 | 16 | if "%1" == "" goto help 17 | 18 | if "%1" == "help" ( 19 | :help 20 | echo.Please use `make ^` where ^ is one of 21 | echo. html to make standalone HTML files 22 | echo. dirhtml to make HTML files named index.html in directories 23 | echo. singlehtml to make a single large HTML file 24 | echo. pickle to make pickle files 25 | echo. json to make JSON files 26 | echo. htmlhelp to make HTML files and a HTML help project 27 | echo. qthelp to make HTML files and a qthelp project 28 | echo. devhelp to make HTML files and a Devhelp project 29 | echo. epub to make an epub 30 | echo. latex to make LaTeX files, you can set PAPER=a4 or PAPER=letter 31 | echo. text to make text files 32 | echo. man to make manual pages 33 | echo. texinfo to make Texinfo files 34 | echo. gettext to make PO message catalogs 35 | echo. changes to make an overview over all changed/added/deprecated items 36 | echo. xml to make Docutils-native XML files 37 | echo. pseudoxml to make pseudoxml-XML files for display purposes 38 | echo. linkcheck to check all external links for integrity 39 | echo. doctest to run all doctests embedded in the documentation if enabled 40 | echo. coverage to run coverage check of the documentation if enabled 41 | goto end 42 | ) 43 | 44 | if "%1" == "clean" ( 45 | for /d %%i in (%BUILDDIR%\*) do rmdir /q /s %%i 46 | del /q /s %BUILDDIR%\* 47 | goto end 48 | ) 49 | 50 | 51 | REM Check if sphinx-build is available and fallback to Python version if any 52 | %SPHINXBUILD% 2> nul 53 | if errorlevel 9009 goto sphinx_python 54 | goto sphinx_ok 55 | 56 | :sphinx_python 57 | 58 | set SPHINXBUILD=python -m sphinx.__init__ 59 | %SPHINXBUILD% 2> nul 60 | if errorlevel 9009 ( 61 | echo. 62 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 63 | echo.installed, then set the SPHINXBUILD environment variable to point 64 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 65 | echo.may add the Sphinx directory to PATH. 66 | echo. 67 | echo.If you don't have Sphinx installed, grab it from 68 | echo.http://sphinx-doc.org/ 69 | exit /b 1 70 | ) 71 | 72 | :sphinx_ok 73 | 74 | 75 | if "%1" == "html" ( 76 | %SPHINXBUILD% -b html %ALLSPHINXOPTS% %BUILDDIR%/html 77 | if errorlevel 1 exit /b 1 78 | echo. 79 | echo.Build finished. The HTML pages are in %BUILDDIR%/html. 
80 | goto end 81 | ) 82 | 83 | if "%1" == "dirhtml" ( 84 | %SPHINXBUILD% -b dirhtml %ALLSPHINXOPTS% %BUILDDIR%/dirhtml 85 | if errorlevel 1 exit /b 1 86 | echo. 87 | echo.Build finished. The HTML pages are in %BUILDDIR%/dirhtml. 88 | goto end 89 | ) 90 | 91 | if "%1" == "singlehtml" ( 92 | %SPHINXBUILD% -b singlehtml %ALLSPHINXOPTS% %BUILDDIR%/singlehtml 93 | if errorlevel 1 exit /b 1 94 | echo. 95 | echo.Build finished. The HTML pages are in %BUILDDIR%/singlehtml. 96 | goto end 97 | ) 98 | 99 | if "%1" == "pickle" ( 100 | %SPHINXBUILD% -b pickle %ALLSPHINXOPTS% %BUILDDIR%/pickle 101 | if errorlevel 1 exit /b 1 102 | echo. 103 | echo.Build finished; now you can process the pickle files. 104 | goto end 105 | ) 106 | 107 | if "%1" == "json" ( 108 | %SPHINXBUILD% -b json %ALLSPHINXOPTS% %BUILDDIR%/json 109 | if errorlevel 1 exit /b 1 110 | echo. 111 | echo.Build finished; now you can process the JSON files. 112 | goto end 113 | ) 114 | 115 | if "%1" == "htmlhelp" ( 116 | %SPHINXBUILD% -b htmlhelp %ALLSPHINXOPTS% %BUILDDIR%/htmlhelp 117 | if errorlevel 1 exit /b 1 118 | echo. 119 | echo.Build finished; now you can run HTML Help Workshop with the ^ 120 | .hhp project file in %BUILDDIR%/htmlhelp. 121 | goto end 122 | ) 123 | 124 | if "%1" == "qthelp" ( 125 | %SPHINXBUILD% -b qthelp %ALLSPHINXOPTS% %BUILDDIR%/qthelp 126 | if errorlevel 1 exit /b 1 127 | echo. 128 | echo.Build finished; now you can run "qcollectiongenerator" with the ^ 129 | .qhcp project file in %BUILDDIR%/qthelp, like this: 130 | echo.^> qcollectiongenerator %BUILDDIR%\qthelp\metric-learn.qhcp 131 | echo.To view the help file: 132 | echo.^> assistant -collectionFile %BUILDDIR%\qthelp\metric-learn.ghc 133 | goto end 134 | ) 135 | 136 | if "%1" == "devhelp" ( 137 | %SPHINXBUILD% -b devhelp %ALLSPHINXOPTS% %BUILDDIR%/devhelp 138 | if errorlevel 1 exit /b 1 139 | echo. 140 | echo.Build finished. 141 | goto end 142 | ) 143 | 144 | if "%1" == "epub" ( 145 | %SPHINXBUILD% -b epub %ALLSPHINXOPTS% %BUILDDIR%/epub 146 | if errorlevel 1 exit /b 1 147 | echo. 148 | echo.Build finished. The epub file is in %BUILDDIR%/epub. 149 | goto end 150 | ) 151 | 152 | if "%1" == "latex" ( 153 | %SPHINXBUILD% -b latex %ALLSPHINXOPTS% %BUILDDIR%/latex 154 | if errorlevel 1 exit /b 1 155 | echo. 156 | echo.Build finished; the LaTeX files are in %BUILDDIR%/latex. 157 | goto end 158 | ) 159 | 160 | if "%1" == "latexpdf" ( 161 | %SPHINXBUILD% -b latex %ALLSPHINXOPTS% %BUILDDIR%/latex 162 | cd %BUILDDIR%/latex 163 | make all-pdf 164 | cd %~dp0 165 | echo. 166 | echo.Build finished; the PDF files are in %BUILDDIR%/latex. 167 | goto end 168 | ) 169 | 170 | if "%1" == "latexpdfja" ( 171 | %SPHINXBUILD% -b latex %ALLSPHINXOPTS% %BUILDDIR%/latex 172 | cd %BUILDDIR%/latex 173 | make all-pdf-ja 174 | cd %~dp0 175 | echo. 176 | echo.Build finished; the PDF files are in %BUILDDIR%/latex. 177 | goto end 178 | ) 179 | 180 | if "%1" == "text" ( 181 | %SPHINXBUILD% -b text %ALLSPHINXOPTS% %BUILDDIR%/text 182 | if errorlevel 1 exit /b 1 183 | echo. 184 | echo.Build finished. The text files are in %BUILDDIR%/text. 185 | goto end 186 | ) 187 | 188 | if "%1" == "man" ( 189 | %SPHINXBUILD% -b man %ALLSPHINXOPTS% %BUILDDIR%/man 190 | if errorlevel 1 exit /b 1 191 | echo. 192 | echo.Build finished. The manual pages are in %BUILDDIR%/man. 193 | goto end 194 | ) 195 | 196 | if "%1" == "texinfo" ( 197 | %SPHINXBUILD% -b texinfo %ALLSPHINXOPTS% %BUILDDIR%/texinfo 198 | if errorlevel 1 exit /b 1 199 | echo. 200 | echo.Build finished. The Texinfo files are in %BUILDDIR%/texinfo. 
201 | goto end 202 | ) 203 | 204 | if "%1" == "gettext" ( 205 | %SPHINXBUILD% -b gettext %I18NSPHINXOPTS% %BUILDDIR%/locale 206 | if errorlevel 1 exit /b 1 207 | echo. 208 | echo.Build finished. The message catalogs are in %BUILDDIR%/locale. 209 | goto end 210 | ) 211 | 212 | if "%1" == "changes" ( 213 | %SPHINXBUILD% -b changes %ALLSPHINXOPTS% %BUILDDIR%/changes 214 | if errorlevel 1 exit /b 1 215 | echo. 216 | echo.The overview file is in %BUILDDIR%/changes. 217 | goto end 218 | ) 219 | 220 | if "%1" == "linkcheck" ( 221 | %SPHINXBUILD% -b linkcheck %ALLSPHINXOPTS% %BUILDDIR%/linkcheck 222 | if errorlevel 1 exit /b 1 223 | echo. 224 | echo.Link check complete; look for any errors in the above output ^ 225 | or in %BUILDDIR%/linkcheck/output.txt. 226 | goto end 227 | ) 228 | 229 | if "%1" == "doctest" ( 230 | %SPHINXBUILD% -b doctest %ALLSPHINXOPTS% %BUILDDIR%/doctest 231 | if errorlevel 1 exit /b 1 232 | echo. 233 | echo.Testing of doctests in the sources finished, look at the ^ 234 | results in %BUILDDIR%/doctest/output.txt. 235 | goto end 236 | ) 237 | 238 | if "%1" == "coverage" ( 239 | %SPHINXBUILD% -b coverage %ALLSPHINXOPTS% %BUILDDIR%/coverage 240 | if errorlevel 1 exit /b 1 241 | echo. 242 | echo.Testing of coverage in the sources finished, look at the ^ 243 | results in %BUILDDIR%/coverage/python.txt. 244 | goto end 245 | ) 246 | 247 | if "%1" == "xml" ( 248 | %SPHINXBUILD% -b xml %ALLSPHINXOPTS% %BUILDDIR%/xml 249 | if errorlevel 1 exit /b 1 250 | echo. 251 | echo.Build finished. The XML files are in %BUILDDIR%/xml. 252 | goto end 253 | ) 254 | 255 | if "%1" == "pseudoxml" ( 256 | %SPHINXBUILD% -b pseudoxml %ALLSPHINXOPTS% %BUILDDIR%/pseudoxml 257 | if errorlevel 1 exit /b 1 258 | echo. 259 | echo.Build finished. The pseudo-XML files are in %BUILDDIR%/pseudoxml. 260 | goto end 261 | ) 262 | 263 | :end 264 | -------------------------------------------------------------------------------- /doc/metric_learn.rst: -------------------------------------------------------------------------------- 1 | metric_learn package 2 | ==================== 3 | 4 | Module Contents 5 | --------------- 6 | 7 | Base Classes 8 | ------------ 9 | 10 | .. autosummary:: 11 | :toctree: generated/ 12 | :template: class.rst 13 | 14 | metric_learn.Constraints 15 | metric_learn.base_metric.BaseMetricLearner 16 | metric_learn.base_metric.MetricTransformer 17 | metric_learn.base_metric.MahalanobisMixin 18 | metric_learn.base_metric._PairsClassifierMixin 19 | metric_learn.base_metric._TripletsClassifierMixin 20 | metric_learn.base_metric._QuadrupletsClassifierMixin 21 | 22 | Supervised Learning Algorithms 23 | ------------------------------ 24 | .. autosummary:: 25 | :toctree: generated/ 26 | :template: class.rst 27 | 28 | metric_learn.LFDA 29 | metric_learn.LMNN 30 | metric_learn.MLKR 31 | metric_learn.NCA 32 | metric_learn.RCA 33 | metric_learn.ITML_Supervised 34 | metric_learn.LSML_Supervised 35 | metric_learn.MMC_Supervised 36 | metric_learn.SDML_Supervised 37 | metric_learn.RCA_Supervised 38 | metric_learn.SCML_Supervised 39 | 40 | Weakly Supervised Learning Algorithms 41 | ------------------------------------- 42 | 43 | .. autosummary:: 44 | :toctree: generated/ 45 | :template: class.rst 46 | 47 | metric_learn.ITML 48 | metric_learn.LSML 49 | metric_learn.MMC 50 | metric_learn.SDML 51 | metric_learn.SCML 52 | 53 | Unsupervised Learning Algorithms 54 | -------------------------------- 55 | 56 | .. 
autosummary:: 57 | :toctree: generated/ 58 | :template: class.rst 59 | 60 | metric_learn.Covariance -------------------------------------------------------------------------------- /doc/preprocessor.rst: -------------------------------------------------------------------------------- 1 | .. _preprocessor_section: 2 | 3 | ============ 4 | Preprocessor 5 | ============ 6 | 7 | Estimators in metric-learn all have a ``preprocessor`` option at instantiation. 8 | Filling this argument allows them to take more compact input representation 9 | when fitting, predicting etc... 10 | 11 | If ``preprocessor=None``, no preprocessor will be used and the user must 12 | provide the classical representation to the fit/predict/score/etc... methods of 13 | the estimators (see the documentation of the particular estimator to know the 14 | type of input it accepts). Otherwise, two types of objects can be put in this 15 | argument: 16 | 17 | Array-like 18 | ---------- 19 | You can specify ``preprocessor=X`` where ``X`` is an array-like containing the 20 | dataset of points. In this case, the fit/predict/score/etc... methods of the 21 | estimator will be able to take as inputs an array-like of indices, replacing 22 | under the hood each index by the corresponding sample. 23 | 24 | 25 | Example with a supervised metric learner: 26 | 27 | >>> from metric_learn import NCA 28 | >>> 29 | >>> X = np.array([[-0.7 , -0.23], 30 | >>> [-0.43, -0.49], 31 | >>> [ 0.14, -0.37]]) # array of 3 samples of 2 features 32 | >>> points_indices = np.array([2, 0, 1, 0]) 33 | >>> y = np.array([1, 0, 1, 1]) 34 | >>> 35 | >>> nca = NCA(preprocessor=X) 36 | >>> nca.fit(points_indices, y) 37 | >>> # under the hood the algorithm will create 38 | >>> # points = np.array([[ 0.14, -0.37], 39 | >>> # [-0.7 , -0.23], 40 | >>> # [-0.43, -0.49], 41 | >>> # [ 0.14, -0.37]]) and fit on it 42 | 43 | 44 | Example with a weakly supervised metric learner: 45 | 46 | >>> from metric_learn import MMC 47 | >>> X = np.array([[-0.7 , -0.23], 48 | >>> [-0.43, -0.49], 49 | >>> [ 0.14, -0.37]]) # array of 3 samples of 2 features 50 | >>> pairs_indices = np.array([[2, 0], [1, 0]]) 51 | >>> y_pairs = np.array([1, -1]) 52 | >>> 53 | >>> mmc = MMC(preprocessor=X) 54 | >>> mmc.fit(pairs_indices, y_pairs) 55 | >>> # under the hood the algorithm will create 56 | >>> # pairs = np.array([[[ 0.14, -0.37], [-0.7 , -0.23]], 57 | >>> # [[-0.43, -0.49], [-0.7 , -0.23]]]) and fit on it 58 | 59 | Callable 60 | -------- 61 | Alternatively, you can provide a callable as ``preprocessor``. Then the 62 | estimator will accept indicators of points instead of points. Under the hood, 63 | the estimator will call this callable on the indicators you provide as input 64 | when fitting, predicting etc... Using a callable can be really useful to 65 | represent lazily a dataset of images stored on the file system for instance. 66 | The callable should take as an input a 1D array-like, and return a 2D 67 | array-like. For supervised learners it will be applied on the whole 1D array of 68 | indicators at once, and for weakly supervised learners it will be applied on 69 | each column of the 2D array of tuples. 
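As a bare illustration of this contract (a 1D array-like of indicators in, a
2D array-like of points out), the sketch below lazily reads rows from a
memory-mapped array on disk; the file name ``data.npy`` and the helper name
``lookup_points`` are hypothetical, and the image-based examples that follow
show the same idea with image files:

>>> import numpy as np
>>> from metric_learn import NCA
>>>
>>> def lookup_points(indicators):
>>>     # lazily memory-map the (hypothetical) dataset stored in 'data.npy'
>>>     data = np.load('data.npy', mmap_mode='r')
>>>     # return the 2D array of points for the given 1D array of indices
>>>     return np.asarray(data[np.asarray(indicators, dtype=int)])
>>>
>>> nca = NCA(preprocessor=lookup_points)
>>> # nca.fit(points_indices, y) would then replace each index by a row of data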
70 | 71 | Example with a supervised metric learner: 72 | 73 | >>> def find_images(file_paths): 74 | >>> # each file contains a small image to use as an input datapoint 75 | >>> return np.row_stack([imread(f).ravel() for f in file_paths]) 76 | >>> 77 | >>> nca = NCA(preprocessor=find_images) 78 | >>> nca.fit(['img01.png', 'img00.png', 'img02.png'], [1, 0, 1]) 79 | >>> # under the hood preprocessor(indicators) will be called 80 | 81 | 82 | Example with a weakly supervised metric learner: 83 | 84 | >>> pairs_images_paths = [['img02.png', 'img00.png'], 85 | >>> ['img01.png', 'img00.png']] 86 | >>> y_pairs = np.array([1, -1]) 87 | >>> 88 | >>> mmc = NCA(preprocessor=find_images) 89 | >>> mmc.fit(pairs_images_paths, y_pairs) 90 | >>> # under the hood preprocessor(pairs_indicators[i]) will be called for each 91 | >>> # i in [0, 1] 92 | 93 | 94 | .. note:: Note that when you fill the ``preprocessor`` option, it allows you 95 | to give more compact inputs, but the classical way of providing inputs 96 | stays valid (2D array-like for supervised learners and 3D array-like of 97 | tuples for weakly supervised learners). If a classical input 98 | is provided, the metric learner will not use the preprocessor. 99 | 100 | Example: This will work: 101 | 102 | >>> from metric_learn import MMC 103 | >>> def preprocessor_wip(array): 104 | >>> raise NotImplementedError("This preprocessor does nothing yet.") 105 | >>> 106 | >>> pairs = np.array([[[ 0.14, -0.37], [-0.7 , -0.23]], 107 | >>> [[-0.43, -0.49], [-0.7 , -0.23]]]) 108 | >>> y_pairs = np.array([1, -1]) 109 | >>> 110 | >>> mmc = MMC(preprocessor=preprocessor_wip) 111 | >>> mmc.fit(pairs, y_pairs) # preprocessor_wip will not be called here 112 | -------------------------------------------------------------------------------- /doc/unsupervised.rst: -------------------------------------------------------------------------------- 1 | ============================ 2 | Unsupervised Metric Learning 3 | ============================ 4 | 5 | Unsupervised metric learning algorithms only take as input an (unlabeled) 6 | dataset `X`. For now, in metric-learn, there only is `Covariance`, which is a 7 | simple baseline algorithm (see below). 8 | 9 | 10 | Algorithms 11 | ========== 12 | .. _covariance: 13 | 14 | Covariance 15 | ---------- 16 | 17 | `Covariance` does not "learn" anything, rather it calculates 18 | the covariance matrix of the input data. This is a simple baseline method. 19 | It can be used for ZCA whitening of the data (see the Wikipedia page of 20 | `whitening transformation `_). 22 | 23 | .. rubric:: Example Code 24 | 25 | :: 26 | 27 | from metric_learn import Covariance 28 | from sklearn.datasets import load_iris 29 | 30 | iris = load_iris()['data'] 31 | 32 | cov = Covariance().fit(iris) 33 | x = cov.transform(iris) 34 | 35 | .. rubric:: References 36 | 37 | 38 | .. container:: hatnote hatnote-gray 39 | 40 | [1]. On the Generalized Distance in Statistics, P.C.Mahalanobis, 1936. -------------------------------------------------------------------------------- /doc/user_guide.rst: -------------------------------------------------------------------------------- 1 | .. title:: User guide: contents 2 | 3 | .. _user_guide: 4 | 5 | ========== 6 | User Guide 7 | ========== 8 | 9 | .. 
toctree:: 10 | :numbered: 11 | 12 | introduction.rst 13 | supervised.rst 14 | weakly_supervised.rst 15 | unsupervised.rst 16 | preprocessor.rst -------------------------------------------------------------------------------- /examples/README.txt: -------------------------------------------------------------------------------- 1 | Examples 2 | ======== 3 | 4 | Below is a gallery of example metric-learn use cases. -------------------------------------------------------------------------------- /examples/plot_sandwich.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Sandwich demo 4 | ============= 5 | 6 | Sandwich demo based on code from http://nbviewer.ipython.org/6576096 7 | """ 8 | 9 | ###################################################################### 10 | # .. note:: 11 | # 12 | # In order to show the charts of the examples you need a graphical 13 | # ``matplotlib`` backend installed. For intance, use ``pip install pyqt5`` 14 | # to get Qt graphical interface or use your favorite one. 15 | 16 | import numpy as np 17 | from matplotlib import pyplot as plt 18 | from sklearn.metrics import pairwise_distances 19 | from sklearn.neighbors import NearestNeighbors 20 | 21 | from metric_learn import (LMNN, ITML_Supervised, LSML_Supervised, 22 | SDML_Supervised) 23 | 24 | 25 | def sandwich_demo(): 26 | x, y = sandwich_data() 27 | knn = nearest_neighbors(x, k=2) 28 | ax = plt.subplot(3, 1, 1) # take the whole top row 29 | plot_sandwich_data(x, y, ax) 30 | plot_neighborhood_graph(x, knn, y, ax) 31 | ax.set_title('input space') 32 | ax.set_aspect('equal') 33 | ax.set_xticks([]) 34 | ax.set_yticks([]) 35 | 36 | mls = [ 37 | LMNN(), 38 | ITML_Supervised(n_constraints=200), 39 | SDML_Supervised(n_constraints=200, balance_param=0.001), 40 | LSML_Supervised(n_constraints=200), 41 | ] 42 | 43 | for ax_num, ml in enumerate(mls, start=3): 44 | ml.fit(x, y) 45 | tx = ml.transform(x) 46 | ml_knn = nearest_neighbors(tx, k=2) 47 | ax = plt.subplot(3, 2, ax_num) 48 | plot_sandwich_data(tx, y, axis=ax) 49 | plot_neighborhood_graph(tx, ml_knn, y, axis=ax) 50 | ax.set_title(ml.__class__.__name__) 51 | ax.set_xticks([]) 52 | ax.set_yticks([]) 53 | plt.show() 54 | 55 | 56 | # TODO: use this somewhere 57 | def visualize_class_separation(X, labels): 58 | _, (ax1, ax2) = plt.subplots(ncols=2) 59 | label_order = np.argsort(labels) 60 | ax1.imshow(pairwise_distances(X[label_order]), interpolation='nearest') 61 | ax2.imshow(pairwise_distances(labels[label_order, None]), 62 | interpolation='nearest') 63 | 64 | 65 | def nearest_neighbors(X, k=5): 66 | knn = NearestNeighbors(n_neighbors=k) 67 | knn.fit(X) 68 | return knn.kneighbors(X, return_distance=False) 69 | 70 | 71 | def sandwich_data(): 72 | # number of distinct classes 73 | num_classes = 6 74 | # number of points per class 75 | num_points = 9 76 | # distance between layers, the points of each class are in a layer 77 | dist = 0.7 78 | 79 | data = np.zeros((num_classes, num_points, 2), dtype=float) 80 | labels = np.zeros((num_classes, num_points), dtype=int) 81 | 82 | x_centers = np.arange(num_points, dtype=float) - num_points / 2 83 | y_centers = dist * (np.arange(num_classes, dtype=float) - num_classes / 2) 84 | for i, yc in enumerate(y_centers): 85 | for k, xc in enumerate(x_centers): 86 | data[i, k, 0] = np.random.normal(xc, 0.1) 87 | data[i, k, 1] = np.random.normal(yc, 0.1) 88 | labels[i, :] = i 89 | return data.reshape((-1, 2)), labels.ravel() 90 | 91 | 92 | def plot_sandwich_data(x, y, 
axis=plt, colors='rbgmky'): 93 | for idx, val in enumerate(np.unique(y)): 94 | xi = x[y == val] 95 | axis.scatter(*xi.T, s=50, facecolors='none', edgecolors=colors[idx]) 96 | 97 | 98 | def plot_neighborhood_graph(x, nn, y, axis=plt, colors='rbgmky'): 99 | for i, a in enumerate(x): 100 | b = x[nn[i, 1]] 101 | axis.plot((a[0], b[0]), (a[1], b[1]), colors[y[i]]) 102 | 103 | 104 | if __name__ == '__main__': 105 | sandwich_demo() 106 | -------------------------------------------------------------------------------- /metric_learn/__init__.py: -------------------------------------------------------------------------------- 1 | from .constraints import Constraints 2 | from .covariance import Covariance 3 | from .itml import ITML, ITML_Supervised 4 | from .lmnn import LMNN 5 | from .lsml import LSML, LSML_Supervised 6 | from .sdml import SDML, SDML_Supervised 7 | from .nca import NCA 8 | from .lfda import LFDA 9 | from .rca import RCA, RCA_Supervised 10 | from .mlkr import MLKR 11 | from .mmc import MMC, MMC_Supervised 12 | from .scml import SCML, SCML_Supervised 13 | 14 | from ._version import __version__ 15 | 16 | __all__ = ['Constraints', 'Covariance', 'ITML', 'ITML_Supervised', 17 | 'LMNN', 'LSML', 'LSML_Supervised', 'SDML', 18 | 'SDML_Supervised', 'NCA', 'LFDA', 'RCA', 'RCA_Supervised', 19 | 'MLKR', 'MMC', 'MMC_Supervised', 'SCML', 20 | 'SCML_Supervised', '__version__'] 21 | -------------------------------------------------------------------------------- /metric_learn/_version.py: -------------------------------------------------------------------------------- 1 | __version__ = '0.7.0' 2 | -------------------------------------------------------------------------------- /metric_learn/constraints.py: -------------------------------------------------------------------------------- 1 | """ 2 | Helper module for generating different types of constraints 3 | from supervised data labels. 4 | """ 5 | import numpy as np 6 | import warnings 7 | from sklearn.utils import check_random_state 8 | from sklearn.neighbors import NearestNeighbors 9 | 10 | 11 | __all__ = ['Constraints'] 12 | 13 | 14 | class Constraints(object): 15 | """ 16 | Class to build constraints from labeled data. 17 | 18 | See more in the :ref:`User Guide `. 19 | 20 | Parameters 21 | ---------- 22 | partial_labels : `numpy.ndarray` of ints, shape=(n_samples,) 23 | Array of labels, with -1 indicating unknown label. 24 | 25 | Attributes 26 | ---------- 27 | partial_labels : `numpy.ndarray` of ints, shape=(n_samples,) 28 | Array of labels, with -1 indicating unknown label. 29 | """ 30 | 31 | def __init__(self, partial_labels): 32 | partial_labels = np.asanyarray(partial_labels, dtype=int) 33 | self.partial_labels = partial_labels 34 | 35 | def positive_negative_pairs(self, n_constraints, same_length=False, 36 | random_state=None, num_constraints='deprecated'): 37 | """ 38 | Generates positive pairs and negative pairs from labeled data. 39 | 40 | Positive pairs are formed by randomly drawing ``n_constraints`` pairs of 41 | points with the same label. Negative pairs are formed by randomly drawing 42 | ``n_constraints`` pairs of points with different label. 43 | 44 | In the case where it is not possible to generate enough positive or 45 | negative pairs, a smaller number of pairs will be returned with a warning. 46 | 47 | Parameters 48 | ---------- 49 | n_constraints : int 50 | Number of positive and negative constraints to generate. 
51 | 52 | same_length : bool, optional (default=False) 53 | If True, forces the number of positive and negative pairs to be 54 | equal by ignoring some pairs from the larger set. 55 | 56 | random_state : int or numpy.RandomState or None, optional (default=None) 57 | A pseudo random number generator object or a seed for it if int. 58 | 59 | num_constraints : Renamed to n_constraints. Will be deprecated in 0.7.0 60 | 61 | Returns 62 | ------- 63 | a : array-like, shape=(n_constraints,) 64 | 1D array of indicators for the left elements of positive pairs. 65 | 66 | b : array-like, shape=(n_constraints,) 67 | 1D array of indicators for the right elements of positive pairs. 68 | 69 | c : array-like, shape=(n_constraints,) 70 | 1D array of indicators for the left elements of negative pairs. 71 | 72 | d : array-like, shape=(n_constraints,) 73 | 1D array of indicators for the right elements of negative pairs. 74 | """ 75 | if num_constraints != 'deprecated': 76 | warnings.warn('"num_constraints" parameter has been renamed to' 77 | ' "n_constraints". It has been deprecated in' 78 | ' version 0.6.3 and will be removed in 0.7.0' 79 | '', FutureWarning) 80 | self.n_constraints = num_constraints 81 | else: 82 | self.n_constraints = n_constraints 83 | random_state = check_random_state(random_state) 84 | a, b = self._pairs(n_constraints, same_label=True, 85 | random_state=random_state) 86 | c, d = self._pairs(n_constraints, same_label=False, 87 | random_state=random_state) 88 | if same_length and len(a) != len(c): 89 | n = min(len(a), len(c)) 90 | return a[:n], b[:n], c[:n], d[:n] 91 | return a, b, c, d 92 | 93 | def generate_knntriplets(self, X, k_genuine, k_impostor): 94 | """ 95 | Generates triplets from labeled data. 96 | 97 | For every point (X_a) the triplets (X_a, X_b, X_c) are constructed from all 98 | the combinations of taking one of its `k_genuine`-nearest neighbors of the 99 | same class (X_b) and taking one of its `k_impostor`-nearest neighbors of 100 | other classes (X_c). 101 | 102 | In the case a class doesn't have enough points in the same class (other 103 | classes) to yield `k_genuine` (`k_impostor`) neighbors a warning will be 104 | raised and the maximum value of genuine (impostor) neighbors will be used 105 | for that class. 106 | 107 | Parameters 108 | ---------- 109 | X : (n x d) matrix 110 | Input data, where each row corresponds to a single instance. 111 | 112 | k_genuine : int 113 | Number of neighbors of the same class to be taken into account. 114 | 115 | k_impostor : int 116 | Number of neighbors of different classes to be taken into account. 117 | 118 | Returns 119 | ------- 120 | triplets : array-like, shape=(n_constraints, 3) 121 | 2D array of triplets of indicators. 122 | """ 123 | # Ignore unlabeled samples 124 | known_labels_mask = self.partial_labels >= 0 125 | known_labels = self.partial_labels[known_labels_mask] 126 | X = X[known_labels_mask] 127 | 128 | labels, labels_count = np.unique(known_labels, return_counts=True) 129 | len_input = known_labels.shape[0] 130 | 131 | # Handle the case where there are too few elements to yield k_genuine or 132 | # k_impostor neighbors for every class. 133 | 134 | k_genuine_vec = np.full_like(labels, k_genuine) 135 | k_impostor_vec = np.full_like(labels, k_impostor) 136 | 137 | for i, count in enumerate(labels_count): 138 | if k_genuine + 1 > count: 139 | k_genuine_vec[i] = count-1 140 | warnings.warn("The class {} has {} elements, which is not sufficient " 141 | "to generate {} genuine neighbors as specified by " 142 | "k_genuine. 
Will generate {} genuine neighbors instead." 143 | "\n" 144 | .format(labels[i], count, k_genuine+1, 145 | k_genuine_vec[i])) 146 | if k_impostor > len_input - count: 147 | k_impostor_vec[i] = len_input - count 148 | warnings.warn("The class {} has {} elements of other classes, which is" 149 | " not sufficient to generate {} impostor neighbors as " 150 | "specified by k_impostor. Will generate {} impostor " 151 | "neighbors instead.\n" 152 | .format(labels[i], k_impostor_vec[i], k_impostor, 153 | k_impostor_vec[i])) 154 | 155 | # The total number of possible triplets combinations per label comes from 156 | # taking one of the k_genuine_vec[i] genuine neighbors and one of the 157 | # k_impostor_vec[i] impostor neighbors for the labels_count[i] elements 158 | comb_per_label = labels_count * k_genuine_vec * k_impostor_vec 159 | 160 | # Get start and finish for later triplet assigning 161 | # append zero at the begining for start and get cumulative sum 162 | start_finish_indices = np.hstack((0, comb_per_label)).cumsum() 163 | 164 | # Total number of triplets is the sum of all possible combinations per 165 | # label 166 | num_triplets = start_finish_indices[-1] 167 | triplets = np.empty((num_triplets, 3), dtype=np.intp) 168 | 169 | neigh = NearestNeighbors() 170 | 171 | for i, label in enumerate(labels): 172 | 173 | # generate mask for current label 174 | gen_mask = known_labels == label 175 | gen_indx = np.where(gen_mask) 176 | 177 | # get k_genuine genuine neighbors 178 | neigh.fit(X=X[gen_indx]) 179 | # Take elements of gen_indx according to the yielded k-neighbors 180 | gen_relative_indx = neigh.kneighbors(n_neighbors=k_genuine_vec[i], 181 | return_distance=False) 182 | gen_neigh = np.take(gen_indx, gen_relative_indx) 183 | 184 | # generate mask for impostors of current label 185 | imp_indx = np.where(~gen_mask) 186 | 187 | # get k_impostor impostor neighbors 188 | neigh.fit(X=X[imp_indx]) 189 | # Take elements of imp_indx according to the yielded k-neighbors 190 | imp_relative_indx = neigh.kneighbors(n_neighbors=k_impostor_vec[i], 191 | X=X[gen_mask], 192 | return_distance=False) 193 | imp_neigh = np.take(imp_indx, imp_relative_indx) 194 | 195 | # length = len_label*k_genuine*k_impostor 196 | start, finish = start_finish_indices[i:i+2] 197 | 198 | triplets[start:finish, :] = comb(gen_indx, gen_neigh, imp_neigh, 199 | k_genuine_vec[i], 200 | k_impostor_vec[i]) 201 | 202 | return triplets 203 | 204 | def _pairs(self, n_constraints, same_label=True, max_iter=10, 205 | random_state=np.random): 206 | known_label_idx, = np.where(self.partial_labels >= 0) 207 | known_labels = self.partial_labels[known_label_idx] 208 | num_labels = len(known_labels) 209 | ab = set() 210 | it = 0 211 | while it < max_iter and len(ab) < n_constraints: 212 | nc = n_constraints - len(ab) 213 | for aidx in random_state.randint(num_labels, size=nc): 214 | if same_label: 215 | mask = known_labels[aidx] == known_labels 216 | mask[aidx] = False # avoid identity pairs 217 | else: 218 | mask = known_labels[aidx] != known_labels 219 | b_choices, = np.where(mask) 220 | if len(b_choices) > 0: 221 | ab.add((aidx, random_state.choice(b_choices))) 222 | it += 1 223 | if len(ab) < n_constraints: 224 | warnings.warn("Only generated %d %s constraints (requested %d)" % ( 225 | len(ab), 'positive' if same_label else 'negative', n_constraints)) 226 | ab = np.array(list(ab)[:n_constraints], dtype=int) 227 | return known_label_idx[ab.T] 228 | 229 | def chunks(self, n_chunks=100, chunk_size=2, random_state=None, 230 | 
num_chunks='deprecated'): 231 | """ 232 | Generates chunks from labeled data. 233 | 234 | Each of ``n_chunks`` chunks is composed of ``chunk_size`` points from 235 | the same class drawn at random. Each point can belong to at most 1 chunk. 236 | 237 | In the case where there is not enough points to generate ``n_chunks`` 238 | chunks of size ``chunk_size``, a ValueError will be raised. 239 | 240 | Parameters 241 | ---------- 242 | n_chunks : int, optional (default=100) 243 | Number of chunks to generate. 244 | 245 | chunk_size : int, optional (default=2) 246 | Number of points in each chunk. 247 | 248 | random_state : int or numpy.RandomState or None, optional (default=None) 249 | A pseudo random number generator object or a seed for it if int. 250 | 251 | num_chunks : Renamed to n_chunks. Will be deprecated in 0.7.0 252 | 253 | Returns 254 | ------- 255 | chunks : array-like, shape=(n_samples,) 256 | 1D array of chunk indicators, where -1 indicates that the point does not 257 | belong to any chunk. 258 | """ 259 | if num_chunks != 'deprecated': 260 | warnings.warn('"num_chunks" parameter has been renamed to' 261 | ' "n_chunks". It has been deprecated in' 262 | ' version 0.6.3 and will be removed in 0.7.0' 263 | '', FutureWarning) 264 | n_chunks = num_chunks 265 | random_state = check_random_state(random_state) 266 | chunks = -np.ones_like(self.partial_labels, dtype=int) 267 | uniq, lookup = np.unique(self.partial_labels, return_inverse=True) 268 | unknown_uniq = np.where(uniq < 0)[0] 269 | all_inds = [set(np.where(lookup == c)[0]) for c in range(len(uniq)) 270 | if c not in unknown_uniq] 271 | max_chunks = int(np.sum([len(s) // chunk_size for s in all_inds])) 272 | if max_chunks < n_chunks: 273 | raise ValueError(('Not enough possible chunks of %d elements in each' 274 | ' class to form expected %d chunks - maximum number' 275 | ' of chunks is %d' 276 | ) % (chunk_size, n_chunks, max_chunks)) 277 | idx = 0 278 | while idx < n_chunks and all_inds: 279 | if len(all_inds) == 1: 280 | c = 0 281 | else: 282 | c = random_state.randint(0, high=len(all_inds) - 1) 283 | inds = all_inds[c] 284 | if len(inds) < chunk_size: 285 | del all_inds[c] 286 | continue 287 | ii = random_state.choice(list(inds), chunk_size, replace=False) 288 | inds.difference_update(ii) 289 | chunks[ii] = idx 290 | idx += 1 291 | return chunks 292 | 293 | 294 | def comb(A, B, C, sizeB, sizeC): 295 | # generate_knntriplets helper function 296 | # generate an array with all combinations of choosing 297 | # an element from A, B and C 298 | return np.vstack((np.tile(A, (sizeB*sizeC, 1)).ravel(order='F'), 299 | np.tile(np.hstack(B), (sizeC, 1)).ravel(order='F'), 300 | np.tile(C, (1, sizeB)).ravel())).T 301 | 302 | 303 | def wrap_pairs(X, constraints): 304 | a = np.array(constraints[0]) 305 | b = np.array(constraints[1]) 306 | c = np.array(constraints[2]) 307 | d = np.array(constraints[3]) 308 | constraints = np.vstack((np.column_stack((a, b)), np.column_stack((c, d)))) 309 | y = np.concatenate([np.ones_like(a), -np.ones_like(c)]) 310 | pairs = X[constraints] 311 | return pairs, y 312 | -------------------------------------------------------------------------------- /metric_learn/covariance.py: -------------------------------------------------------------------------------- 1 | """ 2 | Covariance metric (baseline method) 3 | """ 4 | 5 | import numpy as np 6 | import scipy 7 | from sklearn.base import TransformerMixin 8 | 9 | from .base_metric import MahalanobisMixin 10 | from ._util import components_from_metric 11 | 12 | 13 | class 
Covariance(MahalanobisMixin, TransformerMixin): 14 | """Covariance metric (baseline method) 15 | 16 | This method does not "learn" anything, rather it calculates 17 | the covariance matrix of the input data. 18 | 19 | This is a simple baseline method first introduced in 20 | On the Generalized Distance in Statistics, P.C.Mahalanobis, 1936 21 | 22 | Read more in the :ref:`User Guide `. 23 | 24 | Attributes 25 | ---------- 26 | components_ : `numpy.ndarray`, shape=(n_features, n_features) 27 | The linear transformation ``L`` deduced from the learned Mahalanobis 28 | metric (See function `components_from_metric`.) 29 | 30 | Examples 31 | -------- 32 | >>> from metric_learn import Covariance 33 | >>> from sklearn.datasets import load_iris 34 | >>> iris = load_iris()['data'] 35 | >>> cov = Covariance().fit(iris) 36 | >>> x = cov.transform(iris) 37 | 38 | """ 39 | 40 | def __init__(self, preprocessor=None): 41 | super(Covariance, self).__init__(preprocessor) 42 | 43 | def fit(self, X, y=None): 44 | """ 45 | Calculates the covariance matrix of the input data. 46 | 47 | Parameters 48 | ---------- 49 | X : data matrix, (n x d) 50 | y : unused 51 | """ 52 | X = self._prepare_inputs(X, ensure_min_samples=2) 53 | M = np.atleast_2d(np.cov(X, rowvar=False)) 54 | if M.size == 1: 55 | M = 1. / M 56 | else: 57 | M = scipy.linalg.pinvh(M) 58 | 59 | self.components_ = components_from_metric(np.atleast_2d(M)) 60 | return self 61 | -------------------------------------------------------------------------------- /metric_learn/exceptions.py: -------------------------------------------------------------------------------- 1 | """ 2 | The :mod:`metric_learn.exceptions` module includes all custom warnings and 3 | error classes used across metric-learn. 4 | """ 5 | from numpy.linalg import LinAlgError 6 | 7 | 8 | class PreprocessorError(Exception): 9 | 10 | def __init__(self, original_error): 11 | err_msg = ("An error occurred when trying to use the " 12 | "preprocessor: {}").format(repr(original_error)) 13 | super(PreprocessorError, self).__init__(err_msg) 14 | 15 | 16 | class NonPSDError(LinAlgError): 17 | 18 | def __init__(self): 19 | err_msg = "Matrix is not positive semidefinite (PSD)." 20 | super(LinAlgError, self).__init__(err_msg) 21 | -------------------------------------------------------------------------------- /metric_learn/itml.py: -------------------------------------------------------------------------------- 1 | """ 2 | Information Theoretic Metric Learning (ITML) 3 | """ 4 | 5 | import numpy as np 6 | from sklearn.metrics import pairwise_distances 7 | from sklearn.utils.validation import check_array 8 | from sklearn.base import TransformerMixin 9 | from .base_metric import _PairsClassifierMixin, MahalanobisMixin 10 | from .constraints import Constraints, wrap_pairs 11 | from ._util import components_from_metric, _initialize_metric_mahalanobis 12 | import warnings 13 | 14 | 15 | class _BaseITML(MahalanobisMixin): 16 | """Information Theoretic Metric Learning (ITML)""" 17 | 18 | _tuple_size = 2 # constraints are pairs 19 | 20 | def __init__(self, gamma=1., max_iter=1000, tol=1e-3, 21 | prior='identity', verbose=False, 22 | preprocessor=None, random_state=None, 23 | convergence_threshold='deprecated'): 24 | if convergence_threshold != 'deprecated': 25 | warnings.warn('"convergence_threshold" parameter has been ' 26 | ' renamed to "tol". 
It has been deprecated in' 27 | ' version 0.6.3 and will be removed in 0.7.0' 28 | '', FutureWarning) 29 | tol = convergence_threshold 30 | self.convergence_threshold = 'deprecated' # Avoid errors 31 | self.gamma = gamma 32 | self.max_iter = max_iter 33 | self.tol = tol 34 | self.prior = prior 35 | self.verbose = verbose 36 | self.random_state = random_state 37 | super(_BaseITML, self).__init__(preprocessor) 38 | 39 | def _fit(self, pairs, y, bounds=None): 40 | pairs, y = self._prepare_inputs(pairs, y, 41 | type_of_inputs='tuples') 42 | # init bounds 43 | if bounds is None: 44 | X = np.unique(np.vstack(pairs), axis=0) 45 | self.bounds_ = np.percentile(pairwise_distances(X), (5, 95)) 46 | else: 47 | bounds = check_array(bounds, allow_nd=False, ensure_min_samples=0, 48 | ensure_2d=False) 49 | bounds = bounds.ravel() 50 | if bounds.size != 2: 51 | raise ValueError("`bounds` should be an array-like of two elements.") 52 | self.bounds_ = bounds 53 | self.bounds_[self.bounds_ == 0] = 1e-9 54 | # set the prior 55 | # pairs will be deduplicated into X two times, TODO: avoid that 56 | A = _initialize_metric_mahalanobis(pairs, self.prior, self.random_state, 57 | strict_pd=True, 58 | matrix_name='prior') 59 | gamma = self.gamma 60 | pos_pairs, neg_pairs = pairs[y == 1], pairs[y == -1] 61 | num_pos = len(pos_pairs) 62 | num_neg = len(neg_pairs) 63 | _lambda = np.zeros(num_pos + num_neg) 64 | lambdaold = np.zeros_like(_lambda) 65 | gamma_proj = 1. if gamma is np.inf else gamma / (gamma + 1.) 66 | pos_bhat = np.zeros(num_pos) + self.bounds_[0] 67 | neg_bhat = np.zeros(num_neg) + self.bounds_[1] 68 | pos_vv = pos_pairs[:, 0, :] - pos_pairs[:, 1, :] 69 | neg_vv = neg_pairs[:, 0, :] - neg_pairs[:, 1, :] 70 | 71 | for it in range(self.max_iter): 72 | # update positives 73 | for i, v in enumerate(pos_vv): 74 | wtw = v.dot(A).dot(v) # scalar 75 | alpha = min(_lambda[i], gamma_proj * (1. / wtw - 1. / pos_bhat[i])) 76 | _lambda[i] -= alpha 77 | beta = alpha / (1 - alpha * wtw) 78 | pos_bhat[i] = 1. / ((1 / pos_bhat[i]) + (alpha / gamma)) 79 | Av = A.dot(v) 80 | A += np.outer(Av, Av * beta) 81 | 82 | # update negatives 83 | for i, v in enumerate(neg_vv): 84 | wtw = v.dot(A).dot(v) # scalar 85 | alpha = min(_lambda[i + num_pos], 86 | gamma_proj * (1. / neg_bhat[i] - 1. / wtw)) 87 | _lambda[i + num_pos] -= alpha 88 | beta = -alpha / (1 + alpha * wtw) 89 | neg_bhat[i] = 1. / ((1 / neg_bhat[i]) - (alpha / gamma)) 90 | Av = A.dot(v) 91 | A += np.outer(Av, Av * beta) 92 | 93 | normsum = np.linalg.norm(_lambda) + np.linalg.norm(lambdaold) 94 | if normsum == 0: 95 | conv = np.inf 96 | break 97 | conv = np.abs(lambdaold - _lambda).sum() / normsum 98 | if conv < self.tol: 99 | break 100 | lambdaold = _lambda.copy() 101 | if self.verbose: 102 | print('itml iter: %d, conv = %f' % (it, conv)) 103 | 104 | if self.verbose: 105 | print('itml converged at iter: %d, conv = %f' % (it, conv)) 106 | self.n_iter_ = it 107 | 108 | self.components_ = components_from_metric(A) 109 | return self 110 | 111 | 112 | class ITML(_BaseITML, _PairsClassifierMixin): 113 | """Information Theoretic Metric Learning (ITML) 114 | 115 | `ITML` minimizes the (differential) relative entropy, aka Kullback-Leibler 116 | divergence, between two multivariate Gaussians subject to constraints on the 117 | associated Mahalanobis distance, which can be formulated into a Bregman 118 | optimization problem by minimizing the LogDet divergence subject to 119 | linear constraints. 
This algorithm can handle a wide variety of constraints 120 | and can optionally incorporate a prior on the distance function. Unlike some 121 | other methods, `ITML` does not rely on an eigenvalue computation or 122 | semi-definite programming. 123 | 124 | Read more in the :ref:`User Guide `. 125 | 126 | Parameters 127 | ---------- 128 | gamma : float, optional (default=1.0) 129 | Value for slack variables 130 | 131 | max_iter : int, optional (default=1000) 132 | Maximum number of iteration of the optimization procedure. 133 | 134 | tol : float, optional (default=1e-3) 135 | Convergence tolerance. 136 | 137 | prior : string or numpy array, optional (default='identity') 138 | The Mahalanobis matrix to use as a prior. Possible options are 139 | 'identity', 'covariance', 'random', and a numpy array of shape 140 | (n_features, n_features). For ITML, the prior should be strictly 141 | positive definite (PD). 142 | 143 | 'identity' 144 | An identity matrix of shape (n_features, n_features). 145 | 146 | 'covariance' 147 | The inverse covariance matrix. 148 | 149 | 'random' 150 | The prior will be a random SPD matrix of shape 151 | `(n_features, n_features)`, generated using 152 | `sklearn.datasets.make_spd_matrix`. 153 | 154 | numpy array 155 | A positive definite (PD) matrix of shape 156 | (n_features, n_features), that will be used as such to set the 157 | prior. 158 | 159 | verbose : bool, optional (default=False) 160 | If True, prints information while learning 161 | 162 | preprocessor : array-like, shape=(n_samples, n_features) or callable 163 | The preprocessor to call to get tuples from indices. If array-like, 164 | tuples will be formed like this: X[indices]. 165 | 166 | random_state : int or numpy.RandomState or None, optional (default=None) 167 | A pseudo random number generator object or a seed for it if int. If 168 | ``prior='random'``, ``random_state`` is used to set the prior. 169 | 170 | convergence_threshold : Renamed to tol. Will be deprecated in 0.7.0 171 | 172 | Attributes 173 | ---------- 174 | bounds_ : `numpy.ndarray`, shape=(2,) 175 | Bounds on similarity, aside slack variables, s.t. 176 | ``d(a, b) < bounds_[0]`` for all given pairs of similar points ``a`` 177 | and ``b``, and ``d(c, d) > bounds_[1]`` for all given pairs of 178 | dissimilar points ``c`` and ``d``, with ``d`` the learned distance. If 179 | not provided at initialization, bounds_[0] and bounds_[1] are set at 180 | train time to the 5th and 95th percentile of the pairwise distances among 181 | all points present in the input `pairs`. 182 | 183 | n_iter_ : `int` 184 | The number of iterations the solver has run. 185 | 186 | components_ : `numpy.ndarray`, shape=(n_features, n_features) 187 | The linear transformation ``L`` deduced from the learned Mahalanobis 188 | metric (See function `components_from_metric`.) 189 | 190 | threshold_ : `float` 191 | If the distance metric between two points is lower than this threshold, 192 | points will be classified as similar, otherwise they will be 193 | classified as dissimilar. 
194 | 195 | Examples 196 | -------- 197 | >>> from metric_learn import ITML 198 | >>> pairs = [[[1.2, 7.5], [1.3, 1.5]], 199 | >>> [[6.4, 2.6], [6.2, 9.7]], 200 | >>> [[1.3, 4.5], [3.2, 4.6]], 201 | >>> [[6.2, 5.5], [5.4, 5.4]]] 202 | >>> y = [1, 1, -1, -1] 203 | >>> # in this task we want points where the first feature is close to be 204 | >>> # closer to each other, no matter how close the second feature is 205 | >>> itml = ITML() 206 | >>> itml.fit(pairs, y) 207 | 208 | References 209 | ---------- 210 | .. [1] Jason V. Davis, et al. `Information-theoretic Metric Learning 211 | `_. ICML 2007. 213 | """ 214 | 215 | def fit(self, pairs, y, bounds=None, calibration_params=None): 216 | """Learn the ITML model. 217 | 218 | The threshold will be calibrated on the trainset using the parameters 219 | `calibration_params`. 220 | 221 | Parameters 222 | ---------- 223 | pairs: array-like, shape=(n_constraints, 2, n_features) or \ 224 | (n_constraints, 2) 225 | 3D Array of pairs with each row corresponding to two points, 226 | or 2D array of indices of pairs if the metric learner uses a 227 | preprocessor. 228 | 229 | y: array-like, of shape (n_constraints,) 230 | Labels of constraints. Should be -1 for dissimilar pair, 1 for similar. 231 | 232 | bounds : array-like of two numbers 233 | Bounds on similarity, aside slack variables, s.t. 234 | ``d(a, b) < bounds_[0]`` for all given pairs of similar points ``a`` 235 | and ``b``, and ``d(c, d) > bounds_[1]`` for all given pairs of 236 | dissimilar points ``c`` and ``d``, with ``d`` the learned distance. 237 | If not provided at initialization, bounds_[0] and bounds_[1] will be 238 | set to the 5th and 95th percentile of the pairwise distances among all 239 | points present in the input `pairs`. 240 | 241 | calibration_params : `dict` or `None` 242 | Dictionary of parameters to give to `calibrate_threshold` for the 243 | threshold calibration step done at the end of `fit`. If `None` is 244 | given, `calibrate_threshold` will use the default parameters. 245 | 246 | Returns 247 | ------- 248 | self : object 249 | Returns the instance. 250 | """ 251 | calibration_params = (calibration_params if calibration_params is not 252 | None else dict()) 253 | self._validate_calibration_params(**calibration_params) 254 | self._fit(pairs, y, bounds=bounds) 255 | self.calibrate_threshold(pairs, y, **calibration_params) 256 | return self 257 | 258 | 259 | class ITML_Supervised(_BaseITML, TransformerMixin): 260 | """Supervised version of Information Theoretic Metric Learning (ITML) 261 | 262 | `ITML_Supervised` creates pairs of similar sample by taking same class 263 | samples, and pairs of dissimilar samples by taking different class 264 | samples. It then passes these pairs to `ITML` for training. 265 | 266 | Parameters 267 | ---------- 268 | gamma : float, optional (default=1.0) 269 | Value for slack variables 270 | 271 | max_iter : int, optional (default=1000) 272 | Maximum number of iterations of the optimization procedure. 273 | 274 | tol : float, optional (default=1e-3) 275 | Tolerance of the optimization procedure. 276 | 277 | n_constraints : int, optional (default=None) 278 | Number of constraints to generate. If None, default to `20 * 279 | num_classes**2`. 280 | 281 | prior : string or numpy array, optional (default='identity') 282 | Initialization of the Mahalanobis matrix. Possible options are 283 | 'identity', 'covariance', 'random', and a numpy array of shape 284 | (n_features, n_features). 
For ITML, the prior should be strictly 285 | positive definite (PD). 286 | 287 | 'identity' 288 | An identity matrix of shape (n_features, n_features). 289 | 290 | 'covariance' 291 | The inverse covariance matrix. 292 | 293 | 'random' 294 | The prior will be a random SPD matrix of shape 295 | `(n_features, n_features)`, generated using 296 | `sklearn.datasets.make_spd_matrix`. 297 | 298 | numpy array 299 | A positive definite (PD) matrix of shape 300 | (n_features, n_features), that will be used as such to set the 301 | prior. 302 | 303 | verbose : bool, optional (default=False) 304 | If True, prints information while learning 305 | 306 | preprocessor : array-like, shape=(n_samples, n_features) or callable 307 | The preprocessor to call to get tuples from indices. If array-like, 308 | tuples will be formed like this: X[indices]. 309 | 310 | random_state : int or numpy.RandomState or None, optional (default=None) 311 | A pseudo random number generator object or a seed for it if int. If 312 | ``prior='random'``, ``random_state`` is used to set the prior. In any 313 | case, `random_state` is also used to randomly sample constraints from 314 | labels. 315 | 316 | num_constraints : Renamed to n_constraints. Will be deprecated in 0.7.0 317 | 318 | convergence_threshold : Renamed to tol. Will be deprecated in 0.7.0 319 | 320 | Attributes 321 | ---------- 322 | bounds_ : `numpy.ndarray`, shape=(2,) 323 | Bounds on similarity, aside slack variables, s.t. 324 | ``d(a, b) < bounds_[0]`` for all given pairs of similar points ``a`` 325 | and ``b``, and ``d(c, d) > bounds_[1]`` for all given pairs of 326 | dissimilar points ``c`` and ``d``, with ``d`` the learned distance. 327 | If not provided at initialization, bounds_[0] and bounds_[1] are set at 328 | train time to the 5th and 95th percentile of the pairwise distances 329 | among all points in the training data `X`. 330 | 331 | n_iter_ : `int` 332 | The number of iterations the solver has run. 333 | 334 | components_ : `numpy.ndarray`, shape=(n_features, n_features) 335 | The linear transformation ``L`` deduced from the learned Mahalanobis 336 | metric (See function `components_from_metric`.) 337 | 338 | Examples 339 | -------- 340 | >>> from metric_learn import ITML_Supervised 341 | >>> from sklearn.datasets import load_iris 342 | >>> iris_data = load_iris() 343 | >>> X = iris_data['data'] 344 | >>> Y = iris_data['target'] 345 | >>> itml = ITML_Supervised(n_constraints=200) 346 | >>> itml.fit(X, Y) 347 | 348 | See Also 349 | -------- 350 | metric_learn.ITML : The original weakly-supervised algorithm 351 | :ref:`supervised_version` : The section of the project documentation 352 | that describes the supervised version of weakly supervised estimators. 353 | """ 354 | 355 | def __init__(self, gamma=1.0, max_iter=1000, tol=1e-3, 356 | n_constraints=None, prior='identity', 357 | verbose=False, preprocessor=None, random_state=None, 358 | num_constraints='deprecated', 359 | convergence_threshold='deprecated'): 360 | _BaseITML.__init__(self, gamma=gamma, max_iter=max_iter, 361 | tol=tol, 362 | prior=prior, verbose=verbose, 363 | preprocessor=preprocessor, 364 | random_state=random_state, 365 | convergence_threshold=convergence_threshold) 366 | if num_constraints != 'deprecated': 367 | warnings.warn('"num_constraints" parameter has been renamed to' 368 | ' "n_constraints". 
It has been deprecated in' 369 | ' version 0.6.3 and will be removed in 0.7.0' 370 | '', FutureWarning) 371 | n_constraints = num_constraints 372 | self.n_constraints = n_constraints 373 | # Avoid test get_params from failing (all params passed sholud be set) 374 | self.num_constraints = 'deprecated' 375 | 376 | def fit(self, X, y, bounds=None): 377 | """Create constraints from labels and learn the ITML model. 378 | 379 | 380 | Parameters 381 | ---------- 382 | X : (n x d) matrix 383 | Input data, where each row corresponds to a single instance. 384 | 385 | y : (n) array-like 386 | Data labels. 387 | 388 | bounds : array-like of two numbers 389 | Bounds on similarity, aside slack variables, s.t. 390 | ``d(a, b) < bounds_[0]`` for all given pairs of similar points ``a`` 391 | and ``b``, and ``d(c, d) > bounds_[1]`` for all given pairs of 392 | dissimilar points ``c`` and ``d``, with ``d`` the learned distance. 393 | If not provided at initialization, bounds_[0] and bounds_[1] will be 394 | set to the 5th and 95th percentile of the pairwise distances among all 395 | points in the training data `X`. 396 | """ 397 | X, y = self._prepare_inputs(X, y, ensure_min_samples=2) 398 | n_constraints = self.n_constraints 399 | if n_constraints is None: 400 | num_classes = len(np.unique(y)) 401 | n_constraints = 20 * num_classes**2 402 | 403 | c = Constraints(y) 404 | pos_neg = c.positive_negative_pairs(n_constraints, 405 | random_state=self.random_state) 406 | pairs, y = wrap_pairs(X, pos_neg) 407 | return _BaseITML._fit(self, pairs, y, bounds=bounds) 408 | -------------------------------------------------------------------------------- /metric_learn/lfda.py: -------------------------------------------------------------------------------- 1 | """ 2 | Local Fisher Discriminant Analysis (LFDA) 3 | """ 4 | import numpy as np 5 | import scipy 6 | import warnings 7 | from sklearn.metrics import pairwise_distances 8 | from sklearn.base import TransformerMixin 9 | 10 | from ._util import _check_n_components 11 | from .base_metric import MahalanobisMixin 12 | 13 | 14 | class LFDA(MahalanobisMixin, TransformerMixin): 15 | ''' 16 | Local Fisher Discriminant Analysis for Supervised Dimensionality Reduction 17 | 18 | LFDA is a linear supervised dimensionality reduction method. It is 19 | particularly useful when dealing with multimodality, where one ore more 20 | classes consist of separate clusters in input space. The core optimization 21 | problem of LFDA is solved as a generalized eigenvalue problem. 22 | 23 | Read more in the :ref:`User Guide `. 24 | 25 | Parameters 26 | ---------- 27 | n_components : int or None, optional (default=None) 28 | Dimensionality of reduced space (if None, defaults to dimension of X). 29 | 30 | k : int, optional (default=None) 31 | Number of nearest neighbors used in local scaling method. If None, 32 | defaults to min(7, n_features - 1). 33 | 34 | embedding_type : str, optional (default: 'weighted') 35 | Type of metric in the embedding space. 36 | 37 | 'weighted' 38 | weighted eigenvectors 39 | 40 | 'orthonormalized' 41 | orthonormalized 42 | 43 | 'plain' 44 | raw eigenvectors 45 | 46 | preprocessor : array-like, shape=(n_samples, n_features) or callable 47 | The preprocessor to call to get tuples from indices. If array-like, 48 | tuples will be formed like this: X[indices]. 49 | 50 | Attributes 51 | ---------- 52 | components_ : `numpy.ndarray`, shape=(n_components, n_features) 53 | The learned linear transformation ``L``. 
54 | 55 | Examples 56 | -------- 57 | 58 | >>> import numpy as np 59 | >>> from metric_learn import LFDA 60 | >>> from sklearn.datasets import load_iris 61 | >>> iris_data = load_iris() 62 | >>> X = iris_data['data'] 63 | >>> Y = iris_data['target'] 64 | >>> lfda = LFDA(k=2, dim=2) 65 | >>> lfda.fit(X, Y) 66 | 67 | References 68 | ---------- 69 | .. [1] Masashi Sugiyama. `Dimensionality Reduction of Multimodal Labeled 70 | Data by Local Fisher Discriminant Analysis 71 | `_. JMLR 2007. 72 | 73 | .. [2] Yuan Tang. `Local Fisher Discriminant Analysis on Beer Style 74 | Clustering 75 | `_. 77 | ''' 78 | 79 | def __init__(self, n_components=None, 80 | k=None, embedding_type='weighted', preprocessor=None): 81 | if embedding_type not in ('weighted', 'orthonormalized', 'plain'): 82 | raise ValueError('Invalid embedding_type: %r' % embedding_type) 83 | self.n_components = n_components 84 | self.embedding_type = embedding_type 85 | self.k = k 86 | super(LFDA, self).__init__(preprocessor) 87 | 88 | def fit(self, X, y): 89 | '''Fit the LFDA model. 90 | 91 | Parameters 92 | ---------- 93 | X : (n, d) array-like 94 | Input data. 95 | 96 | y : (n,) array-like 97 | Class labels, one per point of data. 98 | ''' 99 | X, y = self._prepare_inputs(X, y, ensure_min_samples=2) 100 | unique_classes, y = np.unique(y, return_inverse=True) 101 | n, d = X.shape 102 | num_classes = len(unique_classes) 103 | 104 | dim = _check_n_components(d, self.n_components) 105 | 106 | if self.k is None: 107 | k = min(7, d - 1) 108 | elif self.k >= d: 109 | warnings.warn('Chosen k (%d) too large, using %d instead.' 110 | % (self.k, d - 1)) 111 | k = d - 1 112 | else: 113 | k = int(self.k) 114 | tSb = np.zeros((d, d)) 115 | tSw = np.zeros((d, d)) 116 | 117 | for c in range(num_classes): 118 | Xc = X[y == c] 119 | nc = Xc.shape[0] 120 | 121 | # classwise affinity matrix 122 | dist = pairwise_distances(Xc, metric='l2', squared=True) 123 | # distances to k-th nearest neighbor 124 | k = min(k, nc - 1) 125 | sigma = np.sqrt(np.partition(dist, k, axis=0)[:, k]) 126 | 127 | local_scale = np.outer(sigma, sigma) 128 | with np.errstate(divide='ignore', invalid='ignore'): 129 | A = np.exp(-dist / local_scale) 130 | A[local_scale == 0] = 0 131 | 132 | G = Xc.T.dot(A.sum(axis=0)[:, None] * Xc) - Xc.T.dot(A).dot(Xc) 133 | tSb += G / n + (1 - nc / n) * Xc.T.dot(Xc) + _sum_outer(Xc) / n 134 | tSw += G / nc 135 | 136 | tSb -= _sum_outer(X) / n - tSw 137 | 138 | # symmetrize 139 | tSb = (tSb + tSb.T) / 2 140 | tSw = (tSw + tSw.T) / 2 141 | 142 | vals, vecs = _eigh(tSb, tSw, dim) 143 | order = np.argsort(-vals)[:dim] 144 | vals = vals[order].real 145 | vecs = vecs[:, order] 146 | 147 | if self.embedding_type == 'weighted': 148 | vecs *= np.sqrt(vals) 149 | elif self.embedding_type == 'orthonormalized': 150 | vecs, _ = np.linalg.qr(vecs) 151 | 152 | self.components_ = vecs.T 153 | return self 154 | 155 | 156 | def _sum_outer(x): 157 | s = x.sum(axis=0) 158 | return np.outer(s, s) 159 | 160 | 161 | def _eigh(a, b, dim): 162 | try: 163 | return scipy.sparse.linalg.eigsh(a, k=dim, M=b, which='LA') 164 | except np.linalg.LinAlgError: 165 | pass # scipy already tried eigh for us 166 | except (ValueError, scipy.sparse.linalg.ArpackNoConvergence): 167 | try: 168 | return scipy.linalg.eigh(a, b) 169 | except np.linalg.LinAlgError: 170 | pass 171 | return scipy.linalg.eig(a, b) 172 | -------------------------------------------------------------------------------- /metric_learn/lmnn.py: 
-------------------------------------------------------------------------------- 1 | """ 2 | Large Margin Nearest Neighbor Metric learning (LMNN) 3 | """ 4 | import numpy as np 5 | from collections import Counter 6 | from sklearn.metrics import euclidean_distances 7 | from sklearn.base import TransformerMixin 8 | import warnings 9 | 10 | from ._util import _initialize_components, _check_n_components 11 | from .base_metric import MahalanobisMixin 12 | 13 | 14 | class LMNN(MahalanobisMixin, TransformerMixin): 15 | """Large Margin Nearest Neighbor (LMNN) 16 | 17 | LMNN learns a Mahalanobis distance metric in the kNN classification 18 | setting. The learned metric attempts to keep close k-nearest neighbors 19 | from the same class, while keeping examples from different classes 20 | separated by a large margin. This algorithm makes no assumptions about 21 | the distribution of the data. 22 | 23 | Read more in the :ref:`User Guide `. 24 | 25 | Parameters 26 | ---------- 27 | init : string or numpy array, optional (default='auto') 28 | Initialization of the linear transformation. Possible options are 29 | 'auto', 'pca', 'identity', 'random', and a numpy array of shape 30 | (n_features_a, n_features_b). 31 | 32 | 'auto' 33 | Depending on ``n_components``, the most reasonable initialization 34 | will be chosen. If ``n_components <= n_classes`` we use 'lda', as 35 | it uses labels information. If not, but 36 | ``n_components < min(n_features, n_samples)``, we use 'pca', as 37 | it projects data in meaningful directions (those of higher 38 | variance). Otherwise, we just use 'identity'. 39 | 40 | 'pca' 41 | ``n_components`` principal components of the inputs passed 42 | to :meth:`fit` will be used to initialize the transformation. 43 | (See `sklearn.decomposition.PCA`) 44 | 45 | 'lda' 46 | ``min(n_components, n_classes)`` most discriminative 47 | components of the inputs passed to :meth:`fit` will be used to 48 | initialize the transformation. (If ``n_components > n_classes``, 49 | the rest of the components will be zero.) (See 50 | `sklearn.discriminant_analysis.LinearDiscriminantAnalysis`) 51 | 52 | 'identity' 53 | If ``n_components`` is strictly smaller than the 54 | dimensionality of the inputs passed to :meth:`fit`, the identity 55 | matrix will be truncated to the first ``n_components`` rows. 56 | 57 | 'random' 58 | The initial transformation will be a random array of shape 59 | `(n_components, n_features)`. Each value is sampled from the 60 | standard normal distribution. 61 | 62 | numpy array 63 | n_features_b must match the dimensionality of the inputs passed to 64 | :meth:`fit` and n_features_a must be less than or equal to that. 65 | If ``n_components`` is not None, n_features_a must match it. 66 | 67 | n_neighbors : int, optional (default=3) 68 | Number of neighbors to consider, not including self-edges. 69 | 70 | min_iter : int, optional (default=50) 71 | Minimum number of iterations of the optimization procedure. 72 | 73 | max_iter : int, optional (default=1000) 74 | Maximum number of iterations of the optimization procedure. 75 | 76 | learn_rate : float, optional (default=1e-7) 77 | Learning rate of the optimization procedure 78 | 79 | tol : float, optional (default=0.001) 80 | Tolerance of the optimization procedure. If the objective value varies 81 | less than `tol`, we consider the algorithm has converged and stop it. 82 | 83 | verbose : bool, optional (default=False) 84 | Whether to print the progress of the optimization procedure. 
85 | 86 | regularization: float, optional (default=0.5) 87 | Relative weight between pull and push terms, with 0.5 meaning equal 88 | weight. 89 | 90 | preprocessor : array-like, shape=(n_samples, n_features) or callable 91 | The preprocessor to call to get tuples from indices. If array-like, 92 | tuples will be formed like this: X[indices]. 93 | 94 | n_components : int or None, optional (default=None) 95 | Dimensionality of reduced space (if None, defaults to dimension of X). 96 | 97 | random_state : int or numpy.RandomState or None, optional (default=None) 98 | A pseudo random number generator object or a seed for it if int. If 99 | ``init='random'``, ``random_state`` is used to initialize the random 100 | transformation. If ``init='pca'``, ``random_state`` is passed as an 101 | argument to PCA when initializing the transformation. 102 | 103 | k : Renamed to n_neighbors. Will be deprecated in 0.7.0 104 | 105 | Attributes 106 | ---------- 107 | n_iter_ : `int` 108 | The number of iterations the solver has run. 109 | 110 | components_ : `numpy.ndarray`, shape=(n_components, n_features) 111 | The learned linear transformation ``L``. 112 | 113 | Examples 114 | -------- 115 | 116 | >>> import numpy as np 117 | >>> from metric_learn import LMNN 118 | >>> from sklearn.datasets import load_iris 119 | >>> iris_data = load_iris() 120 | >>> X = iris_data['data'] 121 | >>> Y = iris_data['target'] 122 | >>> lmnn = LMNN(n_neighbors=5, learn_rate=1e-6) 123 | >>> lmnn.fit(X, Y, verbose=False) 124 | 125 | References 126 | ---------- 127 | .. [1] K. Q. Weinberger, J. Blitzer, L. K. Saul. `Distance Metric 128 | Learning for Large Margin Nearest Neighbor Classification 129 | `_. NIPS 131 | 2005. 132 | """ 133 | 134 | def __init__(self, init='auto', n_neighbors=3, min_iter=50, max_iter=1000, 135 | learn_rate=1e-7, regularization=0.5, convergence_tol=0.001, 136 | verbose=False, preprocessor=None, 137 | n_components=None, random_state=None, k='deprecated'): 138 | self.init = init 139 | if k != 'deprecated': 140 | warnings.warn('"num_chunks" parameter has been renamed to' 141 | ' "n_chunks". 
It has been deprecated in' 142 | ' version 0.6.3 and will be removed in 0.7.0' 143 | '', FutureWarning) 144 | n_neighbors = k 145 | self.k = 'deprecated' # To avoid no_attribute error 146 | self.n_neighbors = n_neighbors 147 | self.min_iter = min_iter 148 | self.max_iter = max_iter 149 | self.learn_rate = learn_rate 150 | self.regularization = regularization 151 | self.convergence_tol = convergence_tol 152 | self.verbose = verbose 153 | self.n_components = n_components 154 | self.random_state = random_state 155 | super(LMNN, self).__init__(preprocessor) 156 | 157 | def fit(self, X, y): 158 | k = self.n_neighbors 159 | reg = self.regularization 160 | learn_rate = self.learn_rate 161 | 162 | X, y = self._prepare_inputs(X, y, dtype=float, 163 | ensure_min_samples=2) 164 | num_pts, d = X.shape 165 | output_dim = _check_n_components(d, self.n_components) 166 | unique_labels, label_inds = np.unique(y, return_inverse=True) 167 | if len(label_inds) != num_pts: 168 | raise ValueError('Must have one label per point.') 169 | self.labels_ = np.arange(len(unique_labels)) 170 | 171 | self.components_ = _initialize_components(output_dim, X, y, self.init, 172 | self.verbose, 173 | random_state=self.random_state) 174 | required_k = np.bincount(label_inds).min() 175 | if self.n_neighbors > required_k: 176 | raise ValueError('not enough class labels for specified k' 177 | ' (smallest class has %d)' % required_k) 178 | 179 | target_neighbors = self._select_targets(X, label_inds) 180 | 181 | # sum outer products 182 | dfG = _sum_outer_products(X, target_neighbors.flatten(), 183 | np.repeat(np.arange(X.shape[0]), k)) 184 | 185 | # initialize L 186 | L = self.components_ 187 | 188 | # first iteration: we compute variables (including objective and gradient) 189 | # at initialization point 190 | G, objective, total_active = self._loss_grad(X, L, dfG, k, 191 | reg, target_neighbors, 192 | label_inds) 193 | 194 | it = 1 # we already made one iteration 195 | 196 | if self.verbose: 197 | print("iter | objective | objective difference | active constraints", 198 | "| learning rate") 199 | 200 | # main loop 201 | for it in range(2, self.max_iter): 202 | # then at each iteration, we try to find a value of L that has better 203 | # objective than the previous L, following the gradient: 204 | while True: 205 | # the next point next_L to try out is found by a gradient step 206 | L_next = L - learn_rate * G 207 | # we compute the objective at next point 208 | # we copy variables that can be modified by _loss_grad, because if we 209 | # retry we don t want to modify them several times 210 | (G_next, objective_next, total_active_next) = ( 211 | self._loss_grad(X, L_next, dfG, k, reg, target_neighbors, 212 | label_inds)) 213 | assert not np.isnan(objective) 214 | delta_obj = objective_next - objective 215 | if delta_obj > 0: 216 | # if we did not find a better objective, we retry with an L closer to 217 | # the starting point, by decreasing the learning rate (making the 218 | # gradient step smaller) 219 | learn_rate /= 2 220 | else: 221 | # otherwise, if we indeed found a better obj, we get out of the loop 222 | break 223 | # when the better L is found (and the related variables), we set the 224 | # old variables to these new ones before next iteration and we 225 | # slightly increase the learning rate 226 | L = L_next 227 | G, objective, total_active = G_next, objective_next, total_active_next 228 | learn_rate *= 1.01 229 | 230 | if self.verbose: 231 | print(it, objective, delta_obj, total_active, learn_rate) 232 | 233 | # 
check for convergence 234 | if it > self.min_iter and abs(delta_obj) < self.convergence_tol: 235 | if self.verbose: 236 | print("LMNN converged with objective", objective) 237 | break 238 | else: 239 | if self.verbose: 240 | print("LMNN didn't converge in %d steps." % self.max_iter) 241 | 242 | # store the last L 243 | self.components_ = L 244 | self.n_iter_ = it 245 | return self 246 | 247 | def _loss_grad(self, X, L, dfG, k, reg, target_neighbors, label_inds): 248 | # Compute pairwise distances under current metric 249 | Lx = L.dot(X.T).T 250 | 251 | # we need to find the furthest neighbor: 252 | Ni = 1 + _inplace_paired_L2(Lx[target_neighbors], Lx[:, None, :]) 253 | furthest_neighbors = np.take_along_axis(target_neighbors, 254 | Ni.argmax(axis=1)[:, None], 1) 255 | impostors = self._find_impostors(furthest_neighbors.ravel(), X, 256 | label_inds, L) 257 | 258 | g0 = _inplace_paired_L2(*Lx[impostors]) 259 | 260 | # we reorder the target neighbors 261 | g1, g2 = Ni[impostors] 262 | # compute the gradient 263 | total_active = 0 264 | df = np.zeros((X.shape[1], X.shape[1])) 265 | for nn_idx in reversed(range(k)): # note: reverse not useful here 266 | act1 = g0 < g1[:, nn_idx] 267 | act2 = g0 < g2[:, nn_idx] 268 | total_active += act1.sum() + act2.sum() 269 | 270 | targets = target_neighbors[:, nn_idx] 271 | PLUS, pweight = _count_edges(act1, act2, impostors, targets) 272 | df += _sum_outer_products(X, PLUS[:, 0], PLUS[:, 1], pweight) 273 | 274 | in_imp, out_imp = impostors 275 | df -= _sum_outer_products(X, in_imp[act1], out_imp[act1]) 276 | df -= _sum_outer_products(X, in_imp[act2], out_imp[act2]) 277 | 278 | # do the gradient update 279 | assert not np.isnan(df).any() 280 | G = dfG * reg + df * (1 - reg) 281 | G = L.dot(G) 282 | # compute the objective function 283 | objective = total_active * (1 - reg) 284 | objective += G.flatten().dot(L.flatten()) 285 | return 2 * G, objective, total_active 286 | 287 | def _select_targets(self, X, label_inds): 288 | target_neighbors = np.empty((X.shape[0], self.n_neighbors), dtype=int) 289 | for label in self.labels_: 290 | inds, = np.nonzero(label_inds == label) 291 | dd = euclidean_distances(X[inds], squared=True) 292 | np.fill_diagonal(dd, np.inf) 293 | nn = np.argsort(dd)[..., :self.n_neighbors] 294 | target_neighbors[inds] = inds[nn] 295 | return target_neighbors 296 | 297 | def _find_impostors(self, furthest_neighbors, X, label_inds, L): 298 | Lx = X.dot(L.T) 299 | margin_radii = 1 + _inplace_paired_L2(Lx[furthest_neighbors], Lx) 300 | impostors = [] 301 | for label in self.labels_[:-1]: 302 | in_inds, = np.nonzero(label_inds == label) 303 | out_inds, = np.nonzero(label_inds > label) 304 | dist = euclidean_distances(Lx[out_inds], Lx[in_inds], squared=True) 305 | i1, j1 = np.nonzero(dist < margin_radii[out_inds][:, None]) 306 | i2, j2 = np.nonzero(dist < margin_radii[in_inds]) 307 | i = np.hstack((i1, i2)) 308 | j = np.hstack((j1, j2)) 309 | if i.size > 0: 310 | # get unique (i,j) pairs using index trickery 311 | shape = (i.max() + 1, j.max() + 1) 312 | tmp = np.ravel_multi_index((i, j), shape) 313 | i, j = np.unravel_index(np.unique(tmp), shape) 314 | impostors.append(np.vstack((in_inds[j], out_inds[i]))) 315 | if len(impostors) == 0: 316 | # No impostors detected 317 | return impostors 318 | return np.hstack(impostors) 319 | 320 | 321 | def _inplace_paired_L2(A, B): 322 | '''Equivalent to ((A-B)**2).sum(axis=-1), but modifies A in place.''' 323 | A -= B 324 | return np.einsum('...ij,...ij->...i', A, A) 325 | 326 | 327 | def _count_edges(act1, 
act2, impostors, targets): 328 | imp = impostors[0, act1] 329 | c = Counter(zip(imp, targets[imp])) 330 | imp = impostors[1, act2] 331 | c.update(zip(imp, targets[imp])) 332 | if c: 333 | active_pairs = np.array(list(c.keys())) 334 | else: 335 | active_pairs = np.empty((0, 2), dtype=int) 336 | return active_pairs, np.array(list(c.values())) 337 | 338 | 339 | def _sum_outer_products(data, a_inds, b_inds, weights=None): 340 | Xab = data[a_inds] - data[b_inds] 341 | if weights is not None: 342 | return np.dot(Xab.T, Xab * weights[:, None]) 343 | return np.dot(Xab.T, Xab) 344 | -------------------------------------------------------------------------------- /metric_learn/lsml.py: -------------------------------------------------------------------------------- 1 | """ 2 | Metric Learning from Relative Comparisons by Minimizing Squared Residual (LSML) 3 | """ 4 | 5 | import numpy as np 6 | import scipy.linalg 7 | from sklearn.base import TransformerMixin 8 | 9 | from .base_metric import _QuadrupletsClassifierMixin, MahalanobisMixin 10 | from .constraints import Constraints 11 | from ._util import components_from_metric, _initialize_metric_mahalanobis 12 | import warnings 13 | 14 | 15 | class _BaseLSML(MahalanobisMixin): 16 | 17 | _tuple_size = 4 # constraints are quadruplets 18 | 19 | def __init__(self, tol=1e-3, max_iter=1000, prior='identity', 20 | verbose=False, preprocessor=None, random_state=None): 21 | self.prior = prior 22 | self.tol = tol 23 | self.max_iter = max_iter 24 | self.verbose = verbose 25 | self.random_state = random_state 26 | super(_BaseLSML, self).__init__(preprocessor) 27 | 28 | def _fit(self, quadruplets, weights=None): 29 | quadruplets = self._prepare_inputs(quadruplets, 30 | type_of_inputs='tuples') 31 | 32 | # check to make sure that no two constrained vectors are identical 33 | vab = quadruplets[:, 0, :] - quadruplets[:, 1, :] 34 | vcd = quadruplets[:, 2, :] - quadruplets[:, 3, :] 35 | if vab.shape != vcd.shape: 36 | raise ValueError('Constraints must have same length') 37 | if weights is None: 38 | self.w_ = np.ones(vab.shape[0]) 39 | else: 40 | self.w_ = weights 41 | self.w_ /= self.w_.sum() # weights must sum to 1 42 | M, prior_inv = _initialize_metric_mahalanobis( 43 | quadruplets, self.prior, 44 | return_inverse=True, strict_pd=True, matrix_name='prior', 45 | random_state=self.random_state) 46 | 47 | step_sizes = np.logspace(-10, 0, 10) 48 | # Keep track of the best step size and the loss at that step. 49 | l_best = 0 50 | s_best = self._total_loss(M, vab, vcd, prior_inv) 51 | if self.verbose: 52 | print('initial loss', s_best) 53 | for it in range(1, self.max_iter + 1): 54 | grad = self._gradient(M, vab, vcd, prior_inv) 55 | grad_norm = scipy.linalg.norm(grad) 56 | if grad_norm < self.tol: 57 | break 58 | if self.verbose: 59 | print('gradient norm', grad_norm) 60 | M_best = None 61 | for step_size in step_sizes: 62 | step_size /= grad_norm 63 | new_metric = M - step_size * grad 64 | w, v = scipy.linalg.eigh(new_metric) 65 | new_metric = v.dot((np.maximum(w, 1e-8) * v).T) 66 | cur_s = self._total_loss(new_metric, vab, vcd, prior_inv) 67 | if cur_s < s_best: 68 | l_best = step_size 69 | s_best = cur_s 70 | M_best = new_metric 71 | if self.verbose: 72 | print('iter', it, 'cost', s_best, 'best step', l_best * grad_norm) 73 | if M_best is None: 74 | break 75 | M = M_best 76 | else: 77 | if self.verbose: 78 | print("Didn't converge after", it, "iterations. 
Final loss:", s_best) 79 | self.n_iter_ = it 80 | 81 | self.components_ = components_from_metric(M) 82 | return self 83 | 84 | def _comparison_loss(self, metric, vab, vcd): 85 | dab = np.sum(vab.dot(metric) * vab, axis=1) 86 | dcd = np.sum(vcd.dot(metric) * vcd, axis=1) 87 | violations = dab > dcd 88 | return self.w_[violations].dot((np.sqrt(dab[violations]) - 89 | np.sqrt(dcd[violations]))**2) 90 | 91 | def _total_loss(self, metric, vab, vcd, prior_inv): 92 | # Regularization loss 93 | sign, logdet = np.linalg.slogdet(metric) 94 | reg_loss = np.sum(metric * prior_inv) - sign * logdet 95 | return self._comparison_loss(metric, vab, vcd) + reg_loss 96 | 97 | def _gradient(self, metric, vab, vcd, prior_inv): 98 | dMetric = prior_inv - np.linalg.inv(metric) 99 | dabs = np.sum(vab.dot(metric) * vab, axis=1) 100 | dcds = np.sum(vcd.dot(metric) * vcd, axis=1) 101 | violations = dabs > dcds 102 | # TODO: vectorize 103 | for vab, dab, vcd, dcd in zip(vab[violations], dabs[violations], 104 | vcd[violations], dcds[violations]): 105 | dMetric += ((1 - np.sqrt(dcd / dab)) * np.outer(vab, vab) + 106 | (1 - np.sqrt(dab / dcd)) * np.outer(vcd, vcd)) 107 | return dMetric 108 | 109 | 110 | class LSML(_BaseLSML, _QuadrupletsClassifierMixin): 111 | """Least Squared-residual Metric Learning (LSML) 112 | 113 | `LSML` proposes a simple, yet effective, algorithm that minimizes a convex 114 | objective function corresponding to the sum of squared residuals of 115 | constraints. This algorithm uses the constraints in the form of the 116 | relative distance comparisons, such method is especially useful where 117 | pairwise constraints are not natural to obtain, thus pairwise constraints 118 | based algorithms become infeasible to be deployed. Furthermore, its sparsity 119 | extension leads to more stable estimation when the dimension is high and 120 | only a small amount of constraints is given. 121 | 122 | Read more in the :ref:`User Guide `. 123 | 124 | Parameters 125 | ---------- 126 | prior : string or numpy array, optional (default='identity') 127 | Prior to set for the metric. Possible options are 128 | 'identity', 'covariance', 'random', and a numpy array of 129 | shape (n_features, n_features). For LSML, the prior should be strictly 130 | positive definite (PD). 131 | 132 | 'identity' 133 | An identity matrix of shape (n_features, n_features). 134 | 135 | 'covariance' 136 | The inverse covariance matrix. 137 | 138 | 'random' 139 | The initial Mahalanobis matrix will be a random positive definite 140 | (PD) matrix of shape `(n_features, n_features)`, generated using 141 | `sklearn.datasets.make_spd_matrix`. 142 | 143 | numpy array 144 | A positive definite (PD) matrix of shape 145 | (n_features, n_features), that will be used as such to set the 146 | prior. 147 | 148 | tol : float, optional (default=1e-3) 149 | Convergence tolerance of the optimization procedure. 150 | 151 | max_iter : int, optional (default=1000) 152 | Maximum number of iteration of the optimization procedure. 153 | 154 | verbose : bool, optional (default=False) 155 | If True, prints information while learning 156 | 157 | preprocessor : array-like, shape=(n_samples, n_features) or callable 158 | The preprocessor to call to get tuples from indices. If array-like, 159 | tuples will be formed like this: X[indices]. 160 | 161 | random_state : int or numpy.RandomState or None, optional (default=None) 162 | A pseudo random number generator object or a seed for it if int. 
If 163 | ``init='random'``, ``random_state`` is used to set the random 164 | prior. 165 | 166 | Attributes 167 | ---------- 168 | n_iter_ : `int` 169 | The number of iterations the solver has run. 170 | 171 | components_ : `numpy.ndarray`, shape=(n_features, n_features) 172 | The linear transformation ``L`` deduced from the learned Mahalanobis 173 | metric (See function `components_from_metric`.) 174 | 175 | Examples 176 | -------- 177 | >>> from metric_learn import LSML 178 | >>> quadruplets = [[[1.2, 7.5], [1.3, 1.5], [6.4, 2.6], [6.2, 9.7]], 179 | >>> [[1.3, 4.5], [3.2, 4.6], [6.2, 5.5], [5.4, 5.4]], 180 | >>> [[3.2, 7.5], [3.3, 1.5], [8.4, 2.6], [8.2, 9.7]], 181 | >>> [[3.3, 4.5], [5.2, 4.6], [8.2, 5.5], [7.4, 5.4]]] 182 | >>> # we want to make closer points where the first feature is close, and 183 | >>> # further if the second feature is close 184 | >>> lsml = LSML() 185 | >>> lsml.fit(quadruplets) 186 | 187 | References 188 | ---------- 189 | .. [1] Liu et al. `Metric Learning from Relative Comparisons by Minimizing 190 | Squared Residual 191 | `_. ICDM 2012. 192 | 193 | .. [2] Code adapted from https://gist.github.com/kcarnold/5439917 194 | 195 | See Also 196 | -------- 197 | metric_learn.LSML : The original weakly-supervised algorithm 198 | 199 | :ref:`supervised_version` : The section of the project documentation 200 | that describes the supervised version of weakly supervised estimators. 201 | """ 202 | 203 | def fit(self, quadruplets, weights=None): 204 | """Learn the LSML model. 205 | 206 | Parameters 207 | ---------- 208 | quadruplets : array-like, shape=(n_constraints, 4, n_features) or \ 209 | (n_constraints, 4) 210 | 3D array-like of quadruplets of points or 2D array of quadruplets of 211 | indicators. In order to supervise the algorithm in the right way, we 212 | should have the four samples ordered in a way such that: 213 | d(pairs[i, 0],X[i, 1]) < d(X[i, 2], X[i, 3]) for all 0 <= i < 214 | n_constraints. 215 | 216 | weights : (n_constraints,) array of floats, optional 217 | scale factor for each constraint 218 | 219 | Returns 220 | ------- 221 | self : object 222 | Returns the instance. 223 | """ 224 | return self._fit(quadruplets, weights=weights) 225 | 226 | 227 | class LSML_Supervised(_BaseLSML, TransformerMixin): 228 | """Supervised version of Least Squared-residual Metric Learning (LSML) 229 | 230 | `LSML_Supervised` creates quadruplets from labeled samples by taking two 231 | samples from the same class, and two samples from different classes. 232 | This way it builds quadruplets where the two first points must be more 233 | similar than the two last points. 234 | 235 | Parameters 236 | ---------- 237 | tol : float, optional (default=1e-3) 238 | Convergence tolerance of the optimization procedure. 239 | 240 | max_iter : int, optional (default=1000) 241 | Number of maximum iterations of the optimization procedure. 242 | 243 | prior : string or numpy array, optional (default='identity') 244 | Prior to set for the metric. Possible options are 245 | 'identity', 'covariance', 'random', and a numpy array of 246 | shape (n_features, n_features). For LSML, the prior should be strictly 247 | positive definite (PD). 248 | 249 | 'identity' 250 | An identity matrix of shape (n_features, n_features). 251 | 252 | 'covariance' 253 | The inverse covariance matrix. 254 | 255 | 'random' 256 | The initial Mahalanobis matrix will be a random positive definite 257 | (PD) matrix of shape `(n_features, n_features)`, generated using 258 | `sklearn.datasets.make_spd_matrix`. 
259 | 260 | numpy array 261 | A positive definite (PD) matrix of shape 262 | (n_features, n_features), that will be used as such to set the 263 | prior. 264 | 265 | n_constraints: int, optional (default=None) 266 | Number of constraints to generate. If None, default to `20 * 267 | num_classes**2`. 268 | 269 | weights : (n_constraints,) array of floats, optional (default=None) 270 | Relative weight given to each constraint. If None, defaults to uniform 271 | weights. 272 | 273 | verbose : bool, optional (default=False) 274 | If True, prints information while learning 275 | 276 | preprocessor : array-like, shape=(n_samples, n_features) or callable 277 | The preprocessor to call to get tuples from indices. If array-like, 278 | tuples will be formed like this: X[indices]. 279 | 280 | random_state : int or numpy.RandomState or None, optional (default=None) 281 | A pseudo random number generator object or a seed for it if int. If 282 | ``init='random'``, ``random_state`` is used to set the random 283 | prior. In any case, `random_state` is also used to randomly sample 284 | constraints from labels. 285 | 286 | num_constraints : Renamed to n_constraints. Will be deprecated in 0.7.0 287 | 288 | Examples 289 | -------- 290 | >>> from metric_learn import LSML_Supervised 291 | >>> from sklearn.datasets import load_iris 292 | >>> iris_data = load_iris() 293 | >>> X = iris_data['data'] 294 | >>> Y = iris_data['target'] 295 | >>> lsml = LSML_Supervised(n_constraints=200) 296 | >>> lsml.fit(X, Y) 297 | 298 | Attributes 299 | ---------- 300 | n_iter_ : `int` 301 | The number of iterations the solver has run. 302 | 303 | components_ : `numpy.ndarray`, shape=(n_features, n_features) 304 | The linear transformation ``L`` deduced from the learned Mahalanobis 305 | metric (See function `components_from_metric`.) 306 | """ 307 | 308 | def __init__(self, tol=1e-3, max_iter=1000, prior='identity', 309 | n_constraints=None, weights=None, 310 | verbose=False, preprocessor=None, random_state=None, 311 | num_constraints='deprecated'): 312 | _BaseLSML.__init__(self, tol=tol, max_iter=max_iter, prior=prior, 313 | verbose=verbose, preprocessor=preprocessor, 314 | random_state=random_state) 315 | if num_constraints != 'deprecated': 316 | warnings.warn('"num_constraints" parameter has been renamed to' 317 | ' "n_constraints". It has been deprecated in' 318 | ' version 0.6.3 and will be removed in 0.7.0' 319 | '', FutureWarning) 320 | self.n_constraints = num_constraints 321 | else: 322 | self.n_constraints = n_constraints 323 | # Avoid test get_params from failing (all params passed sholud be set) 324 | self.num_constraints = 'deprecated' 325 | self.weights = weights 326 | 327 | def fit(self, X, y): 328 | """Create constraints from labels and learn the LSML model. 329 | 330 | Parameters 331 | ---------- 332 | X : (n x d) matrix 333 | Input data, where each row corresponds to a single instance. 334 | 335 | y : (n) array-like 336 | Data labels. 
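    Returns
    -------
    self : object
        Returns the instance.

    Notes
    -----
    When ``n_constraints`` is None, it defaults to ``20 * num_classes**2``
    (see the body of ``fit`` below). A toy, illustrative computation of that
    default, assuming six labelled points spread over three classes:

    >>> import numpy as np
    >>> num_classes = len(np.unique([0, 0, 1, 1, 2, 2]))
    >>> 20 * num_classes**2
    180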
337 | """ 338 | X, y = self._prepare_inputs(X, y, ensure_min_samples=2) 339 | n_constraints = self.n_constraints 340 | if n_constraints is None: 341 | num_classes = len(np.unique(y)) 342 | n_constraints = 20 * num_classes**2 343 | 344 | c = Constraints(y) 345 | pos_neg = c.positive_negative_pairs(n_constraints, same_length=True, 346 | random_state=self.random_state) 347 | return _BaseLSML._fit(self, X[np.column_stack(pos_neg)], 348 | weights=self.weights) 349 | -------------------------------------------------------------------------------- /metric_learn/mlkr.py: -------------------------------------------------------------------------------- 1 | """ 2 | Metric Learning for Kernel Regression (MLKR) 3 | """ 4 | import time 5 | import sys 6 | import warnings 7 | import numpy as np 8 | from scipy.optimize import minimize 9 | from scipy.special import logsumexp 10 | from sklearn.base import TransformerMixin 11 | from sklearn.exceptions import ConvergenceWarning 12 | from sklearn.metrics import pairwise_distances 13 | 14 | from .base_metric import MahalanobisMixin 15 | from ._util import _initialize_components, _check_n_components 16 | 17 | EPS = np.finfo(float).eps 18 | 19 | 20 | class MLKR(MahalanobisMixin, TransformerMixin): 21 | """Metric Learning for Kernel Regression (MLKR) 22 | 23 | MLKR is an algorithm for supervised metric learning, which learns a 24 | distance function by directly minimizing the leave-one-out regression error. 25 | This algorithm can also be viewed as a supervised variation of PCA and can be 26 | used for dimensionality reduction and high dimensional data visualization. 27 | 28 | Read more in the :ref:`User Guide `. 29 | 30 | Parameters 31 | ---------- 32 | n_components : int or None, optional (default=None) 33 | Dimensionality of reduced space (if None, defaults to dimension of X). 34 | 35 | init : string or numpy array, optional (default='auto') 36 | Initialization of the linear transformation. Possible options are 37 | 'auto', 'pca', 'identity', 'random', and a numpy array of shape 38 | (n_features_a, n_features_b). 39 | 40 | 'auto' 41 | Depending on ``n_components``, the most reasonable initialization 42 | will be chosen. If ``n_components < min(n_features, n_samples)``, 43 | we use 'pca', as it projects data in meaningful directions (those 44 | of higher variance). Otherwise, we just use 'identity'. 45 | 46 | 'pca' 47 | ``n_components`` principal components of the inputs passed 48 | to :meth:`fit` will be used to initialize the transformation. 49 | (See `sklearn.decomposition.PCA`) 50 | 51 | 'identity' 52 | If ``n_components`` is strictly smaller than the 53 | dimensionality of the inputs passed to :meth:`fit`, the identity 54 | matrix will be truncated to the first ``n_components`` rows. 55 | 56 | 'random' 57 | The initial transformation will be a random array of shape 58 | `(n_components, n_features)`. Each value is sampled from the 59 | standard normal distribution. 60 | 61 | numpy array 62 | n_features_b must match the dimensionality of the inputs passed to 63 | :meth:`fit` and n_features_a must be less than or equal to that. 64 | If ``n_components`` is not None, n_features_a must match it. 65 | 66 | tol : float, optional (default=None) 67 | Convergence tolerance for the optimization. 68 | 69 | max_iter : int, optional (default=1000) 70 | Cap on number of conjugate gradient iterations. 71 | 72 | verbose : bool, optional (default=False) 73 | Whether to print progress messages or not. 
74 | 75 | preprocessor : array-like, shape=(n_samples, n_features) or callable 76 | The preprocessor to call to get tuples from indices. If array-like, 77 | tuples will be formed like this: X[indices]. 78 | 79 | random_state : int or numpy.RandomState or None, optional (default=None) 80 | A pseudo random number generator object or a seed for it if int. If 81 | ``init='random'``, ``random_state`` is used to initialize the random 82 | transformation. If ``init='pca'``, ``random_state`` is passed as an 83 | argument to PCA when initializing the transformation. 84 | 85 | Attributes 86 | ---------- 87 | n_iter_ : `int` 88 | The number of iterations the solver has run. 89 | 90 | components_ : `numpy.ndarray`, shape=(n_components, n_features) 91 | The learned linear transformation ``L``. 92 | 93 | Examples 94 | -------- 95 | 96 | >>> from metric_learn import MLKR 97 | >>> from sklearn.datasets import load_iris 98 | >>> iris_data = load_iris() 99 | >>> X = iris_data['data'] 100 | >>> Y = iris_data['target'] 101 | >>> mlkr = MLKR() 102 | >>> mlkr.fit(X, Y) 103 | 104 | References 105 | ---------- 106 | .. [1] K.Q. Weinberger and G. Tesauto. `Metric Learning for Kernel 107 | Regression `_. AISTATS 2007. 109 | """ 110 | 111 | def __init__(self, n_components=None, init='auto', 112 | tol=None, max_iter=1000, verbose=False, 113 | preprocessor=None, random_state=None): 114 | self.n_components = n_components 115 | self.init = init 116 | self.tol = tol 117 | self.max_iter = max_iter 118 | self.verbose = verbose 119 | self.random_state = random_state 120 | super(MLKR, self).__init__(preprocessor) 121 | 122 | def fit(self, X, y): 123 | """ 124 | Fit MLKR model 125 | 126 | Parameters 127 | ---------- 128 | X : (n x d) array of samples 129 | y : (n) data labels 130 | """ 131 | X, y = self._prepare_inputs(X, y, y_numeric=True, 132 | ensure_min_samples=2) 133 | n, d = X.shape 134 | if y.shape[0] != n: 135 | raise ValueError('Data and label lengths mismatch: %d != %d' 136 | % (n, y.shape[0])) 137 | 138 | m = _check_n_components(d, self.n_components) 139 | m = self.n_components 140 | if m is None: 141 | m = d 142 | # if the init is the default (None), we raise a warning 143 | A = _initialize_components(m, X, y, init=self.init, 144 | random_state=self.random_state, 145 | # MLKR works on regression targets: 146 | has_classes=False) 147 | 148 | # Measure the total training time 149 | train_time = time.time() 150 | 151 | self.n_iter_ = 0 152 | res = minimize(self._loss, A.ravel(), (X, y), method='L-BFGS-B', 153 | jac=True, tol=self.tol, 154 | options=dict(maxiter=self.max_iter)) 155 | self.components_ = res.x.reshape(A.shape) 156 | 157 | # Stop timer 158 | train_time = time.time() - train_time 159 | if self.verbose: 160 | cls_name = self.__class__.__name__ 161 | # Warn the user if the algorithm did not converge 162 | if not res.success: 163 | warnings.warn('[{}] MLKR did not converge: {}' 164 | .format(cls_name, res.message), ConvergenceWarning) 165 | print('[{}] Training took {:8.2f}s.'.format(cls_name, train_time)) 166 | 167 | return self 168 | 169 | def _loss(self, flatA, X, y): 170 | 171 | if self.n_iter_ == 0 and self.verbose: 172 | header_fields = ['Iteration', 'Objective Value', 'Time(s)'] 173 | header_fmt = '{:>10} {:>20} {:>10}' 174 | header = header_fmt.format(*header_fields) 175 | cls_name = self.__class__.__name__ 176 | print('[{cls}]'.format(cls=cls_name)) 177 | print('[{cls}] {header}\n[{cls}] {sep}'.format(cls=cls_name, 178 | header=header, 179 | sep='-' * len(header))) 180 | 181 | start_time = 
time.time() 182 | 183 | A = flatA.reshape((-1, X.shape[1])) 184 | X_embedded = np.dot(X, A.T) 185 | dist = pairwise_distances(X_embedded, squared=True) 186 | np.fill_diagonal(dist, np.inf) 187 | softmax = np.exp(- dist - logsumexp(- dist, axis=1)[:, np.newaxis]) 188 | yhat = softmax.dot(y) 189 | ydiff = yhat - y 190 | cost = (ydiff ** 2).sum() 191 | 192 | # also compute the gradient 193 | W = softmax * ydiff[:, np.newaxis] * (y - yhat[:, np.newaxis]) 194 | W_sym = W + W.T 195 | np.fill_diagonal(W_sym, - W.sum(axis=0)) 196 | grad = 4 * (X_embedded.T.dot(W_sym)).dot(X) 197 | 198 | if self.verbose: 199 | start_time = time.time() - start_time 200 | values_fmt = '[{cls}] {n_iter:>10} {loss:>20.6e} {start_time:>10.2f}' 201 | print(values_fmt.format(cls=self.__class__.__name__, 202 | n_iter=self.n_iter_, loss=cost, 203 | start_time=start_time)) 204 | sys.stdout.flush() 205 | 206 | self.n_iter_ += 1 207 | 208 | return cost, grad.ravel() 209 | -------------------------------------------------------------------------------- /metric_learn/nca.py: -------------------------------------------------------------------------------- 1 | """ 2 | Neighborhood Components Analysis (NCA) 3 | """ 4 | 5 | import warnings 6 | import time 7 | import sys 8 | import numpy as np 9 | from scipy.optimize import minimize 10 | from scipy.special import logsumexp 11 | from sklearn.base import TransformerMixin 12 | from sklearn.exceptions import ConvergenceWarning 13 | from sklearn.metrics import pairwise_distances 14 | 15 | from ._util import _initialize_components, _check_n_components 16 | from .base_metric import MahalanobisMixin 17 | 18 | EPS = np.finfo(float).eps 19 | 20 | 21 | class NCA(MahalanobisMixin, TransformerMixin): 22 | """Neighborhood Components Analysis (NCA) 23 | 24 | NCA is a distance metric learning algorithm which aims to improve the 25 | accuracy of nearest neighbors classification compared to the standard 26 | Euclidean distance. The algorithm directly maximizes a stochastic variant 27 | of the leave-one-out k-nearest neighbors(KNN) score on the training set. 28 | It can also learn a low-dimensional linear transformation of data that can 29 | be used for data visualization and fast classification. 30 | 31 | Read more in the :ref:`User Guide `. 32 | 33 | Parameters 34 | ---------- 35 | init : string or numpy array, optional (default='auto') 36 | Initialization of the linear transformation. Possible options are 37 | 'auto', 'pca', 'identity', 'random', and a numpy array of shape 38 | (n_features_a, n_features_b). 39 | 40 | 'auto' 41 | Depending on ``n_components``, the most reasonable initialization 42 | will be chosen. If ``n_components <= n_classes`` we use 'lda', as 43 | it uses labels information. If not, but 44 | ``n_components < min(n_features, n_samples)``, we use 'pca', as 45 | it projects data in meaningful directions (those of higher 46 | variance). Otherwise, we just use 'identity'. 47 | 48 | 'pca' 49 | ``n_components`` principal components of the inputs passed 50 | to :meth:`fit` will be used to initialize the transformation. 51 | (See `sklearn.decomposition.PCA`) 52 | 53 | 'lda' 54 | ``min(n_components, n_classes)`` most discriminative 55 | components of the inputs passed to :meth:`fit` will be used to 56 | initialize the transformation. (If ``n_components > n_classes``, 57 | the rest of the components will be zero.) 
(See 58 | `sklearn.discriminant_analysis.LinearDiscriminantAnalysis`) 59 | 60 | 'identity' 61 | If ``n_components`` is strictly smaller than the 62 | dimensionality of the inputs passed to :meth:`fit`, the identity 63 | matrix will be truncated to the first ``n_components`` rows. 64 | 65 | 'random' 66 | The initial transformation will be a random array of shape 67 | `(n_components, n_features)`. Each value is sampled from the 68 | standard normal distribution. 69 | 70 | numpy array 71 | n_features_b must match the dimensionality of the inputs passed to 72 | :meth:`fit` and n_features_a must be less than or equal to that. 73 | If ``n_components`` is not None, n_features_a must match it. 74 | 75 | n_components : int or None, optional (default=None) 76 | Dimensionality of reduced space (if None, defaults to dimension of X). 77 | 78 | max_iter : int, optional (default=100) 79 | Maximum number of iterations done by the optimization algorithm. 80 | 81 | tol : float, optional (default=None) 82 | Convergence tolerance for the optimization. 83 | 84 | verbose : bool, optional (default=False) 85 | Whether to print progress messages or not. 86 | 87 | random_state : int or numpy.RandomState or None, optional (default=None) 88 | A pseudo random number generator object or a seed for it if int. If 89 | ``init='random'``, ``random_state`` is used to initialize the random 90 | transformation. If ``init='pca'``, ``random_state`` is passed as an 91 | argument to PCA when initializing the transformation. 92 | 93 | Examples 94 | -------- 95 | 96 | >>> import numpy as np 97 | >>> from metric_learn import NCA 98 | >>> from sklearn.datasets import load_iris 99 | >>> iris_data = load_iris() 100 | >>> X = iris_data['data'] 101 | >>> Y = iris_data['target'] 102 | >>> nca = NCA(max_iter=1000) 103 | >>> nca.fit(X, Y) 104 | 105 | Attributes 106 | ---------- 107 | n_iter_ : `int` 108 | The number of iterations the solver has run. 109 | 110 | components_ : `numpy.ndarray`, shape=(n_components, n_features) 111 | The learned linear transformation ``L``. 112 | 113 | References 114 | ---------- 115 | .. [1] J. Goldberger, G. Hinton, S. Roweis, R. Salakhutdinov. `Neighbourhood 116 | Components Analysis 117 | `_. 118 | NIPS 2005. 119 | 120 | .. 
[2] Wikipedia entry on `Neighborhood Components Analysis 121 | `_ 122 | """ 123 | 124 | def __init__(self, init='auto', n_components=None, 125 | max_iter=100, tol=None, verbose=False, preprocessor=None, 126 | random_state=None): 127 | self.n_components = n_components 128 | self.init = init 129 | self.max_iter = max_iter 130 | self.tol = tol 131 | self.verbose = verbose 132 | self.random_state = random_state 133 | super(NCA, self).__init__(preprocessor) 134 | 135 | def fit(self, X, y): 136 | """ 137 | X: data matrix, (n x d) 138 | y: scalar labels, (n) 139 | """ 140 | X, labels = self._prepare_inputs(X, y, ensure_min_samples=2) 141 | n, d = X.shape 142 | n_components = _check_n_components(d, self.n_components) 143 | 144 | # Measure the total training time 145 | train_time = time.time() 146 | 147 | # Initialize A 148 | A = _initialize_components(n_components, X, labels, self.init, 149 | self.verbose, self.random_state) 150 | 151 | # Run NCA 152 | mask = labels[:, np.newaxis] == labels[np.newaxis, :] 153 | optimizer_params = {'method': 'L-BFGS-B', 154 | 'fun': self._loss_grad_lbfgs, 155 | 'args': (X, mask, -1.0), 156 | 'jac': True, 157 | 'x0': A.ravel(), 158 | 'options': dict(maxiter=self.max_iter), 159 | 'tol': self.tol 160 | } 161 | 162 | # Call the optimizer 163 | self.n_iter_ = 0 164 | opt_result = minimize(**optimizer_params) 165 | 166 | self.components_ = opt_result.x.reshape(-1, X.shape[1]) 167 | self.n_iter_ = opt_result.nit 168 | 169 | # Stop timer 170 | train_time = time.time() - train_time 171 | if self.verbose: 172 | cls_name = self.__class__.__name__ 173 | 174 | # Warn the user if the algorithm did not converge 175 | if not opt_result.success: 176 | warnings.warn('[{}] NCA did not converge: {}'.format( 177 | cls_name, opt_result.message), ConvergenceWarning) 178 | 179 | print('[{}] Training took {:8.2f}s.'.format(cls_name, train_time)) 180 | 181 | return self 182 | 183 | def _loss_grad_lbfgs(self, A, X, mask, sign=1.0): 184 | 185 | if self.n_iter_ == 0 and self.verbose: 186 | header_fields = ['Iteration', 'Objective Value', 'Time(s)'] 187 | header_fmt = '{:>10} {:>20} {:>10}' 188 | header = header_fmt.format(*header_fields) 189 | cls_name = self.__class__.__name__ 190 | print('[{cls}]'.format(cls=cls_name)) 191 | print('[{cls}] {header}\n[{cls}] {sep}'.format(cls=cls_name, 192 | header=header, 193 | sep='-' * len(header))) 194 | 195 | start_time = time.time() 196 | 197 | A = A.reshape(-1, X.shape[1]) 198 | X_embedded = np.dot(X, A.T) # (n_samples, n_components) 199 | # Compute softmax distances 200 | p_ij = pairwise_distances(X_embedded, squared=True) 201 | np.fill_diagonal(p_ij, np.inf) 202 | p_ij = np.exp(-p_ij - logsumexp(-p_ij, axis=1)[:, np.newaxis]) 203 | # (n_samples, n_samples) 204 | 205 | # Compute loss 206 | masked_p_ij = p_ij * mask 207 | p = masked_p_ij.sum(axis=1, keepdims=True) # (n_samples, 1) 208 | loss = p.sum() 209 | 210 | # Compute gradient of loss w.r.t. 
`transform` 211 | weighted_p_ij = masked_p_ij - p_ij * p 212 | weighted_p_ij_sym = weighted_p_ij + weighted_p_ij.T 213 | np.fill_diagonal(weighted_p_ij_sym, - weighted_p_ij.sum(axis=0)) 214 | gradient = 2 * (X_embedded.T.dot(weighted_p_ij_sym)).dot(X) 215 | 216 | if self.verbose: 217 | start_time = time.time() - start_time 218 | values_fmt = '[{cls}] {n_iter:>10} {loss:>20.6e} {start_time:>10.2f}' 219 | print(values_fmt.format(cls=self.__class__.__name__, 220 | n_iter=self.n_iter_, loss=loss, 221 | start_time=start_time)) 222 | sys.stdout.flush() 223 | 224 | self.n_iter_ += 1 225 | return sign * loss, sign * gradient.ravel() 226 | -------------------------------------------------------------------------------- /metric_learn/rca.py: -------------------------------------------------------------------------------- 1 | """ 2 | Relative Components Analysis (RCA) 3 | """ 4 | 5 | import numpy as np 6 | import warnings 7 | from sklearn.base import TransformerMixin 8 | 9 | from ._util import _check_n_components 10 | from .base_metric import MahalanobisMixin 11 | from .constraints import Constraints 12 | 13 | 14 | # mean center each chunklet separately 15 | def _chunk_mean_centering(data, chunks): 16 | n_chunks = chunks.max() + 1 17 | chunk_mask = chunks != -1 18 | # We need to ensure the data is float so that we can substract the 19 | # mean on it 20 | chunk_data = data[chunk_mask].astype(float, copy=False) 21 | chunk_labels = chunks[chunk_mask] 22 | for c in range(n_chunks): 23 | mask = chunk_labels == c 24 | chunk_data[mask] -= chunk_data[mask].mean(axis=0) 25 | 26 | return chunk_mask, chunk_data 27 | 28 | 29 | class RCA(MahalanobisMixin, TransformerMixin): 30 | """Relevant Components Analysis (RCA) 31 | 32 | RCA learns a full rank Mahalanobis distance metric based on a weighted sum of 33 | in-chunklets covariance matrices. It applies a global linear transformation 34 | to assign large weights to relevant dimensions and low weights to irrelevant 35 | dimensions. Those relevant dimensions are estimated using "chunklets", 36 | subsets of points that are known to belong to the same class. 37 | 38 | Read more in the :ref:`User Guide `. 39 | 40 | Parameters 41 | ---------- 42 | n_components : int or None, optional (default=None) 43 | Dimensionality of reduced space (if None, defaults to dimension of X). 44 | 45 | preprocessor : array-like, shape=(n_samples, n_features) or callable 46 | The preprocessor to call to get tuples from indices. If array-like, 47 | tuples will be formed like this: X[indices]. 48 | 49 | Examples 50 | -------- 51 | >>> from metric_learn import RCA 52 | >>> X = [[-0.05, 3.0],[0.05, -3.0], 53 | >>> [0.1, -3.55],[-0.1, 3.55], 54 | >>> [-0.95, -0.05],[0.95, 0.05], 55 | >>> [0.4, 0.05],[-0.4, -0.05]] 56 | >>> chunks = [0, 0, 1, 1, 2, 2, 3, 3] 57 | >>> rca = RCA() 58 | >>> rca.fit(X, chunks) 59 | 60 | References 61 | ---------- 62 | .. [1] Noam Shental, et al. `Adjustment learning and relevant component 63 | analysis `_ . 65 | ECCV 2002. 66 | 67 | 68 | Attributes 69 | ---------- 70 | components_ : `numpy.ndarray`, shape=(n_components, n_features) 71 | The learned linear transformation ``L``. 72 | """ 73 | 74 | def __init__(self, n_components=None, preprocessor=None): 75 | self.n_components = n_components 76 | super(RCA, self).__init__(preprocessor) 77 | 78 | def _check_dimension(self, rank, X): 79 | d = X.shape[1] 80 | 81 | if rank < d: 82 | warnings.warn('The inner covariance matrix is not invertible, ' 83 | 'so the transformation matrix may contain Nan values. 
' 84 | 'You should remove any linearly dependent features and/or ' 85 | 'reduce the dimensionality of your input, ' 86 | 'for instance using `sklearn.decomposition.PCA` as a ' 87 | 'preprocessing step.') 88 | 89 | dim = _check_n_components(d, self.n_components) 90 | return dim 91 | 92 | def fit(self, X, chunks): 93 | """Learn the RCA model. 94 | 95 | Parameters 96 | ---------- 97 | data : (n x d) data matrix 98 | Each row corresponds to a single instance 99 | 100 | chunks : (n,) array of ints 101 | When ``chunks[i] == -1``, point i doesn't belong to any chunklet. 102 | When ``chunks[i] == j``, point i belongs to chunklet j. 103 | """ 104 | X, chunks = self._prepare_inputs(X, chunks, ensure_min_samples=2) 105 | 106 | chunks = np.asanyarray(chunks, dtype=int) 107 | chunk_mask, chunked_data = _chunk_mean_centering(X, chunks) 108 | 109 | inner_cov = np.atleast_2d(np.cov(chunked_data, rowvar=0, bias=1)) 110 | dim = self._check_dimension(np.linalg.matrix_rank(inner_cov), X) 111 | 112 | # Fisher Linear Discriminant projection 113 | if dim < X.shape[1]: 114 | total_cov = np.cov(X[chunk_mask], rowvar=0) 115 | tmp = np.linalg.lstsq(total_cov, inner_cov, rcond=None)[0] 116 | vals, vecs = np.linalg.eig(tmp) 117 | inds = np.argsort(vals)[:dim] 118 | A = vecs[:, inds] 119 | inner_cov = np.atleast_2d(A.T.dot(inner_cov).dot(A)) 120 | self.components_ = _inv_sqrtm(inner_cov).dot(A.T) 121 | else: 122 | self.components_ = _inv_sqrtm(inner_cov).T 123 | 124 | return self 125 | 126 | 127 | def _inv_sqrtm(x): 128 | '''Computes x^(-1/2)''' 129 | vals, vecs = np.linalg.eigh(x) 130 | return (vecs / np.sqrt(vals)).dot(vecs.T) 131 | 132 | 133 | class RCA_Supervised(RCA): 134 | """Supervised version of Relevant Components Analysis (RCA) 135 | 136 | `RCA_Supervised` creates chunks of similar points by first sampling a 137 | class, taking `chunk_size` elements in it, and repeating the process 138 | `n_chunks` times. 139 | 140 | Parameters 141 | ---------- 142 | n_components : int or None, optional (default=None) 143 | Dimensionality of reduced space (if None, defaults to dimension of X). 144 | 145 | n_chunks: int, optional (default=100) 146 | Number of chunks to generate. 147 | 148 | chunk_size: int, optional (default=2) 149 | Number of points per chunk. 150 | 151 | preprocessor : array-like, shape=(n_samples, n_features) or callable 152 | The preprocessor to call to get tuples from indices. If array-like, 153 | tuples will be formed like this: X[indices]. 154 | 155 | random_state : int or numpy.RandomState or None, optional (default=None) 156 | A pseudo random number generator object or a seed for it if int. 157 | It is used to randomly sample constraints from labels. 158 | 159 | num_chunks : Renamed to n_chunks. Will be deprecated in 0.7.0 160 | 161 | Examples 162 | -------- 163 | >>> from metric_learn import RCA_Supervised 164 | >>> from sklearn.datasets import load_iris 165 | >>> iris_data = load_iris() 166 | >>> X = iris_data['data'] 167 | >>> Y = iris_data['target'] 168 | >>> rca = RCA_Supervised(n_chunks=30, chunk_size=2) 169 | >>> rca.fit(X, Y) 170 | 171 | Attributes 172 | ---------- 173 | components_ : `numpy.ndarray`, shape=(n_components, n_features) 174 | The learned linear transformation ``L``. 
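    Notes
    -----
    As a rule of thumb taken from the check in ``fit`` below: when
    ``n_chunks * (chunk_size - 1) < n_features``, the inner covariance matrix
    is not invertible and ``fit`` emits a warning. An illustrative check with
    hypothetical values:

    >>> n_chunks, chunk_size, n_features = 30, 2, 4
    >>> n_chunks * (chunk_size - 1) >= n_features
    True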
175 | """ 176 | 177 | def __init__(self, n_components=None, n_chunks=100, chunk_size=2, 178 | preprocessor=None, random_state=None, 179 | num_chunks='deprecated'): 180 | """Initialize the supervised version of `RCA`.""" 181 | RCA.__init__(self, n_components=n_components, preprocessor=preprocessor) 182 | if num_chunks != 'deprecated': 183 | warnings.warn('"num_chunks" parameter has been renamed to' 184 | ' "n_chunks". It has been deprecated in' 185 | ' version 0.6.3 and will be removed in 0.7.0' 186 | '', FutureWarning) 187 | n_chunks = num_chunks 188 | self.num_chunks = 'deprecated' # To avoid no_attribute error 189 | self.n_chunks = n_chunks 190 | self.chunk_size = chunk_size 191 | self.random_state = random_state 192 | 193 | def fit(self, X, y): 194 | """Create constraints from labels and learn the RCA model. 195 | Needs n_constraints specified in constructor. (Not true?) 196 | 197 | Parameters 198 | ---------- 199 | X : (n x d) data matrix 200 | each row corresponds to a single instance 201 | 202 | y : (n) data labels 203 | """ 204 | X, y = self._prepare_inputs(X, y, ensure_min_samples=2) 205 | chunks = Constraints(y).chunks(n_chunks=self.n_chunks, 206 | chunk_size=self.chunk_size, 207 | random_state=self.random_state) 208 | 209 | if self.n_chunks * (self.chunk_size - 1) < X.shape[1]: 210 | warnings.warn('Due to the parameters of RCA_Supervised, ' 211 | 'the inner covariance matrix is not invertible, ' 212 | 'so the transformation matrix will contain Nan values. ' 213 | 'Increase the number or size of the chunks to correct ' 214 | 'this problem.' 215 | ) 216 | 217 | return RCA.fit(self, X, chunks) 218 | -------------------------------------------------------------------------------- /metric_learn/sdml.py: -------------------------------------------------------------------------------- 1 | """ 2 | Sparse High-Dimensional Metric Learning (SDML) 3 | """ 4 | 5 | import warnings 6 | import numpy as np 7 | from sklearn.base import TransformerMixin 8 | from scipy.linalg import pinvh 9 | try: 10 | from sklearn.covariance._graph_lasso import ( 11 | _graphical_lasso as graphical_lasso 12 | ) 13 | except ImportError: 14 | from sklearn.covariance import graphical_lasso 15 | 16 | from sklearn.exceptions import ConvergenceWarning 17 | 18 | from .base_metric import MahalanobisMixin, _PairsClassifierMixin 19 | from .constraints import Constraints, wrap_pairs 20 | from ._util import components_from_metric, _initialize_metric_mahalanobis 21 | try: 22 | from inverse_covariance import quic 23 | except ImportError: 24 | HAS_SKGGM = False 25 | else: 26 | HAS_SKGGM = True 27 | 28 | 29 | class _BaseSDML(MahalanobisMixin): 30 | 31 | _tuple_size = 2 # constraints are pairs 32 | 33 | def __init__(self, balance_param=0.5, sparsity_param=0.01, prior='identity', 34 | verbose=False, preprocessor=None, 35 | random_state=None): 36 | self.balance_param = balance_param 37 | self.sparsity_param = sparsity_param 38 | self.prior = prior 39 | self.verbose = verbose 40 | self.random_state = random_state 41 | super(_BaseSDML, self).__init__(preprocessor) 42 | 43 | def _fit(self, pairs, y): 44 | if not HAS_SKGGM: 45 | if self.verbose: 46 | print("SDML will use scikit-learn's graphical lasso solver.") 47 | else: 48 | if self.verbose: 49 | print("SDML will use skggm's graphical lasso solver.") 50 | pairs, y = self._prepare_inputs(pairs, y, 51 | type_of_inputs='tuples') 52 | n_features = pairs.shape[2] 53 | if n_features < 2: 54 | raise ValueError(f"Cannot fit SDML with {n_features} feature(s)") 55 | 56 | # set up (the inverse of) 
the prior M 57 | # if the prior is the default (None), we raise a warning 58 | _, prior_inv = _initialize_metric_mahalanobis( 59 | pairs, self.prior, 60 | return_inverse=True, strict_pd=True, matrix_name='prior', 61 | random_state=self.random_state) 62 | diff = pairs[:, 0] - pairs[:, 1] 63 | loss_matrix = (diff.T * y).dot(diff) 64 | emp_cov = prior_inv + self.balance_param * loss_matrix 65 | 66 | # our initialization will be the matrix with emp_cov's eigenvalues, 67 | # with a constant added so that they are all positive (plus an epsilon 68 | # to ensure definiteness). This is empirical. 69 | w, V = np.linalg.eigh(emp_cov) 70 | min_eigval = np.min(w) 71 | if min_eigval < 0.: 72 | warnings.warn("Warning, the input matrix of graphical lasso is not " 73 | "positive semi-definite (PSD). The algorithm may diverge, " 74 | "and lead to degenerate solutions. " 75 | "To prevent that, try to decrease the balance parameter " 76 | "`balance_param` and/or to set prior='identity'.", 77 | ConvergenceWarning) 78 | w -= min_eigval # we translate the eigenvalues to make them all positive 79 | w += 1e-10 # we add a small offset to avoid definiteness problems 80 | sigma0 = (V * w).dot(V.T) 81 | try: 82 | if HAS_SKGGM: 83 | theta0 = pinvh(sigma0) 84 | M, _, _, _, _, _ = quic(emp_cov, lam=self.sparsity_param, 85 | msg=self.verbose, 86 | Theta0=theta0, Sigma0=sigma0) 87 | else: 88 | _, M, *_ = graphical_lasso(emp_cov, alpha=self.sparsity_param, 89 | verbose=self.verbose, 90 | cov_init=sigma0) 91 | raised_error = None 92 | w_mahalanobis, _ = np.linalg.eigh(M) 93 | not_spd = any(w_mahalanobis < 0.) 94 | not_finite = not np.isfinite(M).all() 95 | # TODO: Narrow this to the specific exceptions we expect. 96 | except Exception as e: 97 | raised_error = e 98 | not_spd = False # not_spd not applicable here so we set to False 99 | not_finite = False # not_finite not applicable here so we set to False 100 | if raised_error is not None or not_spd or not_finite: 101 | msg = ("There was a problem in SDML when using {}'s graphical " 102 | "lasso solver.").format("skggm" if HAS_SKGGM else "scikit-learn") 103 | if not HAS_SKGGM: 104 | skggm_advice = (" skggm's graphical lasso can sometimes converge " 105 | "on non SPD cases where scikit-learn's graphical " 106 | "lasso fails to converge. Try to install skggm and " 107 | "rerun the algorithm (see the README.md for the " 108 | "right version of skggm).") 109 | msg += skggm_advice 110 | if raised_error is not None: 111 | msg += " The following error message was thrown: {}.".format( 112 | raised_error) 113 | raise RuntimeError(msg) 114 | 115 | self.components_ = components_from_metric(np.atleast_2d(M)) 116 | return self 117 | 118 | 119 | class SDML(_BaseSDML, _PairsClassifierMixin): 120 | r"""Sparse Distance Metric Learning (SDML) 121 | 122 | SDML is an efficient sparse metric learning in high-dimensional space via 123 | double regularization: an L1-penalization on the off-diagonal elements of the 124 | Mahalanobis matrix :math:`\mathbf{M}`, and a log-determinant divergence 125 | between :math:`\mathbf{M}` and :math:`\mathbf{M_0}` (set as either 126 | :math:`\mathbf{I}` or :math:`\mathbf{\Omega}^{-1}`, where 127 | :math:`\mathbf{\Omega}` is the covariance matrix). 128 | 129 | Read more in the :ref:`User Guide `. 130 | 131 | Parameters 132 | ---------- 133 | balance_param : float, optional (default=0.5) 134 | Trade off between sparsity and M0 prior. 135 | 136 | sparsity_param : float, optional (default=0.01) 137 | Trade off between optimizer and sparseness (see graph_lasso). 
138 | 139 | prior : string or numpy array, optional (default='identity') 140 | Prior to set for the metric. Possible options are 141 | 'identity', 'covariance', 'random', and a numpy array of 142 | shape (n_features, n_features). For SDML, the prior should be strictly 143 | positive definite (PD). 144 | 145 | 'identity' 146 | An identity matrix of shape (n_features, n_features). 147 | 148 | 'covariance' 149 | The inverse covariance matrix. 150 | 151 | 'random' 152 | The prior will be a random positive definite (PD) matrix of shape 153 | `(n_features, n_features)`, generated using 154 | `sklearn.datasets.make_spd_matrix`. 155 | 156 | numpy array 157 | A positive definite (PD) matrix of shape 158 | (n_features, n_features), that will be used as such to set the 159 | prior. 160 | 161 | verbose : bool, optional (default=False) 162 | If True, prints information while learning. 163 | 164 | preprocessor : array-like, shape=(n_samples, n_features) or callable 165 | The preprocessor to call to get tuples from indices. If array-like, 166 | tuples will be gotten like this: X[indices]. 167 | 168 | random_state : int or numpy.RandomState or None, optional (default=None) 169 | A pseudo random number generator object or a seed for it if int. If 170 | ``prior='random'``, ``random_state`` is used to set the prior. 171 | 172 | Attributes 173 | ---------- 174 | components_ : `numpy.ndarray`, shape=(n_features, n_features) 175 | The linear transformation ``L`` deduced from the learned Mahalanobis 176 | metric (See function `components_from_metric`.) 177 | 178 | threshold_ : `float` 179 | If the distance metric between two points is lower than this threshold, 180 | points will be classified as similar, otherwise they will be 181 | classified as dissimilar. 182 | 183 | Examples 184 | -------- 185 | >>> from metric_learn import SDML_Supervised 186 | >>> from sklearn.datasets import load_iris 187 | >>> iris_data = load_iris() 188 | >>> X = iris_data['data'] 189 | >>> Y = iris_data['target'] 190 | >>> sdml = SDML_Supervised(n_constraints=200) 191 | >>> sdml.fit(X, Y) 192 | 193 | References 194 | ---------- 195 | .. [1] Qi et al. `An efficient sparse metric learning in high-dimensional 196 | space via L1-penalized log-determinant regularization 197 | `_. 198 | ICML 2009. 199 | 200 | .. [2] Code adapted from https://gist.github.com/kcarnold/5439945 201 | """ 202 | 203 | def fit(self, pairs, y, calibration_params=None): 204 | """Learn the SDML model. 205 | 206 | The threshold will be calibrated on the trainset using the parameters 207 | `calibration_params`. 208 | 209 | Parameters 210 | ---------- 211 | pairs : array-like, shape=(n_constraints, 2, n_features) or \ 212 | (n_constraints, 2) 213 | 3D Array of pairs with each row corresponding to two points, 214 | or 2D array of indices of pairs if the metric learner uses a 215 | preprocessor. 216 | 217 | y : array-like, of shape (n_constraints,) 218 | Labels of constraints. Should be -1 for dissimilar pair, 1 for similar. 219 | 220 | calibration_params : `dict` or `None` 221 | Dictionary of parameters to give to `calibrate_threshold` for the 222 | threshold calibration step done at the end of `fit`. If `None` is 223 | given, `calibrate_threshold` will use the default parameters. 224 | 225 | Returns 226 | ------- 227 | self : object 228 | Returns the instance. 
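    Examples
    --------
    A sketch of the expected input format (toy values, illustrative only;
    the class-level example above uses the supervised variant instead):

    >>> pairs = [[[1.2, 7.5], [1.3, 1.5]],   # a pair labelled similar
    ...          [[6.4, 2.6], [6.2, 9.7]]]   # a pair labelled dissimilar
    >>> y_pairs = [1, -1]
    >>> # sdml.fit(pairs, y_pairs) then learns the metric from these pairs.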
229 | """ 230 | calibration_params = (calibration_params if calibration_params is not 231 | None else dict()) 232 | self._validate_calibration_params(**calibration_params) 233 | self._fit(pairs, y) 234 | self.calibrate_threshold(pairs, y, **calibration_params) 235 | return self 236 | 237 | 238 | class SDML_Supervised(_BaseSDML, TransformerMixin): 239 | """Supervised version of Sparse Distance Metric Learning (SDML) 240 | 241 | `SDML_Supervised` creates pairs of similar sample by taking same class 242 | samples, and pairs of dissimilar samples by taking different class 243 | samples. It then passes these pairs to `SDML` for training. 244 | 245 | Parameters 246 | ---------- 247 | balance_param : float, optional (default=0.5) 248 | Trade off between sparsity and M0 prior. 249 | 250 | sparsity_param : float, optional (default=0.01) 251 | Trade off between optimizer and sparseness (see graph_lasso). 252 | 253 | prior : string or numpy array, optional (default='identity') 254 | Prior to set for the metric. Possible options are 255 | 'identity', 'covariance', 'random', and a numpy array of 256 | shape (n_features, n_features). For SDML, the prior should be strictly 257 | positive definite (PD). 258 | 259 | 'identity' 260 | An identity matrix of shape (n_features, n_features). 261 | 262 | 'covariance' 263 | The inverse covariance matrix. 264 | 265 | 'random' 266 | The prior will be a random SPD matrix of shape 267 | `(n_features, n_features)`, generated using 268 | `sklearn.datasets.make_spd_matrix`. 269 | 270 | numpy array 271 | A positive definite (PD) matrix of shape 272 | (n_features, n_features), that will be used as such to set the 273 | prior. 274 | 275 | n_constraints : int, optional (default=None) 276 | Number of constraints to generate. If None, defaults to `20 * 277 | num_classes**2`. 278 | 279 | verbose : bool, optional (default=False) 280 | If True, prints information while learning. 281 | 282 | preprocessor : array-like, shape=(n_samples, n_features) or callable 283 | The preprocessor to call to get tuples from indices. If array-like, 284 | tuples will be formed like this: X[indices]. 285 | 286 | random_state : int or numpy.RandomState or None, optional (default=None) 287 | A pseudo random number generator object or a seed for it if int. If 288 | ``init='random'``, ``random_state`` is used to set the random 289 | prior. In any case, `random_state` is also used to randomly sample 290 | constraints from labels. 291 | 292 | num_constraints : Renamed to n_constraints. Will be deprecated in 0.7.0 293 | 294 | Attributes 295 | ---------- 296 | components_ : `numpy.ndarray`, shape=(n_features, n_features) 297 | The linear transformation ``L`` deduced from the learned Mahalanobis 298 | metric (See function `components_from_metric`.) 299 | 300 | See Also 301 | -------- 302 | metric_learn.SDML : The original weakly-supervised algorithm 303 | :ref:`supervised_version` : The section of the project documentation 304 | that describes the supervised version of weakly supervised estimators. 
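    Examples
    --------
    An illustrative sketch (assumed usage, not taken from the original
    documentation): since `SDML_Supervised` is a transformer, the learned
    mapping can feed a downstream scikit-learn estimator in a pipeline.

    >>> from sklearn.pipeline import make_pipeline
    >>> from sklearn.neighbors import KNeighborsClassifier
    >>> from metric_learn import SDML_Supervised
    >>> from sklearn.datasets import load_iris
    >>> X, y = load_iris(return_X_y=True)
    >>> pipe = make_pipeline(SDML_Supervised(n_constraints=200),
    ...                      KNeighborsClassifier())
    >>> pipe.fit(X, y)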
305 | """ 306 | 307 | def __init__(self, balance_param=0.5, sparsity_param=0.01, prior='identity', 308 | n_constraints=None, verbose=False, preprocessor=None, 309 | random_state=None, num_constraints='deprecated'): 310 | _BaseSDML.__init__(self, balance_param=balance_param, 311 | sparsity_param=sparsity_param, prior=prior, 312 | verbose=verbose, 313 | preprocessor=preprocessor, random_state=random_state) 314 | if num_constraints != 'deprecated': 315 | warnings.warn('"num_constraints" parameter has been renamed to' 316 | ' "n_constraints". It has been deprecated in' 317 | ' version 0.6.3 and will be removed in 0.7.0' 318 | '', FutureWarning) 319 | self.n_constraints = num_constraints 320 | else: 321 | self.n_constraints = n_constraints 322 | # Avoid test get_params from failing (all params passed sholud be set) 323 | self.num_constraints = 'deprecated' 324 | 325 | def fit(self, X, y): 326 | """Create constraints from labels and learn the SDML model. 327 | 328 | Parameters 329 | ---------- 330 | X : array-like, shape (n, d) 331 | data matrix, where each row corresponds to a single instance 332 | 333 | y : array-like, shape (n,) 334 | data labels, one for each instance 335 | 336 | Returns 337 | ------- 338 | self : object 339 | Returns the instance. 340 | """ 341 | X, y = self._prepare_inputs(X, y, ensure_min_samples=2) 342 | n_constraints = self.n_constraints 343 | if n_constraints is None: 344 | num_classes = len(np.unique(y)) 345 | n_constraints = 20 * num_classes**2 346 | 347 | c = Constraints(y) 348 | pos_neg = c.positive_negative_pairs(n_constraints, 349 | random_state=self.random_state) 350 | pairs, y = wrap_pairs(X, pos_neg) 351 | return _BaseSDML._fit(self, pairs, y) 352 | -------------------------------------------------------------------------------- /metric_learn/sklearn_shims.py: -------------------------------------------------------------------------------- 1 | """This file is for fixing imports due to different APIs 2 | depending on the scikit-learn version""" 3 | import sklearn 4 | from packaging import version 5 | SKLEARN_AT_LEAST_0_22 = (version.parse(sklearn.__version__) 6 | >= version.parse('0.22.0')) 7 | if SKLEARN_AT_LEAST_0_22: 8 | from sklearn.utils._testing import (set_random_state, 9 | ignore_warnings, 10 | assert_allclose_dense_sparse, 11 | _get_args) 12 | from sklearn.utils.estimator_checks import (_is_public_parameter 13 | as is_public_parameter) 14 | from sklearn.metrics._scorer import get_scorer 15 | else: 16 | from sklearn.utils.testing import (set_random_state, 17 | ignore_warnings, 18 | assert_allclose_dense_sparse, 19 | _get_args) 20 | from sklearn.utils.estimator_checks import is_public_parameter 21 | from sklearn.metrics.scorer import get_scorer 22 | 23 | __all__ = ['set_random_state', 'set_random_state', 24 | 'ignore_warnings', 'assert_allclose_dense_sparse', '_get_args', 25 | 'is_public_parameter', 'get_scorer'] 26 | -------------------------------------------------------------------------------- /pytest.ini: -------------------------------------------------------------------------------- 1 | [pytest] 2 | markers = 3 | integration: mark a test as integration 4 | unit: mark a test as unit -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [bdist_wheel] 2 | universal = 1 3 | 4 | [metadata] 5 | description-file = README.rst 6 | license_files = 7 | LICENSE.txt 8 | 
-------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | from setuptools import setup 4 | import os 5 | import io 6 | import sys 7 | 8 | 9 | CURRENT_PYTHON = sys.version_info[:2] 10 | REQUIRED_PYTHON = (3, 6) 11 | 12 | # This check and everything above must remain compatible with Python 2.7. 13 | if CURRENT_PYTHON < REQUIRED_PYTHON: 14 | sys.stderr.write(""" 15 | ========================== 16 | Unsupported Python version 17 | ========================== 18 | This version of metric-learn requires Python {}.{}, but you're trying to 19 | install it on Python {}.{}. 20 | This may be because you are using a version of pip that doesn't 21 | understand the python_requires classifier. Make sure you 22 | have pip >= 9.0 and setuptools >= 24.2, then try again: 23 | $ python -m pip install --upgrade pip setuptools 24 | $ python -m pip install django 25 | This will install the latest version of metric-learn which works on your 26 | version of Python. If you can't upgrade your pip (or Python), request 27 | an older version of metric-learn: 28 | $ python -m pip install "metric-learn<0.6.0" 29 | """.format(*(REQUIRED_PYTHON + CURRENT_PYTHON))) 30 | sys.exit(1) 31 | 32 | 33 | version = {} 34 | with io.open(os.path.join('metric_learn', '_version.py')) as fp: 35 | exec(fp.read(), version) 36 | 37 | # Get the long description from README.md 38 | with io.open('README.rst', encoding='utf-8') as f: 39 | long_description = f.read() 40 | 41 | setup(name='metric-learn', 42 | version=version['__version__'], 43 | description='Python implementations of metric learning algorithms', 44 | long_description=long_description, 45 | python_requires='>={}.{}'.format(*REQUIRED_PYTHON), 46 | author=[ 47 | 'CJ Carey', 48 | 'Yuan Tang', 49 | 'William de Vazelhes', 50 | 'Aurélien Bellet', 51 | 'Nathalie Vauquier' 52 | ], 53 | author_email='ccarey@cs.umass.edu', 54 | url='http://github.com/scikit-learn-contrib/metric-learn', 55 | license='MIT', 56 | classifiers=[ 57 | 'Development Status :: 4 - Beta', 58 | 'License :: OSI Approved :: MIT License', 59 | 'Programming Language :: Python :: 3', 60 | 'Operating System :: OS Independent', 61 | 'Intended Audience :: Science/Research', 62 | 'Topic :: Scientific/Engineering' 63 | ], 64 | packages=['metric_learn'], 65 | install_requires=[ 66 | 'numpy>= 1.11.0', 67 | 'scipy>= 0.17.0', 68 | 'scikit-learn>=0.21.3', 69 | ], 70 | extras_require=dict( 71 | docs=['sphinx', 'sphinx_rtd_theme', 'numpydoc', 'sphinx-gallery', 72 | 'matplotlib'], 73 | demo=['matplotlib'], 74 | sdml=['skggm>=0.2.9'] 75 | ), 76 | test_suite='test', 77 | keywords=[ 78 | 'Metric Learning', 79 | 'Large Margin Nearest Neighbor', 80 | 'Information Theoretic Metric Learning', 81 | 'Sparse Determinant Metric Learning', 82 | 'Least Squares Metric Learning', 83 | 'Neighborhood Components Analysis', 84 | 'Local Fisher Discriminant Analysis', 85 | 'Relative Components Analysis', 86 | 'Mahalanobis Metric for Clustering', 87 | 'Metric Learning for Kernel Regression' 88 | ]) 89 | -------------------------------------------------------------------------------- /test/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scikit-learn-contrib/metric-learn/dc7e4499b1a9e522f03c87ba8dc249f9747cac82/test/__init__.py -------------------------------------------------------------------------------- 
/test/test_base_metric.py: -------------------------------------------------------------------------------- 1 | from numpy.core.numeric import array_equal 2 | import warnings 3 | import pytest 4 | import re 5 | import unittest 6 | import metric_learn 7 | import numpy as np 8 | from sklearn import clone 9 | from test.test_utils import ids_metric_learners, metric_learners, remove_y 10 | from metric_learn.sklearn_shims import set_random_state, SKLEARN_AT_LEAST_0_22 11 | 12 | 13 | def remove_spaces(s): 14 | return re.sub(r'\s+', '', s) 15 | 16 | 17 | def sk_repr_kwargs(def_kwargs, nndef_kwargs): 18 | """Given the non-default arguments, and the default 19 | keywords arguments, build the string that will appear 20 | in the __repr__ of the estimator, depending on the 21 | version of scikit-learn. 22 | """ 23 | if SKLEARN_AT_LEAST_0_22: 24 | def_kwargs = {} 25 | def_kwargs.update(nndef_kwargs) 26 | args_str = ",".join(f"{key}={repr(value)}" 27 | for key, value in def_kwargs.items()) 28 | return args_str 29 | 30 | 31 | class TestStringRepr(unittest.TestCase): 32 | 33 | def test_covariance(self): 34 | def_kwargs = {'preprocessor': None} 35 | nndef_kwargs = {} 36 | merged_kwargs = sk_repr_kwargs(def_kwargs, nndef_kwargs) 37 | self.assertEqual(remove_spaces(str(metric_learn.Covariance())), 38 | remove_spaces(f"Covariance({merged_kwargs})")) 39 | 40 | def test_lmnn(self): 41 | def_kwargs = {'convergence_tol': 0.001, 'init': 'auto', 'n_neighbors': 3, 42 | 'learn_rate': 1e-07, 'max_iter': 1000, 'min_iter': 50, 43 | 'n_components': None, 'preprocessor': None, 44 | 'random_state': None, 'regularization': 0.5, 45 | 'verbose': False} 46 | nndef_kwargs = {'convergence_tol': 0.01, 'n_neighbors': 6} 47 | merged_kwargs = sk_repr_kwargs(def_kwargs, nndef_kwargs) 48 | self.assertEqual( 49 | remove_spaces(str(metric_learn.LMNN(convergence_tol=0.01, 50 | n_neighbors=6))), 51 | remove_spaces(f"LMNN({merged_kwargs})")) 52 | 53 | def test_nca(self): 54 | def_kwargs = {'init': 'auto', 'max_iter': 100, 'n_components': None, 55 | 'preprocessor': None, 'random_state': None, 'tol': None, 56 | 'verbose': False} 57 | nndef_kwargs = {'max_iter': 42} 58 | merged_kwargs = sk_repr_kwargs(def_kwargs, nndef_kwargs) 59 | self.assertEqual(remove_spaces(str(metric_learn.NCA(max_iter=42))), 60 | remove_spaces(f"NCA({merged_kwargs})")) 61 | 62 | def test_lfda(self): 63 | def_kwargs = {'embedding_type': 'weighted', 'k': None, 64 | 'n_components': None, 'preprocessor': None} 65 | nndef_kwargs = {'k': 2} 66 | merged_kwargs = sk_repr_kwargs(def_kwargs, nndef_kwargs) 67 | self.assertEqual(remove_spaces(str(metric_learn.LFDA(k=2))), 68 | remove_spaces(f"LFDA({merged_kwargs})")) 69 | 70 | def test_itml(self): 71 | def_kwargs = {'tol': 0.001, 'gamma': 1.0, 72 | 'max_iter': 1000, 'preprocessor': None, 73 | 'prior': 'identity', 'random_state': None, 'verbose': False} 74 | nndef_kwargs = {'gamma': 0.5} 75 | merged_kwargs = sk_repr_kwargs(def_kwargs, nndef_kwargs) 76 | self.assertEqual(remove_spaces(str(metric_learn.ITML(gamma=0.5))), 77 | remove_spaces(f"ITML({merged_kwargs})")) 78 | def_kwargs = {'tol': 0.001, 'gamma': 1.0, 79 | 'max_iter': 1000, 'n_constraints': None, 80 | 'preprocessor': None, 'prior': 'identity', 81 | 'random_state': None, 'verbose': False} 82 | nndef_kwargs = {'n_constraints': 7} 83 | merged_kwargs = sk_repr_kwargs(def_kwargs, nndef_kwargs) 84 | self.assertEqual( 85 | remove_spaces(str(metric_learn.ITML_Supervised(n_constraints=7))), 86 | remove_spaces(f"ITML_Supervised({merged_kwargs})")) 87 | 88 | def test_lsml(self): 89 | 
def_kwargs = {'max_iter': 1000, 'preprocessor': None, 'prior': 'identity', 90 | 'random_state': None, 'tol': 0.001, 'verbose': False} 91 | nndef_kwargs = {'tol': 0.1} 92 | merged_kwargs = sk_repr_kwargs(def_kwargs, nndef_kwargs) 93 | self.assertEqual(remove_spaces(str(metric_learn.LSML(tol=0.1))), 94 | remove_spaces(f"LSML({merged_kwargs})")) 95 | def_kwargs = {'max_iter': 1000, 'n_constraints': None, 96 | 'preprocessor': None, 'prior': 'identity', 97 | 'random_state': None, 'tol': 0.001, 'verbose': False, 98 | 'weights': None} 99 | nndef_kwargs = {'verbose': True} 100 | merged_kwargs = sk_repr_kwargs(def_kwargs, nndef_kwargs) 101 | self.assertEqual( 102 | remove_spaces(str(metric_learn.LSML_Supervised(verbose=True))), 103 | remove_spaces(f"LSML_Supervised({merged_kwargs})")) 104 | 105 | def test_sdml(self): 106 | def_kwargs = {'balance_param': 0.5, 'preprocessor': None, 107 | 'prior': 'identity', 'random_state': None, 108 | 'sparsity_param': 0.01, 'verbose': False} 109 | nndef_kwargs = {'verbose': True} 110 | merged_kwargs = sk_repr_kwargs(def_kwargs, nndef_kwargs) 111 | self.assertEqual(remove_spaces(str(metric_learn.SDML(verbose=True))), 112 | remove_spaces(f"SDML({merged_kwargs})")) 113 | def_kwargs = {'balance_param': 0.5, 'n_constraints': None, 114 | 'preprocessor': None, 'prior': 'identity', 115 | 'random_state': None, 'sparsity_param': 0.01, 116 | 'verbose': False} 117 | nndef_kwargs = {'sparsity_param': 0.5} 118 | merged_kwargs = sk_repr_kwargs(def_kwargs, nndef_kwargs) 119 | self.assertEqual( 120 | remove_spaces(str(metric_learn.SDML_Supervised(sparsity_param=0.5))), 121 | remove_spaces(f"SDML_Supervised({merged_kwargs})")) 122 | 123 | def test_rca(self): 124 | def_kwargs = {'n_components': None, 'preprocessor': None} 125 | nndef_kwargs = {'n_components': 3} 126 | merged_kwargs = sk_repr_kwargs(def_kwargs, nndef_kwargs) 127 | self.assertEqual(remove_spaces(str(metric_learn.RCA(n_components=3))), 128 | remove_spaces(f"RCA({merged_kwargs})")) 129 | def_kwargs = {'chunk_size': 2, 'n_components': None, 'n_chunks': 100, 130 | 'preprocessor': None, 'random_state': None} 131 | nndef_kwargs = {'n_chunks': 5} 132 | merged_kwargs = sk_repr_kwargs(def_kwargs, nndef_kwargs) 133 | self.assertEqual( 134 | remove_spaces(str(metric_learn.RCA_Supervised(n_chunks=5))), 135 | remove_spaces(f"RCA_Supervised({merged_kwargs})")) 136 | 137 | def test_mlkr(self): 138 | def_kwargs = {'init': 'auto', 'max_iter': 1000, 139 | 'n_components': None, 'preprocessor': None, 140 | 'random_state': None, 'tol': None, 'verbose': False} 141 | nndef_kwargs = {'max_iter': 777} 142 | merged_kwargs = sk_repr_kwargs(def_kwargs, nndef_kwargs) 143 | self.assertEqual(remove_spaces(str(metric_learn.MLKR(max_iter=777))), 144 | remove_spaces(f"MLKR({merged_kwargs})")) 145 | 146 | def test_mmc(self): 147 | def_kwargs = {'tol': 0.001, 'diagonal': False, 148 | 'diagonal_c': 1.0, 'init': 'identity', 'max_iter': 100, 149 | 'max_proj': 10000, 'preprocessor': None, 150 | 'random_state': None, 'verbose': False} 151 | nndef_kwargs = {'diagonal': True} 152 | merged_kwargs = sk_repr_kwargs(def_kwargs, nndef_kwargs) 153 | self.assertEqual(remove_spaces(str(metric_learn.MMC(diagonal=True))), 154 | remove_spaces(f"MMC({merged_kwargs})")) 155 | def_kwargs = {'tol': 1e-06, 'diagonal': False, 156 | 'diagonal_c': 1.0, 'init': 'identity', 'max_iter': 100, 157 | 'max_proj': 10000, 'n_constraints': None, 158 | 'preprocessor': None, 'random_state': None, 159 | 'verbose': False} 160 | nndef_kwargs = {'max_iter': 1} 161 | merged_kwargs = 
sk_repr_kwargs(def_kwargs, nndef_kwargs) 162 | self.assertEqual( 163 | remove_spaces(str(metric_learn.MMC_Supervised(max_iter=1))), 164 | remove_spaces(f"MMC_Supervised({merged_kwargs})")) 165 | 166 | 167 | @pytest.mark.parametrize('estimator, build_dataset', metric_learners, 168 | ids=ids_metric_learners) 169 | def test_get_metric_is_independent_from_metric_learner(estimator, 170 | build_dataset): 171 | """Tests that the get_metric method returns a function that is independent 172 | from the original metric learner""" 173 | input_data, labels, _, X = build_dataset() 174 | model = clone(estimator) 175 | set_random_state(model) 176 | 177 | # we fit the metric learner on it and then we compute the metric on some 178 | # points 179 | model.fit(*remove_y(model, input_data, labels)) 180 | metric = model.get_metric() 181 | score = metric(X[0], X[1]) 182 | 183 | # then we refit the estimator on another dataset 184 | model.fit(*remove_y(model, np.sin(input_data), labels)) 185 | 186 | # we recompute the distance between the two points: it should be the same 187 | score_bis = metric(X[0], X[1]) 188 | assert score_bis == score 189 | 190 | 191 | @pytest.mark.parametrize('estimator, build_dataset', metric_learners, 192 | ids=ids_metric_learners) 193 | def test_get_metric_raises_error(estimator, build_dataset): 194 | """Tests that the metric returned by get_metric raises errors similar to 195 | the distance functions in scipy.spatial.distance""" 196 | input_data, labels, _, X = build_dataset() 197 | model = clone(estimator) 198 | set_random_state(model) 199 | model.fit(*remove_y(model, input_data, labels)) 200 | metric = model.get_metric() 201 | 202 | list_test_get_metric_raises = [(X[0].tolist() + [5.2], X[1]), # vectors with 203 | # different dimensions 204 | (X[0:4], X[1:5]), # 2D vectors 205 | (X[0].tolist() + [5.2], X[1] + [7.2])] 206 | # vectors of same dimension but incompatible with what the metric learner 207 | # was trained on 208 | 209 | for u, v in list_test_get_metric_raises: 210 | with pytest.raises(ValueError): 211 | metric(u, v) 212 | 213 | 214 | @pytest.mark.parametrize('estimator, build_dataset', metric_learners, 215 | ids=ids_metric_learners) 216 | def test_get_metric_works_does_not_raise(estimator, build_dataset): 217 | """Tests that the metric returned by get_metric does not raise errors (or 218 | warnings) similarly to the distance functions in scipy.spatial.distance""" 219 | input_data, labels, _, X = build_dataset() 220 | model = clone(estimator) 221 | set_random_state(model) 222 | model.fit(*remove_y(model, input_data, labels)) 223 | metric = model.get_metric() 224 | 225 | list_test_get_metric_doesnt_raise = [(X[0], X[1]), 226 | (X[0].tolist(), X[1].tolist()), 227 | (X[0][None], X[1][None])] 228 | 229 | for u, v in list_test_get_metric_doesnt_raise: 230 | with warnings.catch_warnings(record=True) as record: 231 | metric(u, v) 232 | assert len(record) == 0 233 | 234 | # Test that the scalar case works 235 | model.components_ = np.array([3.1]) 236 | metric = model.get_metric() 237 | for u, v in [(5, 6.7), ([5], [6.7]), ([[5]], [[6.7]])]: 238 | with warnings.catch_warnings(record=True) as record: 239 | metric(u, v) 240 | assert len(record) == 0 241 | 242 | 243 | @pytest.mark.parametrize('estimator, build_dataset', metric_learners, 244 | ids=ids_metric_learners) 245 | def test_n_components(estimator, build_dataset): 246 | """Check that estimators that have a n_components parameters can use it 247 | and that it actually works as expected""" 248 | input_data, labels, _, X = 
build_dataset() 249 | model = clone(estimator) 250 | 251 | if hasattr(model, 'n_components'): 252 | set_random_state(model) 253 | model.set_params(n_components=None) 254 | model.fit(*remove_y(model, input_data, labels)) 255 | assert model.components_.shape == (X.shape[1], X.shape[1]) 256 | 257 | model = clone(estimator) 258 | set_random_state(model) 259 | model.set_params(n_components=X.shape[1] - 1) 260 | model.fit(*remove_y(model, input_data, labels)) 261 | assert model.components_.shape == (X.shape[1] - 1, X.shape[1]) 262 | 263 | model = clone(estimator) 264 | set_random_state(model) 265 | model.set_params(n_components=X.shape[1] + 1) 266 | with pytest.raises(ValueError) as expected_err: 267 | model.fit(*remove_y(model, input_data, labels)) 268 | assert (str(expected_err.value) == 269 | 'Invalid n_components, must be in [1, {}]'.format(X.shape[1])) 270 | 271 | model = clone(estimator) 272 | set_random_state(model) 273 | model.set_params(n_components=0) 274 | with pytest.raises(ValueError) as expected_err: 275 | model.fit(*remove_y(model, input_data, labels)) 276 | assert (str(expected_err.value) == 277 | 'Invalid n_components, must be in [1, {}]'.format(X.shape[1])) 278 | 279 | 280 | @pytest.mark.parametrize('estimator, build_dataset', metric_learners, 281 | ids=ids_metric_learners) 282 | def test_score_pairs_warning(estimator, build_dataset): 283 | """Tests that score_pairs returns a FutureWarning regarding deprecation. 284 | Also that score_pairs and pair_distance have the same behaviour""" 285 | input_data, labels, _, X = build_dataset() 286 | model = clone(estimator) 287 | set_random_state(model) 288 | 289 | # We fit the metric learner on it and then we call score_pairs on some 290 | # points 291 | model.fit(*remove_y(model, input_data, labels)) 292 | 293 | msg = ("score_pairs will be deprecated in release 0.7.0. 
" 294 | "Use pair_score to compute similarity scores, or " 295 | "pair_distances to compute distances.") 296 | with pytest.warns(FutureWarning) as raised_warning: 297 | score = model.score_pairs([[X[0], X[1]], ]) 298 | dist = model.pair_distance([[X[0], X[1]], ]) 299 | assert array_equal(score, dist) 300 | assert any([str(warning.message) == msg for warning in raised_warning]) 301 | 302 | 303 | if __name__ == '__main__': 304 | unittest.main() 305 | -------------------------------------------------------------------------------- /test/test_components_metric_conversion.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import numpy as np 3 | import pytest 4 | from scipy.stats import ortho_group 5 | from sklearn.datasets import load_iris 6 | from numpy.testing import assert_array_almost_equal, assert_allclose 7 | from metric_learn.sklearn_shims import ignore_warnings 8 | 9 | from metric_learn import ( 10 | LMNN, NCA, LFDA, Covariance, MLKR, 11 | LSML_Supervised, ITML_Supervised, SDML_Supervised, RCA_Supervised) 12 | from metric_learn._util import components_from_metric 13 | from metric_learn.exceptions import NonPSDError 14 | 15 | 16 | class TestTransformerMetricConversion(unittest.TestCase): 17 | @classmethod 18 | def setUpClass(self): 19 | # runs once per test class 20 | iris_data = load_iris() 21 | self.X = iris_data['data'] 22 | self.y = iris_data['target'] 23 | 24 | def test_cov(self): 25 | cov = Covariance() 26 | cov.fit(self.X) 27 | L = cov.components_ 28 | assert_array_almost_equal(L.T.dot(L), cov.get_mahalanobis_matrix()) 29 | 30 | def test_lsml_supervised(self): 31 | seed = np.random.RandomState(1234) 32 | lsml = LSML_Supervised(n_constraints=200, random_state=seed) 33 | lsml.fit(self.X, self.y) 34 | L = lsml.components_ 35 | assert_array_almost_equal(L.T.dot(L), lsml.get_mahalanobis_matrix()) 36 | 37 | def test_itml_supervised(self): 38 | seed = np.random.RandomState(1234) 39 | itml = ITML_Supervised(n_constraints=200, random_state=seed) 40 | itml.fit(self.X, self.y) 41 | L = itml.components_ 42 | assert_array_almost_equal(L.T.dot(L), itml.get_mahalanobis_matrix()) 43 | 44 | def test_lmnn(self): 45 | lmnn = LMNN(n_neighbors=5, learn_rate=1e-6, verbose=False) 46 | lmnn.fit(self.X, self.y) 47 | L = lmnn.components_ 48 | assert_array_almost_equal(L.T.dot(L), lmnn.get_mahalanobis_matrix()) 49 | 50 | def test_sdml_supervised(self): 51 | seed = np.random.RandomState(1234) 52 | sdml = SDML_Supervised(n_constraints=1500, prior='identity', 53 | balance_param=1e-5, random_state=seed) 54 | sdml.fit(self.X, self.y) 55 | L = sdml.components_ 56 | assert_array_almost_equal(L.T.dot(L), sdml.get_mahalanobis_matrix()) 57 | 58 | def test_nca(self): 59 | n = self.X.shape[0] 60 | nca = NCA(max_iter=(100000 // n)) 61 | nca.fit(self.X, self.y) 62 | L = nca.components_ 63 | assert_array_almost_equal(L.T.dot(L), nca.get_mahalanobis_matrix()) 64 | 65 | def test_lfda(self): 66 | lfda = LFDA(k=2, n_components=2) 67 | lfda.fit(self.X, self.y) 68 | L = lfda.components_ 69 | assert_array_almost_equal(L.T.dot(L), lfda.get_mahalanobis_matrix()) 70 | 71 | def test_rca_supervised(self): 72 | rca = RCA_Supervised(n_components=2, n_chunks=30, chunk_size=2) 73 | rca.fit(self.X, self.y) 74 | L = rca.components_ 75 | assert_array_almost_equal(L.T.dot(L), rca.get_mahalanobis_matrix()) 76 | 77 | def test_mlkr(self): 78 | mlkr = MLKR(n_components=2) 79 | mlkr.fit(self.X, self.y) 80 | L = mlkr.components_ 81 | assert_array_almost_equal(L.T.dot(L), 
mlkr.get_mahalanobis_matrix()) 82 | 83 | @ignore_warnings 84 | def test_components_from_metric_edge_cases(self): 85 | """Test that components_from_metric returns the right result in various 86 | edge cases""" 87 | rng = np.random.RandomState(42) 88 | 89 | # an orthonormal matrix useful for creating matrices with given 90 | # eigenvalues: 91 | P = ortho_group.rvs(7, random_state=rng) 92 | 93 | # matrix with all its coefficients very low (to check that the algorithm 94 | # does not consider it as a diagonal matrix) (non-regression test for 95 | # https://github.com/scikit-learn-contrib/metric-learn/issues/175) 96 | M = np.diag([1e-15, 2e-16, 3e-15, 4e-16, 5e-15, 6e-16, 7e-15]) 97 | M = P.dot(M).dot(P.T) 98 | L = components_from_metric(M) 99 | assert_allclose(L.T.dot(L), M) 100 | 101 | # diagonal matrix 102 | M = np.diag(np.abs(rng.randn(5))) 103 | L = components_from_metric(M) 104 | assert_allclose(L.T.dot(L), M) 105 | 106 | # low-rank matrix (with zeros) 107 | M = np.zeros((7, 7)) 108 | small_random = rng.randn(3, 3) 109 | M[:3, :3] = small_random.T.dot(small_random) 110 | L = components_from_metric(M) 111 | assert_allclose(L.T.dot(L), M) 112 | 113 | # low-rank matrix (without necessarily zeros) 114 | R = np.abs(rng.randn(7, 7)) 115 | M = R.dot(np.diag([1, 5, 3, 2, 0, 0, 0])).dot(R.T) 116 | L = components_from_metric(M) 117 | assert_allclose(L.T.dot(L), M) 118 | 119 | # matrix with a determinant still high but which is 120 | # indefinite w.r.t. numpy standards 121 | M = np.diag([1e5, 1e5, 1e5, 1e5, 1e5, 1e5, 1e-20]) 122 | M = P.dot(M).dot(P.T) 123 | assert np.abs(np.linalg.det(M)) > 10 124 | assert np.linalg.slogdet(M)[1] > 1 # (just to show that the computed 125 | # determinant is far from null) 126 | assert np.linalg.matrix_rank(M) < M.shape[0] 127 | # (just to show that this case is indeed considered by numpy as an 128 | # indefinite case) 129 | L = components_from_metric(M) 130 | assert_allclose(L.T.dot(L), M) 131 | 132 | # matrix with lots of small nonzero entries whose product is nearly zero 133 | M = np.diag([1e-3, 1e-3, 1e-3, 1e-3, 1e-3, 1e-3, 1e-3]) 134 | L = components_from_metric(M) 135 | assert_allclose(L.T.dot(L), M) 136 | 137 | # full rank matrix 138 | M = rng.randn(10, 10) 139 | M = M.T.dot(M) 140 | assert np.linalg.matrix_rank(M) == 10 141 | L = components_from_metric(M) 142 | assert_allclose(L.T.dot(L), M) 143 | 144 | def test_non_symmetric_matrix_raises(self): 145 | """Checks that if a non-symmetric matrix is given to 146 | components_from_metric, an error is thrown""" 147 | rng = np.random.RandomState(42) 148 | M = rng.randn(10, 10) 149 | with pytest.raises(ValueError) as raised_error: 150 | components_from_metric(M) 151 | assert str(raised_error.value) == "The input metric should be symmetric." 152 | 153 | def test_non_psd_raises(self): 154 | """Checks that a non-PSD matrix (i.e. with negative eigenvalues) will 155 | raise an error when passed to components_from_metric""" 156 | rng = np.random.RandomState(42) 157 | D = np.diag([1, 5, 3, 4.2, -4, -2, 1]) 158 | P = ortho_group.rvs(7, random_state=rng) 159 | M = P.dot(D).dot(P.T) 160 | msg = ("Matrix is not positive semidefinite (PSD).") 161 | with pytest.raises(NonPSDError) as raised_error: 162 | components_from_metric(M) 163 | assert str(raised_error.value) == msg 164 | with pytest.raises(NonPSDError) as raised_error: 165 | components_from_metric(D) 166 | assert str(raised_error.value) == msg 167 | 168 | def test_almost_psd_dont_raise(self): 169 | """Checks that if the metric is almost PSD (i.e.
it has some negative 170 | eigenvalues very close to zero), then components_from_metric will still 171 | work""" 172 | rng = np.random.RandomState(42) 173 | D = np.diag([1, 5, 3, 4.2, -1e-20, -2e-20, -1e-20]) 174 | P = ortho_group.rvs(7, random_state=rng) 175 | M = P.dot(D).dot(P.T) 176 | L = components_from_metric(M) 177 | assert_allclose(L.T.dot(L), M) 178 | 179 | 180 | if __name__ == '__main__': 181 | unittest.main() 182 | -------------------------------------------------------------------------------- /test/test_constraints.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import numpy as np 3 | from sklearn.utils import shuffle 4 | from metric_learn.constraints import Constraints 5 | from sklearn.datasets import make_blobs 6 | 7 | SEED = 42 8 | 9 | 10 | def gen_labels_for_chunks(n_chunks, chunk_size, 11 | n_classes=10, n_unknown_labels=5): 12 | """Generates n_chunks*chunk_size labels that split in n_chunks chunks, 13 | that are homogeneous in the label.""" 14 | assert min(n_chunks, chunk_size) > 0 15 | classes = shuffle(np.arange(n_classes), random_state=SEED) 16 | n_per_class = chunk_size * (n_chunks // n_classes) 17 | n_maj_class = chunk_size * n_chunks - n_per_class * (n_classes - 1) 18 | 19 | first_labels = classes[0] * np.ones(n_maj_class, dtype=int) 20 | remaining_labels = np.concatenate([k * np.ones(n_per_class, dtype=int) 21 | for k in classes[1:]]) 22 | unknown_labels = -1 * np.ones(n_unknown_labels, dtype=int) 23 | 24 | labels = np.concatenate([first_labels, remaining_labels, unknown_labels]) 25 | return shuffle(labels, random_state=SEED) 26 | 27 | 28 | @pytest.mark.parametrize("n_chunks, chunk_size", [(5, 10), (10, 50)]) 29 | def test_exact_num_points_for_chunks(n_chunks, chunk_size): 30 | """Checks that the chunk generation works well with just enough points.""" 31 | labels = gen_labels_for_chunks(n_chunks, chunk_size) 32 | 33 | constraints = Constraints(labels) 34 | chunks = constraints.chunks(n_chunks=n_chunks, chunk_size=chunk_size, 35 | random_state=SEED) 36 | 37 | chunk_no, size_each_chunk = np.unique(chunks[chunks >= 0], 38 | return_counts=True) 39 | 40 | np.testing.assert_array_equal(size_each_chunk, chunk_size) 41 | assert chunk_no.shape[0] == n_chunks 42 | 43 | 44 | @pytest.mark.parametrize("n_chunks, chunk_size", [(5, 10), (10, 50)]) 45 | def test_chunk_case_one_miss_point(n_chunks, chunk_size): 46 | """Checks that the chunk generation breaks when one point is missing.""" 47 | labels = gen_labels_for_chunks(n_chunks, chunk_size) 48 | 49 | assert len(labels) >= 1 50 | constraints = Constraints(labels[1:]) 51 | with pytest.raises(ValueError) as e: 52 | constraints.chunks(n_chunks=n_chunks, chunk_size=chunk_size, 53 | random_state=SEED) 54 | 55 | expected_message = (('Not enough possible chunks of %d elements in each' 56 | ' class to form expected %d chunks - maximum number' 57 | ' of chunks is %d' 58 | ) % (chunk_size, n_chunks, n_chunks - 1)) 59 | 60 | assert str(e.value) == expected_message 61 | 62 | 63 | @pytest.mark.parametrize("n_chunks, chunk_size", [(5, 10), (10, 50)]) 64 | def test_unknown_labels_not_in_chunks(n_chunks, chunk_size): 65 | """Checks that unknown labels are not assigned to any chunk.""" 66 | labels = gen_labels_for_chunks(n_chunks, chunk_size) 67 | 68 | constraints = Constraints(labels) 69 | chunks = constraints.chunks(n_chunks=n_chunks, chunk_size=chunk_size, 70 | random_state=SEED) 71 | 72 | assert np.all(chunks[labels < 0] < 0) 73 | 74 | 75 | @pytest.mark.parametrize("k_genuine, 
k_impostor, T_test", 76 | [(2, 2, 77 | [[0, 1, 3], [0, 1, 4], [0, 2, 3], [0, 2, 4], 78 | [1, 0, 3], [1, 0, 4], [1, 2, 3], [1, 2, 4], 79 | [2, 0, 3], [2, 0, 4], [2, 1, 3], [2, 1, 4], 80 | [3, 4, 1], [3, 4, 2], [3, 5, 1], [3, 5, 2], 81 | [4, 3, 1], [4, 3, 2], [4, 5, 1], [4, 5, 2], 82 | [5, 3, 1], [5, 3, 2], [5, 4, 1], [5, 4, 2]]), 83 | (1, 3, 84 | [[0, 1, 3], [0, 1, 4], [0, 1, 5], [1, 0, 3], 85 | [1, 0, 4], [1, 0, 5], [2, 1, 3], [2, 1, 4], 86 | [2, 1, 5], [3, 4, 0], [3, 4, 1], [3, 4, 2], 87 | [4, 3, 0], [4, 3, 1], [4, 3, 2], [5, 4, 0], 88 | [5, 4, 1], [5, 4, 2]]), 89 | (1, 2, 90 | [[0, 1, 3], [0, 1, 4], [1, 0, 3], [1, 0, 4], 91 | [2, 1, 3], [2, 1, 4], [3, 4, 1], [3, 4, 2], 92 | [4, 3, 1], [4, 3, 2], [5, 4, 1], [5, 4, 2]])]) 93 | def test_generate_knntriplets_under_edge(k_genuine, k_impostor, T_test): 94 | """Checks under the edge cases of knn triplet construction with enough 95 | neighbors""" 96 | 97 | X = np.array([[0, 0], [2, 2], [4, 4], [8, 8], [16, 16], [32, 32], [33, 33]]) 98 | y = np.array([1, 1, 1, 2, 2, 2, -1]) 99 | 100 | T = Constraints(y).generate_knntriplets(X, k_genuine, k_impostor) 101 | 102 | assert np.array_equal(sorted(T.tolist()), T_test) 103 | 104 | 105 | @pytest.mark.parametrize("k_genuine, k_impostor,", 106 | [(3, 3), (2, 4), (3, 4), (10, 9), (144, 33)]) 107 | def test_generate_knntriplets(k_genuine, k_impostor): 108 | """Checks edge and over the edge cases of knn triplet construction with not 109 | enough neighbors""" 110 | 111 | T_test = [[0, 1, 3], [0, 1, 4], [0, 1, 5], [0, 2, 3], [0, 2, 4], [0, 2, 5], 112 | [1, 0, 3], [1, 0, 4], [1, 0, 5], [1, 2, 3], [1, 2, 4], [1, 2, 5], 113 | [2, 0, 3], [2, 0, 4], [2, 0, 5], [2, 1, 3], [2, 1, 4], [2, 1, 5], 114 | [3, 4, 0], [3, 4, 1], [3, 4, 2], [3, 5, 0], [3, 5, 1], [3, 5, 2], 115 | [4, 3, 0], [4, 3, 1], [4, 3, 2], [4, 5, 0], [4, 5, 1], [4, 5, 2], 116 | [5, 3, 0], [5, 3, 1], [5, 3, 2], [5, 4, 0], [5, 4, 1], [5, 4, 2]] 117 | 118 | X = np.array([[0, 0], [2, 2], [4, 4], [8, 8], [16, 16], [32, 32], [33, 33]]) 119 | y = np.array([1, 1, 1, 2, 2, 2, -1]) 120 | 121 | msg1 = ("The class 1 has 3 elements, which is not sufficient to " 122 | f"generate {k_genuine+1} genuine neighbors " 123 | "as specified by k_genuine") 124 | msg2 = ("The class 2 has 3 elements, which is not sufficient to " 125 | f"generate {k_genuine+1} genuine neighbors " 126 | "as specified by k_genuine") 127 | msg3 = ("The class 1 has 3 elements of other classes, which is " 128 | f"not sufficient to generate {k_impostor} impostor " 129 | "neighbors as specified by k_impostor") 130 | msg4 = ("The class 2 has 3 elements of other classes, which is " 131 | f"not sufficient to generate {k_impostor} impostor " 132 | "neighbors as specified by k_impostor") 133 | msgs = [msg1, msg2, msg3, msg4] 134 | with pytest.warns(UserWarning) as user_warning: 135 | T = Constraints(y).generate_knntriplets(X, k_genuine, k_impostor) 136 | assert any([[msg in str(warn.message) for msg in msgs] 137 | for warn in user_warning]) 138 | assert np.array_equal(sorted(T.tolist()), T_test) 139 | 140 | 141 | def test_generate_knntriplets_k_genuine(): 142 | """Checks the correct error raised when k_genuine is too big """ 143 | X, y = shuffle(*make_blobs(random_state=SEED), 144 | random_state=SEED) 145 | 146 | label, labels_count = np.unique(y, return_counts=True) 147 | labels_count_min = np.min(labels_count) 148 | idx_smallest_label, = np.where(labels_count == labels_count_min) 149 | k_genuine = labels_count_min 150 | 151 | warn_msgs = [] 152 | for idx in idx_smallest_label: 153 | warn_msgs.append("The 
class {} has {} elements, which is not sufficient " 154 | "to generate {} genuine neighbors as specified by " 155 | "k_genuine. Will generate {} genuine neighbors instead." 156 | "\n" 157 | .format(label[idx], k_genuine, k_genuine+1, k_genuine-1)) 158 | 159 | with pytest.warns(UserWarning) as raised_warning: 160 | Constraints(y).generate_knntriplets(X, k_genuine, 1) 161 | for warn in raised_warning: 162 | assert str(warn.message) in warn_msgs 163 | 164 | 165 | def test_generate_knntriplets_k_impostor(): 166 | """Checks the correct error raised when k_impostor is too big """ 167 | X, y = shuffle(*make_blobs(random_state=SEED), 168 | random_state=SEED) 169 | 170 | length = len(y) 171 | label, labels_count = np.unique(y, return_counts=True) 172 | labels_count_max = np.max(labels_count) 173 | idx_biggest_label, = np.where(labels_count == labels_count_max) 174 | k_impostor = length - labels_count_max + 1 175 | 176 | warn_msgs = [] 177 | for idx in idx_biggest_label: 178 | warn_msgs.append("The class {} has {} elements of other classes, which is" 179 | " not sufficient to generate {} impostor neighbors as " 180 | "specified by k_impostor. Will generate {} impostor " 181 | "neighbors instead.\n" 182 | .format(label[idx], k_impostor-1, k_impostor, 183 | k_impostor-1)) 184 | 185 | with pytest.warns(UserWarning) as raised_warning: 186 | Constraints(y).generate_knntriplets(X, 1, k_impostor) 187 | for warn in raised_warning: 188 | assert str(warn.message) in warn_msgs 189 | -------------------------------------------------------------------------------- /test/test_fit_transform.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import numpy as np 3 | from sklearn.datasets import load_iris 4 | from numpy.testing import assert_array_almost_equal 5 | 6 | from metric_learn import ( 7 | LMNN, NCA, LFDA, Covariance, MLKR, 8 | LSML_Supervised, ITML_Supervised, SDML_Supervised, RCA_Supervised, 9 | MMC_Supervised) 10 | 11 | 12 | class TestFitTransform(unittest.TestCase): 13 | @classmethod 14 | def setUpClass(self): 15 | # runs once per test class 16 | iris_data = load_iris() 17 | self.X = iris_data['data'] 18 | self.y = iris_data['target'] 19 | 20 | def test_cov(self): 21 | cov = Covariance() 22 | cov.fit(self.X) 23 | res_1 = cov.transform(self.X) 24 | 25 | cov = Covariance() 26 | res_2 = cov.fit_transform(self.X) 27 | # deterministic result 28 | assert_array_almost_equal(res_1, res_2) 29 | 30 | def test_lsml_supervised(self): 31 | seed = np.random.RandomState(1234) 32 | lsml = LSML_Supervised(n_constraints=200, random_state=seed) 33 | lsml.fit(self.X, self.y) 34 | res_1 = lsml.transform(self.X) 35 | 36 | seed = np.random.RandomState(1234) 37 | lsml = LSML_Supervised(n_constraints=200, random_state=seed) 38 | res_2 = lsml.fit_transform(self.X, self.y) 39 | 40 | assert_array_almost_equal(res_1, res_2) 41 | 42 | def test_itml_supervised(self): 43 | seed = np.random.RandomState(1234) 44 | itml = ITML_Supervised(n_constraints=200, random_state=seed) 45 | itml.fit(self.X, self.y) 46 | res_1 = itml.transform(self.X) 47 | 48 | seed = np.random.RandomState(1234) 49 | itml = ITML_Supervised(n_constraints=200, random_state=seed) 50 | res_2 = itml.fit_transform(self.X, self.y) 51 | 52 | assert_array_almost_equal(res_1, res_2) 53 | 54 | def test_lmnn(self): 55 | lmnn = LMNN(n_neighbors=5, learn_rate=1e-6, verbose=False) 56 | lmnn.fit(self.X, self.y) 57 | res_1 = lmnn.transform(self.X) 58 | 59 | lmnn = LMNN(n_neighbors=5, learn_rate=1e-6, verbose=False) 60 | res_2 = 
lmnn.fit_transform(self.X, self.y) 61 | 62 | assert_array_almost_equal(res_1, res_2) 63 | 64 | def test_sdml_supervised(self): 65 | seed = np.random.RandomState(1234) 66 | sdml = SDML_Supervised(n_constraints=1500, balance_param=1e-5, 67 | prior='identity', random_state=seed) 68 | sdml.fit(self.X, self.y) 69 | res_1 = sdml.transform(self.X) 70 | 71 | seed = np.random.RandomState(1234) 72 | sdml = SDML_Supervised(n_constraints=1500, balance_param=1e-5, 73 | prior='identity', random_state=seed) 74 | res_2 = sdml.fit_transform(self.X, self.y) 75 | 76 | assert_array_almost_equal(res_1, res_2) 77 | 78 | def test_nca(self): 79 | n = self.X.shape[0] 80 | nca = NCA(max_iter=(100000 // n)) 81 | nca.fit(self.X, self.y) 82 | res_1 = nca.transform(self.X) 83 | 84 | nca = NCA(max_iter=(100000 // n)) 85 | res_2 = nca.fit_transform(self.X, self.y) 86 | 87 | assert_array_almost_equal(res_1, res_2) 88 | 89 | def test_lfda(self): 90 | lfda = LFDA(k=2, n_components=2) 91 | lfda.fit(self.X, self.y) 92 | res_1 = lfda.transform(self.X) 93 | 94 | lfda = LFDA(k=2, n_components=2) 95 | res_2 = lfda.fit_transform(self.X, self.y) 96 | 97 | # signs may be flipped, that's okay 98 | assert_array_almost_equal(abs(res_1), abs(res_2)) 99 | 100 | def test_rca_supervised(self): 101 | seed = np.random.RandomState(1234) 102 | rca = RCA_Supervised(n_components=2, n_chunks=30, chunk_size=2, 103 | random_state=seed) 104 | rca.fit(self.X, self.y) 105 | res_1 = rca.transform(self.X) 106 | 107 | seed = np.random.RandomState(1234) 108 | rca = RCA_Supervised(n_components=2, n_chunks=30, chunk_size=2, 109 | random_state=seed) 110 | res_2 = rca.fit_transform(self.X, self.y) 111 | 112 | assert_array_almost_equal(res_1, res_2) 113 | 114 | def test_mlkr(self): 115 | mlkr = MLKR(n_components=2) 116 | mlkr.fit(self.X, self.y) 117 | res_1 = mlkr.transform(self.X) 118 | 119 | mlkr = MLKR(n_components=2) 120 | res_2 = mlkr.fit_transform(self.X, self.y) 121 | 122 | assert_array_almost_equal(res_1, res_2) 123 | 124 | def test_mmc_supervised(self): 125 | seed = np.random.RandomState(1234) 126 | mmc = MMC_Supervised(n_constraints=200, random_state=seed) 127 | mmc.fit(self.X, self.y) 128 | res_1 = mmc.transform(self.X) 129 | 130 | seed = np.random.RandomState(1234) 131 | mmc = MMC_Supervised(n_constraints=200, random_state=seed) 132 | res_2 = mmc.fit_transform(self.X, self.y) 133 | 134 | assert_array_almost_equal(res_1, res_2) 135 | 136 | 137 | if __name__ == '__main__': 138 | unittest.main() 139 | -------------------------------------------------------------------------------- /test/test_quadruplets_classifiers.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from sklearn.exceptions import NotFittedError 3 | from sklearn.model_selection import train_test_split 4 | 5 | from test.test_utils import quadruplets_learners, ids_quadruplets_learners 6 | from metric_learn.sklearn_shims import set_random_state 7 | from sklearn import clone 8 | import numpy as np 9 | 10 | 11 | @pytest.mark.parametrize('with_preprocessor', [True, False]) 12 | @pytest.mark.parametrize('estimator, build_dataset', quadruplets_learners, 13 | ids=ids_quadruplets_learners) 14 | def test_predict_only_one_or_minus_one(estimator, build_dataset, 15 | with_preprocessor): 16 | """Test that all predicted values are either +1 or -1""" 17 | input_data, labels, preprocessor, _ = build_dataset(with_preprocessor) 18 | estimator = clone(estimator) 19 | estimator.set_params(preprocessor=preprocessor) 20 | set_random_state(estimator) 21 | 
(quadruplets_train, 22 | quadruplets_test, y_train, y_test) = train_test_split(input_data, labels) 23 | estimator.fit(quadruplets_train) 24 | predictions = estimator.predict(quadruplets_test) 25 | not_valid = [e for e in predictions if e not in [-1, 1]] 26 | assert len(not_valid) == 0 27 | 28 | 29 | @pytest.mark.parametrize('with_preprocessor', [True, False]) 30 | @pytest.mark.parametrize('estimator, build_dataset', quadruplets_learners, 31 | ids=ids_quadruplets_learners) 32 | def test_raise_not_fitted_error_if_not_fitted(estimator, build_dataset, 33 | with_preprocessor): 34 | """Test that a NotFittedError is raised if someone tries to predict and 35 | the metric learner has not been fitted.""" 36 | input_data, labels, preprocessor, _ = build_dataset(with_preprocessor) 37 | estimator = clone(estimator) 38 | estimator.set_params(preprocessor=preprocessor) 39 | set_random_state(estimator) 40 | with pytest.raises(NotFittedError): 41 | estimator.predict(input_data) 42 | 43 | 44 | @pytest.mark.parametrize('estimator, build_dataset', quadruplets_learners, 45 | ids=ids_quadruplets_learners) 46 | def test_accuracy_toy_example(estimator, build_dataset): 47 | """Test that the default scoring for quadruplets (accuracy) works on some 48 | toy example""" 49 | input_data, labels, preprocessor, X = build_dataset(with_preprocessor=False) 50 | estimator = clone(estimator) 51 | estimator.set_params(preprocessor=preprocessor) 52 | set_random_state(estimator) 53 | estimator.fit(input_data) 54 | # We take the two first points and we build 4 regularly spaced points on the 55 | # line they define, so that it's easy to build quadruplets of different 56 | # similarities. 57 | X_test = X[0] + np.arange(4)[:, np.newaxis] * (X[0] - X[1]) / 4 58 | quadruplets_test = np.array( 59 | [[X_test[0], X_test[2], X_test[0], X_test[1]], 60 | [X_test[1], X_test[3], X_test[1], X_test[0]], 61 | [X_test[1], X_test[2], X_test[0], X_test[3]], 62 | [X_test[3], X_test[0], X_test[2], X_test[1]]]) 63 | # we force the transformation to be identity so that we control what it does 64 | estimator.components_ = np.eye(X.shape[1]) 65 | assert estimator.score(quadruplets_test) == 0.25 66 | -------------------------------------------------------------------------------- /test/test_triplets_classifiers.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from sklearn.exceptions import NotFittedError 3 | from sklearn.model_selection import train_test_split 4 | 5 | from metric_learn import SCML 6 | from test.test_utils import ( 7 | triplets_learners, 8 | ids_triplets_learners, 9 | build_triplets 10 | ) 11 | from metric_learn.sklearn_shims import set_random_state 12 | from sklearn import clone 13 | import numpy as np 14 | from numpy.testing import assert_array_equal 15 | 16 | 17 | @pytest.mark.parametrize('with_preprocessor', [True, False]) 18 | @pytest.mark.parametrize('estimator, build_dataset', triplets_learners, 19 | ids=ids_triplets_learners) 20 | def test_predict_only_one_or_minus_one(estimator, build_dataset, 21 | with_preprocessor): 22 | """Test that all predicted values are either +1 or -1""" 23 | input_data, _, preprocessor, _ = build_dataset(with_preprocessor) 24 | estimator = clone(estimator) 25 | estimator.set_params(preprocessor=preprocessor) 26 | set_random_state(estimator) 27 | triplets_train, triplets_test = train_test_split(input_data) 28 | estimator.fit(triplets_train) 29 | predictions = estimator.predict(triplets_test) 30 | 31 | not_valid = [e for e in predictions if e not in 
[-1, 1]] 32 | assert len(not_valid) == 0 33 | 34 | 35 | @pytest.mark.parametrize('estimator, build_dataset', triplets_learners, 36 | ids=ids_triplets_learners) 37 | def test_no_zero_prediction(estimator, build_dataset): 38 | """ 39 | Test that no predicted value is zero, even when the 40 | distances d(x, y) and d(x, z) are the same for a triplet of the 41 | form (x, y, z), i.e. border cases. 42 | """ 43 | triplets, _, _, X = build_dataset(with_preprocessor=False) 44 | # Force 3 dimensions only, to use the cross product and get easy orthogonal vectors. 45 | triplets = np.array([[t[0][:3], t[1][:3], t[2][:3]] for t in triplets]) 46 | X = X[:, :3] 47 | # Dummy fit 48 | estimator = clone(estimator) 49 | set_random_state(estimator) 50 | estimator.fit(triplets) 51 | # We force the transformation to be identity, to force Euclidean distance 52 | estimator.components_ = np.eye(X.shape[1]) 53 | 54 | # Get two orthogonal vectors with respect to X[1] 55 | k = X[1] / np.linalg.norm(X[1]) # Normalize first vector 56 | x = X[2] - X[2].dot(k) * k # Get a vector orthogonal to k 57 | x /= np.linalg.norm(x) # Normalize 58 | y = np.cross(k, x) # Get orthogonal vector to x 59 | # Assert these orthogonal vectors are different 60 | with pytest.raises(AssertionError): 61 | assert_array_equal(X[1], x) 62 | with pytest.raises(AssertionError): 63 | assert_array_equal(X[1], y) 64 | # Assert the distance is the same for both 65 | assert estimator.get_metric()(X[1], x) == estimator.get_metric()(X[1], y) 66 | 67 | # Form the three scenarios where predict() gives 0 with numpy.sign 68 | triplets_test = np.array( # Critical examples 69 | [[X[0], X[2], X[2]], 70 | [X[1], X[1], X[1]], 71 | [X[1], x, y]]) 72 | # Predict 73 | predictions = estimator.predict(triplets_test) 74 | # Check there are no zero values 75 | assert np.sum(predictions == 0) == 0 76 | 77 | 78 | @pytest.mark.parametrize('with_preprocessor', [True, False]) 79 | @pytest.mark.parametrize('estimator, build_dataset', triplets_learners, 80 | ids=ids_triplets_learners) 81 | def test_raise_not_fitted_error_if_not_fitted(estimator, build_dataset, 82 | with_preprocessor): 83 | """Test that a NotFittedError is raised if someone tries to predict and 84 | the metric learner has not been fitted.""" 85 | input_data, _, preprocessor, _ = build_dataset(with_preprocessor) 86 | estimator = clone(estimator) 87 | estimator.set_params(preprocessor=preprocessor) 88 | set_random_state(estimator) 89 | with pytest.raises(NotFittedError): 90 | estimator.predict(input_data) 91 | 92 | 93 | @pytest.mark.parametrize('estimator, build_dataset', triplets_learners, 94 | ids=ids_triplets_learners) 95 | def test_accuracy_toy_example(estimator, build_dataset): 96 | """Test that the default scoring for triplets (accuracy) works on some 97 | toy example""" 98 | triplets, _, _, X = build_dataset(with_preprocessor=False) 99 | estimator = clone(estimator) 100 | set_random_state(estimator) 101 | estimator.fit(triplets) 102 | # We take the two first points and we build 4 regularly spaced points on the 103 | # line they define, so that it's easy to build triplets of different 104 | # similarities.
105 | X_test = X[0] + np.arange(4)[:, np.newaxis] * (X[0] - X[1]) / 4 106 | 107 | triplets_test = np.array( 108 | [[X_test[0], X_test[2], X_test[1]], 109 | [X_test[1], X_test[3], X_test[0]], 110 | [X_test[1], X_test[2], X_test[3]], 111 | [X_test[3], X_test[0], X_test[2]]]) 112 | # we force the transformation to be identity so that we control what it does 113 | estimator.components_ = np.eye(X.shape[1]) 114 | assert estimator.score(triplets_test) == 0.25 115 | 116 | 117 | def test_raise_big_number_of_features(): 118 | triplets, _, _, X = build_triplets(with_preprocessor=False) 119 | triplets = triplets[:3, :, :] 120 | estimator = SCML(n_basis=320) 121 | set_random_state(estimator) 122 | with pytest.raises(ValueError) as exc_info: 123 | estimator.fit(triplets) 124 | assert exc_info.value.args[0] == \ 125 | "Number of features (4) is greater than the number of triplets(3)." \ 126 | "\nConsider using dimensionality reduction or using another basis " \ 127 | "generation scheme." 128 | --------------------------------------------------------------------------------