├── .gitattributes ├── .github └── workflows │ ├── doc.yml │ ├── release.yml │ └── test.yml ├── .gitignore ├── .python-version ├── LICENSE ├── README.md ├── README_datasets.png ├── README_results.png ├── doc ├── Makefile ├── make.bat └── source │ ├── changelog.rst │ ├── conf.py │ └── index.rst ├── pyproject.toml ├── src └── pyannote │ └── pipeline │ ├── __init__.py │ ├── experiment.py │ ├── optimizer.py │ ├── parameter.py │ ├── pipeline.py │ └── typing.py ├── tests ├── __init__.py ├── test_optimizer.py ├── test_pipeline.py └── utils.py └── uv.lock /.gitattributes: -------------------------------------------------------------------------------- 1 | doc/source/conf.py export-subst 2 | -------------------------------------------------------------------------------- /.github/workflows/doc.yml: -------------------------------------------------------------------------------- 1 | name: Documentation 2 | on: 3 | push: 4 | branches: 5 | - master 6 | 7 | jobs: 8 | build-and-deploy: 9 | runs-on: ubuntu-latest 10 | 11 | steps: 12 | - uses: actions/checkout@v4 13 | with: 14 | persist-credentials: false 15 | fetch-depth: 0 16 | - name: Install uv 17 | uses: astral-sh/setup-uv@v5 18 | with: 19 | enable-cache: true 20 | cache-dependency-glob: uv.lock 21 | 22 | - name: Install the project 23 | run: uv sync --extra doc 24 | 25 | - name: Build documentation 26 | run: | 27 | make --directory=doc html 28 | touch ./doc/build/html/.nojekyll 29 | - name: Deploy 30 | uses: peaceiris/actions-gh-pages@v3 31 | with: 32 | github_token: ${{ secrets.GITHUB_TOKEN }} 33 | publish_dir: ./doc/build/html 34 | -------------------------------------------------------------------------------- /.github/workflows/release.yml: -------------------------------------------------------------------------------- 1 | name: Publish Python 🐍 distribution 📦 to PyPI and TestPyPI 2 | 3 | on: push 4 | 5 | jobs: 6 | build: 7 | name: Build distribution 📦 8 | runs-on: ubuntu-latest 9 | 10 | steps: 11 | - uses: actions/checkout@v4 12 | with: 13 | persist-credentials: false 14 | fetch-depth: 0 15 | - name: Install uv 16 | uses: astral-sh/setup-uv@v5 17 | with: 18 | enable-cache: true 19 | cache-dependency-glob: uv.lock 20 | - name: Set up Python 21 | uses: actions/setup-python@v5 22 | with: 23 | python-version-file: ".python-version" 24 | - name: Build 25 | run: uv build 26 | - name: Store the distribution packages 27 | uses: actions/upload-artifact@v4 28 | with: 29 | name: python-package-distributions 30 | path: dist/ 31 | 32 | publish-to-pypi: 33 | name: >- 34 | Publish Python 🐍 distribution 📦 to PyPI 35 | if: startsWith(github.ref, 'refs/tags/') # only publish to PyPI on tag pushes 36 | needs: 37 | - build 38 | runs-on: ubuntu-latest 39 | environment: 40 | name: pypi 41 | permissions: 42 | id-token: write 43 | steps: 44 | - name: Download all the dists 45 | uses: actions/download-artifact@v4 46 | with: 47 | name: python-package-distributions 48 | path: dist/ 49 | - name: Install uv 50 | uses: astral-sh/setup-uv@v5 51 | with: 52 | enable-cache: true 53 | cache-dependency-glob: uv.lock 54 | - name: Publish distribution 📦 to PyPI 55 | run: uv publish --trusted-publishing always --publish-url https://upload.pypi.org/legacy/ 56 | 57 | 58 | github-release: 59 | name: >- 60 | Sign the Python 🐍 distribution 📦 with Sigstore 61 | and upload them to GitHub Release 62 | needs: 63 | - publish-to-pypi 64 | runs-on: ubuntu-latest 65 | 66 | permissions: 67 | contents: write # IMPORTANT: mandatory for making GitHub Releases 68 | id-token: write # IMPORTANT: mandatory for sigstore 69 | 70 | steps: 71 | - name: Download all the dists 72 | uses: actions/download-artifact@v4 73 | with: 74 | name: python-package-distributions 75 | path: dist/ 76 | - name: Sign the dists with Sigstore 77 | uses: sigstore/gh-action-sigstore-python@v3.0.0 78 | with: 79 | inputs: >- 80 | ./dist/*.tar.gz 81 | ./dist/*.whl 82 | - name: Create GitHub Release 83 | env: 84 | GITHUB_TOKEN: ${{ github.token }} 85 | run: >- 86 | gh release create 87 | "$GITHUB_REF_NAME" 88 | --repo "$GITHUB_REPOSITORY" 89 | --notes "" 90 | - name: Upload artifact signatures to GitHub Release 91 | env: 92 | GITHUB_TOKEN: ${{ github.token }} 93 | # Upload to GitHub Release using the `gh` CLI. 94 | # `dist/` contains the built packages, and the 95 | # sigstore-produced signatures and certificates. 96 | run: >- 97 | gh release upload 98 | "$GITHUB_REF_NAME" dist/** 99 | --repo "$GITHUB_REPOSITORY" 100 | 101 | # publish-to-testpypi: 102 | # name: Publish Python 🐍 distribution 📦 to TestPyPI 103 | # needs: 104 | # - build 105 | # runs-on: ubuntu-latest 106 | # 107 | # environment: 108 | # name: testpypi 109 | # 110 | # permissions: 111 | # id-token: write # IMPORTANT: mandatory for trusted publishing 112 | # 113 | # steps: 114 | # - name: Download all the dists 115 | # uses: actions/download-artifact@v4 116 | # with: 117 | # name: python-package-distributions 118 | # path: dist/ 119 | # - name: Install uv 120 | # uses: astral-sh/setup-uv@v5 121 | # with: 122 | # enable-cache: true 123 | # cache-dependency-glob: uv.lock 124 | # - name: Publish distribution 📦 to PyPI 125 | # run: uv publish --trusted-publishing always --publish-url https://test.pypi.org/legacy/ 126 | -------------------------------------------------------------------------------- /.github/workflows/test.yml: -------------------------------------------------------------------------------- 1 | name: Test 2 | 3 | on: 4 | pull_request: 5 | branches: 6 | - develop 7 | push: 8 | branches: 9 | - develop 10 | - master 11 | - release/* 12 | 13 | 14 | jobs: 15 | test: 16 | name: Test 17 | runs-on: ubuntu-latest 18 | strategy: 19 | matrix: 20 | python-version: 21 | - "3.10" 22 | - "3.11" 23 | - "3.12" 24 | env: 25 | UV_PYTHON: ${{ matrix.python-version }} 26 | steps: 27 | - uses: actions/checkout@v4 28 | 29 | - name: Install uv 30 | uses: astral-sh/setup-uv@v5 31 | 32 | - name: Install the project 33 | run: uv sync --extra test 34 | 35 | - name: Run tests 36 | run: uv run pytest tests -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | 5 | # C extensions 6 | *.so 7 | 8 | # Distribution / packaging 9 | MANIFEST 10 | .Python 11 | env/ 12 | venv/ 13 | bin/ 14 | build/ 15 | develop-eggs/ 16 | dist/ 17 | eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | 27 | # Installer logs 28 | pip-log.txt 29 | pip-delete-this-directory.txt 30 | 31 | # Unit test / coverage reports 32 | htmlcov/ 33 | .tox/ 34 | .coverage 35 | .cache 36 | nosetests.xml 37 | coverage.xml 38 | 39 | # Translations 40 | *.mo 41 | 42 | # Mr Developer 43 | .mr.developer.cfg 44 | .project 45 | .pydevproject 46 | 47 | # Rope 48 | .ropeproject 49 | 50 | # Django stuff: 51 | *.log 52 | *.pot 53 | 54 | # emacs temporary files 55 | *~ 56 | 57 | .mypy_cache/* 58 | 59 | # Pycharm config files 60 | .idea/ -------------------------------------------------------------------------------- /.python-version: -------------------------------------------------------------------------------- 1 | 3.10 2 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2018 CNRS 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in 13 | all copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | 23 | AUTHORS 24 | Hervé BREDIN - http://herve.niderb.fr 25 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ```python 2 | >>> %pylab inline 3 | Populating the interactive namespace from numpy and matplotlib 4 | ``` 5 | 6 | # pyannote.pipeline 7 | 8 | ## Installation 9 | ```bash 10 | $ pip install pyannote.pipeline 11 | ``` 12 | 13 | ## Tutorial 14 | 15 | In this tutorial, we are going to tune hyper-parameters of a clustering pipeline. 16 | 17 | We start by gathering a "training set" of 5 different clustering datasets. 18 | 19 | ```python 20 | >>> # shamelessly stolen from https://scikit-learn.org/stable/auto_examples/cluster/plot_cluster_comparison.html 21 | ... 22 | ... import numpy as np 23 | >>> from sklearn import datasets 24 | ... 25 | >>> np.random.seed(0) 26 | ... 27 | >>> n_samples = 1500 28 | >>> noisy_circles = datasets.make_circles(n_samples=n_samples, factor=.5, 29 | ... noise=.05) 30 | >>> noisy_moons = datasets.make_moons(n_samples=n_samples, noise=.05) 31 | >>> blobs = datasets.make_blobs(n_samples=n_samples, random_state=8) 32 | ... 33 | >>> # Anisotropicly distributed data 34 | ... random_state = 170 35 | >>> X, y = datasets.make_blobs(n_samples=n_samples, random_state=random_state) 36 | >>> transformation = [[0.6, -0.6], [-0.4, 0.8]] 37 | >>> X_aniso = np.dot(X, transformation) 38 | >>> aniso = (X_aniso, y) 39 | ... 40 | >>> # blobs with varied variances 41 | ... varied = datasets.make_blobs(n_samples=n_samples, 42 | ... cluster_std=[1.0, 2.5, 0.5], 43 | ... random_state=random_state) 44 | ... 45 | >>> datasets = [noisy_circles, noisy_moons, varied, aniso, blobs] 46 | ... 47 | >>> figsize(20, 4) 48 | >>> for d, dataset in enumerate(datasets): 49 | ... X, y = dataset 50 | ... subplot(1, len(datasets), d + 1) 51 | ... scatter(X[:, 0], X[:, 1], c=y) 52 | >>> savefig('README_datasets.png') 53 |
54 | ``` 55 | 56 | ![datasets](README_datasets.png) 57 | 58 | 59 | Then, we define the clustering pipeline (including its set of hyper-parameters and the objective function). 60 | 61 | ```python 62 | >>> from pyannote.pipeline import Pipeline 63 | >>> from pyannote.pipeline.parameter import Uniform 64 | >>> from pyannote.pipeline.parameter import Integer 65 | >>> from sklearn.cluster import DBSCAN 66 | >>> from sklearn.metrics import v_measure_score 67 | ... 68 | >>> # a pipeline should inherit from `pyannote.pipeline.Pipeline` 69 | ... class SimplePipeline(Pipeline): 70 | ... 71 | ... # this pipeline has two hyper-parameters. 72 | ... # `eps` follows a uniform distribution between 0 and 10 73 | ... # `min_samples` is a random integer between 1 and 20 74 | ... def __init__(self): 75 | ... super().__init__() 76 | ... self.eps = Uniform(0, 10) 77 | ... self.min_samples = Integer(1, 20) 78 | ... 79 | ... # `initialize` should be used to setup the pipeline. it 80 | ... # is called every time a new set of hyper-parameters is 81 | ... # tested. 82 | ... def initialize(self): 83 | ... # this pipeline relies on scikit-learn DBSCAN. 84 | ... self._dbscan = DBSCAN(eps=self.eps, 85 | ... min_samples=self.min_samples) 86 | ... 87 | ... # this is where the pipeline is applied to a dataset 88 | ... def __call__(self, dataset): 89 | ... X = dataset[0] 90 | ... y_pred = self._dbscan.fit_predict(X) 91 | ... return y_pred 92 | ... 93 | ... # this is the loss we are trying to minimize 94 | ... def loss(self, dataset, y_pred): 95 | ... # we rely on sklearn v_measure_score 96 | ... y_true = dataset[1] 97 | ... return 1. - v_measure_score(y_true, y_pred) 98 | ... 99 | >>> pipeline = SimplePipeline() 100 | ``` 101 | 102 | This is where the hyper-parameter optimization actually happens. 103 | 104 | ```python 105 | >>> !rm dbscan.db 106 | ``` 107 | 108 | ```python 109 | >>> # we initialize an optimizer (that store its trials in SQLite file dbscan.db) 110 | ... from pyannote.pipeline import Optimizer 111 | >>> optimizer = Optimizer(pipeline, db=Path('dbscan.db')) 112 | ``` 113 | 114 | ```python 115 | >>> # we run 100 optimization iterations and display the best set of hyper-parameters 116 | ... optimizer.tune(datasets, n_iterations=100) 117 | >>> optimizer.best_params 118 | {'eps': 0.1912781975831715, 'min_samples': 18} 119 | ``` 120 | 121 | We then compare expected (upper row) and actual (lower row) clustering results with the best set of hyper-parameters 122 | 123 | ```python 124 | >>> best_pipeline = optimizer.best_pipeline 125 | >>> # equivalent to 126 | ... # best_pipeline = pipeline.instantiate(optimizer.best_params) 127 | ... 128 | ... figsize(20, 8) 129 | >>> for d, dataset in enumerate(datasets): 130 | ... X, y_true = dataset 131 | ... y_pred = best_pipeline(dataset) 132 | ... subplot(2, len(datasets), d + 1) 133 | ... scatter(X[:, 0], X[:, 1], c=y_true) 134 | ... subplot(2, len(datasets), d + 1 + len(datasets)) 135 | ... scatter(X[:, 0], X[:, 1], c=y_pred) 136 | ... title(f'score = {v_measure_score(y_true, y_pred):g}') 137 | >>> savefig('README_results.png') 138 |
139 | ``` 140 | 141 | ![results](README_results.png) 142 | 143 | 144 | ## Documentation 145 | 146 | `pyannote.pipeline` can do much more than that (including composing pipelines and freezing hyper-parameters). 147 | See `pyannote.audio.pipeline` for advanced examples. 148 | -------------------------------------------------------------------------------- /README_datasets.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pyannote/pyannote-pipeline/d7013b673ce0bd3c9bdd225c78c2cbd7ff280ff3/README_datasets.png -------------------------------------------------------------------------------- /README_results.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pyannote/pyannote-pipeline/d7013b673ce0bd3c9bdd225c78c2cbd7ff280ff3/README_results.png -------------------------------------------------------------------------------- /doc/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line. 5 | SPHINXOPTS = 6 | SPHINXBUILD = sphinx-build 7 | SPHINXPROJ = pyannotepipeline 8 | SOURCEDIR = source 9 | BUILDDIR = build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | uv run --extra doc $(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | uv run --extra doc $(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) -------------------------------------------------------------------------------- /doc/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=source 11 | set BUILDDIR=build 12 | set SPHINXPROJ=pyannotepipeline 13 | 14 | if "%1" == "" goto help 15 | 16 | %SPHINXBUILD% >NUL 2>NUL 17 | if errorlevel 9009 ( 18 | echo. 19 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 20 | echo.installed, then set the SPHINXBUILD environment variable to point 21 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 22 | echo.may add the Sphinx directory to PATH. 23 | echo. 24 | echo.If you don't have Sphinx installed, grab it from 25 | echo.http://sphinx-doc.org/ 26 | exit /b 1 27 | ) 28 | 29 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% 30 | goto end 31 | 32 | :help 33 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% 34 | 35 | :end 36 | popd 37 | -------------------------------------------------------------------------------- /doc/source/changelog.rst: -------------------------------------------------------------------------------- 1 | ######### 2 | Changelog 3 | ######### 4 | 5 | Version 4.0.0rc2 (2025-02-23) 6 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 7 | 8 | - feat(optimize): add option to pass keyword arguments to pipeline during optimization 9 | 10 | Version 4.0.0rc1 (2025-02-11) 11 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 12 | 13 | - feat(optimize): add option to pass keyword arguments to pipeline during optimization 14 | - BREAKING: drop support for `Python` < 3.10 15 | - BREAKING: switch to native namespace package 16 | - BREAKING: remove `pyannote.pipeline.blocks` submodule 17 | - setup: switch to `uv` 18 | 19 | Version 3.1.2 (2025-02-07) 20 | ~~~~~~~~~~~~~~~~~~~~~~~~~~ 21 | 22 | - fix: should_prune() takes no argument 23 | - fix: make `Optimizer.best_loss`` return infinity with no trial 24 | - fix: fix missing `scipy` import (#54, @arxaqapi) 25 | - feat: add "--use-filter" option to filter training/validation files 26 | 27 | Version 3.0.1 (2023-09-22) 28 | ~~~~~~~~~~~~~~~~~~~~~~~~~~ 29 | 30 | - BREAKING(cli): switch to latest pyannote.database API 31 | - feat: add "seed" parameter for reproducible optimization 32 | - feat(cli): add "device" section in configuration file 33 | - feat(cli): add "--registry" option for custom database loading 34 | - feat(cli): add "--average-case" option to optimize for average case 35 | - setup: switch to optuna 3.1+ 36 | - feat: add support for optuna Journal storage 37 | 38 | Version 2.3 (2022-06-16) 39 | ~~~~~~~~~~~~~~~~~~~~~~~~ 40 | 41 | - BREAKING: optimize loss estimate upper bound instead of average (#42) 42 | - feat: add tests and typing (#41, @hadware) 43 | - feat: add ParamDict structured hyper-parameter (#40, @hadware) 44 | - feat: set sub-pipeline "training" attribute recursively (#39) 45 | - doc: fix various typos 46 | 47 | Version 2.2 (2021-12-10) 48 | ~~~~~~~~~~~~~~~~~~~~~~~~ 49 | 50 | - feat: add Pipeline.instantiated attribute 51 | 52 | Version 2.1 (2021-09-15) 53 | ~~~~~~~~~~~~~~~~~~~~~~~~ 54 | 55 | - feat: add Pipeline.training attribute 56 | 57 | Version 2.0 (2020-11-25) 58 | ~~~~~~~~~~~~~~~~~~~~~~~~ 59 | 60 | - BREAKING: remove "direction" argument from Optimizer 61 | - feat: add Pipeline.get_direction method (defaults to "minimize") 62 | - feat: add progress bar in "apply" mode 63 | 64 | Version 1.5.2 (2020-06-26) 65 | ~~~~~~~~~~~~~~~~~~~~~~~~~~ 66 | 67 | - feat: add show_progress option to control second progress bar 68 | - improve: catch optuna's ExperimentalWarning 69 | 70 | Version 1.5.1 (2020-06-18) 71 | ~~~~~~~~~~~~~~~~~~~~~~~~~~ 72 | 73 | - feat: add second progress bar to display trial internal progress 74 | - fix: skip Frozen parameters in pipeline.instantiate (@PaulLerner) 75 | - fix: switch to pyannote.database 4.0+ (@PaulLerner) 76 | - setup: switch to optuna 1.4+ and pyannote.core 4.0+ 77 | 78 | Version 1.5 (2020-04-01) 79 | ~~~~~~~~~~~~~~~~~~~~~~~~ 80 | 81 | - feat: add "direction" parameter to Optimizer 82 | - fix: fix support for in-memory optimization (when db is None) 83 | - setup: switch to pyannote.database 3.0 84 | 85 | Version 1.4 (2020-03-10) 86 | ~~~~~~~~~~~~~~~~~~~~~~~~ 87 | 88 | - feat: add option to bootstrap optimization with pretrained pipeline 89 | 90 | Version 1.3 (2020-01-27) 91 | ~~~~~~~~~~~~~~~~~~~~~~~~ 92 | 93 | - BREAKING: write "apply" mode output to "train" subdirectory 94 | - feat: store best loss value in "params.yml" 95 | - fix: handle corner case in pyannote.pipeline.blocks.clustering 96 | - fix: use YAML safe loader 97 | 98 | Version 1.2 (2019-06-26) 99 | ~~~~~~~~~~~~~~~~~~~~~~~~ 100 | 101 | - feat: add support for callable preprocessors 102 | - setup: switch to pyannote.core 3.0 103 | - setup: switch to pyannote.database 2.2 104 | 105 | Version 1.1.1 (2019-04-09) 106 | ~~~~~~~~~~~~~~~~~~~~~~~~~~ 107 | 108 | - fix: do not raise FileExistsError when output directory exists in `pyannote-pipeline apply` 109 | - fix: skip evaluation of protocols without groundtruth in `pyannote-pipeline apply` 110 | - setup: switch to pyannote.database 2.1 111 | 112 | Version 1.1 (2019-03-20) 113 | ~~~~~~~~~~~~~~~~~~~~~~~~ 114 | 115 | - feat: add export to RTTM format 116 | - setup: switch to pyannote.database 2.0 117 | - fix: fix "use_threshold" parameter in HAC block 118 | 119 | Version 1.0 (2019-02-05) 120 | ~~~~~~~~~~~~~~~~~~~~~~~~ 121 | 122 | - feat: add support for pyannote.metrics (through `Pipeline.get_metric`) 123 | - feat: add support for optuna trial pruning 124 | - feat: keep track of processing & evaluation time 125 | 126 | Version 0.3 (2019-01-17) 127 | ~~~~~~~~~~~~~~~~~~~~~~~~ 128 | 129 | - feat: switch to optuna backend 130 | - feat: add "use_threshold" option to HAC pipeline 131 | - BREAKING: update Pipeline API 132 | - BREAKING: update Optimizer API 133 | - BREAKING: remove tensorboard support (for now) 134 | 135 | Version 0.2.1 (2018-12-04) 136 | ~~~~~~~~~~~~~~~~~~~~~~~~~~ 137 | 138 | - first public release 139 | -------------------------------------------------------------------------------- /doc/source/conf.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # pyannote.pipeline documentation build configuration file, created by 4 | # sphinx-quickstart on Tue Jan 24 15:45:55 2017. 5 | # 6 | # This file is execfile()d with the current directory set to its 7 | # containing dir. 8 | # 9 | # Note that not all possible configuration values are present in this 10 | # autogenerated file. 11 | # 12 | # All configuration values have a default; values that are commented out 13 | # serve to show the default. 14 | 15 | # If extensions (or modules to document with autodoc) are in another directory, 16 | # add these directories to sys.path here. If the directory is relative to the 17 | # documentation root, use os.path.abspath to make it absolute, like shown here. 18 | # 19 | 20 | # allow pyannote.pipeline import 21 | import os 22 | import sys 23 | 24 | sys.path.insert(0, os.path.abspath("../..")) 25 | 26 | 27 | # -- General configuration ------------------------------------------------ 28 | 29 | # If your documentation needs a minimal Sphinx version, state it here. 30 | # 31 | # needs_sphinx = '1.0' 32 | 33 | # Add any Sphinx extension module names here, as strings. They can be 34 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 35 | # ones. 36 | extensions = [ 37 | "sphinx.ext.autodoc", 38 | "sphinx.ext.napoleon", 39 | "sphinx.ext.intersphinx", 40 | "sphinx.ext.todo", 41 | "sphinx.ext.coverage", 42 | "sphinx.ext.mathjax", 43 | "sphinx.ext.viewcode", 44 | "sphinx.ext.githubpages", 45 | ] 46 | 47 | 48 | # Napoleon settings 49 | napoleon_google_docstring = True 50 | napoleon_numpy_docstring = True 51 | napoleon_include_init_with_doc = False 52 | napoleon_include_private_with_doc = False 53 | napoleon_include_special_with_doc = False 54 | napoleon_use_admonition_for_examples = False 55 | napoleon_use_admonition_for_notes = False 56 | napoleon_use_admonition_for_references = False 57 | napoleon_use_ivar = False 58 | napoleon_use_param = True 59 | napoleon_use_rtype = True 60 | napoleon_use_keyword = True 61 | 62 | # Add any paths that contain templates here, relative to this directory. 63 | templates_path = ["_templates"] 64 | 65 | # The suffix(es) of source filenames. 66 | # You can specify multiple suffix as a list of string: 67 | # 68 | # source_suffix = ['.rst', '.md'] 69 | source_suffix = {".rst": "restructuredtext"} 70 | 71 | # The master toctree document. 72 | master_doc = "index" 73 | 74 | # General information about the project. 75 | project = "pyannote.pipeline" 76 | copyright = "2017, CNRS" 77 | author = "Hervé Bredin" 78 | 79 | # The version info for the project you're documenting, acts as replacement for 80 | # |version| and |release|, also used in various other places throughout the 81 | # built documents. 82 | 83 | import pyannote.pipeline 84 | 85 | # The short X.Y version. 86 | version = pyannote.pipeline.__version__.split("+")[0] 87 | # The full version, including alpha/beta/rc tags. 88 | release = pyannote.pipeline.__version__ 89 | 90 | # The language for content autogenerated by Sphinx. Refer to documentation 91 | # for a list of supported languages. 92 | # 93 | # This is also used if you do content translation via gettext catalogs. 94 | # Usually you set "language" from the command line for these cases. 95 | language = "en" 96 | 97 | # List of patterns, relative to source directory, that match files and 98 | # directories to ignore when looking for source files. 99 | # This patterns also effect to html_static_path and html_extra_path 100 | exclude_patterns = [] 101 | 102 | # The name of the Pygments (syntax highlighting) style to use. 103 | pygments_style = "sphinx" 104 | 105 | # If true, `todo` and `todoList` produce output, else they produce nothing. 106 | todo_include_todos = True 107 | 108 | 109 | # -- Options for HTML output ---------------------------------------------- 110 | 111 | # The theme to use for HTML and HTML Help pages. See the documentation for 112 | # a list of builtin themes. 113 | # 114 | html_theme = "sphinx_rtd_theme" 115 | 116 | # Theme options are theme-specific and customize the look and feel of a theme 117 | # further. For a list of options available for each theme, see the 118 | # documentation. 119 | # 120 | # html_theme_options = {} 121 | 122 | # Add any paths that contain custom static files (such as style sheets) here, 123 | # relative to this directory. They are copied after the builtin static files, 124 | # so a file named "default.css" will overwrite the builtin "default.css". 125 | html_static_path = ["_static"] 126 | 127 | 128 | # -- Options for HTMLHelp output ------------------------------------------ 129 | 130 | # Output file base name for HTML help builder. 131 | htmlhelp_basename = "pyannotepipelinedoc" 132 | 133 | 134 | # -- Options for LaTeX output --------------------------------------------- 135 | 136 | latex_elements = { 137 | # The paper size ('letterpaper' or 'a4paper'). 138 | # 139 | # 'papersize': 'letterpaper', 140 | # The font size ('10pt', '11pt' or '12pt'). 141 | # 142 | # 'pointsize': '10pt', 143 | # Additional stuff for the LaTeX preamble. 144 | # 145 | # 'preamble': '', 146 | # Latex figure (float) alignment 147 | # 148 | # 'figure_align': 'htbp', 149 | } 150 | 151 | # Grouping the document tree into LaTeX files. List of tuples 152 | # (source start file, target name, title, 153 | # author, documentclass [howto, manual, or own class]). 154 | latex_documents = [ 155 | ( 156 | master_doc, 157 | "pyannotepipeline.tex", 158 | "pyannote.pipeline Documentation", 159 | "Hervé Bredin", 160 | "manual", 161 | ), 162 | ] 163 | 164 | 165 | # -- Options for manual page output --------------------------------------- 166 | 167 | # One entry per manual page. List of tuples 168 | # (source start file, name, description, authors, manual section). 169 | man_pages = [ 170 | (master_doc, "pyannotepipeline", "pyannote.pipeline Documentation", [author], 1) 171 | ] 172 | 173 | 174 | # -- Options for Texinfo output ------------------------------------------- 175 | 176 | # Grouping the document tree into Texinfo files. List of tuples 177 | # (source start file, target name, title, author, 178 | # dir menu entry, description, category) 179 | texinfo_documents = [ 180 | ( 181 | master_doc, 182 | "pyannotepipeline", 183 | "pyannote.pipeline Documentation", 184 | author, 185 | "pyannotepipeline", 186 | "One line description of project.", 187 | "Miscellaneous", 188 | ), 189 | ] 190 | 191 | # Example configuration for intersphinx: refer to the Python standard library. 192 | intersphinx_mapping = { 193 | "python": ("https://docs.python.org/", None), 194 | "pyannote.core": ("https://pyannote.github.io/pyannote-core", None), 195 | "pyannote.database": ("https://pyannote.github.io/pyannote-database", None), 196 | } 197 | -------------------------------------------------------------------------------- /doc/source/index.rst: -------------------------------------------------------------------------------- 1 | 2 | ################# 3 | pyannote.pipeline 4 | ################# 5 | 6 | `pyannote.pipeline` is an open-source Python library for pipeline optimization 7 | 8 | 9 | Installation 10 | ============ 11 | 12 | :: 13 | 14 | $ conda create -n pyannote python=3.6 anaconda 15 | $ source activate pyannote 16 | $ pip install --process-dependency-links pyannote.pipeline 17 | 18 | If on MacOS, an `ghalton` fails to install, you might need to do something like 19 | 20 | :: 21 | 22 | $ export CFLAGS="-Wno-deprecated-declarations -std=c++11 -stdlib=libc++" 23 | $ pip install ghalton 24 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "pyannote-pipeline" 3 | description = "Tunable pipelines" 4 | readme = "README.md" 5 | authors = [ 6 | { name = "Hervé BREDIN", email = "herve@pyannote.ai" } 7 | ] 8 | requires-python = ">=3.10" 9 | 10 | dynamic = [ 11 | "version", 12 | ] 13 | 14 | dependencies = [ 15 | "filelock>=3.17.0", 16 | "optuna>=4.2.0", 17 | "pyannote-core>=5.0.0", 18 | "pyannote-database>=5.1.3", 19 | "pyyaml>=6.0.2", 20 | "tqdm>=4.67.1", 21 | ] 22 | 23 | [project.scripts] 24 | pyannote-pipeline = "pyannote.pipeline.experiment:main" 25 | 26 | 27 | [project.optional-dependencies] 28 | test = [ 29 | "pytest>=8.3.4", 30 | ] 31 | doc = [ 32 | "sphinx-rtd-theme>=3.0.2", 33 | "sphinx>=8.1.3", 34 | ] 35 | cli = [ 36 | "docopt>=0.6.2", 37 | ] 38 | 39 | [build-system] 40 | requires = ["hatchling", "hatch-vcs"] 41 | build-backend = "hatchling.build" 42 | 43 | [tool.hatch.version] 44 | source = "vcs" 45 | 46 | [tool.hatch.build.targets.wheel] 47 | packages = ["src/pyannote"] 48 | 49 | [dependency-groups] 50 | dev = [ 51 | "ipykernel>=6.29.5", 52 | ] 53 | -------------------------------------------------------------------------------- /src/pyannote/pipeline/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # encoding: utf-8 3 | 4 | # The MIT License (MIT) 5 | 6 | # Copyright (c) 2018-2020 CNRS 7 | 8 | # Permission is hereby granted, free of charge, to any person obtaining a copy 9 | # of this software and associated documentation files (the "Software"), to deal 10 | # in the Software without restriction, including without limitation the rights 11 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 12 | # copies of the Software, and to permit persons to whom the Software is 13 | # furnished to do so, subject to the following conditions: 14 | 15 | # The above copyright notice and this permission notice shall be included in 16 | # all copies or substantial portions of the Software. 17 | 18 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 19 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 20 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 21 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 22 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 23 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 24 | # SOFTWARE. 25 | 26 | # AUTHORS 27 | # Hervé BREDIN - http://herve.niderb.fr 28 | 29 | import importlib.metadata 30 | 31 | __version__ = importlib.metadata.version("pyannote-pipeline") 32 | 33 | from .pipeline import Pipeline 34 | from .optimizer import Optimizer 35 | 36 | __all__ = ["Pipeline", "Optimizer"] 37 | -------------------------------------------------------------------------------- /src/pyannote/pipeline/experiment.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # encoding: utf-8 3 | 4 | # The MIT License (MIT) 5 | 6 | # Copyright (c) 2018- CNRS 7 | 8 | # Permission is hereby granted, free of charge, to any person obtaining a copy 9 | # of this software and associated documentation files (the "Software"), to deal 10 | # in the Software without restriction, including without limitation the rights 11 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 12 | # copies of the Software, and to permit persons to whom the Software is 13 | # furnished to do so, subject to the following conditions: 14 | 15 | # The above copyright notice and this permission notice shall be included in 16 | # all copies or substantial portions of the Software. 17 | 18 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 19 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 20 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 21 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 22 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 23 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 24 | # SOFTWARE. 25 | 26 | # AUTHORS 27 | # Hervé BREDIN - http://herve.niderb.fr 28 | 29 | """ 30 | Pipeline 31 | 32 | Usage: 33 | pyannote-pipeline train [options] [(--forever | --iterations=)] 34 | pyannote-pipeline best [options] 35 | pyannote-pipeline apply [options] 36 | pyannote-pipeline -h | --help 37 | pyannote-pipeline --version 38 | 39 | Common options: 40 | Experimental protocol (e.g. "Etape.SpeakerDiarization.TV") 41 | --registry= Path to, comma-separated, database configuration files. 42 | [default: ~/.pyannote/db.yml] 43 | --subset= Set subset. Defaults to 'development' in "train" 44 | mode, and to 'test' in "apply" mode. 45 | 46 | "train" mode: 47 | Set experiment root directory. This script expects 48 | a configuration file called "config.yml" to live 49 | in this directory. See "Configuration file" 50 | section below for more details. 51 | --iterations= Number of iterations. [default: 1] 52 | --forever Iterate forever. 53 | --sampler= Choose sampler between RandomSampler or TPESampler 54 | [default: TPESampler]. 55 | --pruner= Choose pruner between MedianPruner or 56 | SuccessiveHalvingPruner. Defaults to no pruning. 57 | --pretrained= Use parameters in existing training directory to 58 | bootstrap the optimization process. In practice, 59 | this will simply run a first trial with this set 60 | of parameters. 61 | --average-case Optimize for average case instead of worst case. 62 | 63 | "apply" mode: 64 | Path to the directory containing trained hyper- 65 | parameters (i.e. the output of "train" mode). 66 | 67 | --use-filter Apply pipeline only to files that pass the filter. 68 | 69 | Configuration file: 70 | The configuration of each experiment is described in a file called 71 | /config.yml that describes the pipeline. 72 | 73 | ................... /config.yml ................... 74 | pipeline: 75 | name: Yin2018 76 | params: 77 | sad: tutorials/pipeline/sad 78 | scd: tutorials/pipeline/scd 79 | emb: tutorials/pipeline/emb 80 | metric: angular 81 | 82 | # preprocessors can be used to automatically add keys into 83 | # each (dict) file obtained from pyannote.database protocols. 84 | preprocessors: 85 | audio: ~/.pyannote/db.yml # load template from YAML file 86 | video: ~/videos/{uri}.mp4 # define template directly 87 | 88 | # filters can be used to filter out some files from the protocol 89 | # (e.g. to only keep files with a specific number of speakers) 90 | filters: 91 | pyannote.audio.utils.protocol.FilterByNumberOfSpeakers: 92 | num_speakers: 2 93 | 94 | # one can freeze some hyper-parameters if needed (e.g. when 95 | # only part of the pipeline needs to be updated) 96 | freeze: 97 | speech_turn_segmentation: 98 | speech_activity_detection: 99 | onset: 0.5 100 | offset: 0.5 101 | 102 | # pyannote.audio pipelines will run on CPU by default. 103 | # use `device` key to send it to GPU. 104 | device: cuda 105 | ................................................................... 106 | 107 | "train" mode: 108 | Tune the pipeline hyper-parameters 109 | /..yml 110 | 111 | "best" mode: 112 | Display current best loss and corresponding hyper-paramters. 113 | 114 | "apply" mode 115 | Apply the pipeline (with best set of hyper-parameters) 116 | 117 | """ 118 | 119 | import os 120 | import os.path 121 | import yaml 122 | import numpy as np 123 | from typing import Optional 124 | from pathlib import Path 125 | from docopt import docopt 126 | 127 | import itertools 128 | from tqdm import tqdm 129 | from datetime import datetime 130 | 131 | from pyannote.database import FileFinder 132 | from pyannote.database import registry 133 | from pyannote.database import get_annotated 134 | 135 | from pyannote.core.utils.helper import get_class_by_name 136 | from .optimizer import Optimizer 137 | 138 | 139 | class Experiment: 140 | """Pipeline experiment 141 | 142 | Parameters 143 | ---------- 144 | experiment_dir : `Path` 145 | Experiment root directory. 146 | training : `bool`, optional 147 | Switch to training mode 148 | """ 149 | 150 | CONFIG_YML = "{experiment_dir}/config.yml" 151 | TRAIN_DIR = "{experiment_dir}/train/{protocol}.{subset}" 152 | APPLY_DIR = "{train_dir}/apply/{date}" 153 | 154 | @classmethod 155 | def from_train_dir(cls, train_dir: Path, training: bool = False) -> "Experiment": 156 | """Load pipeline from train directory 157 | 158 | Parameters 159 | ---------- 160 | train_dir : `Path` 161 | Path to train directory 162 | training : `bool`, optional 163 | Switch to training mode. 164 | 165 | Returns 166 | ------- 167 | xp : `Experiment` 168 | Pipeline experiment. 169 | """ 170 | experiment_dir = train_dir.parents[1] 171 | xp = cls(experiment_dir, training=training) 172 | params_yml = train_dir / "params.yml" 173 | xp.mtime_ = datetime.fromtimestamp(os.path.getmtime(params_yml)) 174 | xp.pipeline_.load_params(params_yml) 175 | return xp 176 | 177 | def __init__(self, experiment_dir: Path, training: bool = False): 178 | super().__init__() 179 | 180 | self.experiment_dir = experiment_dir 181 | 182 | # load configuration file 183 | config_yml = self.CONFIG_YML.format(experiment_dir=self.experiment_dir) 184 | with open(config_yml, "r") as fp: 185 | self.config_ = yaml.load(fp, Loader=yaml.SafeLoader) 186 | 187 | # initialize preprocessors 188 | preprocessors = {} 189 | for key, preprocessor in self.config_.get("preprocessors", {}).items(): 190 | # preprocessors: 191 | # key: 192 | # name: package.module.ClassName 193 | # params: 194 | # param1: value1 195 | # param2: value2 196 | if isinstance(preprocessor, dict): 197 | Klass = get_class_by_name( 198 | preprocessor["name"], default_module_name="pyannote.pipeline" 199 | ) 200 | preprocessors[key] = Klass(**preprocessor.get("params", {})) 201 | continue 202 | 203 | try: 204 | # preprocessors: 205 | # key: /path/to/database.yml 206 | preprocessors[key] = FileFinder(database_yml=preprocessor) 207 | 208 | except FileNotFoundError as e: 209 | # preprocessors: 210 | # key: /path/to/{uri}.wav 211 | template = preprocessor 212 | preprocessors[key] = template 213 | 214 | self.preprocessors_ = preprocessors 215 | 216 | # initialize filters 217 | filters = [] 218 | for key, params in self.config_.get("filters", {}).items(): 219 | Klass = get_class_by_name(key) 220 | filters.append(Klass(**params)) 221 | 222 | def all_filters(i) -> bool: 223 | return all(f(i) for f in filters) 224 | 225 | self.filters_ = all_filters 226 | 227 | # initialize pipeline 228 | pipeline_name = self.config_["pipeline"]["name"] 229 | Klass = get_class_by_name( 230 | pipeline_name, default_module_name="pyannote.pipeline.blocks" 231 | ) 232 | self.pipeline_ = Klass(**self.config_["pipeline"].get("params", {})) 233 | 234 | # freeze parameters 235 | if "freeze" in self.config_: 236 | params = self.config_["freeze"] 237 | self.pipeline_.freeze(params) 238 | 239 | # send to device 240 | if "device" in self.config_: 241 | import torch 242 | 243 | device = torch.device(self.config_["device"]) 244 | self.pipeline_.to(device) 245 | 246 | def train( 247 | self, 248 | protocol_name: str, 249 | subset: Optional[str] = "development", 250 | pretrained: Optional[Path] = None, 251 | n_iterations: int = 1, 252 | sampler: Optional[str] = None, 253 | pruner: Optional[str] = None, 254 | average_case: bool = False, 255 | ): 256 | """Train pipeline 257 | 258 | Parameters 259 | ---------- 260 | protocol_name : `str` 261 | Name of pyannote.database protocol to use. 262 | subset : `str`, optional 263 | Use this subset for training. Defaults to 'development'. 264 | pretrained : Path, optional 265 | Use parameters in "pretrained" training directory to bootstrap the 266 | optimization process. In practice this will simply run a first trial 267 | with this set of parameters. 268 | n_iterations : `int`, optional 269 | Number of iterations. Defaults to 1. 270 | sampler : `str`, optional 271 | Choose sampler between RandomSampler and TPESampler 272 | pruner : `str`, optional 273 | Choose between MedianPruner or SuccessiveHalvingPruner. 274 | average_case : `bool`, optional 275 | Optimise for average case. Defaults to False (i.e. worst case). 276 | """ 277 | train_dir = Path( 278 | self.TRAIN_DIR.format( 279 | experiment_dir=self.experiment_dir, 280 | protocol=protocol_name, 281 | subset=subset, 282 | ) 283 | ) 284 | train_dir.mkdir(parents=True, exist_ok=True) 285 | 286 | protocol = registry.get_protocol( 287 | protocol_name, preprocessors=self.preprocessors_ 288 | ) 289 | 290 | study_name = "default" 291 | optimizer = Optimizer( 292 | self.pipeline_, 293 | db=train_dir / "trials.journal", 294 | study_name=study_name, 295 | sampler=sampler, 296 | pruner=pruner, 297 | average_case=average_case, 298 | ) 299 | 300 | direction = 1 if self.pipeline_.get_direction() == "minimize" else -1 301 | 302 | params_yml = train_dir / "params.yml" 303 | 304 | progress_bar = tqdm(unit="trial", position=0, leave=True) 305 | progress_bar.set_description("First trial in progress") 306 | progress_bar.update(0) 307 | 308 | if pretrained: 309 | pre_params_yml = pretrained / "params.yml" 310 | with open(pre_params_yml, mode="r") as fp: 311 | pre_params = yaml.load(fp, Loader=yaml.SafeLoader) 312 | warm_start = pre_params["params"] 313 | 314 | else: 315 | warm_start = None 316 | 317 | inputs = list(filter(self.filters_, getattr(protocol, subset)())) 318 | 319 | iterations = optimizer.tune_iter( 320 | inputs, warm_start=warm_start, show_progress=True 321 | ) 322 | 323 | try: 324 | best_loss = optimizer.best_loss 325 | except ValueError as e: 326 | best_loss = direction * np.inf 327 | count = itertools.count() if n_iterations < 0 else range(n_iterations) 328 | 329 | for i, status in zip(count, iterations): 330 | loss = status["loss"] 331 | 332 | if direction * loss < direction * best_loss: 333 | best_params = status["params"] 334 | best_loss = loss 335 | self.pipeline_.dump_params( 336 | params_yml, params=best_params, loss=best_loss 337 | ) 338 | 339 | # progress bar 340 | desc = f"Best trial: {100 * best_loss:g}%" 341 | progress_bar.set_description(desc=desc) 342 | progress_bar.update(1) 343 | 344 | def best(self, protocol_name: str, subset: str = "development"): 345 | """Print current best pipeline 346 | 347 | Parameters 348 | ---------- 349 | protocol_name : `str` 350 | Name of pyannote.database protocol used for training. 351 | subset : `str`, optional 352 | Subset used for training. Defaults to 'development'. 353 | """ 354 | 355 | train_dir = Path( 356 | self.TRAIN_DIR.format( 357 | experiment_dir=self.experiment_dir, 358 | protocol=protocol_name, 359 | subset=subset, 360 | ) 361 | ) 362 | 363 | study_name = "default" 364 | optimizer = Optimizer( 365 | self.pipeline_, db=train_dir / "trials.journal", study_name=study_name 366 | ) 367 | 368 | try: 369 | best_loss = optimizer.best_loss 370 | except ValueError as e: 371 | print("Still waiting for at least one iteration to succeed.") 372 | return 373 | 374 | best_params = optimizer.best_params 375 | 376 | print(f"Loss = {100 * best_loss:g}% with the following hyper-parameters:") 377 | 378 | content = yaml.dump(best_params, default_flow_style=False) 379 | print(content) 380 | 381 | def apply( 382 | self, 383 | protocol_name: str, 384 | output_dir: Path, 385 | subset: Optional[str] = "test", 386 | use_filter: bool = False, 387 | ): 388 | """Apply current best pipeline 389 | 390 | Parameters 391 | ---------- 392 | protocol_name : `str` 393 | Name of pyannote.database protocol to process. 394 | subset : `str`, optional 395 | Subset to process. Defaults to 'test' 396 | """ 397 | 398 | # file generator 399 | protocol = registry.get_protocol( 400 | protocol_name, preprocessors=self.preprocessors_ 401 | ) 402 | 403 | # load pipeline metric (when available) 404 | try: 405 | metric = self.pipeline_.get_metric() 406 | except NotImplementedError as e: 407 | metric = None 408 | 409 | output_dir.mkdir(parents=True, exist_ok=True) 410 | if use_filter: 411 | output_ext = ( 412 | output_dir 413 | / f"{protocol_name}.{subset}_INCOMPLETE.{self.pipeline_.write_format}" 414 | ) 415 | else: 416 | output_ext = ( 417 | output_dir / f"{protocol_name}.{subset}.{self.pipeline_.write_format}" 418 | ) 419 | 420 | with open(output_ext, mode="w") as fp: 421 | files = list(getattr(protocol, subset)()) 422 | if use_filter: 423 | files = filter(self.filters_, files) 424 | 425 | desc = f"Processing {protocol_name} ({subset})" 426 | for current_file in tqdm(iterable=files, desc=desc, unit="file"): 427 | # apply pipeline and dump output to file 428 | output = self.pipeline_(current_file) 429 | self.pipeline_.write(fp, output) 430 | 431 | # compute evaluation metric (when possible) 432 | reference = current_file.get("annotation", None) 433 | if reference is None: 434 | metric = None 435 | 436 | # compute evaluation metric (when available) 437 | if metric is None: 438 | continue 439 | 440 | uem = get_annotated(current_file) 441 | _ = metric(reference, output, uem=uem) 442 | 443 | # "latest" symbolic link 444 | latest = output_dir.parent / "latest" 445 | if latest.exists(): 446 | latest.unlink() 447 | latest.symlink_to(output_dir) 448 | 449 | # print pipeline metric (when available) 450 | if metric is None: 451 | msg = ( 452 | f"For some (possibly good) reason, the output of this " 453 | f"pipeline could not be evaluated on {protocol_name}." 454 | ) 455 | print(msg) 456 | return 457 | 458 | if use_filter: 459 | output_eval = output_dir / f"{protocol_name}.{subset}_INCOMPLETE.eval" 460 | else: 461 | output_eval = output_dir / f"{protocol_name}.{subset}.eval" 462 | 463 | with open(output_eval, "w") as fp: 464 | fp.write(str(metric)) 465 | 466 | 467 | def main(): 468 | arguments = docopt(__doc__, version="Tunable pipelines") 469 | 470 | for database_yml in arguments["--registry"].split(","): 471 | registry.load_database(database_yml) 472 | 473 | protocol_name = arguments[""] 474 | subset = arguments["--subset"] 475 | 476 | if arguments["train"]: 477 | if subset is None: 478 | subset = "development" 479 | 480 | if arguments["--forever"]: 481 | iterations = -1 482 | else: 483 | iterations = int(arguments["--iterations"]) 484 | 485 | sampler = arguments["--sampler"] 486 | pruner = arguments["--pruner"] 487 | 488 | pretrained = arguments["--pretrained"] 489 | if pretrained: 490 | pretrained = Path(pretrained).expanduser().resolve(strict=True) 491 | 492 | average_case = arguments["--average-case"] 493 | 494 | experiment_dir = Path(arguments[""]) 495 | experiment_dir = experiment_dir.expanduser().resolve(strict=True) 496 | 497 | experiment = Experiment(experiment_dir, training=True) 498 | experiment.train( 499 | protocol_name, 500 | subset=subset, 501 | n_iterations=iterations, 502 | pretrained=pretrained, 503 | sampler=sampler, 504 | pruner=pruner, 505 | average_case=average_case, 506 | ) 507 | 508 | if arguments["best"]: 509 | if subset is None: 510 | subset = "development" 511 | 512 | experiment_dir = Path(arguments[""]) 513 | experiment_dir = experiment_dir.expanduser().resolve(strict=True) 514 | 515 | experiment = Experiment(experiment_dir, training=False) 516 | experiment.best(protocol_name, subset=subset) 517 | 518 | if arguments["apply"]: 519 | if subset is None: 520 | subset = "test" 521 | 522 | use_filter = arguments["--use-filter"] 523 | 524 | train_dir = Path(arguments[""]) 525 | train_dir = train_dir.expanduser().resolve(strict=True) 526 | experiment = Experiment.from_train_dir(train_dir, training=False) 527 | 528 | output_dir = Path( 529 | experiment.APPLY_DIR.format( 530 | train_dir=train_dir, date=experiment.mtime_.strftime("%Y%m%d-%H%M%S") 531 | ) 532 | ) 533 | 534 | experiment.apply( 535 | protocol_name, output_dir, subset=subset, use_filter=use_filter 536 | ) 537 | -------------------------------------------------------------------------------- /src/pyannote/pipeline/optimizer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # encoding: utf-8 3 | 4 | # The MIT License (MIT) 5 | 6 | # Copyright (c) 2018-2021 CNRS 7 | 8 | # Permission is hereby granted, free of charge, to any person obtaining a copy 9 | # of this software and associated documentation files (the "Software"), to deal 10 | # in the Software without restriction, including without limitation the rights 11 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 12 | # copies of the Software, and to permit persons to whom the Software is 13 | # furnished to do so, subject to the following conditions: 14 | 15 | # The above copyright notice and this permission notice shall be included in 16 | # all copies or substantial portions of the Software. 17 | 18 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 19 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 20 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 21 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 22 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 23 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 24 | # SOFTWARE. 25 | 26 | # AUTHORS 27 | # Hervé BREDIN - http://herve.niderb.fr 28 | # Hadrien TITEUX 29 | 30 | import time 31 | import warnings 32 | from pathlib import Path 33 | from typing import Iterable, Optional, Callable, Generator, Mapping, Union, Dict 34 | 35 | import numpy as np 36 | import optuna.logging 37 | import optuna.pruners 38 | import optuna.samplers 39 | from optuna.exceptions import ExperimentalWarning 40 | from optuna.pruners import BasePruner 41 | from optuna.samplers import BaseSampler, TPESampler 42 | from optuna.trial import Trial, FixedTrial 43 | from optuna.storages import RDBStorage, JournalStorage, JournalFileStorage 44 | from tqdm import tqdm 45 | from optuna.storages import RDBStorage, JournalStorage, JournalFileStorage 46 | from scipy.stats import bayes_mvs 47 | 48 | from .pipeline import Pipeline 49 | from .typing import PipelineInput 50 | 51 | optuna.logging.set_verbosity(optuna.logging.WARNING) 52 | 53 | 54 | class Optimizer: 55 | """Pipeline optimizer 56 | 57 | Parameters 58 | ---------- 59 | pipeline : `Pipeline` 60 | Pipeline. 61 | db : `Path`, optional 62 | Path to trial database on disk. Use ".sqlite" extension for SQLite 63 | backend, and ".journal" for Journal backend (prefered for parallel 64 | optimization). 65 | study_name : `str`, optional 66 | Name of study. In case it already exists, study will continue from 67 | there. # TODO -- generate this automatically 68 | sampler : `str` or sampler instance, optional 69 | Algorithm for value suggestion. Must be one of "RandomSampler" or 70 | "TPESampler", or a sampler instance. Defaults to "TPESampler". 71 | pruner : `str` or pruner instance, optional 72 | Algorithm for early pruning of trials. Must be one of "MedianPruner" or 73 | "SuccessiveHalvingPruner", or a pruner instance. 74 | Defaults to no pruning. 75 | seed : `int`, optional 76 | Seed value for the random number generator of the sampler. 77 | Defaults to no seed. 78 | average_case : `bool`, optional 79 | Optimize for average case (default). 80 | Set to False to optimize for worst case. 81 | """ 82 | 83 | def __init__( 84 | self, 85 | pipeline: Pipeline, 86 | db: Optional[Path] = None, 87 | study_name: Optional[str] = None, 88 | sampler: Optional[Union[str, BaseSampler]] = None, 89 | pruner: Optional[Union[str, BasePruner]] = None, 90 | seed: Optional[int] = None, 91 | average_case: bool = True, 92 | ): 93 | self.pipeline = pipeline 94 | 95 | self.db = db 96 | if db is None: 97 | self.storage_ = None 98 | else: 99 | extension = Path(self.db).suffix 100 | if extension == ".db": 101 | warnings.warn( 102 | "Storage with '.db' extension has been deprecated. Use '.sqlite' instead." 103 | ) 104 | self.storage_ = RDBStorage(f"sqlite:///{self.db}") 105 | elif extension == ".sqlite": 106 | self.storage_ = RDBStorage(f"sqlite:///{self.db}") 107 | elif extension == ".journal": 108 | self.storage_ = JournalStorage(JournalFileStorage(f"{self.db}")) 109 | self.study_name = study_name 110 | 111 | if isinstance(sampler, BaseSampler): 112 | self.sampler = sampler 113 | elif isinstance(sampler, str): 114 | try: 115 | self.sampler = getattr(optuna.samplers, sampler)(seed=seed) 116 | except AttributeError as e: 117 | msg = '`sampler` must be one of "RandomSampler" or "TPESampler"' 118 | raise ValueError(msg) 119 | elif sampler is None: 120 | self.sampler = TPESampler(seed=seed) 121 | 122 | if isinstance(pruner, BasePruner): 123 | self.pruner = pruner 124 | elif isinstance(pruner, str): 125 | try: 126 | self.pruner = getattr(optuna.pruners, pruner)() 127 | except AttributeError as e: 128 | msg = '`pruner` must be one of "MedianPruner" or "SuccessiveHalvingPruner"' 129 | raise ValueError(msg) 130 | else: 131 | self.pruner = None 132 | 133 | # generate name of study based on pipeline hash 134 | # Klass = pipeline.__class__ 135 | # study_name = f'{Klass.__module__}.{Klass.__name__}[{hash(pipeline)}]' 136 | 137 | self.study_ = optuna.create_study( 138 | study_name=self.study_name, 139 | load_if_exists=True, 140 | storage=self.storage_, 141 | sampler=self.sampler, 142 | pruner=self.pruner, 143 | direction=self.pipeline.get_direction(), 144 | ) 145 | 146 | self.average_case = average_case 147 | 148 | @property 149 | def best_loss(self) -> float: 150 | """Return best loss so far""" 151 | try: 152 | best_value = self.study_.best_value 153 | except Exception: 154 | direction: int = 1 if self.pipeline.get_direction() == "minimize" else -1 155 | best_value = direction * np.inf 156 | return best_value 157 | 158 | @property 159 | def best_params(self) -> dict: 160 | """Return best parameters so far""" 161 | trial = FixedTrial(self.study_.best_params) 162 | return self.pipeline.parameters(trial=trial) 163 | 164 | @property 165 | def best_pipeline(self) -> Pipeline: 166 | """Return pipeline instantiated with best parameters so far""" 167 | return self.pipeline.instantiate(self.best_params) 168 | 169 | def get_objective( 170 | self, 171 | inputs: Iterable[PipelineInput], 172 | show_progress: Union[bool, Dict] = False, 173 | ) -> Callable[[Trial], float]: 174 | """ 175 | Create objective function used by optuna 176 | 177 | Parameters 178 | ---------- 179 | inputs : `iterable` 180 | List of inputs to process. 181 | show_progress : bool or dict 182 | Show within-trial progress bar using tqdm progress bar. 183 | Can also be a **kwarg dict passed to tqdm. 184 | 185 | Returns 186 | ------- 187 | objective : `callable` 188 | Callable that takes trial as input and returns correspond loss. 189 | """ 190 | 191 | # this is needed for `inputs` that can be only iterated once. 192 | inputs = list(inputs) 193 | n_inputs = len(inputs) 194 | 195 | if show_progress == True: 196 | show_progress = {"desc": "Current trial", "leave": False, "position": 1} 197 | 198 | def objective(trial: Trial) -> float: 199 | """Compute objective value 200 | 201 | Parameter 202 | --------- 203 | trial : `Trial` 204 | Current trial 205 | 206 | Returns 207 | ------- 208 | loss : `float` 209 | Loss 210 | """ 211 | 212 | # use pyannote.metrics metric when available 213 | try: 214 | metric = self.pipeline.get_metric() 215 | except NotImplementedError as e: 216 | metric = None 217 | losses = [] 218 | 219 | processing_time = [] 220 | evaluation_time = [] 221 | 222 | # instantiate pipeline with value suggested in current trial 223 | pipeline = self.pipeline.instantiate(self.pipeline.parameters(trial=trial)) 224 | 225 | if show_progress != False: 226 | progress_bar = tqdm(total=len(inputs), **show_progress) 227 | progress_bar.update(0) 228 | 229 | # accumulate loss for each input 230 | for i, input in enumerate(inputs): 231 | # process input with pipeline 232 | # (and keep track of processing time) 233 | before_processing = time.time() 234 | 235 | # get optional kwargs to be passed to the pipeline 236 | # (e.g. num_speakers for speaker diarization). they 237 | # must be stored in a 'pipeline_kwargs' key in the 238 | # `input` dictionary. 239 | if isinstance(input, Mapping): 240 | pipeline_kwargs = input.get("pipeline_kwargs", {}) 241 | else: 242 | pipeline_kwargs = {} 243 | output = pipeline(input, **pipeline_kwargs) 244 | after_processing = time.time() 245 | processing_time.append(after_processing - before_processing) 246 | 247 | # evaluate output (and keep track of evaluation time) 248 | before_evaluation = time.time() 249 | 250 | # when metric is not available, use loss method instead 251 | if metric is None: 252 | loss = pipeline.loss(input, output) 253 | losses.append(loss) 254 | 255 | # when metric is available,`input` is expected to be provided 256 | # by a `pyannote.database` protocol 257 | else: 258 | from pyannote.database import get_annotated 259 | 260 | _ = metric(input["annotation"], output, uem=get_annotated(input)) 261 | 262 | after_evaluation = time.time() 263 | evaluation_time.append(after_evaluation - before_evaluation) 264 | 265 | if show_progress != False: 266 | progress_bar.update(1) 267 | 268 | if self.pruner is None: 269 | continue 270 | 271 | trial.report(np.mean(losses) if metric is None else abs(metric), i) 272 | if trial.should_prune(): 273 | raise optuna.TrialPruned() 274 | 275 | if show_progress != False: 276 | progress_bar.close() 277 | 278 | trial.set_user_attr("processing_time", sum(processing_time)) 279 | trial.set_user_attr("evaluation_time", sum(evaluation_time)) 280 | 281 | if metric is None: 282 | if len(np.unique(losses)) == 1: 283 | mean = lower_bound = upper_bound = losses[0] 284 | else: 285 | (mean, (lower_bound, upper_bound)), _, _ = bayes_mvs( 286 | losses, alpha=0.9 287 | ) 288 | else: 289 | mean, (lower_bound, upper_bound) = metric.confidence_interval(alpha=0.9) 290 | 291 | if self.average_case: 292 | if metric is None: 293 | return mean 294 | 295 | else: 296 | return abs(metric) 297 | 298 | return ( 299 | upper_bound 300 | if self.pipeline.get_direction() == "minimize" 301 | else lower_bound 302 | ) 303 | 304 | return objective 305 | 306 | def tune( 307 | self, 308 | inputs: Iterable[PipelineInput], 309 | n_iterations: int = 10, 310 | warm_start: dict = None, 311 | show_progress: Union[bool, Dict] = True, 312 | ) -> dict: 313 | """Tune pipeline 314 | 315 | Parameters 316 | ---------- 317 | inputs : iterable 318 | List of inputs processed by the pipeline at each iteration. 319 | n_iterations : int, optional 320 | Number of iterations. Defaults to 10. 321 | warm_start : dict, optional 322 | Nested dictionary of initial parameters used to bootstrap tuning. 323 | 324 | Returns 325 | ------- 326 | result : dict 327 | ['loss'] 328 | ['params'] nested dictionary of optimal parameters 329 | """ 330 | 331 | # pipeline is currently being optimized 332 | self.pipeline.training = True 333 | 334 | objective = self.get_objective(inputs, show_progress=show_progress) 335 | 336 | if warm_start: 337 | flattened_params = self.pipeline._flatten(warm_start) 338 | 339 | with warnings.catch_warnings(): 340 | warnings.filterwarnings("ignore", category=ExperimentalWarning) 341 | self.study_.enqueue_trial(flattened_params) 342 | 343 | self.study_.optimize(objective, n_trials=n_iterations, timeout=None, n_jobs=1) 344 | 345 | # pipeline is no longer being optimized 346 | self.pipeline.training = False 347 | 348 | return {"loss": self.best_loss, "params": self.best_params} 349 | 350 | def tune_iter( 351 | self, 352 | inputs: Iterable[PipelineInput], 353 | warm_start: dict = None, 354 | show_progress: Union[bool, Dict] = True, 355 | ) -> Generator[dict, None, None]: 356 | """ 357 | 358 | Parameters 359 | ---------- 360 | inputs : iterable 361 | List of inputs processed by the pipeline at each iteration. 362 | warm_start : dict, optional 363 | Nested dictionary of initial parameters used to bootstrap tuning. 364 | 365 | Yields 366 | ------ 367 | result : dict 368 | ['loss'] 369 | ['params'] nested dictionary of optimal parameters 370 | """ 371 | 372 | objective = self.get_objective(inputs, show_progress=show_progress) 373 | 374 | try: 375 | best_loss = self.best_loss 376 | except ValueError as e: 377 | best_loss = np.inf 378 | 379 | if warm_start: 380 | flattened_params = self.pipeline._flatten(warm_start) 381 | with warnings.catch_warnings(): 382 | warnings.filterwarnings("ignore", category=ExperimentalWarning) 383 | self.study_.enqueue_trial(flattened_params) 384 | 385 | while True: 386 | # pipeline is currently being optimized 387 | self.pipeline.training = True 388 | 389 | # one trial at a time 390 | self.study_.optimize(objective, n_trials=1, timeout=None, n_jobs=1) 391 | 392 | try: 393 | best_loss = self.best_loss 394 | best_params = self.best_params 395 | except ValueError as e: 396 | continue 397 | 398 | # pipeline is no longer being optimized 399 | self.pipeline.training = False 400 | 401 | yield {"loss": best_loss, "params": best_params} 402 | -------------------------------------------------------------------------------- /src/pyannote/pipeline/parameter.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # encoding: utf-8 3 | 4 | # The MIT License (MIT) 5 | 6 | # Copyright (c) 2018-2020 CNRS 7 | 8 | # Permission is hereby granted, free of charge, to any person obtaining a copy 9 | # of this software and associated documentation files (the "Software"), to deal 10 | # in the Software without restriction, including without limitation the rights 11 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 12 | # copies of the Software, and to permit persons to whom the Software is 13 | # furnished to do so, subject to the following conditions: 14 | 15 | # The above copyright notice and this permission notice shall be included in 16 | # all copies or substantial portions of the Software. 17 | 18 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 19 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 20 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 21 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 22 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 23 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 24 | # SOFTWARE. 25 | 26 | # AUTHORS 27 | # Hervé BREDIN - http://herve.niderb.fr 28 | # Hadrien TITEUX - https://github.com/hadware 29 | 30 | 31 | from typing import Iterable, Any 32 | from optuna.trial import Trial 33 | 34 | from .pipeline import Pipeline 35 | from collections.abc import Mapping 36 | 37 | 38 | class Parameter: 39 | """Base hyper-parameter""" 40 | 41 | pass 42 | 43 | 44 | class Categorical(Parameter): 45 | """Categorical hyper-parameter 46 | 47 | The value is sampled from `choices`. 48 | 49 | Parameters 50 | ---------- 51 | choices : iterable 52 | Candidates of hyper-parameter value. 53 | """ 54 | 55 | def __init__(self, choices: Iterable): 56 | super().__init__() 57 | self.choices = list(choices) 58 | 59 | def __call__(self, name: str, trial: Trial): 60 | return trial.suggest_categorical(name, self.choices) 61 | 62 | 63 | class DiscreteUniform(Parameter): 64 | """Discrete uniform hyper-parameter 65 | 66 | The value is sampled from the range [low, high], 67 | and the step of discretization is `q`. 68 | 69 | Parameters 70 | ---------- 71 | low : `float` 72 | Lower endpoint of the range of suggested values. 73 | `low` is included in the range. 74 | high : `float` 75 | Upper endpoint of the range of suggested values. 76 | `high` is included in the range. 77 | q : `float` 78 | A step of discretization. 79 | """ 80 | 81 | def __init__(self, low: float, high: float, q: float): 82 | super().__init__() 83 | self.low = float(low) 84 | self.high = float(high) 85 | self.q = float(q) 86 | 87 | def __call__(self, name: str, trial: Trial): 88 | return trial.suggest_discrete_uniform(name, self.low, self.high, self.q) 89 | 90 | 91 | class Integer(Parameter): 92 | """Integer hyper-parameter 93 | 94 | The value is sampled from the integers in [low, high]. 95 | 96 | Parameters 97 | ---------- 98 | low : `int` 99 | Lower endpoint of the range of suggested values. 100 | `low` is included in the range. 101 | high : `int` 102 | Upper endpoint of the range of suggested values. 103 | `high` is included in the range. 104 | """ 105 | 106 | def __init__(self, low: int, high: int): 107 | super().__init__() 108 | self.low = int(low) 109 | self.high = int(high) 110 | 111 | def __call__(self, name: str, trial: Trial): 112 | return trial.suggest_int(name, self.low, self.high) 113 | 114 | 115 | class LogUniform(Parameter): 116 | """Log-uniform hyper-parameter 117 | 118 | The value is sampled from the range [low, high) in the log domain. 119 | 120 | Parameters 121 | ---------- 122 | low : `float` 123 | Lower endpoint of the range of suggested values. 124 | `low` is included in the range. 125 | high : `float` 126 | Upper endpoint of the range of suggested values. 127 | `high` is excluded from the range. 128 | """ 129 | 130 | def __init__(self, low: float, high: float): 131 | super().__init__() 132 | self.low = float(low) 133 | self.high = float(high) 134 | 135 | def __call__(self, name: str, trial: Trial): 136 | return trial.suggest_loguniform(name, self.low, self.high) 137 | 138 | 139 | class Uniform(Parameter): 140 | """Uniform hyper-parameter 141 | 142 | The value is sampled from the range [low, high) in the linear domain. 143 | 144 | Parameters 145 | ---------- 146 | low : `float` 147 | Lower endpoint of the range of suggested values. 148 | `low` is included in the range. 149 | high : `float` 150 | Upper endpoint of the range of suggested values. 151 | `high` is excluded from the range. 152 | """ 153 | 154 | def __init__(self, low: float, high: float): 155 | super().__init__() 156 | self.low = float(low) 157 | self.high = float(high) 158 | 159 | def __call__(self, name: str, trial: Trial): 160 | return trial.suggest_uniform(name, self.low, self.high) 161 | 162 | 163 | class Frozen(Parameter): 164 | """Frozen hyper-parameter 165 | 166 | The value is fixed a priori 167 | 168 | Parameters 169 | ---------- 170 | value : 171 | Fixed value. 172 | """ 173 | 174 | def __init__(self, value: Any): 175 | super().__init__() 176 | self.value = value 177 | 178 | def __call__(self, name: str, trial: Trial): 179 | return self.value 180 | 181 | 182 | class ParamDict(Pipeline, Mapping): 183 | """Dict-like structured hyper-parameter 184 | 185 | Usage 186 | ----- 187 | >>> params = ParamDict(param1=Uniform(0.0, 1.0), param2=Uniform(-1.0, 1.0)) 188 | >>> params = ParamDict(**{"param1": Uniform(0.0, 1.0), "param2": Uniform(-1.0, 1.0)}) 189 | """ 190 | 191 | def __init__(self, **params): 192 | super().__init__() 193 | self.__params = params 194 | for param_name, param_value in params.items(): 195 | setattr(self, param_name, param_value) 196 | 197 | def __len__(self): 198 | return len(self.__params) 199 | 200 | def __iter__(self): 201 | return iter(self.__params) 202 | 203 | def __getitem__(self, param_name): 204 | return getattr(self, param_name) 205 | -------------------------------------------------------------------------------- /src/pyannote/pipeline/pipeline.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # encoding: utf-8 3 | 4 | # The MIT License (MIT) 5 | 6 | # Copyright (c) 2018-2022 CNRS 7 | 8 | # Permission is hereby granted, free of charge, to any person obtaining a copy 9 | # of this software and associated documentation files (the "Software"), to deal 10 | # in the Software without restriction, including without limitation the rights 11 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 12 | # copies of the Software, and to permit persons to whom the Software is 13 | # furnished to do so, subject to the following conditions: 14 | 15 | # The above copyright notice and this permission notice shall be included in 16 | # all copies or substantial portions of the Software. 17 | 18 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 19 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 20 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 21 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 22 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 23 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 24 | # SOFTWARE. 25 | 26 | # AUTHORS 27 | # Hervé BREDIN - http://herve.niderb.fr 28 | 29 | from typing import Optional, TextIO, Union, Dict, Any 30 | 31 | from pathlib import Path 32 | from collections import OrderedDict 33 | from .typing import PipelineInput 34 | from .typing import PipelineOutput 35 | from .typing import Direction 36 | from filelock import FileLock 37 | import yaml 38 | import warnings 39 | 40 | from pyannote.core import Timeline 41 | from pyannote.core import Annotation 42 | from optuna.trial import Trial 43 | 44 | 45 | class Pipeline: 46 | """Base tunable pipeline""" 47 | 48 | def __init__(self): 49 | 50 | # un-instantiated parameters (= `Parameter` instances) 51 | self._parameters: Dict[str, Parameter] = OrderedDict() 52 | 53 | # instantiated parameters 54 | self._instantiated: Dict[str, Any] = OrderedDict() 55 | 56 | # sub-pipelines 57 | self._pipelines: Dict[str, Pipeline] = OrderedDict() 58 | 59 | # whether pipeline is currently being optimized 60 | self.training = False 61 | 62 | @property 63 | def training(self): 64 | return self._training 65 | 66 | @training.setter 67 | def training(self, training): 68 | self._training = training 69 | # recursively set sub-pipeline training attribute 70 | for _, pipeline in self._pipelines.items(): 71 | pipeline.training = training 72 | 73 | def __hash__(self): 74 | # FIXME -- also keep track of (sub)pipeline attributes 75 | frozen = self.parameters(frozen=True) 76 | return hash(tuple(sorted(self._flatten(frozen).items()))) 77 | 78 | def __getattr__(self, name): 79 | """(Advanced) attribute getter""" 80 | 81 | # in case `name` corresponds to an instantiated parameter value, returns it 82 | if "_instantiated" in self.__dict__: 83 | _instantiated = self.__dict__["_instantiated"] 84 | if name in _instantiated: 85 | return _instantiated[name] 86 | 87 | # in case `name` corresponds to a parameter, returns it 88 | if "_parameters" in self.__dict__: 89 | _parameters = self.__dict__["_parameters"] 90 | if name in _parameters: 91 | return _parameters[name] 92 | 93 | # in case `name` corresponds to a sub-pipeline, returns it 94 | if "_pipelines" in self.__dict__: 95 | _pipelines = self.__dict__["_pipelines"] 96 | if name in _pipelines: 97 | return _pipelines[name] 98 | 99 | msg = "'{}' object has no attribute '{}'".format(type(self).__name__, name) 100 | raise AttributeError(msg) 101 | 102 | def __setattr__(self, name, value): 103 | """(Advanced) attribute setter 104 | 105 | If `value` is an instance of `Parameter`, store it in `_parameters`. 106 | elif `value` is an instance of `Pipeline`, store it in `_pipelines`. 107 | elif `value` isn't an instance of `Parameter` and `name` is in `_parameters`, 108 | store `value` in `_instantiated`. 109 | """ 110 | 111 | # imported here to avoid circular import 112 | from .parameter import Parameter 113 | 114 | def remove_from(*dicts): 115 | for d in dicts: 116 | if name in d: 117 | del d[name] 118 | 119 | _parameters = self.__dict__.get("_parameters") 120 | _instantiated = self.__dict__.get("_instantiated") 121 | _pipelines = self.__dict__.get("_pipelines") 122 | 123 | # if `value` is an instance of `Parameter`, store it in `_parameters` 124 | 125 | if isinstance(value, Parameter): 126 | if _parameters is None: 127 | msg = ( 128 | "cannot assign hyper-parameters " "before Pipeline.__init__() call" 129 | ) 130 | raise AttributeError(msg) 131 | remove_from(self.__dict__, _instantiated, _pipelines) 132 | _parameters[name] = value 133 | return 134 | 135 | # add/update one sub-pipeline 136 | if isinstance(value, Pipeline): 137 | if _pipelines is None: 138 | msg = "cannot assign sub-pipelines " "before Pipeline.__init__() call" 139 | raise AttributeError(msg) 140 | remove_from(self.__dict__, _parameters, _instantiated) 141 | _pipelines[name] = value 142 | return 143 | 144 | # store instantiated parameter value 145 | if _parameters is not None and name in _parameters: 146 | _instantiated[name] = value 147 | return 148 | 149 | object.__setattr__(self, name, value) 150 | 151 | def __delattr__(self, name): 152 | 153 | if name in self._parameters: 154 | del self._parameters[name] 155 | 156 | elif name in self._instantiated: 157 | del self._instantiated[name] 158 | 159 | elif name in self._pipelines: 160 | del self._pipelines[name] 161 | 162 | else: 163 | object.__delattr__(self, name) 164 | 165 | def _flattened_parameters( 166 | self, frozen: Optional[bool] = False, instantiated: Optional[bool] = False 167 | ) -> dict: 168 | """Get flattened dictionary of parameters 169 | 170 | Parameters 171 | ---------- 172 | frozen : `bool`, optional 173 | Only return value of frozen parameters. 174 | instantiated : `bool`, optional 175 | Only return value of instantiated parameters. 176 | 177 | Returns 178 | ------- 179 | params : `dict` 180 | Flattened dictionary of parameters. 181 | """ 182 | 183 | # imported here to avoid circular imports 184 | from .parameter import Frozen 185 | 186 | if frozen and instantiated: 187 | msg = "one must choose between `frozen` and `instantiated`." 188 | raise ValueError(msg) 189 | 190 | # initialize dictionary with root parameters 191 | if instantiated: 192 | params = dict(self._instantiated) 193 | 194 | elif frozen: 195 | params = { 196 | n: p.value for n, p in self._parameters.items() if isinstance(p, Frozen) 197 | } 198 | 199 | else: 200 | params = dict(self._parameters) 201 | 202 | # recursively add sub-pipeline parameters 203 | for pipeline_name, pipeline in self._pipelines.items(): 204 | pipeline_params = pipeline._flattened_parameters( 205 | frozen=frozen, instantiated=instantiated 206 | ) 207 | for name, value in pipeline_params.items(): 208 | params[f"{pipeline_name}>{name}"] = value 209 | 210 | return params 211 | 212 | def _flatten(self, nested_params: dict) -> dict: 213 | """Convert nested dictionary to flattened dictionary 214 | 215 | For instance, a nested dictionary like this one: 216 | 217 | ~~~~~~~~~~~~~~~~~~~~~ 218 | param: value1 219 | pipeline: 220 | param: value2 221 | subpipeline: 222 | param: value3 223 | ~~~~~~~~~~~~~~~~~~~~~ 224 | 225 | becomes the following flattened dictionary: 226 | 227 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 228 | param : value1 229 | pipeline>param : value2 230 | pipeline>subpipeline>param : value3 231 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 232 | 233 | Parameter 234 | --------- 235 | nested_params : `dict` 236 | 237 | Returns 238 | ------- 239 | flattened_params : `dict` 240 | """ 241 | flattened_params = dict() 242 | for name, value in nested_params.items(): 243 | if isinstance(value, dict): 244 | for subname, subvalue in self._flatten(value).items(): 245 | flattened_params[f"{name}>{subname}"] = subvalue 246 | else: 247 | flattened_params[name] = value 248 | return flattened_params 249 | 250 | def _unflatten(self, flattened_params: dict) -> dict: 251 | """Convert flattened dictionary to nested dictionary 252 | 253 | For instance, a flattened dictionary like this one: 254 | 255 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 256 | param : value1 257 | pipeline>param : value2 258 | pipeline>subpipeline>param : value3 259 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 260 | 261 | becomes the following nested dictionary: 262 | 263 | ~~~~~~~~~~~~~~~~~~~~~ 264 | param: value1 265 | pipeline: 266 | param: value2 267 | subpipeline: 268 | param: value3 269 | ~~~~~~~~~~~~~~~~~~~~~ 270 | 271 | Parameter 272 | --------- 273 | flattened_params : `dict` 274 | 275 | Returns 276 | ------- 277 | nested_params : `dict` 278 | """ 279 | 280 | nested_params = {} 281 | 282 | pipeline_params = {name: {} for name in self._pipelines} 283 | for name, value in flattened_params.items(): 284 | # if name contains has multipe ">"-separated tokens 285 | # it means that it is a sub-pipeline parameter 286 | tokens = name.split(">") 287 | if len(tokens) > 1: 288 | # read sub-pipeline name 289 | pipeline_name = tokens[0] 290 | # read parameter name 291 | param_name = ">".join(tokens[1:]) 292 | # update sub-pipeline flattened dictionary 293 | pipeline_params[pipeline_name][param_name] = value 294 | 295 | # otherwise, it is an actual parameter of this pipeline 296 | else: 297 | # store it as such 298 | nested_params[name] = value 299 | 300 | # recursively unflatten sub-pipeline flattened dictionary 301 | for name, pipeline in self._pipelines.items(): 302 | nested_params[name] = pipeline._unflatten(pipeline_params[name]) 303 | 304 | return nested_params 305 | 306 | def parameters( 307 | self, 308 | trial: Optional[Trial] = None, 309 | frozen: Optional[bool] = False, 310 | instantiated: Optional[bool] = False, 311 | ) -> dict: 312 | """Returns nested dictionary of (optionnaly instantiated) parameters. 313 | 314 | For a pipeline with one `param`, one sub-pipeline with its own param 315 | and its own sub-pipeline, it will returns something like: 316 | 317 | ~~~~~~~~~~~~~~~~~~~~~ 318 | param: value1 319 | pipeline: 320 | param: value2 321 | subpipeline: 322 | param: value3 323 | ~~~~~~~~~~~~~~~~~~~~~ 324 | 325 | Parameter 326 | --------- 327 | trial : `Trial`, optional 328 | When provided, use trial to suggest new parameter values 329 | and return them. 330 | frozen : `bool`, optional 331 | Return frozen parameter value 332 | instantiated : `bool`, optional 333 | Return instantiated parameter values. 334 | 335 | Returns 336 | ------- 337 | params : `dict` 338 | Nested dictionary of parameters. See above for the actual format. 339 | """ 340 | 341 | if (instantiated or frozen) and trial is not None: 342 | msg = "One must choose between `trial`, `instantiated`, or `frozen`" 343 | raise ValueError(msg) 344 | 345 | # get flattened dictionary of uninstantiated parameters 346 | params = self._flattened_parameters(frozen=frozen, instantiated=instantiated) 347 | 348 | if trial is not None: 349 | # use provided `trial` to suggest values for parameters 350 | params = {name: param(name, trial) for name, param in params.items()} 351 | 352 | # un-flatten flattened dictionary 353 | return self._unflatten(params) 354 | 355 | def initialize(self): 356 | """Instantiate root pipeline with current set of parameters""" 357 | pass 358 | 359 | def freeze(self, params: dict) -> "Pipeline": 360 | """Recursively freeze pipeline parameters 361 | 362 | Parameters 363 | ---------- 364 | params : `dict` 365 | Nested dictionary of parameters. 366 | 367 | Returns 368 | ------- 369 | self : `Pipeline` 370 | Pipeline. 371 | """ 372 | 373 | # imported here to avoid circular imports 374 | from .parameter import Frozen 375 | 376 | for name, value in params.items(): 377 | 378 | # recursively freeze sub-pipelines parameters 379 | if name in self._pipelines: 380 | if not isinstance(value, dict): 381 | msg = ( 382 | f"only parameters of '{name}' pipeline can " 383 | f"be frozen (not the whole pipeline)" 384 | ) 385 | raise ValueError(msg) 386 | self._pipelines[name].freeze(value) 387 | continue 388 | 389 | # instantiate parameter value 390 | if name in self._parameters: 391 | setattr(self, name, Frozen(value)) 392 | continue 393 | 394 | msg = f"parameter '{name}' does not exist" 395 | raise ValueError(msg) 396 | 397 | return self 398 | 399 | def instantiate(self, params: dict) -> "Pipeline": 400 | """Recursively instantiate all pipelines 401 | 402 | Parameters 403 | ---------- 404 | params : `dict` 405 | Nested dictionary of parameters. 406 | 407 | Returns 408 | ------- 409 | self : `Pipeline` 410 | Instantiated pipeline. 411 | """ 412 | 413 | # imported here to avoid circular imports 414 | from .parameter import Frozen 415 | 416 | for name, value in params.items(): 417 | 418 | # recursively call `instantiate` with sub-pipelines 419 | if name in self._pipelines: 420 | if not isinstance(value, dict): 421 | msg = ( 422 | f"only parameters of '{name}' pipeline can " 423 | f"be instantiated (not the whole pipeline)" 424 | ) 425 | raise ValueError(msg) 426 | self._pipelines[name].instantiate(value) 427 | continue 428 | 429 | # instantiate parameter value 430 | if name in self._parameters: 431 | param = getattr(self, name) 432 | # overwrite provided value of frozen parameters 433 | if isinstance(param, Frozen) and param.value != value: 434 | msg = ( 435 | f"Parameter '{name}' is frozen: using its frozen value " 436 | f"({param.value}) instead of the one provided ({value})." 437 | ) 438 | warnings.warn(msg) 439 | value = param.value 440 | setattr(self, name, value) 441 | continue 442 | 443 | msg = f"parameter '{name}' does not exist" 444 | raise ValueError(msg) 445 | 446 | self.initialize() 447 | 448 | return self 449 | 450 | @property 451 | def instantiated(self): 452 | """Whether pipeline has been instantiated (and therefore can be applied)""" 453 | parameters = set(self._flatten(self.parameters())) 454 | instantiated = set(self._flatten(self.parameters(instantiated=True))) 455 | return parameters == instantiated 456 | 457 | def dump_params( 458 | self, 459 | params_yml: Path, 460 | params: Optional[dict] = None, 461 | loss: Optional[float] = None, 462 | ) -> str: 463 | """Dump parameters to disk 464 | 465 | Parameters 466 | ---------- 467 | params_yml : `Path` 468 | Path to YAML file. 469 | params : `dict`, optional 470 | Nested Parameters. Defaults to pipeline current parameters. 471 | loss : `float`, optional 472 | Loss value. Defaults to not write loss to file. 473 | 474 | Returns 475 | ------- 476 | content : `str` 477 | Content written in `param_yml`. 478 | """ 479 | # use instantiated parameters when `params` is not provided 480 | if params is None: 481 | params = self.parameters(instantiated=True) 482 | 483 | content = {"params": params} 484 | if loss is not None: 485 | content["loss"] = loss 486 | 487 | # format as valid YAML 488 | content_yml = yaml.dump(content, default_flow_style=False) 489 | 490 | # (safely) dump YAML content 491 | with FileLock(params_yml.with_suffix(".lock")): 492 | with open(params_yml, mode="w") as fp: 493 | fp.write(content_yml) 494 | 495 | return content_yml 496 | 497 | def load_params(self, params_yml: Path) -> "Pipeline": 498 | """Instantiate pipeline using parameters from disk 499 | 500 | Parameters 501 | ---------- 502 | param_yml : `Path` 503 | Path to YAML file. 504 | 505 | Returns 506 | ------- 507 | self : `Pipeline` 508 | Instantiated pipeline 509 | 510 | """ 511 | 512 | with open(params_yml, mode="r") as fp: 513 | params = yaml.load(fp, Loader=yaml.SafeLoader) 514 | return self.instantiate(params["params"]) 515 | 516 | def __call__(self, input: PipelineInput) -> PipelineOutput: 517 | """Apply pipeline on input and return its output""" 518 | raise NotImplementedError 519 | 520 | def get_metric(self) -> "pyannote.metrics.base.BaseMetric": 521 | """Return new metric (from pyannote.metrics) 522 | 523 | When this method is implemented, the returned metric is used as a 524 | replacement for the loss method below. 525 | 526 | Returns 527 | ------- 528 | metric : `pyannote.metrics.base.BaseMetric` 529 | """ 530 | raise NotImplementedError() 531 | 532 | def get_direction(self) -> Direction: 533 | return "minimize" 534 | 535 | def loss(self, input: PipelineInput, output: PipelineOutput) -> float: 536 | """Compute loss for given input/output pair 537 | 538 | Parameters 539 | ---------- 540 | input : object 541 | Pipeline input. 542 | output : object 543 | Pipeline output 544 | 545 | Returns 546 | ------- 547 | loss : `float` 548 | Loss value 549 | """ 550 | raise NotImplementedError() 551 | 552 | @property 553 | def write_format(self): 554 | return "rttm" 555 | 556 | def write(self, file: TextIO, output: PipelineOutput): 557 | """Write pipeline output to file 558 | 559 | Parameters 560 | ---------- 561 | file : file object 562 | output : object 563 | Pipeline output 564 | """ 565 | 566 | return getattr(self, f"write_{self.write_format}")(file, output) 567 | 568 | def write_rttm(self, file: TextIO, output: Union[Timeline, Annotation]): 569 | """Write pipeline output to "rttm" file 570 | 571 | Parameters 572 | ---------- 573 | file : file object 574 | output : `pyannote.core.Timeline` or `pyannote.core.Annotation` 575 | Pipeline output 576 | """ 577 | 578 | if isinstance(output, Timeline): 579 | output = output.to_annotation(generator="string") 580 | 581 | if isinstance(output, Annotation): 582 | for s, t, l in output.itertracks(yield_label=True): 583 | line = ( 584 | f"SPEAKER {output.uri} 1 {s.start:.3f} {s.duration:.3f} " 585 | f" {l} \n" 586 | ) 587 | file.write(line) 588 | return 589 | 590 | msg = ( 591 | f'Dumping {output.__class__.__name__} instances to "rttm" files ' 592 | f"is not supported." 593 | ) 594 | raise NotImplementedError(msg) 595 | 596 | def write_txt(self, file: TextIO, output: Union[Timeline, Annotation]): 597 | """Write pipeline output to "txt" file 598 | 599 | Parameters 600 | ---------- 601 | file : file object 602 | output : `pyannote.core.Timeline` or `pyannote.core.Annotation` 603 | Pipeline output 604 | """ 605 | 606 | if isinstance(output, Timeline): 607 | for s in output: 608 | line = f"{output.uri} {s.start:.3f} {s.end:.3f}\n" 609 | file.write(line) 610 | return 611 | 612 | if isinstance(output, Annotation): 613 | for s, t, l in output.itertracks(yield_label=True): 614 | line = f"{output.uri} {s.start:.3f} {s.end:.3f} {t} {l}\n" 615 | file.write(line) 616 | return 617 | 618 | msg = ( 619 | f'Dumping {output.__class__.__name__} instances to "txt" files ' 620 | f"is not supported." 621 | ) 622 | raise NotImplementedError(msg) 623 | -------------------------------------------------------------------------------- /src/pyannote/pipeline/typing.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # encoding: utf-8 3 | 4 | # The MIT License (MIT) 5 | 6 | # Copyright (c) 2018-2020 CNRS 7 | 8 | # Permission is hereby granted, free of charge, to any person obtaining a copy 9 | # of this software and associated documentation files (the "Software"), to deal 10 | # in the Software without restriction, including without limitation the rights 11 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 12 | # copies of the Software, and to permit persons to whom the Software is 13 | # furnished to do so, subject to the following conditions: 14 | 15 | # The above copyright notice and this permission notice shall be included in 16 | # all copies or substantial portions of the Software. 17 | 18 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 19 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 20 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 21 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 22 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 23 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 24 | # SOFTWARE. 25 | 26 | # AUTHORS 27 | # Hervé BREDIN - http://herve.niderb.fr 28 | 29 | from typing import TypeVar 30 | 31 | PipelineInput = TypeVar("PipelineInput") 32 | PipelineOutput = TypeVar("PipelineOutput") 33 | 34 | try: 35 | from typing import Literal 36 | except ImportError: 37 | from typing_extensions import Literal 38 | Direction = Literal["minimize", "maximize"] 39 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # encoding: utf-8 3 | 4 | # The MIT License (MIT) 5 | 6 | # Copyright (c) 2018-2021 CNRS 7 | 8 | # Permission is hereby granted, free of charge, to any person obtaining a copy 9 | # of this software and associated documentation files (the "Software"), to deal 10 | # in the Software without restriction, including without limitation the rights 11 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 12 | # copies of the Software, and to permit persons to whom the Software is 13 | # furnished to do so, subject to the following conditions: 14 | 15 | # The above copyright notice and this permission notice shall be included in 16 | # all copies or substantial portions of the Software. 17 | 18 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 19 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 20 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 21 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 22 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 23 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 24 | # SOFTWARE. 25 | 26 | # AUTHORS 27 | # Hervé BREDIN - http://herve.niderb.fr 28 | # Hadrien TITEUX -------------------------------------------------------------------------------- /tests/test_optimizer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # encoding: utf-8 3 | 4 | # The MIT License (MIT) 5 | 6 | # Copyright (c) 2018-2022 CNRS 7 | 8 | # Permission is hereby granted, free of charge, to any person obtaining a copy 9 | # of this software and associated documentation files (the "Software"), to deal 10 | # in the Software without restriction, including without limitation the rights 11 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 12 | # copies of the Software, and to permit persons to whom the Software is 13 | # furnished to do so, subject to the following conditions: 14 | 15 | # The above copyright notice and this permission notice shall be included in 16 | # all copies or substantial portions of the Software. 17 | 18 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 19 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 20 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 21 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 22 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 23 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 24 | # SOFTWARE. 25 | 26 | # AUTHORS 27 | # Hervé BREDIN - http://herve.niderb.fr 28 | # Hadrien TITEUX 29 | 30 | from typing import List, Dict, Any 31 | 32 | import numpy as np 33 | import pytest 34 | from optuna.samplers import TPESampler 35 | 36 | from pyannote.pipeline import Pipeline, Optimizer 37 | from pyannote.pipeline.parameter import Integer, ParamDict 38 | from pyannote.pipeline.typing import Direction 39 | 40 | 41 | def optimizer_tester(pipeline: Pipeline, target: Any): 42 | dataset = np.ones(10) 43 | sampler = TPESampler(seed=4577) 44 | optimizer = Optimizer(pipeline, sampler=sampler) 45 | optimizer.tune(dataset, n_iterations=100, show_progress=False) 46 | assert optimizer.best_params == target 47 | 48 | 49 | @pytest.mark.parametrize("target, direction", [ 50 | ({'param_a': 10, 'param_b': 10}, "maximize"), 51 | ({'param_a': 0, 'param_b': 0}, "minimize") 52 | ]) 53 | def test_basic_optimization(target, direction: Direction): 54 | class SumPipeline(Pipeline): 55 | 56 | def __init__(self): 57 | super().__init__() 58 | self.param_a: int = Integer(0, 10) 59 | self.param_b: int = Integer(0, 10) 60 | 61 | def __call__(self, data: float) -> float: 62 | return data + self.param_a + self.param_b 63 | 64 | def loss(self, data: float, y_preds: float) -> float: 65 | return y_preds 66 | 67 | def get_direction(self) -> Direction: 68 | return direction 69 | 70 | optimizer_tester(pipeline=SumPipeline(), target=target) 71 | 72 | 73 | @pytest.mark.parametrize("target, direction", [ 74 | ({'param_dict': {'param_a': 10, 'param_b': 10}}, "maximize"), 75 | ({'param_dict': {'param_a': 0, 'param_b': 0}}, "minimize") 76 | ]) 77 | def test_structured_dict_param_optim(target, direction: Direction): 78 | class SumPipeline(Pipeline): 79 | 80 | def __init__(self): 81 | super().__init__() 82 | self.param_dict: Dict[str, int] = ParamDict( 83 | param_a=Integer(0, 10), 84 | param_b=Integer(0, 10) 85 | ) 86 | 87 | def __call__(self, data: float) -> float: 88 | return data + self.param_dict["param_b"] + self.param_dict["param_a"] 89 | 90 | def loss(self, data: float, y_preds: float) -> float: 91 | return y_preds 92 | 93 | def get_direction(self) -> Direction: 94 | return direction 95 | 96 | optimizer_tester(pipeline=SumPipeline(), target=target) 97 | -------------------------------------------------------------------------------- /tests/test_pipeline.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # encoding: utf-8 3 | 4 | # The MIT License (MIT) 5 | 6 | # Copyright (c) 2018-2022 CNRS 7 | 8 | # Permission is hereby granted, free of charge, to any person obtaining a copy 9 | # of this software and associated documentation files (the "Software"), to deal 10 | # in the Software without restriction, including without limitation the rights 11 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 12 | # copies of the Software, and to permit persons to whom the Software is 13 | # furnished to do so, subject to the following conditions: 14 | 15 | # The above copyright notice and this permission notice shall be included in 16 | # all copies or substantial portions of the Software. 17 | 18 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 19 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 20 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 21 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 22 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 23 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 24 | # SOFTWARE. 25 | 26 | # AUTHORS 27 | # Hervé BREDIN - http://herve.niderb.fr 28 | # Hadrien TITEUX 29 | 30 | from pyannote.pipeline import Pipeline 31 | from pyannote.pipeline.parameter import Uniform, Integer, ParamDict 32 | from .utils import FakeTrial 33 | 34 | 35 | def pipeline_tester(pl, fake_params): 36 | assert pl.parameters(FakeTrial()) == fake_params 37 | pl.instantiate(pl.parameters(FakeTrial())) 38 | assert (pl._unflatten(pl._flattened_parameters(instantiated=True)) 39 | == 40 | fake_params 41 | ) 42 | 43 | 44 | def test_pipeline_params_simple(): 45 | class TestPipeline(Pipeline): 46 | 47 | def __init__(self): 48 | super().__init__() 49 | self.param_a = Uniform(0, 1) 50 | self.param_b = Integer(3, 10) 51 | 52 | pl = TestPipeline() 53 | # assert pl.parameters(FakeTrial()) == {"param_a": 0, "param_b": 3} 54 | pipeline_tester(pl, {"param_a": 0, "param_b": 3}) 55 | 56 | 57 | def test_pipeline_params_structured(): 58 | class TestPipeline(Pipeline): 59 | 60 | def __init__(self): 61 | super().__init__() 62 | self.params_dict = ParamDict(**{ 63 | "param_a": Uniform(0, 1), 64 | "param_b": Integer(5, 10) 65 | }) 66 | 67 | pl = TestPipeline() 68 | fake_params = {'params_dict': {'param_a': 0.0, 69 | 'param_b': 5}} 70 | pipeline_tester(pl, fake_params) 71 | 72 | 73 | def test_pipeline_with_subpipeline(): 74 | class SubPipeline(Pipeline): 75 | 76 | def __init__(self): 77 | super().__init__() 78 | self.param = Uniform(0, 1) 79 | 80 | class TestPipeline(Pipeline): 81 | 82 | def __init__(self): 83 | super().__init__() 84 | self.params_dict = ParamDict(**{ 85 | "param_a": Uniform(0, 1), 86 | "param_b": Integer(5, 10) 87 | }) 88 | self.subpl = SubPipeline() 89 | 90 | pl = TestPipeline() 91 | fake_params = {'subpl': {'param': 0.0}, 92 | 'params_dict': {'param_a': 0.0, 93 | 'param_b': 5} 94 | } 95 | 96 | pipeline_tester(pl, fake_params) 97 | -------------------------------------------------------------------------------- /tests/utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # encoding: utf-8 3 | 4 | # The MIT License (MIT) 5 | 6 | # Copyright (c) 2018-2021 CNRS 7 | 8 | # Permission is hereby granted, free of charge, to any person obtaining a copy 9 | # of this software and associated documentation files (the "Software"), to deal 10 | # in the Software without restriction, including without limitation the rights 11 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 12 | # copies of the Software, and to permit persons to whom the Software is 13 | # furnished to do so, subject to the following conditions: 14 | 15 | # The above copyright notice and this permission notice shall be included in 16 | # all copies or substantial portions of the Software. 17 | 18 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 19 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 20 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 21 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 22 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 23 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 24 | # SOFTWARE. 25 | 26 | # AUTHORS 27 | # Hervé BREDIN - http://herve.niderb.fr 28 | # Hadrien TITEUX 29 | from typing import Sequence 30 | 31 | from optuna import Trial 32 | from optuna.distributions import CategoricalChoiceType 33 | 34 | 35 | class FakeTrial(Trial): 36 | 37 | def __init__(self): 38 | pass 39 | 40 | def suggest_uniform(self, name: str, low: float, high: float) -> float: 41 | return low 42 | 43 | def suggest_int(self, name: str, low: int, high: int, step: int = 1, log: bool = False) -> int: 44 | return low 45 | 46 | def suggest_categorical(self, name: str, choices: Sequence[CategoricalChoiceType]) -> CategoricalChoiceType: 47 | return choices[0] 48 | --------------------------------------------------------------------------------