├── .coveragerc
├── .gitattributes
├── .github
└── workflows
│ └── python-package.yml
├── .gitignore
├── .vscode
└── settings.json
├── LICENSE
├── README.md
├── docs
├── index.html
├── search.js
├── simpleder.html
└── simpleder
│ └── der.html
├── publish.sh
├── requirements.txt
├── run_pdoc.sh
├── run_tests.sh
├── setup.py
├── simpleder
├── __init__.py
└── der.py
└── tests
└── det_test.py
/.coveragerc:
--------------------------------------------------------------------------------
1 | [run]
2 | source=simpleder
3 |
--------------------------------------------------------------------------------
/.gitattributes:
--------------------------------------------------------------------------------
1 | # Auto detect text files and perform LF normalization
2 | * text=auto
3 |
--------------------------------------------------------------------------------
/.github/workflows/python-package.yml:
--------------------------------------------------------------------------------
1 | # This workflow will install Python dependencies, run tests and lint with a variety of Python versions
2 | # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions
3 |
4 | name: Python package
5 |
6 | on:
7 | push:
8 | branches: [ master ]
9 | pull_request:
10 | branches: [ master ]
11 |
12 | jobs:
13 | build:
14 |
15 | runs-on: ubuntu-latest
16 |
17 | steps:
18 | - uses: actions/checkout@v2
19 | - name: Set up Python 3.8
20 | uses: actions/setup-python@v1
21 | with:
22 | python-version: 3.8
23 | - name: Install dependencies
24 | run: |
25 | python -m pip install --upgrade pip
26 | pip install flake8 codecov
27 | pip install -r requirements.txt
28 | - name: Lint with flake8
29 | run: |
30 | flake8 .
31 | - name: Run tests
32 | run: |
33 | bash run_tests.sh
34 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | *~
2 | *.pyc
3 | build/*
4 | dist/*
5 | simpleder.egg-info/*
6 | .coverage
7 |
--------------------------------------------------------------------------------
/.vscode/settings.json:
--------------------------------------------------------------------------------
1 | {
2 | "editor.tabSize": 4,
3 | "editor.insertSpaces": true,
4 | "editor.rulers": [
5 | 80
6 | ],
7 | "files.trimFinalNewlines": true,
8 | "files.trimTrailingWhitespace": true,
9 | "editor.formatOnSave": true,
10 | "terminal.integrated.fontSize": 13,
11 | "python.formatting.provider": "autopep8",
12 | "python.formatting.autopep8Args": [
13 | "--indent-size=4",
14 | "--max-line-length=80"
 15 | ]
16 | }
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright 2019 Quan Wang
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # SimpleDER  [](https://pypi.python.org/pypi/simpleder) [](https://pypi.org/project/simpleder) [](https://pepy.tech/project/simpleder) [](https://codecov.io/gh/wq2012/SimpleDER) [](https://wq2012.github.io/SimpleDER)
2 |
3 | ## Overview
4 |
5 | This is a lightweight library to compute Diarization Error Rate (DER).
6 |
7 | Features **NOT** supported:
8 |
9 | * Handling overlapped speech, *i.e.* two speakers speaking at the same time.
10 | * Allowing segment boundary tolerance, *a.k.a.* the `collar` value.
11 |
 12 | For more sophisticated metrics that support these features, please use
13 | [pyannote-metrics](https://github.com/pyannote/pyannote-metrics) instead.
14 |
15 | To learn more about speaker diarization, here is a curated list of resources:
16 | [awesome-diarization](https://github.com/wq2012/awesome-diarization).
17 |
18 | ## Diarization Error Rate
19 |
 20 | Diarization Error Rate (DER) is the most commonly used metric for
21 | [speaker diarization](https://en.wikipedia.org/wiki/Speaker_diarisation).
22 |
23 | Its strict form is:
24 |
25 | ```
26 | False Alarm + Miss + Overlap + Confusion
27 | DER = ------------------------------------------
28 | Reference Length
29 | ```
30 |
31 | The definition of each term:
32 |
 33 | * `Reference Length`: The total length of the reference (ground truth).
34 | * `False Alarm`: Length of segments which are considered as speech in
35 | hypothesis, but not in reference.
36 | * `Miss`: Length of segments which are considered as speech in
37 | reference, but not in hypothesis.
38 | * `Overlap`: Length of segments which are considered as overlapped speech
39 | in hypothesis, but not in reference.
40 | **This library does NOT support overlap.**
41 | * `Confusion`: Length of segments which are assigned to different speakers
42 | in hypothesis and reference (after applying an optimal assignment).
43 |
44 | The unit of each term is *seconds*.
45 |
46 | Note that DER can theoretically be larger than 1.0.
47 |
48 | References:
49 |
50 | * [pyannote-metrics documentation](https://pyannote.github.io/pyannote-metrics/reference.html)
51 | * [Xavier Anguera's thesis](http://www.xavieranguera.com/phdthesis/node108.html)
52 |
53 | ## Tutorial
54 |
55 | ### Install
56 |
57 | Install the package by:
58 |
59 | ```bash
60 | pip3 install simpleder
61 | ```
62 |
63 | or
64 |
65 | ```bash
66 | python3 -m pip install simpleder
67 | ```
68 |
69 | ### API
70 |
71 | Here is a minimal example:
72 |
73 | ```python
74 | import simpleder
75 |
76 | # reference (ground truth)
77 | ref = [("A", 0.0, 1.0),
78 | ("B", 1.0, 1.5),
79 | ("A", 1.6, 2.1)]
80 |
81 | # hypothesis (diarization result from your algorithm)
82 | hyp = [("1", 0.0, 0.8),
83 | ("2", 0.8, 1.4),
84 | ("3", 1.5, 1.8),
85 | ("1", 1.8, 2.0)]
86 |
87 | error = simpleder.DER(ref, hyp)
88 |
89 | print("DER={:.3f}".format(error))
90 | ```
91 |
92 | This should output:
93 |
94 | ```
95 | DER=0.350
96 | ```
--------------------------------------------------------------------------------
/docs/index.html:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
--------------------------------------------------------------------------------
/docs/search.js:
--------------------------------------------------------------------------------
1 | window.pdocSearch = (function(){
2 | /** elasticlunr - http://weixsong.github.io * Copyright (C) 2017 Oliver Nightingale * Copyright (C) 2017 Wei Song * MIT Licensed */!function(){function e(e){if(null===e||"object"!=typeof e)return e;var t=e.constructor();for(var n in e)e.hasOwnProperty(n)&&(t[n]=e[n]);return t}var t=function(e){var n=new t.Index;return n.pipeline.add(t.trimmer,t.stopWordFilter,t.stemmer),e&&e.call(n,n),n};t.version="0.9.5",lunr=t,t.utils={},t.utils.warn=function(e){return function(t){e.console&&console.warn&&console.warn(t)}}(this),t.utils.toString=function(e){return void 0===e||null===e?"":e.toString()},t.EventEmitter=function(){this.events={}},t.EventEmitter.prototype.addListener=function(){var e=Array.prototype.slice.call(arguments),t=e.pop(),n=e;if("function"!=typeof t)throw new TypeError("last argument must be a function");n.forEach(function(e){this.hasHandler(e)||(this.events[e]=[]),this.events[e].push(t)},this)},t.EventEmitter.prototype.removeListener=function(e,t){if(this.hasHandler(e)){var n=this.events[e].indexOf(t);-1!==n&&(this.events[e].splice(n,1),0==this.events[e].length&&delete this.events[e])}},t.EventEmitter.prototype.emit=function(e){if(this.hasHandler(e)){var t=Array.prototype.slice.call(arguments,1);this.events[e].forEach(function(e){e.apply(void 0,t)},this)}},t.EventEmitter.prototype.hasHandler=function(e){return e in this.events},t.tokenizer=function(e){if(!arguments.length||null===e||void 0===e)return[];if(Array.isArray(e)){var n=e.filter(function(e){return null===e||void 0===e?!1:!0});n=n.map(function(e){return t.utils.toString(e).toLowerCase()});var i=[];return n.forEach(function(e){var n=e.split(t.tokenizer.seperator);i=i.concat(n)},this),i}return e.toString().trim().toLowerCase().split(t.tokenizer.seperator)},t.tokenizer.defaultSeperator=/[\s\-]+/,t.tokenizer.seperator=t.tokenizer.defaultSeperator,t.tokenizer.setSeperator=function(e){null!==e&&void 0!==e&&"object"==typeof 
e&&(t.tokenizer.seperator=e)},t.tokenizer.resetSeperator=function(){t.tokenizer.seperator=t.tokenizer.defaultSeperator},t.tokenizer.getSeperator=function(){return t.tokenizer.seperator},t.Pipeline=function(){this._queue=[]},t.Pipeline.registeredFunctions={},t.Pipeline.registerFunction=function(e,n){n in t.Pipeline.registeredFunctions&&t.utils.warn("Overwriting existing registered function: "+n),e.label=n,t.Pipeline.registeredFunctions[n]=e},t.Pipeline.getRegisteredFunction=function(e){return e in t.Pipeline.registeredFunctions!=!0?null:t.Pipeline.registeredFunctions[e]},t.Pipeline.warnIfFunctionNotRegistered=function(e){var n=e.label&&e.label in this.registeredFunctions;n||t.utils.warn("Function is not registered with pipeline. This may cause problems when serialising the index.\n",e)},t.Pipeline.load=function(e){var n=new t.Pipeline;return e.forEach(function(e){var i=t.Pipeline.getRegisteredFunction(e);if(!i)throw new Error("Cannot load un-registered function: "+e);n.add(i)}),n},t.Pipeline.prototype.add=function(){var e=Array.prototype.slice.call(arguments);e.forEach(function(e){t.Pipeline.warnIfFunctionNotRegistered(e),this._queue.push(e)},this)},t.Pipeline.prototype.after=function(e,n){t.Pipeline.warnIfFunctionNotRegistered(n);var i=this._queue.indexOf(e);if(-1===i)throw new Error("Cannot find existingFn");this._queue.splice(i+1,0,n)},t.Pipeline.prototype.before=function(e,n){t.Pipeline.warnIfFunctionNotRegistered(n);var i=this._queue.indexOf(e);if(-1===i)throw new Error("Cannot find existingFn");this._queue.splice(i,0,n)},t.Pipeline.prototype.remove=function(e){var t=this._queue.indexOf(e);-1!==t&&this._queue.splice(t,1)},t.Pipeline.prototype.run=function(e){for(var t=[],n=e.length,i=this._queue.length,o=0;n>o;o++){for(var r=e[o],s=0;i>s&&(r=this._queue[s](r,o,e),void 0!==r&&null!==r);s++);void 0!==r&&null!==r&&t.push(r)}return t},t.Pipeline.prototype.reset=function(){this._queue=[]},t.Pipeline.prototype.get=function(){return 
this._queue},t.Pipeline.prototype.toJSON=function(){return this._queue.map(function(e){return t.Pipeline.warnIfFunctionNotRegistered(e),e.label})},t.Index=function(){this._fields=[],this._ref="id",this.pipeline=new t.Pipeline,this.documentStore=new t.DocumentStore,this.index={},this.eventEmitter=new t.EventEmitter,this._idfCache={},this.on("add","remove","update",function(){this._idfCache={}}.bind(this))},t.Index.prototype.on=function(){var e=Array.prototype.slice.call(arguments);return this.eventEmitter.addListener.apply(this.eventEmitter,e)},t.Index.prototype.off=function(e,t){return this.eventEmitter.removeListener(e,t)},t.Index.load=function(e){e.version!==t.version&&t.utils.warn("version mismatch: current "+t.version+" importing "+e.version);var n=new this;n._fields=e.fields,n._ref=e.ref,n.documentStore=t.DocumentStore.load(e.documentStore),n.pipeline=t.Pipeline.load(e.pipeline),n.index={};for(var i in e.index)n.index[i]=t.InvertedIndex.load(e.index[i]);return n},t.Index.prototype.addField=function(e){return this._fields.push(e),this.index[e]=new t.InvertedIndex,this},t.Index.prototype.setRef=function(e){return this._ref=e,this},t.Index.prototype.saveDocument=function(e){return this.documentStore=new t.DocumentStore(e),this},t.Index.prototype.addDoc=function(e,n){if(e){var n=void 0===n?!0:n,i=e[this._ref];this.documentStore.addDoc(i,e),this._fields.forEach(function(n){var o=this.pipeline.run(t.tokenizer(e[n]));this.documentStore.addFieldLength(i,n,o.length);var r={};o.forEach(function(e){e in r?r[e]+=1:r[e]=1},this);for(var s in r){var u=r[s];u=Math.sqrt(u),this.index[n].addToken(s,{ref:i,tf:u})}},this),n&&this.eventEmitter.emit("add",e,this)}},t.Index.prototype.removeDocByRef=function(e){if(e&&this.documentStore.isDocStored()!==!1&&this.documentStore.hasDoc(e)){var t=this.documentStore.getDoc(e);this.removeDoc(t,!1)}},t.Index.prototype.removeDoc=function(e,n){if(e){var n=void 
0===n?!0:n,i=e[this._ref];this.documentStore.hasDoc(i)&&(this.documentStore.removeDoc(i),this._fields.forEach(function(n){var o=this.pipeline.run(t.tokenizer(e[n]));o.forEach(function(e){this.index[n].removeToken(e,i)},this)},this),n&&this.eventEmitter.emit("remove",e,this))}},t.Index.prototype.updateDoc=function(e,t){var t=void 0===t?!0:t;this.removeDocByRef(e[this._ref],!1),this.addDoc(e,!1),t&&this.eventEmitter.emit("update",e,this)},t.Index.prototype.idf=function(e,t){var n="@"+t+"/"+e;if(Object.prototype.hasOwnProperty.call(this._idfCache,n))return this._idfCache[n];var i=this.index[t].getDocFreq(e),o=1+Math.log(this.documentStore.length/(i+1));return this._idfCache[n]=o,o},t.Index.prototype.getFields=function(){return this._fields.slice()},t.Index.prototype.search=function(e,n){if(!e)return[];e="string"==typeof e?{any:e}:JSON.parse(JSON.stringify(e));var i=null;null!=n&&(i=JSON.stringify(n));for(var o=new t.Configuration(i,this.getFields()).get(),r={},s=Object.keys(e),u=0;u0&&t.push(e);for(var i in n)"docs"!==i&&"df"!==i&&this.expandToken(e+i,t,n[i]);return t},t.InvertedIndex.prototype.toJSON=function(){return{root:this.root}},t.Configuration=function(e,n){var e=e||"";if(void 0==n||null==n)throw new Error("fields should not be null");this.config={};var i;try{i=JSON.parse(e),this.buildUserConfig(i,n)}catch(o){t.utils.warn("user configuration parse failed, will use default configuration"),this.buildDefaultConfig(n)}},t.Configuration.prototype.buildDefaultConfig=function(e){this.reset(),e.forEach(function(e){this.config[e]={boost:1,bool:"OR",expand:!1}},this)},t.Configuration.prototype.buildUserConfig=function(e,n){var i="OR",o=!1;if(this.reset(),"bool"in e&&(i=e.bool||i),"expand"in e&&(o=e.expand||o),"fields"in e)for(var r in e.fields)if(n.indexOf(r)>-1){var s=e.fields[r],u=o;void 0!=s.expand&&(u=s.expand),this.config[r]={boost:s.boost||0===s.boost?s.boost:1,bool:s.bool||i,expand:u}}else t.utils.warn("field name in user configuration not found in index instance 
fields");else this.addAllFields2UserConfig(i,o,n)},t.Configuration.prototype.addAllFields2UserConfig=function(e,t,n){n.forEach(function(n){this.config[n]={boost:1,bool:e,expand:t}},this)},t.Configuration.prototype.get=function(){return this.config},t.Configuration.prototype.reset=function(){this.config={}},lunr.SortedSet=function(){this.length=0,this.elements=[]},lunr.SortedSet.load=function(e){var t=new this;return t.elements=e,t.length=e.length,t},lunr.SortedSet.prototype.add=function(){var e,t;for(e=0;e1;){if(r===e)return o;e>r&&(t=o),r>e&&(n=o),i=n-t,o=t+Math.floor(i/2),r=this.elements[o]}return r===e?o:-1},lunr.SortedSet.prototype.locationFor=function(e){for(var t=0,n=this.elements.length,i=n-t,o=t+Math.floor(i/2),r=this.elements[o];i>1;)e>r&&(t=o),r>e&&(n=o),i=n-t,o=t+Math.floor(i/2),r=this.elements[o];return r>e?o:e>r?o+1:void 0},lunr.SortedSet.prototype.intersect=function(e){for(var t=new lunr.SortedSet,n=0,i=0,o=this.length,r=e.length,s=this.elements,u=e.elements;;){if(n>o-1||i>r-1)break;s[n]!==u[i]?s[n]u[i]&&i++:(t.add(s[n]),n++,i++)}return t},lunr.SortedSet.prototype.clone=function(){var e=new lunr.SortedSet;return e.elements=this.toArray(),e.length=e.elements.length,e},lunr.SortedSet.prototype.union=function(e){var t,n,i;this.length>=e.length?(t=this,n=e):(t=e,n=this),i=t.clone();for(var o=0,r=n.toArray();o\n"}, {"fullname": "simpleder.DER", "modulename": "simpleder", "qualname": "DER", "kind": "function", "doc": "
Compute Diarization Error Rate.
\n\n
Args:\n ref: a list of tuples for the ground truth, where each tuple is\n (speaker, start, end) of type (string, float, float)\n hyp: a list of tuples for the diarization result hypothesis, same type\n as ref
\n\n
Returns:\n a float number for the Diarization Error Rate
Compute the total length of the union of reference and hypothesis.
\n\n
Args:\n ref: a list of tuples for the ground truth, where each tuple is\n (speaker, start, end) of type (string, float, float)\n hyp: a list of tuples for the diarization result hypothesis, same type\n as ref
\n\n
Returns:\n a float number for the union total length
Args:\n ref: a list of tuples for the ground truth, where each tuple is\n (speaker, start, end) of type (string, float, float)\n hyp: a list of tuples for the diarization result hypothesis, same type\n as ref
\n\n
Returns:\n a 2-dim numpy array, whose element (i, j) is the overlap between\n ith reference speaker and jth hypothesis speaker
Args:\n ref: a list of tuples for the ground truth, where each tuple is\n (speaker, start, end) of type (string, float, float)\n hyp: a list of tuples for the diarization result hypothesis, same type\n as ref
\n\n
Returns:\n a float number for the Diarization Error Rate
def DER(ref, hyp):
    """Compute Diarization Error Rate.

    Args:
        ref: a list of tuples for the ground truth, where each tuple is
            (speaker, start, end) of type (string, float, float)
        hyp: a list of tuples for the diarization result hypothesis, same type
            as `ref`

    Returns:
        a float number for the Diarization Error Rate
    """
    # Validate both sequences before doing any arithmetic.
    for sequence in (ref, hyp):
        check_input(sequence)
    # Find the speaker assignment that maximizes total overlap (the
    # Hungarian algorithm minimizes, hence the negated matrix).
    cost_matrix = build_cost_matrix(ref, hyp)
    row_index, col_index = optimize.linear_sum_assignment(-cost_matrix)
    optimal_match_overlap = cost_matrix[row_index, col_index].sum()
    # DER = (union length - optimally matched overlap) / reference length.
    ref_total_length = compute_total_length(ref)
    union_total_length = compute_merged_total_length(ref, hyp)
    return (union_total_length - optimal_match_overlap) / ref_total_length
95 |
96 |
97 |
98 |
Compute Diarization Error Rate.
99 |
100 |
Args:
101 | ref: a list of tuples for the ground truth, where each tuple is
102 | (speaker, start, end) of type (string, float, float)
103 | hyp: a list of tuples for the diarization result hypothesis, same type
104 | as ref
105 |
106 |
Returns:
107 | a float number for the Diarization Error Rate
1importnumpyasnp
76 | 2fromscipyimportoptimize
77 | 3
78 | 4
def check_input(hyp):
    """Check whether a hypothesis/reference is valid.

    Args:
        hyp: a list of tuples, where each tuple is (speaker, start, end)
            of type (string, float, float)

    Raises:
        TypeError: if the type of `hyp` is incorrect
        ValueError: if some tuple has start > end; or if two tuples intersect
            with each other
    """
    if not isinstance(hyp, list):
        raise TypeError("Input must be a list.")
    # Per-element structural and type checks.
    for entry in hyp:
        if not isinstance(entry, tuple):
            raise TypeError("Input must be a list of tuples.")
        if len(entry) != 3:
            raise TypeError(
                "Each tuple must have the elements: (speaker, start, end).")
        speaker, start, end = entry
        if not isinstance(speaker, str):
            raise TypeError("Speaker must be a string.")
        if not isinstance(start, float) or not isinstance(end, float):
            raise TypeError("Start and end must be float numbers.")
        if start > end:
            raise ValueError("Start must not be larger than end.")
    # Pairwise check: segments within one sequence must not overlap.
    for i, first in enumerate(hyp):
        for second in hyp[i + 1:]:
            if compute_intersection_length(first, second) > 0.0:
                raise ValueError(
                    "Input must not contain overlapped speech.")
112 | 38
113 | 39
def compute_total_length(hyp):
    """Compute total length of a hypothesis/reference.

    Args:
        hyp: a list of tuples, where each tuple is (speaker, start, end)
            of type (string, float, float)

    Returns:
        a float number for the total length
    """
    # Sum the duration (end - start) of every segment; 0.0 seed keeps the
    # result a float for an empty input.
    return sum((end - start for _, start, end in hyp), 0.0)
128 | 54
129 | 55
def compute_intersection_length(A, B):
    """Compute the intersection length of two tuples.

    Args:
        A: a (speaker, start, end) tuple of type (string, float, float)
        B: a (speaker, start, end) tuple of type (string, float, float)

    Returns:
        a float number of the intersection between `A` and `B`
    """
    # The overlap of [A_start, A_end] and [B_start, B_end]; negative when
    # the segments are disjoint, so clamp at zero.
    overlap = min(A[2], B[2]) - max(A[1], B[1])
    return overlap if overlap > 0.0 else 0.0
143 | 69
144 | 70
def compute_merged_total_length(ref, hyp):
    """Compute the total length of the union of reference and hypothesis.

    Args:
        ref: a list of tuples for the ground truth, where each tuple is
            (speaker, start, end) of type (string, float, float)
        hyp: a list of tuples for the diarization result hypothesis, same type
            as `ref`

    Returns:
        a float number for the union total length
    """
    # Drop speaker labels; only the time spans matter for the union.
    intervals = sorted((start, end) for _, start, end in ref + hyp)
    # Standard forward interval merge over the start-sorted spans.
    merged = []
    for start, end in intervals:
        if merged and start <= merged[-1][1]:
            # Overlapping or touching: extend the last merged interval.
            merged[-1] = (merged[-1][0], max(merged[-1][1], end))
        else:
            merged.append((start, end))
    return sum((end - start for start, end in merged), 0.0)
175 | 101
176 | 102
177 | 103defbuild_speaker_index(hyp):
178 | 104"""Build the index for the speakers.
179 | 105
180 | 106 Args:
181 | 107 hyp: a list of tuples, where each tuple is (speaker, start, end)
182 | 108 of type (string, float, float)
183 | 109
184 | 110 Returns:
185 | 111 a dict from speaker to integer
186 | 112 """
187 | 113speaker_set=sorted({element[0]forelementinhyp})
188 | 114index={speaker:ifori,speakerinenumerate(speaker_set)}
189 | 115returnindex
190 | 116
191 | 117
192 | 118defbuild_cost_matrix(ref,hyp):
193 | 119"""Build the cost matrix.
194 | 120
195 | 121 Args:
196 | 122 ref: a list of tuples for the ground truth, where each tuple is
197 | 123 (speaker, start, end) of type (string, float, float)
198 | 124 hyp: a list of tuples for the diarization result hypothesis, same type
199 | 125 as `ref`
200 | 126
201 | 127 Returns:
202 | 128 a 2-dim numpy array, whose element (i, j) is the overlap between
203 | 129 `i`th reference speaker and `j`th hypothesis speaker
204 | 130 """
205 | 131ref_index=build_speaker_index(ref)
206 | 132hyp_index=build_speaker_index(hyp)
207 | 133cost_matrix=np.zeros((len(ref_index),len(hyp_index)))
208 | 134forref_elementinref:
209 | 135forhyp_elementinhyp:
210 | 136i=ref_index[ref_element[0]]
211 | 137j=hyp_index[hyp_element[0]]
212 | 138cost_matrix[i,j]+=compute_intersection_length(
213 | 139ref_element,hyp_element)
214 | 140returncost_matrix
215 | 141
216 | 142
217 | 143defDER(ref,hyp):
218 | 144"""Compute Diarization Error Rate.
219 | 145
220 | 146 Args:
221 | 147 ref: a list of tuples for the ground truth, where each tuple is
222 | 148 (speaker, start, end) of type (string, float, float)
223 | 149 hyp: a list of tuples for the diarization result hypothesis, same type
224 | 150 as `ref`
225 | 151
226 | 152 Returns:
227 | 153 a float number for the Diarization Error Rate
228 | 154 """
229 | 155check_input(ref)
230 | 156check_input(hyp)
231 | 157ref_total_length=compute_total_length(ref)
232 | 158cost_matrix=build_cost_matrix(ref,hyp)
233 | 159row_index,col_index=optimize.linear_sum_assignment(-cost_matrix)
234 | 160optimal_match_overlap=cost_matrix[row_index,col_index].sum()
235 | 161union_total_length=compute_merged_total_length(ref,hyp)
236 | 162der=(union_total_length-optimal_match_overlap)/ref_total_length
237 | 163returnder
238 |
6defcheck_input(hyp):
254 | 7"""Check whether a hypothesis/reference is valid.
255 | 8
256 | 9 Args:
257 | 10 hyp: a list of tuples, where each tuple is (speaker, start, end)
258 | 11 of type (string, float, float)
259 | 12
260 | 13 Raises:
261 | 14 TypeError: if the type of `hyp` is incorrect
262 | 15 ValueError: if some tuple has start > end; or if two tuples intersect
263 | 16 with each other
264 | 17 """
265 | 18ifnotisinstance(hyp,list):
266 | 19raiseTypeError("Input must be a list.")
267 | 20forelementinhyp:
268 | 21ifnotisinstance(element,tuple):
269 | 22raiseTypeError("Input must be a list of tuples.")
270 | 23iflen(element)!=3:
271 | 24raiseTypeError(
272 | 25"Each tuple must have the elements: (speaker, start, end).")
273 | 26ifnotisinstance(element[0],str):
274 | 27raiseTypeError("Speaker must be a string.")
275 | 28ifnotisinstance(element[1],float)ornotisinstance(
276 | 29element[2],float):
277 | 30raiseTypeError("Start and end must be float numbers.")
278 | 31ifelement[1]>element[2]:
279 | 32raiseValueError("Start must not be larger than end.")
280 | 33num_elements=len(hyp)
281 | 34foriinrange(num_elements-1):
282 | 35forjinrange(i+1,num_elements):
283 | 36ifcompute_intersection_length(hyp[i],hyp[j])>0.0:
284 | 37raiseValueError(
285 | 38"Input must not contain overlapped speech.")
286 |
287 |
288 |
289 |
Check whether a hypothesis/reference is valid.
290 |
291 |
Args:
292 | hyp: a list of tuples, where each tuple is (speaker, start, end)
293 | of type (string, float, float)
294 |
295 |
Raises:
296 | TypeError: if the type of hyp is incorrect
297 | ValueError: if some tuple has start > end; or if two tuples intersect
298 | with each other
41defcompute_total_length(hyp):
315 | 42"""Compute total length of a hypothesis/reference.
316 | 43
317 | 44 Args:
318 | 45 hyp: a list of tuples, where each tuple is (speaker, start, end)
319 | 46 of type (string, float, float)
320 | 47
321 | 48 Returns:
322 | 49 a float number for the total length
323 | 50 """
324 | 51total_length=0.0
325 | 52forelementinhyp:
326 | 53total_length+=element[2]-element[1]
327 | 54returntotal_length
328 |
329 |
330 |
331 |
Compute total length of a hypothesis/reference.
332 |
333 |
Args:
334 | hyp: a list of tuples, where each tuple is (speaker, start, end)
335 | of type (string, float, float)
336 |
337 |
Returns:
338 | a float number for the total length
72defcompute_merged_total_length(ref,hyp):
394 | 73"""Compute the total length of the union of reference and hypothesis.
395 | 74
396 | 75 Args:
397 | 76 ref: a list of tuples for the ground truth, where each tuple is
398 | 77 (speaker, start, end) of type (string, float, float)
399 | 78 hyp: a list of tuples for the diarization result hypothesis, same type
400 | 79 as `ref`
401 | 80
402 | 81 Returns:
403 | 82 a float number for the union total length
404 | 83 """
405 | 84# Remove speaker label and merge.
406 | 85merged=[(element[1],element[2])forelementin(ref+hyp)]
407 | 86# Sort by start.
408 | 87merged=sorted(merged,key=lambdaelement:element[0])
409 | 88i=len(merged)-2
410 | 89whilei>=0:
411 | 90ifmerged[i][1]>=merged[i+1][0]:
412 | 91max_end=max(merged[i][1],merged[i+1][1])
413 | 92merged[i]=(merged[i][0],max_end)
414 | 93delmerged[i+1]
415 | 94ifi==len(merged)-1:
416 | 95i-=1
417 | 96else:
418 | 97i-=1
419 | 98total_length=0.0
420 | 99forelementinmerged:
421 | 100total_length+=element[1]-element[0]
422 | 101returntotal_length
423 |
424 |
425 |
426 |
Compute the total length of the union of reference and hypothesis.
427 |
428 |
Args:
429 | ref: a list of tuples for the ground truth, where each tuple is
430 | (speaker, start, end) of type (string, float, float)
431 | hyp: a list of tuples for the diarization result hypothesis, same type
432 | as ref
433 |
434 |
Returns:
435 | a float number for the union total length
104defbuild_speaker_index(hyp):
452 | 105"""Build the index for the speakers.
453 | 106
454 | 107 Args:
455 | 108 hyp: a list of tuples, where each tuple is (speaker, start, end)
456 | 109 of type (string, float, float)
457 | 110
458 | 111 Returns:
459 | 112 a dict from speaker to integer
460 | 113 """
461 | 114speaker_set=sorted({element[0]forelementinhyp})
462 | 115index={speaker:ifori,speakerinenumerate(speaker_set)}
463 | 116returnindex
464 |
465 |
466 |
467 |
Build the index for the speakers.
468 |
469 |
Args:
470 | hyp: a list of tuples, where each tuple is (speaker, start, end)
471 | of type (string, float, float)
119defbuild_cost_matrix(ref,hyp):
491 | 120"""Build the cost matrix.
492 | 121
493 | 122 Args:
494 | 123 ref: a list of tuples for the ground truth, where each tuple is
495 | 124 (speaker, start, end) of type (string, float, float)
496 | 125 hyp: a list of tuples for the diarization result hypothesis, same type
497 | 126 as `ref`
498 | 127
499 | 128 Returns:
500 | 129 a 2-dim numpy array, whose element (i, j) is the overlap between
501 | 130 `i`th reference speaker and `j`th hypothesis speaker
502 | 131 """
503 | 132ref_index=build_speaker_index(ref)
504 | 133hyp_index=build_speaker_index(hyp)
505 | 134cost_matrix=np.zeros((len(ref_index),len(hyp_index)))
506 | 135forref_elementinref:
507 | 136forhyp_elementinhyp:
508 | 137i=ref_index[ref_element[0]]
509 | 138j=hyp_index[hyp_element[0]]
510 | 139cost_matrix[i,j]+=compute_intersection_length(
511 | 140ref_element,hyp_element)
512 | 141returncost_matrix
513 |
514 |
515 |
516 |
Build the cost matrix.
517 |
518 |
Args:
519 | ref: a list of tuples for the ground truth, where each tuple is
520 | (speaker, start, end) of type (string, float, float)
521 | hyp: a list of tuples for the diarization result hypothesis, same type
522 | as ref
523 |
524 |
Returns:
525 | a 2-dim numpy array, whose element (i, j) is the overlap between
526 | ith reference speaker and jth hypothesis speaker
144defDER(ref,hyp):
543 | 145"""Compute Diarization Error Rate.
544 | 146
545 | 147 Args:
546 | 148 ref: a list of tuples for the ground truth, where each tuple is
547 | 149 (speaker, start, end) of type (string, float, float)
548 | 150 hyp: a list of tuples for the diarization result hypothesis, same type
549 | 151 as `ref`
550 | 152
551 | 153 Returns:
552 | 154 a float number for the Diarization Error Rate
553 | 155 """
554 | 156check_input(ref)
555 | 157check_input(hyp)
556 | 158ref_total_length=compute_total_length(ref)
557 | 159cost_matrix=build_cost_matrix(ref,hyp)
558 | 160row_index,col_index=optimize.linear_sum_assignment(-cost_matrix)
559 | 161optimal_match_overlap=cost_matrix[row_index,col_index].sum()
560 | 162union_total_length=compute_merged_total_length(ref,hyp)
561 | 163der=(union_total_length-optimal_match_overlap)/ref_total_length
562 | 164returnder
563 |
564 |
565 |
566 |
Compute Diarization Error Rate.
567 |
568 |
Args:
569 | ref: a list of tuples for the ground truth, where each tuple is
570 | (speaker, start, end) of type (string, float, float)
571 | hyp: a list of tuples for the diarization result hypothesis, same type
572 | as ref
573 |
574 |
Returns:
575 | a float number for the Diarization Error Rate
576 |
577 |
578 |
579 |
580 |
581 |
763 |
--------------------------------------------------------------------------------
/publish.sh:
--------------------------------------------------------------------------------
#!/bin/bash
# Build and publish the simpleder package to PyPI.
set -o errexit

# This script requires these tools:
#   pip3 install --user --upgrade setuptools wheel
#   pip3 install --user --upgrade twine

# Get project path (the directory containing this script).
PROJECT_PATH="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"

# Quote the path so the script also works when it contains spaces.
pushd "${PROJECT_PATH}"

# Clean up stale artifacts from previous builds.
rm -rf build
rm -rf dist
rm -rf simpleder.egg-info

# Build source and wheel distributions, then upload to PyPI.
python3 setup.py sdist bdist_wheel
python3 -m twine upload dist/*

popd
23 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy
2 | scipy
--------------------------------------------------------------------------------
/run_pdoc.sh:
--------------------------------------------------------------------------------
#!/bin/bash
# Regenerate the API documentation under docs/ using pdoc.
set -o errexit

# Get project path (the directory containing this script).
PROJECT_PATH="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"

# Quote the path so the script also works when it contains spaces.
pushd "${PROJECT_PATH}"

# Use -rf so a missing docs/ directory (fresh checkout) does not abort
# the script under `set -o errexit`.
rm -rf docs

# This script requires pdoc:
#   pip3 install pdoc
python3 -m pdoc simpleder -o docs

popd
16 |
--------------------------------------------------------------------------------
/run_tests.sh:
--------------------------------------------------------------------------------
#!/bin/bash
# Run all unit tests with coverage and upload the report to codecov.
set -o errexit

# Get project path (the directory containing this script).
PROJECT_PATH="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"

# Add project modules to PYTHONPATH so tests can import them.
if [[ "${PYTHONPATH}" != *"${PROJECT_PATH}"* ]]; then
  export PYTHONPATH="${PYTHONPATH}:${PROJECT_PATH}"
fi

# Quote the path so the script also works when it contains spaces.
pushd "${PROJECT_PATH}"

# Remove stale coverage data; -f avoids an error when none exists.
rm -f .coverage

# Run tests, accumulating coverage (-a) across test files.
for TEST_FILE in $(find tests -name "*_test.py"); do
  echo "Running tests in ${TEST_FILE}"
  python3 -m coverage run -a ${TEST_FILE}
done
echo "All tests passed!"

python3 -m codecov

popd
26 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
"""Setup script for the package."""

import setuptools

VERSION = "0.0.3"

# Read the long description from the README; specify the encoding
# explicitly so the build does not depend on the platform default.
with open("README.md", "r", encoding="utf-8") as file_object:
    LONG_DESCRIPTION = file_object.read()

SHORT_DESCRIPTION = """
A lightweight library to compute Diarization Error Rate (DER).
""".strip()

setuptools.setup(
    name="simpleder",
    version=VERSION,
    author="Quan Wang",
    author_email="quanw@google.com",
    description=SHORT_DESCRIPTION,
    long_description=LONG_DESCRIPTION,
    long_description_content_type="text/markdown",
    url="https://github.com/wq2012/SimpleDER",
    packages=setuptools.find_packages(),
    classifiers=[
        "Programming Language :: Python :: 3",
        "License :: OSI Approved :: Apache Software License",
        "Operating System :: OS Independent",
    ],
)
30 |
--------------------------------------------------------------------------------
/simpleder/__init__.py:
--------------------------------------------------------------------------------
"""SimpleDER: a lightweight library to compute Diarization Error Rate."""

from . import der

# Re-export the main entry point at package level: `simpleder.DER(ref, hyp)`.
DER = der.DER
4 |
--------------------------------------------------------------------------------
/simpleder/der.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from scipy import optimize
3 |
4 |
def check_input(hyp):
    """Validate a hypothesis/reference segment list.

    Args:
        hyp: a list of tuples, where each tuple is (speaker, start, end)
            of type (string, float, float)

    Raises:
        TypeError: if the type of `hyp` is incorrect
        ValueError: if some tuple has start > end; or if two tuples intersect
            with each other
    """
    if not isinstance(hyp, list):
        raise TypeError("Input must be a list.")
    # Per-segment structural checks.
    for entry in hyp:
        if not isinstance(entry, tuple):
            raise TypeError("Input must be a list of tuples.")
        if len(entry) != 3:
            raise TypeError(
                "Each tuple must have the elements: (speaker, start, end).")
        speaker, start, end = entry
        if not isinstance(speaker, str):
            raise TypeError("Speaker must be a string.")
        if not (isinstance(start, float) and isinstance(end, float)):
            raise TypeError("Start and end must be float numbers.")
        if start > end:
            raise ValueError("Start must not be larger than end.")
    # Pairwise check: no two segments may overlap in time.
    for i, first in enumerate(hyp):
        for second in hyp[i + 1:]:
            if compute_intersection_length(first, second) > 0.0:
                raise ValueError(
                    "Input must not contain overlapped speech.")
38 |
39 |
def compute_total_length(hyp):
    """Sum the durations of all segments in a hypothesis/reference.

    Args:
        hyp: a list of tuples, where each tuple is (speaker, start, end)
            of type (string, float, float)

    Returns:
        a float number for the total length
    """
    # Start the sum at 0.0 so an empty input yields a float, not an int.
    return sum((element[2] - element[1] for element in hyp), 0.0)
54 |
55 |
def compute_intersection_length(A, B):
    """Return the length of the time overlap between two segments.

    Args:
        A: a (speaker, start, end) tuple of type (string, float, float)
        B: a (speaker, start, end) tuple of type (string, float, float)

    Returns:
        a float number of the intersection between `A` and `B`
    """
    # Overlap is latest-start to earliest-end; clamp negatives to zero.
    overlap = min(A[2], B[2]) - max(A[1], B[1])
    return overlap if overlap > 0.0 else 0.0
69 |
70 |
def compute_merged_total_length(ref, hyp):
    """Compute the total length of the union of reference and hypothesis.

    Args:
        ref: a list of tuples for the ground truth, where each tuple is
            (speaker, start, end) of type (string, float, float)
        hyp: a list of tuples for the diarization result hypothesis, same type
            as `ref`

    Returns:
        a float number for the union total length
    """
    # Remove speaker labels: only the (start, end) intervals matter here.
    # Sorting the pairs sorts primarily by start time.
    intervals = sorted((element[1], element[2]) for element in (ref + hyp))
    # Single forward pass merging overlapping or touching intervals.
    # (The original scanned backwards, deleting list entries in place —
    # O(n^2) — and ended with an if/else whose branches were identical.)
    total_length = 0.0
    current_start = None
    current_end = None
    for start, end in intervals:
        if current_start is None:
            # First interval opens the running merged span.
            current_start, current_end = start, end
        elif start <= current_end:
            # Overlapping or touching: extend the running span.
            current_end = max(current_end, end)
        else:
            # Disjoint: close the running span and start a new one.
            total_length += current_end - current_start
            current_start, current_end = start, end
    if current_start is not None:
        total_length += current_end - current_start
    return total_length
101 |
102 |
def build_speaker_index(hyp):
    """Map each distinct speaker label to a dense integer index.

    Args:
        hyp: a list of tuples, where each tuple is (speaker, start, end)
            of type (string, float, float)

    Returns:
        a dict from speaker to integer
    """
    # Deduplicate labels, then index them in sorted (deterministic) order.
    speakers = sorted(set(element[0] for element in hyp))
    return dict((speaker, position) for position, speaker in enumerate(speakers))
116 |
117 |
def build_cost_matrix(ref, hyp):
    """Accumulate pairwise speaker overlaps into a matrix.

    Args:
        ref: a list of tuples for the ground truth, where each tuple is
            (speaker, start, end) of type (string, float, float)
        hyp: a list of tuples for the diarization result hypothesis, same type
            as `ref`

    Returns:
        a 2-dim numpy array, whose element (i, j) is the overlap between
        `i`th reference speaker and `j`th hypothesis speaker
    """
    ref_index = build_speaker_index(ref)
    hyp_index = build_speaker_index(hyp)
    cost_matrix = np.zeros((len(ref_index), len(hyp_index)))
    # Every (ref segment, hyp segment) pair contributes its time overlap
    # to the cell for its (ref speaker, hyp speaker) combination.
    for ref_element in ref:
        row = ref_index[ref_element[0]]
        for hyp_element in hyp:
            col = hyp_index[hyp_element[0]]
            cost_matrix[row, col] += compute_intersection_length(
                ref_element, hyp_element)
    return cost_matrix
141 |
142 |
def DER(ref, hyp):
    """Compute Diarization Error Rate.

    Args:
        ref: a list of tuples for the ground truth, where each tuple is
            (speaker, start, end) of type (string, float, float)
        hyp: a list of tuples for the diarization result hypothesis, same type
            as `ref`

    Returns:
        a float number for the Diarization Error Rate
    """
    check_input(ref)
    check_input(hyp)
    # Find the speaker mapping that maximizes the total overlap between
    # reference and hypothesis speakers (Hungarian algorithm on -cost).
    cost_matrix = build_cost_matrix(ref, hyp)
    row_index, col_index = optimize.linear_sum_assignment(-cost_matrix)
    optimal_match_overlap = cost_matrix[row_index, col_index].sum()
    # DER = (union speech time - optimally matched time) / reference time.
    union_total_length = compute_merged_total_length(ref, hyp)
    ref_total_length = compute_total_length(ref)
    return (union_total_length - optimal_match_overlap) / ref_total_length
164 |
--------------------------------------------------------------------------------
/tests/det_test.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import unittest
3 |
4 | from simpleder import der
5 |
6 |
class TestCheckInput(unittest.TestCase):
    """Tests for the check_input function."""

    def test_valid(self):
        """Well-formed inputs pass validation without raising."""
        reference = [
            ("A", 1.0, 2.0), ("B", 4.0, 5.0), ("A", 6.7, 9.0),
            ("C", 10.0, 12.0), ("D", 12.0, 13.0)]
        hypothesis = [
            ("A", 1.0, 3.0), ("B", 4.0, 4.8), ("A", 7.0, 9.0),
            ("C", 10.0, 13.0)]
        der.check_input(reference)
        der.check_input(hypothesis)

    def test_missing_speaker(self):
        """A 2-tuple with no speaker label triggers a TypeError."""
        hypothesis = [
            ("A", 1.0, 3.0), ("B", 4.0, 7.1), ("A", 7.0, 9.0),
            (10.0, 13.0)]
        with self.assertRaises(TypeError):
            der.check_input(hypothesis)

    def test_wrong_speaker_type(self):
        """A non-string speaker label triggers a TypeError."""
        hypothesis = [
            ("A", 1.0, 3.0), ("B", 4.0, 7.1), ("A", 7.0, 9.0),
            (3, 10.0, 13.0)]
        with self.assertRaises(TypeError):
            der.check_input(hypothesis)

    def test_overlap(self):
        """Two overlapping segments trigger a ValueError."""
        hypothesis = [
            ("A", 1.0, 3.0), ("B", 4.0, 7.1), ("A", 7.0, 9.0),
            ("C", 10.0, 13.0)]
        with self.assertRaises(ValueError):
            der.check_input(hypothesis)
46 |
47 |
class TestComputeIntersectionLength(unittest.TestCase):
    """Tests for the compute_intersection_length function."""

    def test_include(self):
        """B lies entirely inside A."""
        self.assertEqual(
            1.0,
            der.compute_intersection_length(("A", 1.0, 5.0), ("B", 2.0, 3.0)))

    def test_separate(self):
        """Disjoint segments have zero intersection."""
        self.assertEqual(
            0.0,
            der.compute_intersection_length(("A", 1.0, 3.0), ("B", 5.0, 7.0)))

    def test_overlap(self):
        """Partially overlapping segments."""
        self.assertEqual(
            3.0,
            der.compute_intersection_length(("A", 1.0, 5.0), ("B", 2.0, 9.0)))
65 |
66 |
class TestComputeTotalLength(unittest.TestCase):
    """Tests for the compute_total_length function."""

    def test_example(self):
        """Total length is the sum of all segment durations."""
        segments = [
            ("A", 1.0, 3.0), ("B", 4.0, 5.0), ("A", 7.0, 9.0),
            ("C", 10.0, 13.0)]
        self.assertEqual(8.0, der.compute_total_length(segments))
76 |
77 |
class TestComputeMergedTotalLength(unittest.TestCase):
    """Tests for the compute_merged_total_length function."""

    def test_example1(self):
        """Union length of two multi-speaker segmentations."""
        reference = [
            ("A", 1.0, 2.0), ("B", 4.0, 5.0), ("A", 6.7, 9.0),
            ("C", 10.0, 12.0), ("D", 12.0, 13.0)]
        hypothesis = [
            ("A", 1.0, 3.0), ("B", 4.0, 4.8), ("A", 7.0, 9.0),
            ("C", 10.0, 13.0)]
        self.assertEqual(
            8.3, der.compute_merged_total_length(reference, hypothesis))

    def test_example2(self):
        """Overlapping segments are merged into a single union interval."""
        reference = [("A", 1.0, 2.0)]
        hypothesis = [("A", 1.0, 1.6), ("A", 1.7, 2.5)]
        self.assertEqual(
            1.5, der.compute_merged_total_length(reference, hypothesis))
100 |
101 |
class TestBuildSpeakerIndex(unittest.TestCase):
    """Tests for the build_speaker_index function."""

    def test_example(self):
        """Speakers are indexed in sorted label order."""
        segments = [
            ("A", 1.0, 3.0), ("B", 4.0, 5.0), ("A", 7.0, 9.0),
            ("C", 10.0, 13.0)]
        self.assertDictEqual(
            {"A": 0, "B": 1, "C": 2}, der.build_speaker_index(segments))
117 |
118 |
class TestBuildCostMatrix(unittest.TestCase):
    """Tests for the build_cost_matrix function."""

    def test_example(self):
        """Entry (i, j) holds the overlap of ref speaker i and hyp speaker j."""
        reference = [
            ("A", 1.0, 2.0), ("B", 4.0, 4.8), ("A", 6.7, 9.0),
            ("C", 10.0, 12.0), ("D", 12.0, 13.0)]
        hypothesis = [
            ("A", 1.0, 3.0), ("B", 4.0, 5.0), ("A", 7.0, 9.0),
            ("C", 10.0, 13.0)]
        expected = np.array(
            [[3.0, 0.0, 0.0],
             [0.0, 0.8, 0.0],
             [0.0, 0.0, 2.0],
             [0.0, 0.0, 1.0]])
        cost_matrix = der.build_cost_matrix(reference, hypothesis)
        self.assertTrue(np.allclose(expected, cost_matrix, atol=0.0001))
139 |
140 |
class TestDER(unittest.TestCase):
    """Tests for the DER function."""

    def test_single_tuple_same(self):
        """Identical segmentation (label names differ) gives zero error."""
        self.assertEqual(
            0.0, der.DER([("A", 0.0, 1.0)], [("B", 0.0, 1.0)]))

    def test_single_tuple_all_miss(self):
        """An empty hypothesis misses all reference speech."""
        self.assertEqual(1.0, der.DER([("A", 0.0, 1.0)], []))

    def test_single_tuple_separate(self):
        """Disjoint segments: one unit missed plus one unit false alarm."""
        self.assertEqual(
            2.0, der.DER([("A", 0.0, 1.0)], [("B", 1.0, 2.0)]))

    def test_single_tuple_half_miss(self):
        """Half of the reference is not covered by the hypothesis."""
        self.assertEqual(
            0.5, der.DER([("A", 0.0, 1.0)], [("B", 0.0, 0.5)]))

    def test_single_tuple_half_different(self):
        """Two hypothesis speakers split one reference speaker."""
        self.assertEqual(
            0.5,
            der.DER([("A", 0.0, 1.0)], [("B", 0.0, 0.5), ("C", 0.5, 1.0)]))

    def test_two_tuples_same(self):
        """Perfect segmentation up to a relabeling of speakers."""
        reference = [("A", 0.0, 0.5), ("B", 0.5, 1.0)]
        hypothesis = [("B", 0.0, 0.5), ("C", 0.5, 1.0)]
        self.assertEqual(0.0, der.DER(reference, hypothesis))

    def test_two_tuples_half_miss(self):
        """Half of each reference segment is missed."""
        reference = [("A", 0.0, 0.5), ("B", 0.5, 1.0)]
        hypothesis = [("B", 0.0, 0.25), ("C", 0.5, 0.75)]
        self.assertEqual(0.5, der.DER(reference, hypothesis))

    def test_two_tuples_one_quarter_correct(self):
        """Only a quarter of the reference is correctly attributed."""
        reference = [("A", 0.0, 0.5), ("B", 0.5, 1.0)]
        hypothesis = [("B", 0.0, 0.25), ("C", 0.25, 0.5)]
        self.assertEqual(0.75, der.DER(reference, hypothesis))

    def test_two_tuples_half_correct(self):
        """Four hypothesis speakers over two reference speakers."""
        reference = [("A", 0.0, 0.5), ("B", 0.5, 1.0)]
        hypothesis = [
            ("B", 0.0, 0.25), ("C", 0.25, 0.5),
            ("D", 0.5, 0.75), ("E", 0.75, 1.0)]
        self.assertEqual(0.5, der.DER(reference, hypothesis))

    def test_three_tuples(self):
        """Mixed misses, false alarms, and confusions."""
        reference = [("A", 0.0, 1.0), ("B", 1.0, 1.5), ("A", 1.6, 2.1)]
        hypothesis = [
            ("1", 0.0, 0.8), ("2", 0.8, 1.4),
            ("3", 1.5, 1.8), ("1", 1.8, 2.0)]
        self.assertAlmostEqual(
            0.35, der.DER(reference, hypothesis), delta=0.0001)

    def test_hyp_has_more_labels(self):
        """Optimal matching when the hypothesis uses more labels."""
        reference = [
            ("0", 0.0, 1.0), ("1", 1.0, 2.0), ("1", 2.0, 3.0),
            ("1", 3.0, 4.0), ("0", 4.0, 5.0)]
        hypothesis = [
            ("0", 0.0, 1.0), ("2", 1.0, 2.0), ("2", 2.0, 3.0),
            ("0", 3.0, 4.0), ("2", 4.0, 5.0)]
        self.assertAlmostEqual(
            0.4, der.DER(reference, hypothesis), delta=0.0001)

    def test_ref_has_more_labels(self):
        """Optimal matching when the reference uses more labels."""
        reference = [
            ("0", 0.0, 1.0), ("2", 1.0, 2.0), ("2", 2.0, 3.0),
            ("0", 3.0, 4.0), ("1", 4.0, 5.0)]
        hypothesis = [
            ("0", 0.0, 1.0), ("1", 1.0, 2.0), ("1", 2.0, 3.0),
            ("1", 3.0, 4.0), ("0", 4.0, 5.0)]
        self.assertAlmostEqual(
            0.4, der.DER(reference, hypothesis), delta=0.0001)
235 |
236 |
# Allow running this test file directly: `python3 tests/det_test.py`.
if __name__ == "__main__":
    unittest.main()
239 |
--------------------------------------------------------------------------------