├── .gitignore ├── LICENSE ├── README.md ├── data ├── toy_testing_data.npz └── toy_training_data.npz ├── demo.py ├── resources ├── SML_diag.png ├── SUPERVISED_ONLINE_DIARIZATION_WITH_SAMPLE_MEAN_LOSS_FOR_MULTI_DOMAIN_DATA.pdf └── uisrnn.gif └── uisrnn ├── __init__.py ├── arguments.py ├── evals.py ├── loss_func.py ├── uisrnn.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | *~ 2 | *.pyc 3 | .DS_Store 4 | *result.txt 5 | *.uisrnn 6 | build/* 7 | dist/* 8 | uisrnn.egg-info/* 9 | .coverage 10 | runs 11 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. 
For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. 
You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 203 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Better, Faster, Stronger UIS-RNN 2 | 3 | This repository implements some useful features on top of the original [UIS-RNN repository](https://github.com/google/uis-rnn). Some of them are described in the following paper: [Supervised Online Diarization with Sample Mean Loss for Multi-Domain Data](https://arxiv.org/abs/1911.01266). 
4 | Here is a list: 5 | * **Sample Mean Loss (SML)**, a loss function that improves performance and training efficiency. To learn more about it, you can read our paper. 6 | * **Estimation of** `crp_alpha`, a parameter of the distance-dependent Chinese Restaurant Process (ddCRP) that determines the probability of switching to a new speaker. Again, more info in our paper. 7 | * **Parallel prediction** using `torch.multiprocessing`, which mitigates slow decoding and enables higher GPU utilization. 8 | * **TensorBoard** logging, for visualizing training. 9 | 10 | Here is a diagram of the Sample Mean Loss: 11 | 12 |
![Sample Mean Loss diagram](resources/SML_diag.png)
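In rough terms, instead of scoring the network prediction against a single noisy d-vector, the Sample Mean Loss scores it against the *sample mean* of a few d-vectors drawn from the same speaker. Below is a minimal, illustrative sketch of that idea; the actual implementation lives in `uisrnn/utils.py` and is controlled by `--loss_samples`, and the helper below is hypothetical, not part of the package API. See our paper for the precise formulation.

```python
import numpy as np

def sample_mean(speaker_embeddings, num_samples):
    """Illustrative only: mean of `num_samples` embeddings drawn (with
    replacement) from one speaker, used as a less noisy regression target."""
    idx = np.random.randint(0, len(speaker_embeddings), size=num_samples)
    return speaker_embeddings[idx].mean(axis=0)
```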
13 | 14 |

15 | 16 | The UIS-RNN was originally proposed in [Fully Supervised Speaker Diarization](https://arxiv.org/abs/1810.04719). 17 | 18 |
![UIS-RNN speaker diarization animation](resources/uisrnn.gif)
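During decoding, UIS-RNN combines the acoustic score with a speaker-turn prior: the speaker changes with probability `transition_bias` (p0), and a change goes to an existing speaker in proportion to how many blocks it already owns, or to a new speaker in proportion to `crp_alpha`. The sketch below mirrors the log-probability terms used in `UISRNN._update_beam_state` in `uisrnn/uisrnn.py`; the standalone function itself is a hypothetical rearrangement, not part of the package API.

```python
import numpy as np

def speaker_turn_log_prior(block_counts, cluster, last_cluster, p0, crp_alpha):
    """Log prior of assigning the next segment to `cluster`, given the
    per-speaker block counts of the current beam state."""
    if cluster == last_cluster:
        return np.log(1 - p0)  # same speaker continues
    total = sum(block_counts) + crp_alpha
    if cluster < len(block_counts):  # switch to an already-seen speaker
        return np.log(p0) + np.log(block_counts[cluster]) - np.log(total)
    return np.log(p0) + np.log(crp_alpha) - np.log(total)  # new speaker
```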
19 | 20 |

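As an example of combining the features above, the following command trains the demo with the Sample Mean Loss enabled and a deeper fully connected head; the flags are documented under *Arguments* below, and the values here are only an illustration:

```bash
python3 demo.py --train_iteration=1000 -l=0.001 --loss_samples=20 --fc_depth=2
```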
21 | 22 | ## Run the demo 23 | 24 | To get started, simply run this command: 25 | 26 | ```bash 27 | python3 demo.py --train_iteration=1000 -l=0.001 28 | ``` 29 | 30 | This will train a UIS-RNN model using `data/toy_training_data.npz`, 31 | then store the model on disk, perform inference on `data/toy_testing_data.npz`, 32 | print the inference results, and save the averaged accuracy in a text file. 33 | 34 | P.S.: The files under `data/` are manually generated *toy data*, 35 | for demonstration purposes only. 36 | These data are very simple, so we expect 100% accuracy on the 37 | testing data. 38 | 39 | ## Arguments 40 | 41 | * `--loss_samples`: the number of samples for the Sample Mean Loss. If `loss_samples <= 0`, it is ignored and the loss is computed as in the original UIS-RNN. 42 | * `--fc_depth`: the number of fully connected layers after the GRU. 43 | * `--non_lin`: whether to use a non-linearity (ReLU) in the fully connected layers. 44 | * `NUM_WORKERS`: the number of workers (processes) for multiprocessing. This constant can be found in `demo.py`. 45 | 46 | All other arguments are the same as in the [original repository](https://github.com/google/uis-rnn). 47 | 48 | ## Citations 49 | 50 | Our paper can be cited as: 51 | 52 | ``` 53 | @article{fini2019supervised, 54 | title={Supervised online diarization with sample mean loss for multi-domain data}, 55 | author={Fini, Enrico and Brutti, Alessio}, 56 | journal={arXiv preprint arXiv:1911.01266}, 57 | year={2019} 58 | } 59 | ``` 60 | -------------------------------------------------------------------------------- /data/toy_testing_data.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DonkeyShot21/uis-rnn-sml/820d8698b60c88683c5be15b8bf529afb3c886c9/data/toy_testing_data.npz -------------------------------------------------------------------------------- /data/toy_training_data.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DonkeyShot21/uis-rnn-sml/820d8698b60c88683c5be15b8bf529afb3c886c9/data/toy_training_data.npz -------------------------------------------------------------------------------- /demo.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """A demo script showing how to use the uisrnn package on toy data.""" 15 | 16 | import numpy as np 17 | from functools import partial 18 | from torch.utils.tensorboard import SummaryWriter 19 | import torch.multiprocessing as mp 20 | mp = mp.get_context('forkserver') 21 | 22 | import uisrnn 23 | 24 | 25 | SAVED_MODEL_NAME = 'saved_model.uisrnn' 26 | NUM_WORKERS = 2 27 | 28 | 29 | def diarization_experiment(model_args, training_args, inference_args): 30 | """Experiment pipeline. 
31 | 32 | Load data --> train model --> test model --> output result 33 | 34 | Args: 35 | model_args: model configurations 36 | training_args: training configurations 37 | inference_args: inference configurations 38 | """ 39 | # data loading 40 | train_data = np.load('./data/toy_training_data.npz', allow_pickle=True) 41 | test_data = np.load('./data/toy_testing_data.npz', allow_pickle=True) 42 | train_sequence = train_data['train_sequence'] 43 | train_cluster_id = train_data['train_cluster_id'] 44 | test_sequences = test_data['test_sequences'].tolist() 45 | test_cluster_ids = test_data['test_cluster_ids'].tolist() 46 | 47 | # model init 48 | model = uisrnn.UISRNN(model_args) 49 | # model.load(SAVED_MODEL_NAME) # to load a checkpoint 50 | # tensorboard writer init 51 | writer = SummaryWriter() 52 | 53 | # training 54 | for epoch in range(training_args.epochs): 55 | stats = model.fit(train_sequence, train_cluster_id, training_args) 56 | # add to tensorboard 57 | for loss, cur_iter in stats: 58 | for loss_name, loss_value in loss.items(): 59 | writer.add_scalar('loss/' + loss_name, loss_value, cur_iter) 60 | # save the model 61 | model.save(SAVED_MODEL_NAME) 62 | 63 | # testing 64 | predicted_cluster_ids = [] 65 | test_record = [] 66 | # predict sequences in parallel 67 | model.rnn_model.share_memory() 68 | pool = mp.Pool(NUM_WORKERS, maxtasksperchild=None) 69 | pred_gen = pool.imap( 70 | func=partial(model.predict, args=inference_args), 71 | iterable=test_sequences) 72 | # collect and score predictions 73 | for idx, predicted_cluster_id in enumerate(pred_gen): 74 | accuracy = uisrnn.compute_sequence_match_accuracy( 75 | test_cluster_ids[idx], predicted_cluster_id) 76 | predicted_cluster_ids.append(predicted_cluster_id) 77 | test_record.append((accuracy, len(test_cluster_ids[idx]))) 78 | print('Ground truth labels:') 79 | print(test_cluster_ids[idx]) 80 | print('Predicted labels:') 81 | print(predicted_cluster_id) 82 | print('-' * 80) 83 | 84 | # close multiprocessing pool 85 | pool.close() 86 | # close tensorboard writer 87 | writer.close() 88 | 89 | print('Finished diarization experiment') 90 | print(uisrnn.output_result(model_args, training_args, test_record)) 91 | 92 | 93 | def main(): 94 | """The main function.""" 95 | model_args, training_args, inference_args = uisrnn.parse_arguments() 96 | diarization_experiment(model_args, training_args, inference_args) 97 | 98 | 99 | if __name__ == '__main__': 100 | main() 101 | -------------------------------------------------------------------------------- /resources/SML_diag.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DonkeyShot21/uis-rnn-sml/820d8698b60c88683c5be15b8bf529afb3c886c9/resources/SML_diag.png -------------------------------------------------------------------------------- /resources/SUPERVISED_ONLINE_DIARIZATION_WITH_SAMPLE_MEAN_LOSS_FOR_MULTI_DOMAIN_DATA.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DonkeyShot21/uis-rnn-sml/820d8698b60c88683c5be15b8bf529afb3c886c9/resources/SUPERVISED_ONLINE_DIARIZATION_WITH_SAMPLE_MEAN_LOSS_FOR_MULTI_DOMAIN_DATA.pdf -------------------------------------------------------------------------------- /resources/uisrnn.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DonkeyShot21/uis-rnn-sml/820d8698b60c88683c5be15b8bf529afb3c886c9/resources/uisrnn.gif
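Beyond `demo.py`, the package can also be driven directly from Python. Here is a minimal sketch using only the public API exported by `uisrnn/__init__.py` (`parse_arguments`, `UISRNN`, `predict`); it assumes a checkpoint previously saved by `demo.py`, and the random test sequence is a stand-in for real d-vectors:

```python
import numpy as np
import uisrnn

model_args, training_args, inference_args = uisrnn.parse_arguments()
model = uisrnn.UISRNN(model_args)
model.load('saved_model.uisrnn')  # checkpoint produced by demo.py

# a single test sequence: N frames of observation_dim-dimensional d-vectors
test_sequence = np.random.rand(100, model_args.observation_dim)
predicted_cluster_id = model.predict(test_sequence, inference_args)
print(predicted_cluster_id)  # e.g. [0, 0, 1, 1, 0, ...]
```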
-------------------------------------------------------------------------------- /uisrnn/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """The module for Unbounded Interleaved-State Recurrent Neural Network. 15 | 16 | An introduction is available at [README.md]. 17 | 18 | [README.md]: https://github.com/google/uis-rnn/blob/master/README.md 19 | """ 20 | 21 | from . import arguments 22 | from . import evals 23 | from . import loss_func 24 | from . import uisrnn 25 | from . import utils 26 | 27 | #pylint: disable=C0103 28 | parse_arguments = arguments.parse_arguments 29 | compute_sequence_match_accuracy = evals.compute_sequence_match_accuracy 30 | output_result = utils.output_result 31 | UISRNN = uisrnn.UISRNN 32 | -------------------------------------------------------------------------------- /uisrnn/arguments.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Arguments for UISRNN.""" 15 | 16 | import argparse 17 | 18 | _DEFAULT_OBSERVATION_DIM = 256 19 | 20 | 21 | def str2bool(value): 22 | """A function to convert string to bool value.""" 23 | if value.lower() in {'yes', 'true', 't', 'y', '1'}: 24 | return True 25 | if value.lower() in {'no', 'false', 'f', 'n', '0'}: 26 | return False 27 | raise argparse.ArgumentTypeError('Boolean value expected.') 28 | 29 | 30 | def parse_arguments(): 31 | """Parse arguments. 32 | 33 | Returns: 34 | A tuple of: 35 | 36 | - `model_args`: model arguments 37 | - `training_args`: training arguments 38 | - `inference_args`: inference arguments 39 | """ 40 | # model configurations 41 | model_parser = argparse.ArgumentParser( 42 | description='Model configurations.', add_help=False) 43 | 44 | model_parser.add_argument( 45 | '--observation_dim', 46 | default=_DEFAULT_OBSERVATION_DIM, 47 | type=int, 48 | help='The dimension of the embeddings (e.g. 
d-vectors).') 49 | 50 | model_parser.add_argument( 51 | '--rnn_hidden_size', 52 | default=512, 53 | type=int, 54 | help='The number of nodes for each RNN layer.') 55 | model_parser.add_argument( 56 | '--rnn_depth', 57 | default=1, 58 | type=int, 59 | help='The number of RNN layers.') 60 | model_parser.add_argument( 61 | '--fc_depth', 62 | default=1, 63 | type=int, 64 | help='The number of fully connected layers.') 65 | model_parser.add_argument( 66 | '--rnn_dropout', 67 | default=0.2, 68 | type=float, 69 | help='The dropout rate for all RNN layers.') 70 | model_parser.add_argument( 71 | '--non_lin', 72 | default=True, 73 | type=str2bool, 74 | help='Whether to use a non-linearity (ReLU) after the linear layers.') 75 | model_parser.add_argument( 76 | '--transition_bias', 77 | default=None, 78 | type=float, 79 | help='The value of p0, corresponding to Eq. (6) in the ' 80 | 'paper. If the value is given, we will fix it to this value. If the ' 81 | 'value is None, we will estimate it from training data ' 82 | 'using Eq. (13) in the paper.') 83 | model_parser.add_argument( 84 | '--crp_alpha', 85 | default=None, 86 | type=float, 87 | help='The value of alpha for the Chinese restaurant process (CRP), ' 88 | 'corresponding to Eq. (7) in the paper. If the value is given, ' 89 | 'we will fix it to this value. If the value is None, we will estimate ' 90 | 'it from training data.') 91 | model_parser.add_argument( 92 | '--sigma2', 93 | default=None, 94 | type=float, 95 | help='The value of sigma squared, corresponding to Eq. (11) in the ' 96 | 'paper. If the value is given, we will fix it to this value. If the ' 97 | 'value is None, we will estimate it from training data.') 98 | model_parser.add_argument( 99 | '--verbosity', 100 | default=2, 101 | type=int, 102 | help='How verbose the logging information will be. Higher value ' 103 | 'represents more verbose information. A general guideline: ' 104 | '0 for errors; 1 for finishing important steps; ' 105 | '2 for finishing less important steps; 3 or above for debugging ' 106 | 'information.') 107 | model_parser.add_argument( 108 | '--enable_cuda', 109 | default=True, 110 | type=str2bool, 111 | help='Whether we should use CUDA if it is available. If False, we will ' 112 | 'always use CPU.') 113 | 114 | # training configurations 115 | training_parser = argparse.ArgumentParser( 116 | description='Training configurations.', add_help=False) 117 | 118 | training_parser.add_argument( 119 | '--optimizer', 120 | '-o', 121 | default='adam', 122 | choices=['adam'], 123 | help='The optimizer for training.') 124 | training_parser.add_argument( 125 | '--learning_rate', 126 | '-l', 127 | default=1e-3, 128 | type=float, 129 | help='The learning rate for training.') 130 | training_parser.add_argument( 131 | '--train_iteration', 132 | '-t', 133 | default=20000, 134 | type=int, 135 | help='The total number of training iterations.') 136 | training_parser.add_argument( 137 | '--epochs', 138 | '-e', 139 | default=10, 140 | type=int, 141 | help='The total number of training epochs.') 142 | training_parser.add_argument( 143 | '--batch_size', 144 | '-b', 145 | default=10, 146 | type=int, 147 | help='The batch size for training.') 148 | training_parser.add_argument( 149 | '--num_permutations', 150 | default=10, 151 | type=int, 152 | help='The number of permutations per utterance sampled in the training ' 153 | 'data.') 154 | training_parser.add_argument( 155 | '--sigma_alpha', 156 | default=1.0, 157 | type=float, 158 | help='The inverse gamma shape for estimating sigma2. 
This value is only ' 159 | 'meaningful when sigma2 is not given and is estimated from data.') 160 | training_parser.add_argument( 161 | '--sigma_beta', 162 | default=1.0, 163 | type=float, 164 | help='The inverse gamma scale for estimating sigma2. This value is only ' 165 | 'meaningful when sigma2 is not given and is estimated from data.') 166 | training_parser.add_argument( 167 | '--regularization_weight', 168 | '-r', 169 | default=1e-5, 170 | type=float, 171 | help='The multiplicative weight of the network regularization term.') 172 | training_parser.add_argument( 173 | '--grad_max_norm', 174 | default=5.0, 175 | type=float, 176 | help='Max norm of the gradient.') 177 | training_parser.add_argument( 178 | '--enforce_cluster_id_uniqueness', 179 | default=True, 180 | type=str2bool, 181 | help='Whether to enforce cluster ID uniqueness across different ' 182 | 'training sequences. Only effective when the first input to fit() ' 183 | 'is a list of sequences. In general, assume the cluster IDs for two ' 184 | 'sequences are [a, b] and [a, c]. If the `a` from the two sequences ' 185 | 'are not the same label, then this arg should be True.') 186 | training_parser.add_argument( 187 | '--loss_samples', 188 | default=-1, 189 | type=int, 190 | help='If loss_samples > 0, it is the number of embeddings to be ' 191 | 'sampled in the Sample Mean Loss; otherwise it is ignored and the ' 192 | 'loss is computed as in the original UIS-RNN.') 193 | 194 | # inference configurations 195 | inference_parser = argparse.ArgumentParser( 196 | description='Inference configurations.', add_help=False) 197 | 198 | inference_parser.add_argument( 199 | '--beam_size', 200 | '-s', 201 | default=10, 202 | type=int, 203 | help='The beam search size for inference.') 204 | inference_parser.add_argument( 205 | '--look_ahead', 206 | default=1, 207 | type=int, 208 | help='The number of look ahead steps during inference.') 209 | inference_parser.add_argument( 210 | '--test_iteration', 211 | default=2, 212 | type=int, 213 | help='During inference, we concatenate M duplicates of the test ' 214 | 'sequence, and run inference on this concatenated sequence. ' 215 | 'Then we return the inference results on the last duplicate as the ' 216 | 'final prediction for the test sequence.') 217 | 218 | # a super parser for sanity checks 219 | super_parser = argparse.ArgumentParser( 220 | parents=[model_parser, training_parser, inference_parser]) 221 | 222 | # get arguments 223 | super_parser.parse_args() 224 | model_args, _ = model_parser.parse_known_args() 225 | training_args, _ = training_parser.parse_known_args() 226 | inference_args, _ = inference_parser.parse_known_args() 227 | 228 | return (model_args, training_args, inference_args) 229 | -------------------------------------------------------------------------------- /uisrnn/evals.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | """Utils for model evaluation.""" 15 | 16 | from scipy import optimize 17 | import numpy as np 18 | 19 | 20 | def get_list_inverse_index(unique_ids): 21 | """Get a value-to-position index from a list of unique ids. 22 | 23 | Args: 24 | unique_ids: A list of unique integers or strings. 25 | 26 | Returns: 27 | result: a dict from value to position 28 | 29 | Raises: 30 | TypeError: If unique_ids is not a list. 31 | """ 32 | if not isinstance(unique_ids, list): 33 | raise TypeError('unique_ids must be a list') 34 | result = dict() 35 | for i, unique_id in enumerate(unique_ids): 36 | result[unique_id] = i 37 | return result 38 | 39 | 40 | def compute_sequence_match_accuracy(sequence1, sequence2): 41 | """Compute the accuracy between two sequences by finding optimal matching. 42 | 43 | Args: 44 | sequence1: A list of integers or strings. 45 | sequence2: A list of integers or strings. 46 | 47 | Returns: 48 | accuracy: sequence matching accuracy as a number in [0.0, 1.0] 49 | 50 | Raises: 51 | TypeError: If sequence1 or sequence2 is not a list. 52 | ValueError: If sequence1 and sequence2 are not the same size. 53 | """ 54 | if not isinstance(sequence1, list) or not isinstance(sequence2, list): 55 | raise TypeError('sequence1 and sequence2 must be lists') 56 | if not sequence1 or len(sequence1) != len(sequence2): 57 | raise ValueError( 58 | 'sequence1 and sequence2 must have the same non-zero length') 59 | # get unique ids from sequences 60 | unique_ids1 = sorted(set(sequence1)) 61 | unique_ids2 = sorted(set(sequence2)) 62 | inverse_index1 = get_list_inverse_index(unique_ids1) 63 | inverse_index2 = get_list_inverse_index(unique_ids2) 64 | # get the count matrix 65 | count_matrix = np.zeros((len(unique_ids1), len(unique_ids2))) 66 | for item1, item2 in zip(sequence1, sequence2): 67 | index1 = inverse_index1[item1] 68 | index2 = inverse_index2[item2] 69 | count_matrix[index1, index2] += 1.0 70 | row_index, col_index = optimize.linear_sum_assignment(-count_matrix) 71 | optimal_match_count = count_matrix[row_index, col_index].sum() 72 | accuracy = optimal_match_count / len(sequence1) 73 | return accuracy 74 | -------------------------------------------------------------------------------- /uisrnn/loss_func.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Loss functions for training.""" 15 | 16 | import torch 17 | 18 | 19 | def weighted_mse_loss(input_tensor, target_tensor, weight=1): 20 | """Compute weighted MSE loss. 21 | 22 | Note that this is a weighted loss that only sums over non-zero entries. 
23 | 24 | Args: 25 | input_tensor: input tensor 26 | target_tensor: target tensor 27 | weight: weight tensor, in this case 1/sigma^2 28 | 29 | Returns: 30 | the weighted MSE loss 31 | """ 32 | observation_dim = input_tensor.size()[-1] 33 | streched_tensor = ((input_tensor - target_tensor) ** 2).view( 34 | -1, observation_dim) 35 | entry_num = float(streched_tensor.size()[0]) 36 | non_zero_entry_num = torch.sum(streched_tensor[:, 0] != 0).float() 37 | weighted_tensor = torch.mm( 38 | ((input_tensor - target_tensor)**2).view(-1, observation_dim), 39 | (torch.diag(weight.float().view(-1)))) 40 | return torch.mean( 41 | weighted_tensor) * weight.nelement() * entry_num / non_zero_entry_num 42 | 43 | 44 | def sigma2_prior_loss(num_non_zero, sigma_alpha, sigma_beta, sigma2): 45 | """Compute sigma2 prior loss. 46 | 47 | Args: 48 | num_non_zero: since rnn_truth is a collection of different length sequences 49 | padded with zeros to fit them into a tensor, we count the sum of 50 | 'real lengths' of all sequences 51 | sigma_alpha: inverse gamma shape 52 | sigma_beta: inverse gamma scale 53 | sigma2: sigma squared 54 | 55 | Returns: 56 | the sigma2 prior loss 57 | """ 58 | return ((2 * sigma_alpha + num_non_zero + 2) / 59 | (2 * num_non_zero) * torch.log(sigma2)).sum() + ( 60 | sigma_beta / (sigma2 * num_non_zero)).sum() 61 | 62 | 63 | def regularization_loss(params, weight): 64 | """Compute regularization loss. 65 | 66 | Args: 67 | params: iterable of all parameters 68 | weight: weight for the regularization term 69 | 70 | Returns: 71 | the regularization loss 72 | """ 73 | l2_reg = 0 74 | for param in params: 75 | l2_reg += torch.norm(param) 76 | return weight * l2_reg 77 | -------------------------------------------------------------------------------- /uisrnn/uisrnn.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | """The UIS-RNN model.""" 15 | 16 | import numpy as np 17 | import torch 18 | from torch import autograd 19 | from torch import nn 20 | from torch import optim 21 | import torch.nn.functional as F 22 | 23 | from uisrnn import loss_func 24 | from uisrnn import utils 25 | 26 | _INITIAL_SIGMA2_VALUE = 0.1 27 | 28 | 29 | class CoreRNN(nn.Module): 30 | """The core Recurrent Neural Network used by UIS-RNN.""" 31 | 32 | def __init__(self, input_dim, hidden_size, rnn_depth, fc_depth, 33 | observation_dim, dropout=0, non_lin=True): 34 | super(CoreRNN, self).__init__() 35 | self.hidden_size = hidden_size 36 | self.non_lin = non_lin 37 | if rnn_depth >= 2: 38 | self.gru = nn.GRU(input_dim, hidden_size, rnn_depth, dropout=dropout) 39 | else: 40 | self.gru = nn.GRU(input_dim, hidden_size, rnn_depth) 41 | self.linears = nn.ModuleList( 42 | [nn.Linear(hidden_size, hidden_size) for _ in range(fc_depth)]) 43 | self.last_linear = nn.Linear(hidden_size, observation_dim) 44 | 45 | def forward(self, input_seq, hidden=None): 46 | out, hidden = self.gru(input_seq, hidden) 47 | if isinstance(out, torch.nn.utils.rnn.PackedSequence): 48 | out, _ = torch.nn.utils.rnn.pad_packed_sequence(out, batch_first=False) 49 | for layer in self.linears: 50 | out = layer(out) 51 | if self.non_lin: 52 | out = F.relu(out) 53 | mean = self.last_linear(out) 54 | return mean, hidden 55 | 56 | 57 | class BeamState: 58 | """Structure that contains necessary states for beam search.""" 59 | 60 | def __init__(self, source=None): 61 | if not source: 62 | self.mean_set = [] 63 | self.hidden_set = [] 64 | self.neg_likelihood = 0 65 | self.trace = [] 66 | self.block_counts = [] 67 | else: 68 | self.mean_set = source.mean_set.copy() 69 | self.hidden_set = source.hidden_set.copy() 70 | self.trace = source.trace.copy() 71 | self.block_counts = source.block_counts.copy() 72 | self.neg_likelihood = source.neg_likelihood 73 | 74 | def append(self, mean, hidden, cluster): 75 | """Append new item to the BeamState.""" 76 | self.mean_set.append(mean.clone()) 77 | self.hidden_set.append(hidden.clone()) 78 | self.block_counts.append(1) 79 | self.trace.append(cluster) 80 | 81 | 82 | class UISRNN: 83 | """Unbounded Interleaved-State Recurrent Neural Networks.""" 84 | 85 | def __init__(self, args): 86 | """Construct the UISRNN object. 87 | 88 | Args: 89 | args: Model configurations. See `arguments.py` for details. 
90 | """ 91 | self.observation_dim = args.observation_dim 92 | self.device = torch.device( 93 | 'cuda:0' if (torch.cuda.is_available() and args.enable_cuda) else 'cpu') 94 | self.rnn_model = CoreRNN(self.observation_dim, args.rnn_hidden_size, 95 | args.rnn_depth, args.fc_depth, 96 | self.observation_dim, args.rnn_dropout, 97 | args.non_lin).to(self.device) 98 | self.rnn_init_hidden = nn.Parameter( 99 | torch.zeros(args.rnn_depth, 1, args.rnn_hidden_size).to(self.device)) 100 | # booleans indicating which variables are trainable 101 | self.estimate_sigma2 = (args.sigma2 is None) 102 | self.estimate_transition_bias = (args.transition_bias is None) 103 | self.estimate_crp_alpha = (args.crp_alpha is None) 104 | # initial values of variables 105 | sigma2 = _INITIAL_SIGMA2_VALUE if self.estimate_sigma2 else args.sigma2 106 | self.sigma2 = nn.Parameter( 107 | sigma2 * torch.ones(self.observation_dim).to(self.device)) 108 | self.transition_bias = args.transition_bias 109 | self.transition_bias_denominator = 0.0 110 | self.crp_alpha = args.crp_alpha 111 | self.crp_alpha_denominator = 0.0 112 | self.logger = utils.Logger(args.verbosity) 113 | self.current_iter = 0 114 | 115 | 116 | def _get_optimizer(self, optimizer, learning_rate): 117 | """Get optimizer for UISRNN. 118 | 119 | Args: 120 | optimizer: string - name of the optimizer. 121 | learning_rate: - learning rate for the entire model. 122 | We do not customize learning rate for separate parts. 123 | 124 | Returns: 125 | a pytorch "optim" object 126 | """ 127 | params = [ 128 | { 129 | 'params': self.rnn_model.parameters() 130 | }, # rnn parameters 131 | { 132 | 'params': self.rnn_init_hidden 133 | } # rnn initial hidden state 134 | ] 135 | if self.estimate_sigma2: # train sigma2 136 | params.append({ 137 | 'params': self.sigma2 138 | }) # variance parameters 139 | assert optimizer == 'adam', 'Only adam optimizer is supported.' 140 | return optim.Adam(params, lr=learning_rate) 141 | 142 | def save(self, filepath): 143 | """Save the model to a file. 144 | 145 | Args: 146 | filepath: the path of the file. 147 | """ 148 | torch.save({ 149 | 'rnn_state_dict': self.rnn_model.state_dict(), 150 | 'rnn_init_hidden': self.rnn_init_hidden.detach().cpu().numpy(), 151 | 'transition_bias': self.transition_bias, 152 | 'transition_bias_denominator': self.transition_bias_denominator, 153 | 'crp_alpha': self.crp_alpha, 154 | 'crp_alpha_denominator': self.crp_alpha_denominator, 155 | 'sigma2': self.sigma2.detach().cpu().numpy(), 156 | 'current_iter': self.current_iter}, filepath) 157 | 158 | def load(self, filepath): 159 | """Load the model from a file. 160 | 161 | Args: 162 | filepath: the path of the file. 
163 | """ 164 | var_dict = torch.load(filepath) 165 | self.rnn_model.load_state_dict(var_dict['rnn_state_dict']) 166 | self.rnn_init_hidden = nn.Parameter( 167 | torch.from_numpy(var_dict['rnn_init_hidden']).to(self.device)) 168 | self.transition_bias = float(var_dict['transition_bias']) 169 | self.transition_bias_denominator = float( 170 | var_dict['transition_bias_denominator']) 171 | self.crp_alpha = float(var_dict['crp_alpha']) 172 | self.crp_alpha_denominator = float(var_dict['crp_alpha_denominator']) 173 | self.sigma2 = nn.Parameter( 174 | torch.from_numpy(var_dict['sigma2']).to(self.device)) 175 | self.current_iter = int(var_dict['current_iter']) 176 | 177 | self.logger.print( 178 | 3, 'Loaded model with transition_bias={}, crp_alpha={}, sigma2={}, ' 179 | 'rnn_init_hidden={}'.format( 180 | self.transition_bias, self.crp_alpha, var_dict['sigma2'], 181 | var_dict['rnn_init_hidden'])) 182 | 183 | def fit_concatenated(self, train_sequence, train_cluster_id, args): 184 | """Fit UISRNN model to concatenated sequence and cluster_id. 185 | 186 | Args: 187 | train_sequence: the training observation sequence, which is a 188 | 2-dim numpy array of real numbers, of size `N * D`. 189 | 190 | - `N`: summation of lengths of all utterances. 191 | - `D`: observation dimension. 192 | 193 | For example, 194 | ``` 195 | train_sequence = 196 | [[1.2 3.0 -4.1 6.0] --> an entry of speaker #0 from utterance 'iaaa' 197 | [0.8 -1.1 0.4 0.5] --> an entry of speaker #1 from utterance 'iaaa' 198 | [-0.2 1.0 3.8 5.7] --> an entry of speaker #0 from utterance 'iaaa' 199 | [3.8 -0.1 1.5 2.3] --> an entry of speaker #0 from utterance 'ibbb' 200 | [1.2 1.4 3.6 -2.7]] --> an entry of speaker #0 from utterance 'ibbb' 201 | ``` 202 | Here `N=5`, `D=4`. 203 | 204 | We concatenate all training utterances into this single sequence. 205 | train_cluster_id: the speaker id sequence, which is 1-dim list or 206 | numpy array of strings, of size `N`. 207 | For example, 208 | ``` 209 | train_cluster_id = 210 | ['iaaa_0', 'iaaa_1', 'iaaa_0', 'ibbb_0', 'ibbb_0'] 211 | ``` 212 | 'iaaa_0' means the entry belongs to speaker #0 in utterance 'iaaa'. 213 | 214 | Note that the order of entries within an utterance are preserved, 215 | and all utterances are simply concatenated together. 216 | args: Training configurations. See `arguments.py` for details. 217 | 218 | Raises: 219 | TypeError: If train_sequence or train_cluster_id is of wrong type. 220 | ValueError: If train_sequence or train_cluster_id has wrong dimension. 
221 | """ 222 | # check type 223 | if (not isinstance(train_sequence, np.ndarray) or 224 | train_sequence.dtype != float): 225 | raise TypeError('train_sequence should be a numpy array of float type.') 226 | if isinstance(train_cluster_id, list): 227 | train_cluster_id = np.array(train_cluster_id) 228 | if (not isinstance(train_cluster_id, np.ndarray) or 229 | not train_cluster_id.dtype.name.startswith(('str', 'unicode'))): 230 | raise TypeError('train_cluster_id type be a numpy array of strings.') 231 | # check dimension 232 | if train_sequence.ndim != 2: 233 | raise ValueError('train_sequence must be 2-dim array.') 234 | if train_cluster_id.ndim != 1: 235 | raise ValueError('train_cluster_id must be 1-dim array.') 236 | # check length and size 237 | train_total_length, observation_dim = train_sequence.shape 238 | if observation_dim != self.observation_dim: 239 | raise ValueError('train_sequence does not match the dimension specified ' 240 | 'by args.observation_dim.') 241 | if train_total_length != len(train_cluster_id): 242 | raise ValueError('train_sequence length is not equal to ' 243 | 'train_cluster_id length.') 244 | 245 | self.rnn_model.train() 246 | optimizer = self._get_optimizer(optimizer=args.optimizer, 247 | learning_rate=args.learning_rate) 248 | 249 | sub_sequences, seq_lengths = utils.resize_sequence( 250 | sequence=train_sequence, 251 | cluster_id=train_cluster_id, 252 | num_permutations=args.num_permutations) 253 | 254 | # For batch learning, pack the entire dataset. 255 | if args.batch_size is None: 256 | packed_train_sequence, rnn_truth = utils.pack_sequence( 257 | sub_sequences, 258 | seq_lengths, 259 | args.batch_size, 260 | self.observation_dim, 261 | self.device, 262 | args.loss_samples) 263 | train_loss = [] 264 | for _ in range(args.train_iteration): 265 | self.current_iter += 1 266 | optimizer.zero_grad() 267 | # For online learning, pack a subset in each iteration. 268 | if args.batch_size is not None: 269 | packed_train_sequence, rnn_truth = utils.pack_sequence( 270 | sub_sequences, 271 | seq_lengths, 272 | args.batch_size, 273 | self.observation_dim, 274 | self.device, 275 | args.loss_samples) 276 | hidden = self.rnn_init_hidden.repeat(1, args.batch_size, 1) 277 | mean, _ = self.rnn_model(packed_train_sequence, hidden) 278 | # use mean to predict 279 | mean = torch.cumsum(mean, dim=0) 280 | mean_size = mean.size() 281 | mean = torch.mm( 282 | torch.diag( 283 | 1.0 / torch.arange(1, mean_size[0] + 1).float().to(self.device)), 284 | mean.view(mean_size[0], -1)) 285 | mean = mean.view(mean_size) 286 | 287 | # Likelihood part. 288 | loss1 = loss_func.weighted_mse_loss( 289 | input_tensor=(rnn_truth != 0).float() * mean[:-1, :, :], 290 | target_tensor=rnn_truth, 291 | weight=1 / (2 * self.sigma2)) 292 | 293 | # Sigma2 prior part. 294 | weight = (((rnn_truth != 0).float() * mean[:-1, :, :] - rnn_truth) 295 | ** 2).view(-1, observation_dim) 296 | num_non_zero = torch.sum((weight != 0).float(), dim=0).squeeze() 297 | loss2 = loss_func.sigma2_prior_loss( 298 | num_non_zero, args.sigma_alpha, args.sigma_beta, self.sigma2) 299 | 300 | # Regularization part. 
301 | loss3 = loss_func.regularization_loss( 302 | self.rnn_model.parameters(), args.regularization_weight) 303 | 304 | loss = loss1 + loss2 + loss3 305 | loss.backward() 306 | nn.utils.clip_grad_norm_(self.rnn_model.parameters(), args.grad_max_norm) 307 | optimizer.step() 308 | # avoid numerical issues 309 | self.sigma2.data.clamp_(min=1e-6) 310 | 311 | if (np.remainder(self.current_iter, 10) == 0 or 312 | self.current_iter == args.train_iteration - 1): 313 | self.logger.print( 314 | 2, 315 | 'Iter: {:d} \t' 316 | 'Training Loss: {:.4f} \n' 317 | ' Negative Log Likelihood: {:.4f}\t' 318 | 'Sigma2 Prior: {:.4f}\t' 319 | 'Regularization: {:.4f}'.format( 320 | self.current_iter, 321 | float(loss.data), 322 | float(loss1.data), 323 | float(loss2.data), 324 | float(loss3.data))) 325 | yield {'total': loss.data, 326 | 'nll': loss1.data, 327 | 's2p': loss2.data, 328 | 'regular': loss3.data}, self.current_iter 329 | train_loss.append(float(loss1.data)) # only save the likelihood part 330 | self.logger.print( 331 | 1, 'Done training with {} iterations'.format(args.train_iteration)) 332 | 333 | 334 | def fit(self, train_sequences, train_cluster_ids, args): 335 | """Fit UISRNN model. 336 | 337 | Args: 338 | train_sequences: Either a list of training sequences, or a single 339 | concatenated training sequence: 340 | 341 | 1. train_sequences is list, and each element is a 2-dim numpy array 342 | of real numbers, of size: `length * D`. 343 | The length varies among different sequences, but the D is the same. 344 | In speaker diarization, each sequence is the sequence of speaker 345 | embeddings of one utterance. 346 | 2. train_sequences is a single concatenated sequence, which is a 347 | 2-dim numpy array of real numbers. See `fit_concatenated()` 348 | for more details. 349 | train_cluster_ids: Ground truth labels for train_sequences: 350 | 351 | 1. if train_sequences is a list, this must also be a list of the same 352 | size, each element being a 1-dim list or numpy array of strings. 353 | 2. if train_sequences is a single concatenated sequence, this 354 | must also be the concatenated 1-dim list or numpy array of strings 355 | args: Training configurations. See `arguments.py` for details. 356 | 357 | Raises: 358 | TypeError: If train_sequences or train_cluster_ids is of wrong type. 359 | """ 360 | if isinstance(train_sequences, np.ndarray): 361 | # train_sequences is already the concatenated sequence 362 | if self.estimate_transition_bias: 363 | # see issue #55: https://github.com/google/uis-rnn/issues/55 364 | self.logger.print( 365 | 2, 366 | 'Warning: transition_bias cannot be correctly estimated from a ' 367 | 'concatenated sequence; train_sequences will be treated as a ' 368 | 'single sequence. This can lead to inaccurate estimation of ' 369 | 'transition_bias. 
Please consider estimating transition_bias ' 370 | 'before concatenating the sequences and passing it as an argument.') 371 | train_sequences = [train_sequences] 372 | train_cluster_ids = [train_cluster_ids] 373 | elif isinstance(train_sequences, list): 374 | # train_sequences is a list of un-concatenated sequences 375 | # we will concatenate it later, after estimating transition_bias 376 | pass 377 | else: 378 | raise TypeError('train_sequences must be a list or numpy.ndarray') 379 | 380 | # estimate transition_bias 381 | if self.estimate_transition_bias: 382 | (transition_bias, 383 | transition_bias_denominator) = utils.estimate_transition_bias( 384 | train_cluster_ids) 385 | # set or update transition_bias 386 | if self.transition_bias is None: 387 | self.transition_bias = transition_bias 388 | self.transition_bias_denominator = transition_bias_denominator 389 | else: 390 | self.transition_bias = ( 391 | self.transition_bias * self.transition_bias_denominator + 392 | transition_bias * transition_bias_denominator) / ( 393 | self.transition_bias_denominator + transition_bias_denominator) 394 | self.transition_bias_denominator += transition_bias_denominator 395 | 396 | # estimate crp_alpha 397 | if self.estimate_crp_alpha: 398 | (crp_alpha, 399 | crp_alpha_denominator) = utils.estimate_crp_alpha(train_cluster_ids) 400 | # set or update crp_alpha 401 | if self.crp_alpha is None: 402 | self.crp_alpha = crp_alpha 403 | self.crp_alpha_denominator = crp_alpha_denominator 404 | else: 405 | # weighted running average, mirroring the transition_bias update 406 | self.crp_alpha = ( 407 | self.crp_alpha * self.crp_alpha_denominator + 408 | crp_alpha * crp_alpha_denominator) / ( 409 | self.crp_alpha_denominator + crp_alpha_denominator) 410 | self.crp_alpha_denominator += crp_alpha_denominator 411 | 412 | # concatenate train_sequences 413 | (concatenated_train_sequence, 414 | concatenated_train_cluster_id) = utils.concatenate_training_data( 415 | train_sequences, 416 | train_cluster_ids, 417 | args.enforce_cluster_id_uniqueness, 418 | True) 419 | 420 | return self.fit_concatenated( 421 | concatenated_train_sequence, concatenated_train_cluster_id, args) 422 | 423 | def _update_beam_state(self, beam_state, look_ahead_seq, cluster_seq): 424 | """Update a beam state given a look ahead sequence and known cluster 425 | assignments. 426 | 427 | Args: 428 | beam_state: A BeamState object. 429 | look_ahead_seq: Look ahead sequence, size: look_ahead*D. 430 | look_ahead: number of steps to look ahead in the beam search. 431 | D: observation dimension 432 | cluster_seq: Cluster assignment sequence for look_ahead_seq. 433 | 434 | Returns: 435 | new_beam_state: An updated BeamState object. 
435 | """ 436 | 437 | loss = 0 438 | new_beam_state = BeamState(beam_state) 439 | for sub_idx, cluster in enumerate(cluster_seq): 440 | if cluster > len(new_beam_state.mean_set): # invalid trace 441 | new_beam_state.neg_likelihood = float('inf') 442 | break 443 | elif cluster < len(new_beam_state.mean_set): # existing cluster 444 | last_cluster = new_beam_state.trace[-1] 445 | loss = loss_func.weighted_mse_loss( 446 | input_tensor=torch.squeeze(new_beam_state.mean_set[cluster]), 447 | target_tensor=look_ahead_seq[sub_idx, :], 448 | weight=1 / (2 * self.sigma2)).cpu().detach().numpy() 449 | if cluster == last_cluster: 450 | loss -= np.log(1 - self.transition_bias) 451 | else: 452 | loss -= np.log(self.transition_bias) + np.log( 453 | new_beam_state.block_counts[cluster]) - np.log( 454 | sum(new_beam_state.block_counts) + self.crp_alpha) 455 | # update new mean and new hidden 456 | mean, hidden = self.rnn_model( 457 | look_ahead_seq[sub_idx, :].unsqueeze(0).unsqueeze(0), 458 | new_beam_state.hidden_set[cluster]) 459 | new_beam_state.mean_set[cluster] = (new_beam_state.mean_set[cluster]*( 460 | (np.array(new_beam_state.trace) == cluster).sum() - 461 | 1).astype(float) + mean.clone()) / ( 462 | np.array(new_beam_state.trace) == cluster).sum().astype( 463 | float) # use mean to predict 464 | new_beam_state.hidden_set[cluster] = hidden.clone() 465 | if cluster != last_cluster: 466 | new_beam_state.block_counts[cluster] += 1 467 | new_beam_state.trace.append(cluster) 468 | else: # new cluster 469 | init_input = autograd.Variable( 470 | torch.zeros(self.observation_dim) 471 | ).unsqueeze(0).unsqueeze(0).to(self.device) 472 | mean, hidden = self.rnn_model(init_input, 473 | self.rnn_init_hidden) 474 | loss = loss_func.weighted_mse_loss( 475 | input_tensor=torch.squeeze(mean), 476 | target_tensor=look_ahead_seq[sub_idx, :], 477 | weight=1 / (2 * self.sigma2)).cpu().detach().numpy() 478 | loss -= np.log(self.transition_bias) + np.log( 479 | self.crp_alpha) - np.log( 480 | sum(new_beam_state.block_counts) + self.crp_alpha) 481 | # update new min and new hidden 482 | mean, hidden = self.rnn_model( 483 | look_ahead_seq[sub_idx, :].unsqueeze(0).unsqueeze(0), 484 | hidden) 485 | new_beam_state.append(mean, hidden, cluster) 486 | new_beam_state.neg_likelihood += loss 487 | return new_beam_state 488 | 489 | def _calculate_score(self, beam_state, look_ahead_seq): 490 | """Calculate negative log likelihoods for all possible state allocations 491 | of a look ahead sequence, according to the current beam state. 492 | 493 | Args: 494 | beam_state: A BeamState object. 495 | look_ahead_seq: Look ahead sequence, size: look_ahead*D. 496 | look_ahead: number of step to look ahead in the beam search. 497 | D: observation dimension 498 | 499 | Returns: 500 | beam_score_set: a set of scores for each possible state allocation. 501 | """ 502 | 503 | look_ahead, _ = look_ahead_seq.shape 504 | beam_num_clusters = len(beam_state.mean_set) 505 | beam_score_set = float('inf') * np.ones( 506 | beam_num_clusters + 1 + np.arange(look_ahead)) 507 | for cluster_seq, _ in np.ndenumerate(beam_score_set): 508 | updated_beam_state = self._update_beam_state(beam_state, 509 | look_ahead_seq, cluster_seq) 510 | beam_score_set[cluster_seq] = updated_beam_state.neg_likelihood 511 | return beam_score_set 512 | 513 | def predict_single(self, test_sequence, args): 514 | """Predict labels for a single test sequence using UISRNN model. 
515 | 516 | Args: 517 | test_sequence: the test observation sequence, which is 2-dim numpy array 518 | of real numbers, of size `N * D`. 519 | 520 | - `N`: length of one test utterance. 521 | - `D` : observation dimension. 522 | 523 | For example: 524 | ``` 525 | test_sequence = 526 | [[2.2 -1.0 3.0 5.6] --> 1st entry of utterance 'iccc' 527 | [0.5 1.8 -3.2 0.4] --> 2nd entry of utterance 'iccc' 528 | [-2.2 5.0 1.8 3.7] --> 3rd entry of utterance 'iccc' 529 | [-3.8 0.1 1.4 3.3] --> 4th entry of utterance 'iccc' 530 | [0.1 2.7 3.5 -1.7]] --> 5th entry of utterance 'iccc' 531 | ``` 532 | Here `N=5`, `D=4`. 533 | args: Inference configurations. See `arguments.py` for details. 534 | 535 | Returns: 536 | predicted_cluster_id: predicted speaker id sequence, which is 537 | an array of integers, of size `N`. 538 | For example, `predicted_cluster_id = [0, 1, 0, 0, 1]` 539 | 540 | Raises: 541 | TypeError: If test_sequence is of wrong type. 542 | ValueError: If test_sequence has wrong dimension. 543 | """ 544 | # check type 545 | if (not isinstance(test_sequence, np.ndarray) or 546 | test_sequence.dtype != float): 547 | raise TypeError('test_sequence should be a numpy array of float type.') 548 | # check dimension 549 | if test_sequence.ndim != 2: 550 | raise ValueError('test_sequence must be 2-dim array.') 551 | # check size 552 | test_sequence_length, observation_dim = test_sequence.shape 553 | if observation_dim != self.observation_dim: 554 | raise ValueError('test_sequence does not match the dimension specified ' 555 | 'by args.observation_dim.') 556 | 557 | self.rnn_model.eval() 558 | test_sequence = np.tile(test_sequence, (args.test_iteration, 1)) 559 | test_sequence = autograd.Variable( 560 | torch.from_numpy(test_sequence).float()).to(self.device) 561 | # bookkeeping for beam search 562 | beam_set = [BeamState()] 563 | for num_iter in np.arange(0, args.test_iteration * test_sequence_length, 564 | args.look_ahead): 565 | max_clusters = max([len(beam_state.mean_set) for beam_state in beam_set]) 566 | look_ahead_seq = test_sequence[num_iter: num_iter + args.look_ahead, :] 567 | look_ahead_seq_length = look_ahead_seq.shape[0] 568 | score_set = float('inf') * np.ones( 569 | np.append( 570 | args.beam_size, max_clusters + 1 + np.arange( 571 | look_ahead_seq_length))) 572 | for beam_rank, beam_state in enumerate(beam_set): 573 | beam_score_set = self._calculate_score(beam_state, look_ahead_seq) 574 | score_set[beam_rank, :] = np.pad( 575 | beam_score_set, 576 | np.tile([[0, max_clusters - len(beam_state.mean_set)]], 577 | (look_ahead_seq_length, 1)), 'constant', 578 | constant_values=float('inf')) 579 | # find top scores 580 | score_ranked = np.sort(score_set, axis=None) 581 | score_ranked[score_ranked == float('inf')] = 0 582 | score_ranked = np.trim_zeros(score_ranked) 583 | idx_ranked = np.argsort(score_set, axis=None) 584 | updated_beam_set = [] 585 | for new_beam_rank in range( 586 | np.min((len(score_ranked), args.beam_size))): 587 | total_idx = np.unravel_index(idx_ranked[new_beam_rank], 588 | score_set.shape) 589 | prev_beam_rank = total_idx[0] 590 | cluster_seq = total_idx[1:] 591 | updated_beam_state = self._update_beam_state( 592 | beam_set[prev_beam_rank], look_ahead_seq, cluster_seq) 593 | updated_beam_set.append(updated_beam_state) 594 | beam_set = updated_beam_set 595 | predicted_cluster_id = beam_set[0].trace[-test_sequence_length:] 596 | return predicted_cluster_id 597 | 598 | def predict(self, test_sequences, args): 599 | """Predict labels for a single or many test sequences using 
615 |     Raises:
616 |       TypeError: If test_sequences is of wrong type.
617 |     """
618 |     # check type
619 |     if isinstance(test_sequences, np.ndarray):
620 |       return self.predict_single(test_sequences, args)
621 |     if isinstance(test_sequences, list):
622 |       return [self.predict_single(test_sequence, args)
623 |               for test_sequence in test_sequences]
624 |     raise TypeError('test_sequences should be either a list or a numpy array.')
625 | 
--------------------------------------------------------------------------------
/uisrnn/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2018 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Utils for UIS-RNN."""
15 | 
16 | import random
17 | import string
18 | 
19 | import numpy as np
20 | import torch
21 | from torch import autograd
22 | 
23 | 
24 | class Logger:
25 |   """A class for printing logging information to screen."""
26 | 
27 |   def __init__(self, verbosity):
28 |     self._verbosity = verbosity
29 | 
30 |   def print(self, level, message):
31 |     """Print a message if its level is not higher than the verbosity.
32 | 
33 |     Args:
34 |       level: the level of this message; a smaller value means more important
35 |       message: the message to be printed
36 |     """
37 |     if level <= self._verbosity:
38 |       print(message)
39 | 
40 | 
41 | def generate_random_string(length=6):
42 |   """Generate a random string of upper case letters and digits.
43 | 
44 |   Args:
45 |     length: length of the generated string
46 | 
47 |   Returns:
48 |     the generated string
49 |   """
50 |   return ''.join([
51 |       random.choice(string.ascii_uppercase + string.digits)
52 |       for _ in range(length)])
53 | 
54 | 
55 | def enforce_cluster_id_uniqueness(cluster_ids):
56 |   """Enforce uniqueness of cluster ids across sequences.
57 | 
58 |   Args:
59 |     cluster_ids: a list of 1-dim list/numpy.ndarray of strings
60 | 
61 |   Returns:
62 |     a new list with the same length as cluster_ids
63 | 
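    For example (an illustrative sketch; the prefixes shown are made up,
    since each call draws fresh random ones via generate_random_string()):
    ```
    cluster_ids = [['A', 'B', 'A'], ['A', 'C']]
    new_cluster_ids = enforce_cluster_id_uniqueness(cluster_ids)
    # e.g. [['Q3ZP8K_A', 'Q3ZP8K_B', 'Q3ZP8K_A'], ['7HMWE2_A', '7HMWE2_C']]
    # 'A' in the first sequence no longer collides with 'A' in the second.
    ```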
64 |   Raises:
65 |     TypeError: if cluster_ids or its elements have the wrong type
66 |   """
67 |   if not isinstance(cluster_ids, list):
68 |     raise TypeError('cluster_ids must be a list')
69 |   new_cluster_ids = []
70 |   for cluster_id in cluster_ids:
71 |     sequence_id = generate_random_string()
72 |     if isinstance(cluster_id, np.ndarray):
73 |       cluster_id = cluster_id.tolist()
74 |     if not isinstance(cluster_id, list):
75 |       raise TypeError('Elements of cluster_ids must be list or numpy.ndarray')
76 |     new_cluster_id = ['_'.join([sequence_id, s]) for s in cluster_id]
77 |     new_cluster_ids.append(new_cluster_id)
78 |   return new_cluster_ids
79 | 
80 | 
81 | def concatenate_training_data(train_sequences, train_cluster_ids,
82 |                               enforce_uniqueness=True, shuffle=True):
83 |   """Concatenate training data.
84 | 
85 |   Args:
86 |     train_sequences: a list of 2-dim numpy arrays to be concatenated
87 |     train_cluster_ids: a list of 1-dim list/numpy.ndarray of strings
88 |     enforce_uniqueness: a boolean indicating whether we should enforce
89 |       uniqueness of train_cluster_ids
90 |     shuffle: whether to randomly shuffle input order
91 | 
92 |   Returns:
93 |     concatenated_train_sequence: a 2-dim numpy array
94 |     concatenated_train_cluster_id: a list of strings
95 | 
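    For example (an illustrative sketch; uniqueness and shuffling are
    turned off here so that the output is deterministic):
    ```
    train_sequences = [np.array([[1.0, 2.0], [3.0, 4.0]]),
                       np.array([[5.0, 6.0]])]
    train_cluster_ids = [['A', 'B'], ['A']]
    seq, ids = concatenate_training_data(
        train_sequences, train_cluster_ids,
        enforce_uniqueness=False, shuffle=False)
    # seq is the 3 * 2 array [[1. 2.], [3. 4.], [5. 6.]]
    # ids is ['A', 'B', 'A']
    ```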
96 |   Raises:
97 |     TypeError: if input has wrong type
98 |     ValueError: if sizes/dimensions of input or their elements are incorrect
99 |   """
100 |   # check input
101 |   if not isinstance(train_sequences, list) or not isinstance(
102 |       train_cluster_ids, list):
103 |     raise TypeError('train_sequences and train_cluster_ids must be lists')
104 |   if len(train_sequences) != len(train_cluster_ids):
105 |     raise ValueError(
106 |         'train_sequences and train_cluster_ids must have same size')
107 |   train_cluster_ids = [
108 |       x.tolist() if isinstance(x, np.ndarray) else x
109 |       for x in train_cluster_ids]
110 |   global_observation_dim = None
111 |   for i, (train_sequence, train_cluster_id) in enumerate(
112 |       zip(train_sequences, train_cluster_ids)):
113 |     train_length, observation_dim = train_sequence.shape
114 |     if i == 0:
115 |       global_observation_dim = observation_dim
116 |     elif global_observation_dim != observation_dim:
117 |       raise ValueError(
118 |           'train_sequences must have consistent observation dimension')
119 |     if not isinstance(train_cluster_id, list):
120 |       raise TypeError(
121 |           'Elements of train_cluster_ids must be list or numpy.ndarray')
122 |     if len(train_cluster_id) != train_length:
123 |       raise ValueError(
124 |           'Each train_sequence and its train_cluster_id must have same length')
125 | 
126 |   # enforce uniqueness
127 |   if enforce_uniqueness:
128 |     train_cluster_ids = enforce_cluster_id_uniqueness(train_cluster_ids)
129 | 
130 |   # random shuffle
131 |   if shuffle:
132 |     zipped_input = list(zip(train_sequences, train_cluster_ids))
133 |     random.shuffle(zipped_input)
134 |     train_sequences, train_cluster_ids = zip(*zipped_input)
135 | 
136 |   # concatenate
137 |   concatenated_train_sequence = np.concatenate(train_sequences, axis=0)
138 |   concatenated_train_cluster_id = [x for train_cluster_id in train_cluster_ids
139 |                                    for x in train_cluster_id]
140 |   return concatenated_train_sequence, concatenated_train_cluster_id
141 | 
142 | 
143 | def sample_permuted_segments(index_sequence, number_samples):
144 |   """Sample sequences with permuted blocks.
145 | 
146 |   Args:
147 |     index_sequence: (integer array, size: L)
148 |       - subsequence index
149 |       For example, index_sequence = [1,2,6,10,11,12].
150 |     number_samples: (integer)
151 |       - number of subsampled block-preserving permuted sequences.
152 |       For example, number_samples = 5
153 | 
154 |   Returns:
155 |     sampled_index_sequences: (a list of numpy arrays) - a list of subsampled
156 |       block-preserving permuted sequences. For example,
157 |     ```
158 |     sampled_index_sequences =
159 |     [[10,11,12,1,2,6],
160 |      [6,1,2,10,11,12],
161 |      [1,2,10,11,12,6],
162 |      [6,1,2,10,11,12],
163 |      [1,2,6,10,11,12]]
164 |     ```
165 |     The length of "sampled_index_sequences" is "number_samples".
166 |   """
167 |   segments = []
168 |   if len(index_sequence) == 1:
169 |     segments.append(index_sequence)
170 |   else:
171 |     prev = 0
172 |     for i in range(len(index_sequence) - 1):
173 |       if index_sequence[i + 1] != index_sequence[i] + 1:
174 |         segments.append(index_sequence[prev:(i + 1)])
175 |         prev = i + 1
176 |       if i + 1 == len(index_sequence) - 1:
177 |         segments.append(index_sequence[prev:])
178 |   # sample permutations
179 |   sampled_index_sequences = []
180 |   for _ in range(number_samples):
181 |     segments_array = []
182 |     permutation = np.random.permutation(len(segments))
183 |     for permutation_item in permutation:
184 |       segments_array.append(segments[permutation_item])
185 |     sampled_index_sequences.append(np.concatenate(segments_array))
186 |   return sampled_index_sequences
187 | 
188 | 
189 | def resize_sequence(sequence, cluster_id, num_permutations=None):
190 |   """Resize sequences for packing and batching.
191 | 
192 |   Args:
193 |     sequence: (real numpy matrix, size: seq_len * obs_size) - observed sequence
194 |     cluster_id: (numpy vector, size: seq_len) - cluster indicator sequence
195 |     num_permutations: int - Number of permutations per utterance sampled.
196 | 
197 |   Returns:
198 |     sub_sequences: A list of numpy arrays, with observation vectors from the
199 |       same cluster gathered into the same array.
200 |     seq_lengths: The length of each cluster, plus one.
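    For example (without permutations; np.unique() visits the cluster ids
    in sorted order):
    ```
    sequence = np.array([[1.0], [2.0], [3.0], [4.0]])
    cluster_id = np.array(['A', 'B', 'A', 'B'])
    sub_sequences, seq_lengths = resize_sequence(sequence, cluster_id)
    # sub_sequences: [array([[1.], [3.]]), array([[2.], [4.]])]
    # seq_lengths: [3, 3], i.e. two observations per cluster, plus one
    ```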
201 |   """
202 |   # merge sub-sequences that belong to a single cluster to a single sequence
203 |   unique_id = np.unique(cluster_id)
204 |   sub_sequences = []
205 |   seq_lengths = []
206 |   if num_permutations and num_permutations > 1:
207 |     for i in unique_id:
208 |       idx_set = np.where(cluster_id == i)[0]
209 |       sampled_idx_sets = sample_permuted_segments(idx_set, num_permutations)
210 |       for j in range(num_permutations):
211 |         sub_sequences.append(sequence[sampled_idx_sets[j], :])
212 |         seq_lengths.append(len(idx_set) + 1)
213 |   else:
214 |     for i in unique_id:
215 |       idx_set = np.where(cluster_id == i)
216 |       sub_sequences.append(sequence[idx_set, :][0])
217 |       seq_lengths.append(len(idx_set[0]) + 1)
218 |   return sub_sequences, seq_lengths
219 | 
220 | 
221 | def pack_sequence(sub_sequences, seq_lengths, batch_size, observation_dim,
222 |                   device, loss_samples):
223 |   """Pack sequences for training.
224 | 
225 |   Args:
226 |     sub_sequences: A list of numpy arrays, with observation vectors from the
227 |       same cluster gathered into the same array.
228 |     seq_lengths: The length of each cluster, plus one.
229 |     batch_size: int or None - Run batch learning if batch_size is None. Else,
230 |       run online learning with specified batch size.
231 |     observation_dim: int - dimension for observation vectors
232 |     device: str - Your device. E.g., `cuda:0` or `cpu`.
    loss_samples: int - Number of future observations sampled per step to
      build the sample mean loss ground truth; if 0 or negative, the shifted
      input sequence itself is used as ground truth.
233 | 
234 |   Returns:
235 |     packed_rnn_input: (PackedSequence object) packed rnn input
236 |     rnn_truth: ground truth
237 |   """
238 |   num_clusters = len(seq_lengths)
239 |   sorted_seq_lengths = np.sort(seq_lengths)[::-1]
240 |   permute_index = np.argsort(seq_lengths)[::-1]
241 | 
242 |   if batch_size is None:
243 |     rnn_input = np.zeros((sorted_seq_lengths[0],
244 |                           num_clusters,
245 |                           observation_dim))
246 |     for i in range(num_clusters):
247 |       rnn_input[1:sorted_seq_lengths[i], i,
248 |                 :] = sub_sequences[permute_index[i]]
249 |     rnn_input_tensor = autograd.Variable(
250 |         torch.from_numpy(rnn_input).float()).to(device)
251 |     packed_rnn_input = torch.nn.utils.rnn.pack_padded_sequence(
252 |         rnn_input_tensor, sorted_seq_lengths, batch_first=False)
    # in batch learning, treat all clusters as one mini-batch, so that the
    # ground truth construction below works in both modes
    mini_batch = np.arange(num_clusters)
    batch_size = num_clusters
253 |   else:
254 |     mini_batch = np.sort(np.random.choice(num_clusters, batch_size))
255 |     rnn_input = np.zeros((sorted_seq_lengths[mini_batch[0]],
256 |                           batch_size,
257 |                           observation_dim))
258 |     for i in range(batch_size):
259 |       rnn_input[1:sorted_seq_lengths[mini_batch[i]],
260 |                 i, :] = sub_sequences[permute_index[mini_batch[i]]]
261 |     rnn_input_tensor = autograd.Variable(
262 |         torch.from_numpy(rnn_input).float()).to(device)
263 |     packed_rnn_input = torch.nn.utils.rnn.pack_padded_sequence(
264 |         rnn_input_tensor, sorted_seq_lengths[mini_batch], batch_first=False)
265 |   # build ground truth
266 |   if loss_samples > 0:
267 |     rnn_truth = np.zeros_like(rnn_input[1:, :, :])
268 |     for i in range(batch_size):
269 |       for j in range(sorted_seq_lengths[mini_batch[i]] - 1):
270 |         samples_idx = np.random.randint(
271 |             low=j + 1,
272 |             high=sorted_seq_lengths[mini_batch[i]],
273 |             size=min(loss_samples, sorted_seq_lengths[mini_batch[i]] - j - 1))
        # the truth for sequence i at step j is the sample mean of future
        # observations from the same sequence
274 |         rnn_truth[j, i, :] = np.mean(rnn_input[samples_idx, i, :], axis=0)
275 |   else:
276 |     rnn_truth = rnn_input[1:, :, :]
277 |   rnn_truth = torch.from_numpy(rnn_truth).float().to(device)
278 |   return packed_rnn_input, rnn_truth
279 | 
280 | 
281 | def output_result(model_args, training_args, test_record):
282 |   """Produce a string to summarize the experiment."""
283 |   accuracy_array, _ = zip(*test_record)
284 |   total_accuracy = np.mean(accuracy_array)
285 |   output_string = """
286 | Config:
287 |   sigma_alpha: {}
288 |   sigma_beta: {}
289 |   crp_alpha: {}
290 |   learning rate: {}
291 |   regularization: {}
292 |   batch size: {}
293 | 
294 | Performance:
295 |   averaged accuracy: {:.6f}
296 |   accuracy numbers for all testing sequences:
297 |   """.strip().format(
298 |       training_args.sigma_alpha,
299 |       training_args.sigma_beta,
300 |       model_args.crp_alpha,
301 |       training_args.learning_rate,
302 |       training_args.regularization_weight,
303 |       training_args.batch_size,
304 |       total_accuracy)
305 |   for accuracy in accuracy_array:
306 |     output_string += '\n    {:.6f}'.format(accuracy)
307 |   output_string += '\n' + '=' * 80 + '\n'
308 |   filename = 'layer_{}_{}_{:.1f}_result.txt'.format(
309 |       model_args.rnn_hidden_size,
310 |       model_args.rnn_depth, model_args.rnn_dropout)
311 |   with open(filename, 'a') as file_object:
312 |     file_object.write(output_string)
313 |   return output_string
314 | 
315 | 
316 | def estimate_transition_bias(cluster_ids, smooth=1):
317 |   """Estimate the transition bias.
318 | 
319 |   Args:
320 |     cluster_ids: Either a list of cluster indicator sequences, or a single
321 |       concatenated sequence. The former is strongly preferred, since the
322 |       transition_bias estimated from the latter will be inaccurate.
323 |     smooth: int or float - Smoothing coefficient; avoids a -inf value in
324 |       np.log for a sequence with a single speaker, and division by 0 for
325 |       empty sequences. A small smooth value decreases the bias in the
326 |       calculation of transition_bias but can lead to underflow in some rare
327 |       cases; larger values are safer but less accurate.
328 | 
329 |   Returns:
330 |     bias: Flipping coin head probability, i.e. the speaker change probability.
331 |     bias_denominator: The denominator of the bias, used for multiple calls to
332 |       fit().
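    For example (with the default smooth=1):
    ```
    cluster_ids = [['A', 'A', 'B', 'B'], ['A', 'A', 'A']]
    bias, bias_denominator = estimate_transition_bias(cluster_ids)
    # one speaker change among 5 adjacent pairs, so with smoothing:
    # bias = (1 + 1) / (5 + 2) = 2 / 7, bias_denominator = 7
    ```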
333 |   """
334 |   transit_num = smooth
335 |   bias_denominator = 2 * smooth
336 |   for cluster_id_seq in cluster_ids:
337 |     for entry in range(len(cluster_id_seq) - 1):
338 |       transit_num += (cluster_id_seq[entry] != cluster_id_seq[entry + 1])
339 |       bias_denominator += 1
340 |   bias = transit_num / bias_denominator
341 |   return bias, bias_denominator
342 | 
343 | 
344 | def estimate_crp_alpha(cluster_ids, smooth=1):
345 |   """Estimate the crp_alpha parameter of the ddCRP.
346 | 
347 |   Args:
348 |     cluster_ids: Either a list of cluster indicator sequences, or a single
349 |       concatenated sequence. The former is strongly preferred, since the
350 |       crp_alpha estimated from the latter will be inaccurate.
351 |     smooth: int or float - Smoothing coefficient; avoids a -inf value in
352 |       np.log for a sequence with a single speaker, and division by 0 for
353 |       empty sequences. A small smooth value decreases the bias in the
354 |       calculation of crp_alpha but can lead to underflow in some rare
355 |       cases; larger values are safer but less accurate.
356 | 
357 |   Returns:
358 |     crp_alpha: alpha parameter of the ddCRP that quantifies the probability
359 |       of a new speaker joining the conversation.
360 |     crp_alpha_denominator: The denominator for crp_alpha, used for
361 |       multiple calls to fit().
362 |   """
363 |   speaker_joins = sum(len(set(seq)) - 1 for seq in cluster_ids) + smooth
364 |   speaker_changes = 2 * smooth
365 |   for cluster_id_seq in cluster_ids:
366 |     for entry in range(len(cluster_id_seq) - 1):
367 |       speaker_changes += (cluster_id_seq[entry] != cluster_id_seq[entry + 1])
368 |   return speaker_joins / speaker_changes, speaker_changes
369 | 
--------------------------------------------------------------------------------