├── .gitignore ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── data ├── dev │ ├── full │ │ └── split1.txt │ ├── head │ │ └── split1.txt │ └── tail │ │ └── split1.txt └── train │ └── full │ └── split1.txt ├── experiments └── demo │ ├── test_config.ini │ └── train_config.ini ├── requirements.txt ├── run_inference.py ├── run_model.py └── src ├── __init__.py ├── data.py ├── loss.py ├── metadata.py ├── model.py ├── tokenizer.py └── util.py /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_STORE 2 | 3 | # Environments 4 | .env 5 | .venv 6 | env/ 7 | venv/ 8 | ENV/ 9 | env.bak/ 10 | venv.bak/E 11 | 12 | 13 | .Python 14 | build/ 15 | develop-eggs/ 16 | dist/ 17 | downloads/ 18 | eggs/ 19 | .eggs/ 20 | lib/ 21 | lib64/ 22 | parts/ 23 | sdist/ 24 | var/ 25 | wheels/ 26 | share/python-wheels/ 27 | *.egg-info/ 28 | .installed.cfg 29 | *.egg 30 | *.pyc 31 | MANIFEST 32 | 33 | *.log 34 | *.pt 35 | tb_log/ 36 | cache/ 37 | checkpoints/ 38 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | ## Code of Conduct 2 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct). 3 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact 4 | opensource-codeofconduct@amazon.com with any additional questions or comments. 5 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing Guidelines 2 | 3 | Thank you for your interest in contributing to our project. Whether it's a bug report, new feature, correction, or additional 4 | documentation, we greatly value feedback and contributions from our community. 5 | 6 | Please read through this document before submitting any issues or pull requests to ensure we have all the necessary 7 | information to effectively respond to your bug report or contribution. 8 | 9 | 10 | ## Reporting Bugs/Feature Requests 11 | 12 | We welcome you to use the GitHub issue tracker to report bugs or suggest features. 13 | 14 | When filing an issue, please check existing open, or recently closed, issues to make sure somebody else hasn't already 15 | reported the issue. Please try to include as much information as you can. Details like these are incredibly useful: 16 | 17 | * A reproducible test case or series of steps 18 | * The version of our code being used 19 | * Any modifications you've made relevant to the bug 20 | * Anything unusual about your environment or deployment 21 | 22 | 23 | ## Contributing via Pull Requests 24 | Contributions via pull requests are much appreciated. Before sending us a pull request, please ensure that: 25 | 26 | 1. You are working against the latest source on the *main* branch. 27 | 2. You check existing open, and recently merged, pull requests to make sure someone else hasn't addressed the problem already. 28 | 3. You open an issue to discuss any significant work - we would hate for your time to be wasted. 29 | 30 | To send us a pull request, please: 31 | 32 | 1. Fork the repository. 33 | 2. Modify the source; please focus on the specific change you are contributing. If you also reformat all the code, it will be hard for us to focus on your change. 34 | 3. Ensure local tests pass. 35 | 4. 
Commit to your fork using clear commit messages. 36 | 5. Send us a pull request, answering any default questions in the pull request interface. 37 | 6. Pay attention to any automated CI failures reported in the pull request, and stay involved in the conversation. 38 | 39 | GitHub provides additional document on [forking a repository](https://help.github.com/articles/fork-a-repo/) and 40 | [creating a pull request](https://help.github.com/articles/creating-a-pull-request/). 41 | 42 | 43 | ## Finding contributions to work on 44 | Looking at the existing issues is a great way to find something to contribute on. As our projects, by default, use the default GitHub issue labels (enhancement/bug/duplicate/help wanted/invalid/question/wontfix), looking at any 'help wanted' issues is a great place to start. 45 | 46 | 47 | ## Code of Conduct 48 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct). 49 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact 50 | opensource-codeofconduct@amazon.com with any additional questions or comments. 51 | 52 | 53 | ## Security issue notifications 54 | If you discover a potential security issue in this project we ask that you notify AWS/Amazon Security via our [vulnerability reporting page](http://aws.amazon.com/security/vulnerability-reporting/). Please do **not** create a public github issue. 55 | 56 | 57 | ## Licensing 58 | 59 | See the [LICENSE](LICENSE) file for our project's licensing. We will ask you to confirm the licensing of your contribution. 60 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Attribution-NonCommercial 4.0 International 2 | 3 | ======================================================================= 4 | 5 | Creative Commons Corporation ("Creative Commons") is not a law firm and 6 | does not provide legal services or legal advice. Distribution of 7 | Creative Commons public licenses does not create a lawyer-client or 8 | other relationship. Creative Commons makes its licenses and related 9 | information available on an "as-is" basis. Creative Commons gives no 10 | warranties regarding its licenses, any material licensed under their 11 | terms and conditions, or any related information. Creative Commons 12 | disclaims all liability for damages resulting from their use to the 13 | fullest extent possible. 14 | 15 | Using Creative Commons Public Licenses 16 | 17 | Creative Commons public licenses provide a standard set of terms and 18 | conditions that creators and other rights holders may use to share 19 | original works of authorship and other material subject to copyright 20 | and certain other rights specified in the public license below. The 21 | following considerations are for informational purposes only, are not 22 | exhaustive, and do not form part of our licenses. 23 | 24 | Considerations for licensors: Our public licenses are 25 | intended for use by those authorized to give the public 26 | permission to use material in ways otherwise restricted by 27 | copyright and certain other rights. Our licenses are 28 | irrevocable. Licensors should read and understand the terms 29 | and conditions of the license they choose before applying it. 30 | Licensors should also secure all rights necessary before 31 | applying our licenses so that the public can reuse the 32 | material as expected. 
Licensors should clearly mark any 33 | material not subject to the license. This includes other CC- 34 | licensed material, or material used under an exception or 35 | limitation to copyright. More considerations for licensors: 36 | wiki.creativecommons.org/Considerations_for_licensors 37 | 38 | Considerations for the public: By using one of our public 39 | licenses, a licensor grants the public permission to use the 40 | licensed material under specified terms and conditions. If 41 | the licensor's permission is not necessary for any reason--for 42 | example, because of any applicable exception or limitation to 43 | copyright--then that use is not regulated by the license. Our 44 | licenses grant only permissions under copyright and certain 45 | other rights that a licensor has authority to grant. Use of 46 | the licensed material may still be restricted for other 47 | reasons, including because others have copyright or other 48 | rights in the material. A licensor may make special requests, 49 | such as asking that all changes be marked or described. 50 | Although not required by our licenses, you are encouraged to 51 | respect those requests where reasonable. More considerations 52 | for the public: 53 | wiki.creativecommons.org/Considerations_for_licensees 54 | 55 | ======================================================================= 56 | 57 | Creative Commons Attribution-NonCommercial 4.0 International Public 58 | License 59 | 60 | By exercising the Licensed Rights (defined below), You accept and agree 61 | to be bound by the terms and conditions of this Creative Commons 62 | Attribution-NonCommercial 4.0 International Public License ("Public 63 | License"). To the extent this Public License may be interpreted as a 64 | contract, You are granted the Licensed Rights in consideration of Your 65 | acceptance of these terms and conditions, and the Licensor grants You 66 | such rights in consideration of benefits the Licensor receives from 67 | making the Licensed Material available under these terms and 68 | conditions. 69 | 70 | 71 | Section 1 -- Definitions. 72 | 73 | a. Adapted Material means material subject to Copyright and Similar 74 | Rights that is derived from or based upon the Licensed Material 75 | and in which the Licensed Material is translated, altered, 76 | arranged, transformed, or otherwise modified in a manner requiring 77 | permission under the Copyright and Similar Rights held by the 78 | Licensor. For purposes of this Public License, where the Licensed 79 | Material is a musical work, performance, or sound recording, 80 | Adapted Material is always produced where the Licensed Material is 81 | synched in timed relation with a moving image. 82 | 83 | b. Adapter's License means the license You apply to Your Copyright 84 | and Similar Rights in Your contributions to Adapted Material in 85 | accordance with the terms and conditions of this Public License. 86 | 87 | c. Copyright and Similar Rights means copyright and/or similar rights 88 | closely related to copyright including, without limitation, 89 | performance, broadcast, sound recording, and Sui Generis Database 90 | Rights, without regard to how the rights are labeled or 91 | categorized. For purposes of this Public License, the rights 92 | specified in Section 2(b)(1)-(2) are not Copyright and Similar 93 | Rights. 94 | d. 
Effective Technological Measures means those measures that, in the 95 | absence of proper authority, may not be circumvented under laws 96 | fulfilling obligations under Article 11 of the WIPO Copyright 97 | Treaty adopted on December 20, 1996, and/or similar international 98 | agreements. 99 | 100 | e. Exceptions and Limitations means fair use, fair dealing, and/or 101 | any other exception or limitation to Copyright and Similar Rights 102 | that applies to Your use of the Licensed Material. 103 | 104 | f. Licensed Material means the artistic or literary work, database, 105 | or other material to which the Licensor applied this Public 106 | License. 107 | 108 | g. Licensed Rights means the rights granted to You subject to the 109 | terms and conditions of this Public License, which are limited to 110 | all Copyright and Similar Rights that apply to Your use of the 111 | Licensed Material and that the Licensor has authority to license. 112 | 113 | h. Licensor means the individual(s) or entity(ies) granting rights 114 | under this Public License. 115 | 116 | i. NonCommercial means not primarily intended for or directed towards 117 | commercial advantage or monetary compensation. For purposes of 118 | this Public License, the exchange of the Licensed Material for 119 | other material subject to Copyright and Similar Rights by digital 120 | file-sharing or similar means is NonCommercial provided there is 121 | no payment of monetary compensation in connection with the 122 | exchange. 123 | 124 | j. Share means to provide material to the public by any means or 125 | process that requires permission under the Licensed Rights, such 126 | as reproduction, public display, public performance, distribution, 127 | dissemination, communication, or importation, and to make material 128 | available to the public including in ways that members of the 129 | public may access the material from a place and at a time 130 | individually chosen by them. 131 | 132 | k. Sui Generis Database Rights means rights other than copyright 133 | resulting from Directive 96/9/EC of the European Parliament and of 134 | the Council of 11 March 1996 on the legal protection of databases, 135 | as amended and/or succeeded, as well as other essentially 136 | equivalent rights anywhere in the world. 137 | 138 | l. You means the individual or entity exercising the Licensed Rights 139 | under this Public License. Your has a corresponding meaning. 140 | 141 | 142 | Section 2 -- Scope. 143 | 144 | a. License grant. 145 | 146 | 1. Subject to the terms and conditions of this Public License, 147 | the Licensor hereby grants You a worldwide, royalty-free, 148 | non-sublicensable, non-exclusive, irrevocable license to 149 | exercise the Licensed Rights in the Licensed Material to: 150 | 151 | a. reproduce and Share the Licensed Material, in whole or 152 | in part, for NonCommercial purposes only; and 153 | 154 | b. produce, reproduce, and Share Adapted Material for 155 | NonCommercial purposes only. 156 | 157 | 2. Exceptions and Limitations. For the avoidance of doubt, where 158 | Exceptions and Limitations apply to Your use, this Public 159 | License does not apply, and You do not need to comply with 160 | its terms and conditions. 161 | 162 | 3. Term. The term of this Public License is specified in Section 163 | 6(a). 164 | 165 | 4. Media and formats; technical modifications allowed. 
The 166 | Licensor authorizes You to exercise the Licensed Rights in 167 | all media and formats whether now known or hereafter created, 168 | and to make technical modifications necessary to do so. The 169 | Licensor waives and/or agrees not to assert any right or 170 | authority to forbid You from making technical modifications 171 | necessary to exercise the Licensed Rights, including 172 | technical modifications necessary to circumvent Effective 173 | Technological Measures. For purposes of this Public License, 174 | simply making modifications authorized by this Section 2(a) 175 | (4) never produces Adapted Material. 176 | 177 | 5. Downstream recipients. 178 | 179 | a. Offer from the Licensor -- Licensed Material. Every 180 | recipient of the Licensed Material automatically 181 | receives an offer from the Licensor to exercise the 182 | Licensed Rights under the terms and conditions of this 183 | Public License. 184 | 185 | b. No downstream restrictions. You may not offer or impose 186 | any additional or different terms or conditions on, or 187 | apply any Effective Technological Measures to, the 188 | Licensed Material if doing so restricts exercise of the 189 | Licensed Rights by any recipient of the Licensed 190 | Material. 191 | 192 | 6. No endorsement. Nothing in this Public License constitutes or 193 | may be construed as permission to assert or imply that You 194 | are, or that Your use of the Licensed Material is, connected 195 | with, or sponsored, endorsed, or granted official status by, 196 | the Licensor or others designated to receive attribution as 197 | provided in Section 3(a)(1)(A)(i). 198 | 199 | b. Other rights. 200 | 201 | 1. Moral rights, such as the right of integrity, are not 202 | licensed under this Public License, nor are publicity, 203 | privacy, and/or other similar personality rights; however, to 204 | the extent possible, the Licensor waives and/or agrees not to 205 | assert any such rights held by the Licensor to the limited 206 | extent necessary to allow You to exercise the Licensed 207 | Rights, but not otherwise. 208 | 209 | 2. Patent and trademark rights are not licensed under this 210 | Public License. 211 | 212 | 3. To the extent possible, the Licensor waives any right to 213 | collect royalties from You for the exercise of the Licensed 214 | Rights, whether directly or through a collecting society 215 | under any voluntary or waivable statutory or compulsory 216 | licensing scheme. In all other cases the Licensor expressly 217 | reserves any right to collect such royalties, including when 218 | the Licensed Material is used other than for NonCommercial 219 | purposes. 220 | 221 | 222 | Section 3 -- License Conditions. 223 | 224 | Your exercise of the Licensed Rights is expressly made subject to the 225 | following conditions. 226 | 227 | a. Attribution. 228 | 229 | 1. If You Share the Licensed Material (including in modified 230 | form), You must: 231 | 232 | a. retain the following if it is supplied by the Licensor 233 | with the Licensed Material: 234 | 235 | i. identification of the creator(s) of the Licensed 236 | Material and any others designated to receive 237 | attribution, in any reasonable manner requested by 238 | the Licensor (including by pseudonym if 239 | designated); 240 | 241 | ii. a copyright notice; 242 | 243 | iii. a notice that refers to this Public License; 244 | 245 | iv. a notice that refers to the disclaimer of 246 | warranties; 247 | 248 | v. 
a URI or hyperlink to the Licensed Material to the 249 | extent reasonably practicable; 250 | 251 | b. indicate if You modified the Licensed Material and 252 | retain an indication of any previous modifications; and 253 | 254 | c. indicate the Licensed Material is licensed under this 255 | Public License, and include the text of, or the URI or 256 | hyperlink to, this Public License. 257 | 258 | 2. You may satisfy the conditions in Section 3(a)(1) in any 259 | reasonable manner based on the medium, means, and context in 260 | which You Share the Licensed Material. For example, it may be 261 | reasonable to satisfy the conditions by providing a URI or 262 | hyperlink to a resource that includes the required 263 | information. 264 | 265 | 3. If requested by the Licensor, You must remove any of the 266 | information required by Section 3(a)(1)(A) to the extent 267 | reasonably practicable. 268 | 269 | 4. If You Share Adapted Material You produce, the Adapter's 270 | License You apply must not prevent recipients of the Adapted 271 | Material from complying with this Public License. 272 | 273 | 274 | Section 4 -- Sui Generis Database Rights. 275 | 276 | Where the Licensed Rights include Sui Generis Database Rights that 277 | apply to Your use of the Licensed Material: 278 | 279 | a. for the avoidance of doubt, Section 2(a)(1) grants You the right 280 | to extract, reuse, reproduce, and Share all or a substantial 281 | portion of the contents of the database for NonCommercial purposes 282 | only; 283 | 284 | b. if You include all or a substantial portion of the database 285 | contents in a database in which You have Sui Generis Database 286 | Rights, then the database in which You have Sui Generis Database 287 | Rights (but not its individual contents) is Adapted Material; and 288 | 289 | c. You must comply with the conditions in Section 3(a) if You Share 290 | all or a substantial portion of the contents of the database. 291 | 292 | For the avoidance of doubt, this Section 4 supplements and does not 293 | replace Your obligations under this Public License where the Licensed 294 | Rights include other Copyright and Similar Rights. 295 | 296 | 297 | Section 5 -- Disclaimer of Warranties and Limitation of Liability. 298 | 299 | a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE 300 | EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS 301 | AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF 302 | ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS, 303 | IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION, 304 | WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR 305 | PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS, 306 | ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT 307 | KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT 308 | ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU. 309 | 310 | b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE 311 | TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION, 312 | NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT, 313 | INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES, 314 | COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR 315 | USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN 316 | ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR 317 | DAMAGES. 
WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR 318 | IN PART, THIS LIMITATION MAY NOT APPLY TO YOU. 319 | 320 | c. The disclaimer of warranties and limitation of liability provided 321 | above shall be interpreted in a manner that, to the extent 322 | possible, most closely approximates an absolute disclaimer and 323 | waiver of all liability. 324 | 325 | 326 | Section 6 -- Term and Termination. 327 | 328 | a. This Public License applies for the term of the Copyright and 329 | Similar Rights licensed here. However, if You fail to comply with 330 | this Public License, then Your rights under this Public License 331 | terminate automatically. 332 | 333 | b. Where Your right to use the Licensed Material has terminated under 334 | Section 6(a), it reinstates: 335 | 336 | 1. automatically as of the date the violation is cured, provided 337 | it is cured within 30 days of Your discovery of the 338 | violation; or 339 | 340 | 2. upon express reinstatement by the Licensor. 341 | 342 | For the avoidance of doubt, this Section 6(b) does not affect any 343 | right the Licensor may have to seek remedies for Your violations 344 | of this Public License. 345 | 346 | c. For the avoidance of doubt, the Licensor may also offer the 347 | Licensed Material under separate terms or conditions or stop 348 | distributing the Licensed Material at any time; however, doing so 349 | will not terminate this Public License. 350 | 351 | d. Sections 1, 5, 6, 7, and 8 survive termination of this Public 352 | License. 353 | 354 | 355 | Section 7 -- Other Terms and Conditions. 356 | 357 | a. The Licensor shall not be bound by any additional or different 358 | terms or conditions communicated by You unless expressly agreed. 359 | 360 | b. Any arrangements, understandings, or agreements regarding the 361 | Licensed Material not stated herein are separate from and 362 | independent of the terms and conditions of this Public License. 363 | 364 | 365 | Section 8 -- Interpretation. 366 | 367 | a. For the avoidance of doubt, this Public License does not, and 368 | shall not be interpreted to, reduce, limit, restrict, or impose 369 | conditions on any use of the Licensed Material that could lawfully 370 | be made without permission under this Public License. 371 | 372 | b. To the extent possible, if any provision of this Public License is 373 | deemed unenforceable, it shall be automatically reformed to the 374 | minimum extent necessary to make it enforceable. If the provision 375 | cannot be reformed, it shall be severed from this Public License 376 | without affecting the enforceability of the remaining terms and 377 | conditions. 378 | 379 | c. No term or condition of this Public License will be waived and no 380 | failure to comply consented to unless expressly agreed to by the 381 | Licensor. 382 | 383 | d. Nothing in this Public License constitutes or may be interpreted 384 | as a limitation upon, or waiver of, any privileges and immunities 385 | that apply to the Licensor or You, including from the legal 386 | processes of any jurisdiction or authority. 387 | 388 | ======================================================================= 389 | 390 | Creative Commons is not a party to its public 391 | licenses. Notwithstanding, Creative Commons may elect to apply one of 392 | its public licenses to material it publishes and in those instances 393 | will be considered the “Licensor.” The text of the Creative Commons 394 | public licenses is dedicated to the public domain under the CC0 Public 395 | Domain Dedication. 
Except for the limited purpose of indicating that 396 | material is shared under a Creative Commons public license or as 397 | otherwise permitted by the Creative Commons policies published at 398 | creativecommons.org/policies, Creative Commons does not authorize the 399 | use of the trademark "Creative Commons" or any other trademark or logo 400 | of Creative Commons without its prior written consent including, 401 | without limitation, in connection with any unauthorized modifications 402 | to any of its public licenses or any other arrangements, 403 | understandings, or agreements concerning use of licensed material. For 404 | the avoidance of doubt, this paragraph does not form part of the 405 | public licenses. 406 | 407 | Creative Commons may be contacted at creativecommons.org. 408 | -------------------------------------------------------------------------------- /README.md: --------------------------------------------------------------------------------
1 | # Attention-Based Contextual Language Modeling Adaptation
2 | 
3 | This project provides the source to reproduce the main methods of the paper
4 | "Attention-Based Contextual Language Model Adaptation for Speech Recognition",
5 | submitted to ACL 2021. This codebase also implements additional functionality
6 | that was not explicitly described in the paper, such as experimental methods
7 | for combining multiple types of non-linguistic context together (e.g., geo-location
8 | and datetime).
9 | 
10 | ## Onboarding
11 | 
12 | Basic environment setup:
13 | 
14 | ```
15 | virtualenv -p python3 env
16 | source env/bin/activate
17 | pip install -r requirements.txt
18 | ```
19 | 
20 | ## Data
21 | 
22 | We are unable to provide the data we used for the results we report in our
23 | paper. However, to illustrate the input expected by the model, we provide data
24 | samples in the data folder that demonstrate the data format. In general, data
25 | should be structured in a TSV format where the first column corresponds to
26 | the transcribed utterance and the subsequent columns correspond to associated
27 | non-linguistic context.
28 | 
29 | ## Training a Model
30 | 
31 | In all of our experiments, we adapt a base 1-layer LSTM model with additional
32 | context, using an attention mechanism. To run an experiment, you need to
33 | define a config file with the desired configurations for the model architecture,
34 | data processing, model training and model evaluation parameters. To illustrate
35 | how to set up a config file for an experiment, we provide a sample config file
36 | under experiments/demo. The sample config provides the configurations for
37 | conditioning an NLM on datetime context, using a Bahdanau attention mechanism.
38 | 
39 | To train a model using the sample config, run the following command from the
40 | root directory.
41 | 
42 | ```
43 | python3 run_model.py experiments/demo/train_config.ini
44 | ```
45 | 
46 | Running this script will generate a log containing the training results. Using
47 | the provided train_config.ini config, you should expect to see the following
48 | final evaluation (numbers might vary a bit):
49 | 
50 | ```
51 | Finished Evaluation Model.
52 | Full Dev Data -- Loss: 4.678730704567649, PPL: 107.63334655761719
53 | Head Dev Data -- Loss: 4.679276899857954, PPL: 107.69217681884766
54 | Tail Dev Data -- Loss: 4.678184509277344, PPL: 107.57459259033203
55 | ```
56 | 
57 | ## Running Inference
58 | 
59 | Similarly to how we train a model by defining a config file, we also provide a
60 | config file for evaluating a model on a given dataset. In experiments/demo you
61 | will find a sample config for evaluating the same model defined in train_config.ini,
62 | using the sample dev dataset we provide in data/dev.
63 | 
64 | To run inference, run the following command from the root directory.
65 | 
66 | ```
67 | python3 run_inference.py experiments/demo/test_config.ini
68 | ```
69 | 
70 | Note that the configuration we provide in test_config.ini assumes that you have
71 | already trained a model using the config in train_config.ini.
72 | 
73 | 
74 | For additional information or questions, reach out to mrtimri@amazon.com
75 | 
-------------------------------------------------------------------------------- /data/dev/full/split1.txt: --------------------------------------------------------------------------------
1 | This is a sample dev utterance from head of dataset	2019-11-08-23
2 | This is another sample dev utterance from head of dataset	2019-11-08-23
3 | This is a sample train utterance from tail of dataset	2019-11-08-23
4 | This is another sample train utterance from tail of dataset	2019-11-08-23
5 | 
-------------------------------------------------------------------------------- /data/dev/head/split1.txt: --------------------------------------------------------------------------------
1 | This is a sample dev utterance from head of dataset	2019-11-08-23
2 | This is another sample dev utterance from head of dataset	2019-11-08-23
3 | 
-------------------------------------------------------------------------------- /data/dev/tail/split1.txt: --------------------------------------------------------------------------------
1 | This is a sample train utterance from tail of dataset	2019-11-08-23
2 | This is another sample train utterance from tail of dataset	2019-11-08-23
3 | 
-------------------------------------------------------------------------------- /data/train/full/split1.txt: --------------------------------------------------------------------------------
1 | This is a sample train utterance	2019-11-08-23
2 | This is another sample train utterance	2019-11-08-23
3 | This is a third sample train utterance	2019-11-08-23
4 | 
-------------------------------------------------------------------------------- /experiments/demo/test_config.ini: --------------------------------------------------------------------------------
1 | [EXPERIMENT]
2 | experiment_directory = experiments/demo
3 | 
4 | [DATA]
5 | train_data_directory_full = data/train/full
6 | 
7 | [TOKENIZER]
8 | tokenizer_type = basic_tokenizer
9 | vocab_limit = 20
10 | 
11 | [METADATA]
12 | text_index = 0
13 | md_indices = 1
14 | md_transformations = all_tokens
15 | 
16 | [MODEL]
17 | model_type = concat_lstm
18 | batch_size = 2
19 | 
20 | [TEST]
21 | model_path = experiments/demo/checkpoints/checkpoint_step_4/model.pt
22 | 
23 | ppl_data_directory_full = data/dev/full
24 | ppl_data_directory_head = data/dev/head
25 | ppl_data_directory_tail = data/dev/tail
26 | 
27 | text_index = 0
28 | md_indices = 1
29 | md_transformations = all_tokens
30 | 
-------------------------------------------------------------------------------- /experiments/demo/train_config.ini:
-------------------------------------------------------------------------------- 1 | [EXPERIMENT] 2 | experiment_directory = experiments/demo 3 | 4 | [DATA] 5 | train_data_directory_full = data/train/full 6 | dev_data_directory_full = data/dev/full 7 | dev_data_directory_head = data/dev/head 8 | dev_data_directory_tail = data/dev/tail 9 | 10 | [TOKENIZER] 11 | tokenizer_type = basic_tokenizer 12 | vocab_limit = 20 13 | 14 | [METADATA] 15 | text_index = 0 16 | md_indices = 1 17 | md_transformations = all_tokens 18 | 19 | [MODEL] 20 | model_type = concat_lstm 21 | batch_size = 2 22 | emb_dim = 512 23 | context_dim = 512 24 | hidden_dim = 512 25 | 26 | attention_mechanism = bahdanau 27 | query_type = word 28 | 29 | use_weight_tying = False 30 | use_null_token = True 31 | use_layernorm = True 32 | 33 | md_projection_dim = 512 34 | md_dims = 512 35 | md_group_sizes = 4 36 | 37 | hierarchical_attention = False 38 | 39 | [TRAIN] 40 | print_every = 1 41 | eval_every = 1 42 | max_train_steps = 4 43 | scheduler = plateau 44 | learning_rate = 1e-3 45 | eps_tolerance = 0 46 | patience = 1 47 | decay_factor = 0.75 48 | 49 | [EVAL] 50 | text_index = 0 51 | md_indices = 1 52 | md_transformations = all_tokens 53 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | protobuf==3.18.3 2 | pygeohash==1.2.0 3 | scipy==1.4.0 4 | torch==1.7.0 5 | sentencepiece==0.1.91 6 | tensorboard==2.1 7 | timezonefinder==4.4.0 8 | click==7.1.2 9 | -------------------------------------------------------------------------------- /run_inference.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: CC-BY-NC-4.0 3 | 4 | import torch 5 | import logging 6 | import click 7 | import os 8 | import math 9 | 10 | from src.tokenizer import get_tokenizer 11 | from src.metadata import MetaDataTransformer 12 | from src.model import get_model 13 | from src.util import move_to_device, device, eval_model, get_dataloader 14 | from src.loss import get_loss_fn, get_no_reduction_loss_fn 15 | 16 | from configparser import ConfigParser 17 | 18 | """ 19 | Basic utils for setting up main training and evaluation loops. 20 | """ 21 | 22 | def setup_config(config_file_path): 23 | config = ConfigParser() 24 | config.read(config_file_path) 25 | return config 26 | 27 | def setup_logger(config): 28 | # Removing handlers that might be associated with environment; and logs 29 | # out to both stderr and a log file 30 | for handler in logging.root.handlers[:]: 31 | logging.root.removeHandler(handler) 32 | log_file_name = os.path.join(config.get("EXPERIMENT", "experiment_directory"), "inference_result.log") 33 | logging.basicConfig( 34 | format='%(asctime)s [%(levelname)s] %(message)s', 35 | datefmt='%m/%d/%Y %I:%M:%S %p', 36 | level=logging.DEBUG, 37 | handlers=[ 38 | logging.FileHandler(log_file_name), 39 | logging.StreamHandler() 40 | ] 41 | ) 42 | logging.info(f"Initializing experiment: {config.get('EXPERIMENT', 'experiment_directory')}") 43 | logging.info(f"Running model on device: {device}") 44 | 45 | def setup(config_file_path): 46 | config = setup_config(config_file_path) 47 | setup_logger(config) 48 | return config 49 | 50 | """ 51 | Main script for running inference on a trained model. 
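Expects a single positional argument: the path to a test config file (see experiments/demo/test_config.ini for a working example).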
52 | """ 53 | 54 | @click.command() 55 | @click.argument('config_file_path') 56 | def main(config_file_path): 57 | config = setup(config_file_path) 58 | md_transformer = MetaDataTransformer(text_index=config.get("TEST", "text_index"), 59 | md_indices=config.get("TEST", "md_indices", 60 | fallback=""), 61 | md_transformations=config.get("TEST", 62 | "md_transformations", 63 | fallback="")) 64 | 65 | tokenizer = get_tokenizer(tokenizer_type=config.get("TOKENIZER", "tokenizer_type"), 66 | data_path=config.get("DATA", "train_data_directory_full"), 67 | md_transformer=md_transformer, 68 | vocab_limit=int(config.get("TOKENIZER", "vocab_limit")), 69 | force_new_creation=False) 70 | tokenizer.add_special_tokens(md_transformer.get_md_tokens()) 71 | 72 | # Getting dataloaders 73 | ppl_dataloader_full = get_dataloader(config, tokenizer, md_transformer, "ppl", "full", config_section="TEST") 74 | ppl_dataloader_head = get_dataloader(config, tokenizer, md_transformer, "ppl", "head", config_section="TEST") 75 | ppl_dataloader_tail = get_dataloader(config, tokenizer, md_transformer, "ppl", "tail", config_section="TEST") 76 | 77 | # Setting up model 78 | model = torch.load(config.get("TEST", "model_path"), map_location=device) 79 | _, dev_loss_fn = get_loss_fn(config.get("MODEL", "model_type")) 80 | no_reduction_loss_fn = get_no_reduction_loss_fn(config.get("MODEL", "model_type")) 81 | 82 | #### Evaluation Cycle ### 83 | model.eval() 84 | model.to(device) 85 | 86 | loss_full, ppl_full = eval_model(model, ppl_dataloader_full, 87 | dev_loss_fn) 88 | loss_head, ppl_head = eval_model(model, ppl_dataloader_head, 89 | dev_loss_fn) 90 | loss_tail, ppl_tail = eval_model(model, ppl_dataloader_tail, 91 | dev_loss_fn) 92 | 93 | logging.info("Full evaluation: ") 94 | logging.info(f"\t loss: {loss_full} ppl: {ppl_full}") 95 | logging.info("Head evaluation: ") 96 | logging.info(f"\t loss: {loss_head} ppl: {ppl_head}") 97 | logging.info("Tail evaluation: ") 98 | logging.info(f"\t loss: {loss_tail} ppl: {ppl_tail}") 99 | 100 | if __name__ == "__main__": 101 | main() 102 | -------------------------------------------------------------------------------- /run_model.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: CC-BY-NC-4.0 3 | 4 | import torch 5 | import logging 6 | import click 7 | import os 8 | import math 9 | 10 | from src.tokenizer import get_tokenizer 11 | from src.metadata import MetaDataTransformer 12 | from src.model import get_model 13 | from src.util import move_to_device, device, save_model_checkpoint, eval_model,\ 14 | get_dataloader, get_lr_scheduler 15 | from src.loss import get_loss_fn, get_no_reduction_loss_fn 16 | 17 | from configparser import ConfigParser 18 | from torch.optim import Adam 19 | from torch.utils.tensorboard import SummaryWriter 20 | 21 | """ 22 | Basic utils for setting up main training and evaluation loops. 
23 | """ 24 | 25 | def setup_config(config_file_path): 26 | config = ConfigParser() 27 | config.read(config_file_path) 28 | return config 29 | 30 | def setup_logger(config): 31 | # Removing handlers that might be associated with environment; and logs 32 | # out to both stderr and a log file 33 | for handler in logging.root.handlers[:]: 34 | logging.root.removeHandler(handler) 35 | log_file_name = os.path.join(config.get("EXPERIMENT", "experiment_directory"), "experiment.log") 36 | logging.basicConfig( 37 | format='%(asctime)s [%(levelname)s] %(message)s', 38 | datefmt='%m/%d/%Y %I:%M:%S %p', 39 | level=logging.DEBUG, 40 | handlers=[ 41 | logging.FileHandler(log_file_name), 42 | logging.StreamHandler() 43 | ] 44 | ) 45 | logging.info(f"Initializing experiment: {config.get('EXPERIMENT', 'experiment_directory')}") 46 | logging.info(f"Running model on device: {device}") 47 | 48 | def setup(config_file_path): 49 | config = setup_config(config_file_path) 50 | setup_logger(config) 51 | return config 52 | 53 | """ 54 | Main script for training a language model with meta data input. 55 | """ 56 | 57 | @click.command() 58 | @click.argument('config_file_path') 59 | def main(config_file_path): 60 | 61 | ####### Initial Model Setup ####### 62 | config = setup(config_file_path) 63 | writer = SummaryWriter(os.path.join(config.get("EXPERIMENT", "experiment_directory"), "tb_log")) 64 | 65 | md_transformer = MetaDataTransformer(text_index=config.get("METADATA", "text_index"), 66 | md_indices=config.get("METADATA", "md_indices", 67 | fallback=""), 68 | md_transformations=config.get("METADATA", 69 | "md_transformations", 70 | fallback="")) 71 | 72 | tokenizer = get_tokenizer(tokenizer_type=config.get("TOKENIZER", "tokenizer_type"), 73 | data_path=config.get("DATA", "train_data_directory_full"), 74 | md_transformer=md_transformer, 75 | vocab_limit=int(config.get("TOKENIZER", "vocab_limit")), 76 | force_new_creation=False) 77 | tokenizer.add_special_tokens(md_transformer.get_md_tokens()) 78 | 79 | # Constructing datasets 80 | train_dataloader = get_dataloader(config, tokenizer, md_transformer, "train", "full") 81 | dev_dataloader_full = get_dataloader(config, tokenizer, md_transformer, "dev", "full") 82 | dev_dataloader_head = get_dataloader(config, tokenizer, md_transformer, "dev", "head") 83 | dev_dataloader_tail = get_dataloader(config, tokenizer, md_transformer, "dev", "tail") 84 | 85 | # Loading model 86 | model = get_model(config, tokenizer.get_vocab_size()) 87 | train_loss_fn, dev_loss_fn = get_loss_fn(config.get("MODEL", "model_type")) 88 | 89 | ####### Training Configurations ####### 90 | model_params = filter(lambda p: p.requires_grad, model.parameters()) 91 | lr = float(config.get("TRAIN", "learning_rate")) 92 | optimizer = Adam(model_params, lr=lr) 93 | # Scheduler is either 1) lr_scheduler.ReduceLROnPlateau or 2) lr_scheduler.OneCycleLR 94 | update_lr_per_step, scheduler = get_lr_scheduler(config, optimizer) 95 | 96 | print_every = int(config.get("TRAIN", "print_every")) 97 | eval_every = int(config.get("TRAIN", "eval_every")) 98 | 99 | max_train_steps = int(config.get("TRAIN", "max_train_steps")) 100 | 101 | logging.info(f"Training Configuration:") 102 | logging.info(f"\t Evaluating model every: {eval_every} steps") 103 | logging.info(f"\t Training with Adam using lr: {lr}") 104 | 105 | ####### Training Loop ####### 106 | model.to(device) 107 | model.train() 108 | logging.info("Beginning training loop.") 109 | best_dev_loss_full = 1e3 110 | best_dev_loss_head = 1e3 111 | best_dev_loss_tail = 1e3 
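    # Track the best dev loss seen on each split (initialized to a large sentinel) and, below, the training step at which it occurred.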
112 | best_dev_loss_full_step = -1 113 | best_dev_loss_head_step = -1 114 | best_dev_loss_tail_step = -1 115 | 116 | for train_step, train_batch in enumerate(train_dataloader): 117 | if train_step > max_train_steps: 118 | break 119 | 120 | # MODEL EVALUATION 121 | if (train_step and train_step % eval_every == 0): 122 | model.eval() 123 | 124 | dev_loss_full, dev_ppl_full = eval_model(model, dev_dataloader_full, 125 | dev_loss_fn) 126 | dev_loss_head, dev_ppl_head = eval_model(model, dev_dataloader_head, 127 | dev_loss_fn) 128 | dev_loss_tail, dev_ppl_tail = eval_model(model, dev_dataloader_tail, 129 | dev_loss_fn) 130 | 131 | logging.info(f"\t Finished Evaluation Model.") 132 | logging.info(f"\t \t Full Dev Data -- Loss: {dev_loss_full}, PPL: {dev_ppl_full}") 133 | logging.info(f"\t \t Head Dev Data -- Loss: {dev_loss_head}, PPL: {dev_ppl_head}") 134 | logging.info(f"\t \t Tail Dev Data -- Loss: {dev_loss_tail}, PPL: {dev_ppl_tail}") 135 | writer.add_scalar('Loss/dev_full', dev_loss_full, train_step) 136 | writer.add_scalar('PPL/dev_full', dev_ppl_full, train_step) 137 | writer.add_scalar('Loss/dev_head', dev_loss_head, train_step) 138 | writer.add_scalar('PPL/dev_head', dev_ppl_head, train_step) 139 | writer.add_scalar('Loss/dev_tail', dev_loss_tail, train_step) 140 | writer.add_scalar('PPL/dev_tail', dev_ppl_tail, train_step) 141 | 142 | if dev_loss_full < best_dev_loss_full: 143 | best_dev_loss_full = dev_loss_full 144 | best_dev_loss_full_step = train_step 145 | if dev_loss_head < best_dev_loss_head: 146 | best_dev_loss_head = dev_loss_head 147 | best_dev_loss_head_step = train_step 148 | if dev_loss_tail < best_dev_loss_tail: 149 | best_dev_loss_tail = dev_loss_tail 150 | best_dev_loss_tail_step = train_step 151 | 152 | if not update_lr_per_step: 153 | scheduler.step(dev_loss_full) 154 | 155 | save_model_checkpoint(model, train_step, config, 156 | dev_loss_full, dev_ppl_full, 157 | dev_loss_head, dev_ppl_head, 158 | dev_loss_tail, dev_ppl_tail) 159 | 160 | 161 | model.train() 162 | 163 | model.zero_grad() 164 | train_batch = move_to_device(train_batch) 165 | pred_logits = model(train_batch) 166 | train_loss = train_loss_fn(pred_logits, train_batch) 167 | 168 | if (train_step and train_step % print_every == 0): 169 | logging.info(f"\t Training Step {train_step} Loss: {train_loss.item()}") 170 | writer.add_scalar('Loss/train', train_loss.item(), train_step) 171 | 172 | train_loss.backward() 173 | optimizer.step() 174 | if update_lr_per_step and train_step < max_train_steps: 175 | # don't step scheduler at last update step 176 | scheduler.step() 177 | 178 | logging.info(f"Completed {max_train_steps} training steps.") 179 | logging.info(f"Evaluation Overview: ") 180 | logging.info(f"Best full dev loss/PPL: {best_dev_loss_full}/{math.exp(best_dev_loss_full)} \t step: {best_dev_loss_full_step}") 181 | logging.info(f"Best head dev loss/PPL: {best_dev_loss_head}/{math.exp(best_dev_loss_head)} \t step: {best_dev_loss_head_step}") 182 | logging.info(f"Best tail dev loss/PPL: {best_dev_loss_tail}/{math.exp(best_dev_loss_tail)} \t step: {best_dev_loss_tail_step}") 183 | 184 | if __name__ == "__main__": 185 | main() 186 | -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
2 | # SPDX-License-Identifier: CC-BY-NC-4.0 3 | -------------------------------------------------------------------------------- /src/data.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: CC-BY-NC-4.0 3 | 4 | import os 5 | import torch 6 | from torch.utils.data import IterableDataset 7 | from itertools import cycle 8 | from collections import defaultdict 9 | """ 10 | Custom iterable dataset for streaming in data and data processing utils. 11 | """ 12 | 13 | def list_files(directory, ignore_str="json"): 14 | # ignore_str set to json to skip nbest files 15 | files = [os.path.join(directory, file) for file in os.listdir(directory) if ignore_str not in file\ 16 | and os.path.isfile(os.path.join(directory,file))] 17 | return files 18 | 19 | def return_split(file_name): 20 | split = file_name.split('/')[-1] 21 | return split 22 | 23 | def get_next_utterance(directory, sort_by_function=return_split): 24 | ''' A generator that yields the next utterance ''' 25 | data_files = list_files(directory) 26 | data_files.sort(key=sort_by_function, reverse=True) 27 | 28 | for idx, file_path in enumerate(data_files): 29 | with open(file_path, "r") as transcription_fp: 30 | for line in transcription_fp: 31 | yield line 32 | 33 | def custom_collate(batch): 34 | """Collate function to deal with variable length input """ 35 | batch_size = len(batch) 36 | max_len = max([sample["text_len"] for sample in batch]) 37 | 38 | # IMPORTANT: Enforce padding token to be 0 39 | padded_input = torch.zeros((batch_size, max_len)) 40 | padded_output = torch.zeros((batch_size, max_len)) 41 | text_len = [] 42 | 43 | md, md_len = defaultdict(list), defaultdict(list) 44 | 45 | for idx, sample in enumerate(batch): 46 | curr_len = sample["text_len"] 47 | text_len.append(curr_len) 48 | padded_input[idx, :curr_len] = sample["input"] 49 | padded_output[idx, :curr_len] = sample["output"] 50 | 51 | sample_md = sample["md"] 52 | sample_md_len = sample["md_len"] 53 | if sample_md is None: 54 | md = None 55 | md_len = None 56 | continue 57 | 58 | for curr_md_transform, curr_md in sample_md.items(): 59 | md[curr_md_transform].append(curr_md) 60 | 61 | for curr_md_transform, curr_md_len in sample_md_len.items(): 62 | md_len[curr_md_transform].append(curr_md_len) 63 | 64 | 65 | text_len = torch.stack(text_len) 66 | 67 | if md: 68 | for curr_md_transform in md.keys(): 69 | md[curr_md_transform] = torch.stack(md[curr_md_transform]) 70 | for curr_md_transform in md.keys(): 71 | md_len[curr_md_transform] = torch.stack(md_len[curr_md_transform]) 72 | 73 | processed_batch = {"input": padded_input, 74 | "output": padded_output, 75 | "md": md, 76 | "text_len": text_len, 77 | "md_len": md_len} 78 | 79 | return processed_batch 80 | 81 | 82 | class MetaDataset(IterableDataset): 83 | """Dataset that can include meta data information. 
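Each yielded sample is a dict with "input", "output", "md", "text_len" and "md_len" entries; directories whose path contains "train" are cycled over indefinitely, while other splits are streamed through once.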
""" 84 | 85 | def __init__(self, data_directory, tokenizer, md_transformer): 86 | self.data_directory = data_directory 87 | self.tokenizer = tokenizer 88 | self.md_transformer = md_transformer 89 | self.cycle_data = "train" in data_directory 90 | 91 | def generate_processed_stream(self): 92 | for utterance in get_next_utterance(self.data_directory): 93 | md_dict, text = self.md_transformer.parse_raw_input(utterance) 94 | input = torch.tensor(self.tokenizer.encode_text(text, add_sos=True)) 95 | output = torch.tensor(self.tokenizer.encode_text(text, add_eos=True)) 96 | text_len = torch.tensor(len(input)) 97 | 98 | if md_dict: 99 | md = {} 100 | md_len = {} 101 | for curr_md_transform, curr_md in md_dict.items(): 102 | if not isinstance(curr_md, torch.Tensor): 103 | curr_md = torch.tensor(self.tokenizer.encode_text(curr_md)) 104 | curr_md_len = torch.tensor(len(curr_md)) 105 | else: 106 | curr_md_len = torch.tensor(1) 107 | 108 | md[curr_md_transform] = curr_md 109 | md_len[curr_md_transform] = curr_md_len 110 | 111 | else: 112 | md = None 113 | md_len = None 114 | 115 | sample = {"input": input, 116 | "output": output, 117 | "md": md, 118 | "text_len": text_len, 119 | "md_len": md_len} 120 | yield sample 121 | 122 | def __iter__(self): 123 | if self.cycle_data: 124 | return cycle(self.generate_processed_stream()) 125 | else: 126 | return self.generate_processed_stream() 127 | -------------------------------------------------------------------------------- /src/loss.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: CC-BY-NC-4.0 3 | 4 | import torch 5 | from torch.nn import CrossEntropyLoss 6 | 7 | """ 8 | Loss functions for training different models defined in src/model.py 9 | """ 10 | 11 | ce_loss_sum = CrossEntropyLoss(reduction="sum", ignore_index=0) 12 | ce_loss_mean = CrossEntropyLoss(reduction="mean", ignore_index=0) 13 | ce_loss_no_reduction = CrossEntropyLoss(reduction="none", ignore_index=0) 14 | 15 | def base_lstm_train_loss(pred_logits, batch): 16 | """Calculates CE loss ignoring metadata tokens """ 17 | 18 | labels = batch["output"].long() 19 | 20 | ignore_len = 0 if "md" not in batch else 1 21 | 22 | # reshaping to ignore meta data tokens 23 | pred_logits = pred_logits[:, ignore_len:, :] 24 | pred_logits = torch.reshape(pred_logits, [-1, pred_logits.size(-1)]) 25 | labels = torch.reshape(labels, [-1]) #same as flatten 26 | mean_loss = ce_loss_mean(pred_logits, labels) 27 | return mean_loss 28 | 29 | def base_lstm_dev_loss(pred_logits, batch): 30 | """Calculates CE loss ignoring metadata tokens """ 31 | 32 | labels = batch["output"].long() 33 | 34 | ignore_len = 0 if "md" not in batch else 1 35 | n_tokens = torch.sum(labels > 0) 36 | 37 | # reshaping to ignore meta data tokens 38 | pred_logits = pred_logits[:, ignore_len:, :] 39 | pred_logits = torch.reshape(pred_logits, [-1, pred_logits.size(-1)]) 40 | labels = torch.reshape(labels, [-1]) #same as flatten 41 | sum_loss = ce_loss_sum(pred_logits, labels) 42 | return sum_loss, n_tokens 43 | 44 | def base_lstm_no_reduction_loss(pred_logits, batch): 45 | """Calculates CE loss ignoring metadata tokens """ 46 | 47 | labels = batch["output"].long() 48 | 49 | ignore_len = 0 if "md" not in batch else 1 50 | n_tokens = torch.sum(labels > 0, dim=-1) 51 | 52 | # reshaping to ignore meta data tokens 53 | pred_logits = pred_logits[:, ignore_len:, :] 54 | initial_pred_shape = pred_logits.shape[:-1] 55 | 
pred_logits = torch.reshape(pred_logits, [-1, pred_logits.size(-1)]) 56 | labels = torch.reshape(labels, [-1]) #same as flatten 57 | all_loss = ce_loss_no_reduction(pred_logits, labels) 58 | all_loss = all_loss.view(initial_pred_shape) 59 | 60 | return all_loss, n_tokens 61 | 62 | def advanced_lstm_train_loss(pred_logits, batch): 63 | """ LSTM loss used by advanced baselines - no need to remove md logits """ 64 | 65 | labels = batch["output"].long() 66 | pred_logits = torch.reshape(pred_logits, [-1, pred_logits.size(-1)]) 67 | labels = torch.reshape(labels, [-1]) #same as flatten 68 | mean_loss = ce_loss_mean(pred_logits, labels) 69 | return mean_loss 70 | 71 | def advanced_lstm_dev_loss(pred_logits, batch): 72 | """ LSTM loss used by advanced baselines - no need to remove md logits """ 73 | 74 | labels = batch["output"].long() 75 | n_tokens = torch.sum(labels > 0) 76 | 77 | pred_logits = torch.reshape(pred_logits, [-1, pred_logits.size(-1)]) 78 | labels = torch.reshape(labels, [-1]) #same as flatten 79 | sum_loss = ce_loss_sum(pred_logits, labels) 80 | return sum_loss, n_tokens 81 | 82 | def advanced_lstm_no_reduction_loss(pred_logits, batch): 83 | """ LSTM loss used by advanced baselines - no need to remove md logits """ 84 | 85 | labels = batch["output"].long() 86 | n_tokens = torch.sum(labels > 0, dim=-1) 87 | 88 | initial_pred_shape = pred_logits.shape[:-1] 89 | pred_logits = torch.reshape(pred_logits, [-1, pred_logits.size(-1)]) 90 | labels = torch.reshape(labels, [-1]) #same as flatten 91 | all_loss = ce_loss_no_reduction(pred_logits, labels) 92 | all_loss = all_loss.view(initial_pred_shape) 93 | return all_loss, n_tokens 94 | 95 | ##### Utility Wrapper ##### 96 | 97 | 98 | LOSS_MAP = {"base_lstm": (base_lstm_train_loss, base_lstm_dev_loss), 99 | "concat_lstm": (advanced_lstm_train_loss, advanced_lstm_dev_loss),} 100 | 101 | LOSS_MAP_NO_REDUCTION = {"base_lstm": base_lstm_no_reduction_loss, 102 | "concat_lstm": advanced_lstm_no_reduction_loss,} 103 | 104 | def get_loss_fn(model_type): 105 | """Given a model type maps that to the model train and dev loss functions.""" 106 | assert(model_type in LOSS_MAP), f"Invalid model: {model_type}" 107 | return LOSS_MAP[model_type] 108 | 109 | def get_no_reduction_loss_fn(model_type): 110 | """ 111 | Given a model type maps that to the model loss function with no 112 | no reduction. Used primarily for per utterance ppl and wer inference. 113 | """ 114 | assert(model_type in LOSS_MAP_NO_REDUCTION), f"Invalid model: {model_type}" 115 | return LOSS_MAP_NO_REDUCTION[model_type] 116 | -------------------------------------------------------------------------------- /src/metadata.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: CC-BY-NC-4.0 3 | 4 | import logging 5 | import math 6 | import torch 7 | from datetime import datetime 8 | 9 | """ 10 | Functions for converting metadata strings into metadata tokens. 11 | """ 12 | 13 | logger = logging.getLogger(__name__) 14 | 15 | ''' 16 | Datetime metadata processing functions. 17 | 18 | These functions take in a string with the associated context and process it. 19 | The functions described below are the ones we use in the ACL paper, and we 20 | keep them here to showcase examples of processing functions. You will likely 21 | need to implement your own for your particular context. 
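For example, the datetime string 2019-11-08-23 used in the sample data is parsed by split_datetime_md below (format %Y-%m-%d-%H); extract_part_of_day_token buckets it into the evening token and extract_weekend_weekday_token into the weekday token, since November 8, 2019 fell on a Friday.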
22 | '''
23 | 
24 | def split_datetime_md(md_string):
25 |     """Convert md_string into a datetime object"""
26 |     date = datetime.strptime(md_string,"%Y-%m-%d-%H")
27 |     return date
28 | 
29 | 
30 | def extract_hour_token(md_string):
31 |     """Returns one token per md_string corresponding to one of 24 hour tokens"""
32 |     date = split_datetime_md(md_string)
33 |     md_token = "<HOUR_" + str(date.hour) + ">"  # the "<...>" spellings used here are an assumed naming convention; they only need to be unique and to match the sets in TRANSFORM_MAP below
34 |     return md_token
35 | 
36 | 
37 | def extract_part_of_day_token(md_string):
38 |     """Splits into one of 3 tokens: morning, midday and evening"""
39 |     date = split_datetime_md(md_string)
40 |     hour = date.hour
41 |     md_token = "<"
42 |     if hour >= 5 and hour < 10:
43 |         md_token = md_token + "MORNING" + ">"
44 |     elif hour >= 10 and hour < 14:
45 |         md_token = md_token + "MIDDAY" + ">"
46 |     else:
47 |         md_token = md_token + "EVENING" + ">"
48 |     return md_token
49 | 
50 | 
51 | def extract_day_of_week_token(md_string):
52 |     """ Splits into one of 7 days """
53 |     date = split_datetime_md(md_string)
54 |     md_token = "<DAY_" + str(date.weekday()) + ">"
55 |     return md_token
56 | 
57 | 
58 | def extract_weekend_weekday_token(md_string):
59 |     """ Splits into one of 2 tokens: weekday and weekend """
60 |     date = split_datetime_md(md_string)
61 |     md_token = "<"
62 |     week_day = date.weekday()
63 |     if week_day < 5:
64 |         md_token = md_token + "WEEKDAY" + ">"
65 |     else:
66 |         md_token = md_token + "WEEKEND" + ">"
67 |     return md_token
68 | 
69 | 
70 | def extract_week_of_year_token(md_string):
71 |     """ Splits into 1 of 53 week numbers in range(1,54) """
72 |     date = split_datetime_md(md_string)
73 |     md_token = "<WEEK_" + str(date.isocalendar()[1]) + ">"
74 |     return md_token
75 | 
76 | 
77 | def extract_month_of_year_token(md_string):
78 |     """ Splits into 1 of 12 month numbers in range(1,13) """
79 |     date = split_datetime_md(md_string)
80 |     md_token = "<MONTH_" + str(date.month) + ">"
81 |     return md_token
82 | 
83 | 
84 | def extract_year_token(md_string):
85 |     """ Returns one of 7 year numbers (2014-2020)"""
86 |     date = split_datetime_md(md_string)
87 |     md_token = "<YEAR_" + str(date.year) + ">"
88 |     return md_token
89 | 
90 | 
91 | def extract_all_tokens(md_string):
92 |     """ Combines month, week, day and hour information """
93 |     date = split_datetime_md(md_string)
94 |     month_token = "<MONTH_" + str(date.month) + ">"
95 |     week_token = "<WEEK_" + str(date.isocalendar()[1]) + ">"
96 |     day_token = "<DAY_" + str(date.weekday()) + ">"
97 |     hour_token = "<HOUR_" + str(date.hour) + ">"
98 | 
99 |     md_data = [month_token, week_token, day_token, hour_token]
100 |     md_tokens = " ".join(md_data)
101 |     return md_tokens
102 | 
103 | 
104 | def extract_radians(md_string):
105 |     """
106 |     Creates a radian representation of date information that uses hour, day, week, and
107 |     month information, encapsulated in one 8-dimensional vector.
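    Each quantity is encoded as the (sin, cos) of its phase within its period, so cyclically adjacent values (e.g. hour 23 and hour 0, or December and January) map to nearby points.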
108 |     """
109 |     date = split_datetime_md(md_string)
110 | 
111 |     hour = date.hour # Between 0 and 23
112 |     day = date.weekday() # Between 0 and 6
113 |     week = date.isocalendar()[1] - 1 # Between 0 and 52
114 |     month = date.month - 1 # Between 0 and 11
115 | 
116 |     radians_hour = [math.sin(2*math.pi*hour/24), math.cos(2*math.pi*hour/24)]
117 |     radians_day = [math.sin(2*math.pi*day/7), math.cos(2*math.pi*day/7)]
118 |     radians_week = [math.sin(2*math.pi*week/53), math.cos(2*math.pi*week/53)]
119 |     radians_month = [math.sin(2*math.pi*month/12), math.cos(2*math.pi*month/12)]
120 | 
121 |     all_radians = radians_hour + radians_day + radians_week + radians_month
122 | 
123 |     return torch.FloatTensor(all_radians)
124 | 
125 | # Geo hash metadata processing functions
126 | from string import ascii_lowercase as alphabet
127 | geo_hash_chars = list(alphabet) + list(range(0,10))
128 | geo_hash_set = {f"{x}{y}" for x in geo_hash_chars for y in geo_hash_chars}
129 | 
130 | def extract_geo_hash(md_string):
131 |     if md_string not in geo_hash_set:
132 |         md_string = "None"
133 |     return f"<{md_string}>"
134 | 
135 | ### Metadata transformation class
136 | 
137 | TRANSFORM_MAP = {  # maps a transformation name to (extraction function, set of special tokens it can emit); token spellings mirror the assumed "<...>" convention used by the extractors above
138 |     "hour_token": (extract_hour_token, set(["<HOUR_" + str(i) + ">" for i in range(24)])),
139 |     "part_of_day_token": (extract_part_of_day_token, set(["<MORNING>", "<MIDDAY>", "<EVENING>"])),
140 |     "day_of_week_token": (extract_day_of_week_token, set(["<DAY_" + str(i) + ">" for i in range(7)])),
141 |     "weekend_weekday_token": (extract_weekend_weekday_token, set(["<WEEKDAY>", "<WEEKEND>"])),
142 |     "week_of_year_token": (extract_week_of_year_token, set(["<WEEK_" + str(i) + ">" for i in range(1,54)])),
143 |     "month_of_year_token": (extract_month_of_year_token, set(["<MONTH_" + str(i) + ">" for i in range(1,13)])),
144 |     "year_token": (extract_year_token, set(["<YEAR_" + str(i) + ">" for i in range(2014,2021)])),
145 |     "all_tokens": (extract_all_tokens, set(["<MONTH_" + str(i) + ">" for i in range(1,13)] +\
146 |                                            ["<WEEK_" + str(i) + ">" for i in range(1,54)] +\
147 |                                            ["<DAY_" + str(i) + ">" for i in range(7)] +\
148 |                                            ["<HOUR_" + str(i) + ">" for i in range(24)])),
149 |     "radians": (extract_radians, set()),
150 |     "geo_hash": (extract_geo_hash, {f"<{md_string}>" for md_string in geo_hash_set}),
151 | }
152 | 
153 | class MetaDataTransformer():
154 |     def __init__(self, text_index, md_indices, md_transformations):
155 |         self.text_index = int(text_index)
156 |         self.md_transformations = md_transformations.split(',') if md_transformations else []
157 |         self.md_indices = [int(x) for x in md_indices.split(',') if x]
158 | 
159 |         assert(len(self.md_indices) == len(self.md_transformations)), \
160 |             "Length of metadata indices and metadata transformations are mismatched"
161 | 
162 |     def get_md_tokens(self):
163 |         all_tokens = set()
164 |         for md_transform in self.md_transformations:
165 |             _, tokens = TRANSFORM_MAP[md_transform]
166 |             all_tokens.update(tokens)
167 |         return all_tokens
168 | 
169 |     def parse_raw_input(self, utterance):
170 |         md = {}
171 |         split_utterance = utterance.rstrip().split('\t')
172 |         for md_index, md_transform in zip(self.md_indices, self.md_transformations):
173 |             raw_md = split_utterance[md_index].rstrip()
174 |             md_transform_func, _ = TRANSFORM_MAP[md_transform]
175 |             md[md_transform] = md_transform_func(raw_md)
176 |         text = split_utterance[self.text_index]
177 |         return (md, text)
178 | 
-------------------------------------------------------------------------------- /src/model.py: --------------------------------------------------------------------------------
1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 | # SPDX-License-Identifier: CC-BY-NC-4.0 3 | 4 | import math 5 | import torch 6 | import torch.nn as nn 7 | from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence 8 | from src.util import device 9 | 10 | """ 11 | Attention Model Classes. Note that all classes share the same set of arguments. 12 | """ 13 | 14 | class BahdanauAttention(nn.Module): 15 | ''' 16 | Standard bahdanau attention mechanism: https://arxiv.org/abs/1409.0473 17 | ''' 18 | def __init__(self, md_dim, query_dim, md_group_size, use_null_token): 19 | super().__init__() 20 | 21 | hidden_dim = md_dim # TODO (low priority): allow user to customize this 22 | self.md_dim = md_dim #size of keys 23 | self.query_dim = query_dim #size of query vector 24 | self.md_group_size = md_group_size 25 | 26 | self.use_null_token = use_null_token 27 | if self.use_null_token: 28 | self.zeros = torch.zeros([1,1,self.md_dim], requires_grad=False) 29 | 30 | self.key_projection = nn.Linear(md_dim, hidden_dim, bias=False) 31 | self.query_projection = nn.Linear(query_dim, hidden_dim, bias=False) 32 | self.energy_projection = nn.Linear(hidden_dim, 1, bias=False) 33 | 34 | def forward(self, input, query): 35 | if query.dim() == 3: 36 | # 3-dimensional query means we are precomputing attention 37 | seq_len = query.size(0) 38 | query = query.view(seq_len, -1, 1, self.query_dim) 39 | else: 40 | query = query.view(-1, 1, self.query_dim) 41 | 42 | input_dim = input.dim() 43 | if input_dim == 3: 44 | input = input.view(-1, self.md_group_size, self.md_dim) 45 | elif input_dim == 4: 46 | input = input.view(seq_len, -1, self.md_group_size, self.md_dim) 47 | else: 48 | raise Exception(f"Invalid number of input dimension: {input_dim}") 49 | 50 | if self.use_null_token: 51 | if input_dim == 3: 52 | zeros = self.zeros.repeat(input.size(0), 1, 1).to(device) 53 | input = torch.cat((input, zeros), dim=1) 54 | else: 55 | zeros = self.zeros.repeat(input.size(0), input.size(1), 1, 1).to(device) 56 | test = self.zeros.repeat(input.size(0), 1, 1).to(device) 57 | input = torch.cat((input, zeros), dim=2) 58 | 59 | hidden_keys = self.key_projection(input) 60 | hidden_query = self.query_projection(query) 61 | 62 | scores = self.energy_projection(torch.tanh(hidden_query + hidden_keys)) 63 | alphas = nn.Softmax(dim=-1)(scores).transpose(-1,-2) 64 | context = torch.matmul(alphas, input).squeeze() 65 | return context 66 | 67 | class GeneralAttention(nn.Module): 68 | ''' 69 | Implements a general purpose general attention mechanism where key and query 70 | vectors are multiplied together by a learned weight matrix W. 
71 | ''' 72 | def __init__(self, md_dim, query_dim, md_group_size, use_null_token): 73 | 74 | super().__init__() 75 | 76 | self.md_dim = md_dim 77 | self.query_dim = query_dim 78 | self.md_group_size = md_group_size 79 | self.use_null_token = use_null_token 80 | 81 | if self.use_null_token: 82 | self.zeros = torch.zeros([1,1,self.md_dim], requires_grad=False) 83 | 84 | self.W = nn.Parameter(torch.Tensor(self.query_dim, self.md_dim)) 85 | nn.init.kaiming_uniform_(self.W, a=math.sqrt(5)) 86 | 87 | def forward(self, input, query): 88 | if query.dim() == 3: 89 | # 3-dimensional query means we are precomputing attention 90 | seq_len = query.size(0) 91 | query = query.view(seq_len, -1, 1, self.query_dim) 92 | else: 93 | query = query.view(-1, 1, self.query_dim) 94 | 95 | input_dim = input.dim() 96 | if input_dim == 3: 97 | input = input.view(-1, self.md_group_size, self.md_dim) 98 | elif input_dim == 4: 99 | input = input.view(seq_len, -1, self.md_group_size, self.md_dim) 100 | else: 101 | raise Exception(f"Invalid number of input dimension: {input_dim}") 102 | 103 | if self.use_null_token: 104 | if input_dim == 3: 105 | zeros = self.zeros.repeat(input.size(0), 1, 1).to(device) 106 | input = torch.cat((input, zeros), dim=1) 107 | else: 108 | zeros = self.zeros.repeat(input.size(0), input.size(1), 1, 1).to(device) 109 | test = self.zeros.repeat(input.size(0), 1, 1).to(device) 110 | input = torch.cat((input, zeros), dim=2) 111 | 112 | scores = torch.matmul(query, self.W) 113 | scores = torch.matmul(scores, input.transpose(-1,-2)) 114 | 115 | alphas = nn.Softmax(dim=-1)(scores) 116 | context = torch.matmul(alphas, input).squeeze() 117 | return context 118 | 119 | ATTENTION_MAP = {"general": GeneralAttention, "bahdanau": BahdanauAttention} 120 | 121 | 122 | class MetadataConstructor(nn.Module): 123 | """ 124 | General module that processes different metadata and construct a metadata 125 | representation by using an attention based approach, or more simply concatenating 126 | different metadata embeddings. 
127 | """ 128 | 129 | def __init__(self, metadata_constructor_params, dimension_params): 130 | 131 | super().__init__() 132 | 133 | self.md_projection_dim = metadata_constructor_params["md_projection_dim"] 134 | self.md_dims = metadata_constructor_params["md_dims"] 135 | self.md_group_sizes = metadata_constructor_params["md_group_sizes"] 136 | self.context_dim = dimension_params["context_dim"] 137 | 138 | # when using radians 139 | self.attention_mechanism = metadata_constructor_params["attention_mechanism"] 140 | 141 | if self.attention_mechanism: 142 | self.query_type = metadata_constructor_params["query_type"] 143 | 144 | assert(self.attention_mechanism in ATTENTION_MAP),\ 145 | f"Invalid attention type: {self.attention_mechanism}" 146 | assert(self.query_type in ("word", "hidden")),\ 147 | f"Invalid query type: {self.query_type}" 148 | 149 | query_dim = self.get_query_dim(dimension_params) 150 | 151 | self.use_null_token = metadata_constructor_params["use_null_token"] 152 | 153 | self.attention_modules = [] 154 | for md_dim, md_group_size in zip(self.md_dims, self.md_group_sizes): 155 | attention_module = ATTENTION_MAP[self.attention_mechanism](md_dim, 156 | query_dim, 157 | md_group_size, 158 | self.use_null_token).to(device) 159 | self.attention_modules.append(attention_module) 160 | 161 | # After attention module, the resulting metadata embeddings are projected 162 | # to size md_projection_dim 163 | self.projection_layers = [] 164 | for md_dim in self.md_dims: 165 | projection = nn.Linear(md_dim, self.md_projection_dim).to(device) 166 | self.projection_layers.append(projection) 167 | 168 | # The resulting metadata embeddings can now be combined via another 169 | # attention mechanism (specified by "hierarchical_attention" bool parameter), 170 | # or via a simpler concatenation of the metadata together 171 | self.use_hierarchical_attention = metadata_constructor_params["hierarchical_attention"] 172 | if self.use_hierarchical_attention: 173 | num_attention_groups = len(self.md_dims) 174 | query_dim = self.get_query_dim(dimension_params) 175 | # NOTE: we use the same query embedding in the attention module 176 | # as in the previous attention modules 177 | self.hierarchical_attention_module = ATTENTION_MAP[self.attention_mechanism](self.md_projection_dim, 178 | query_dim, 179 | num_attention_groups, 180 | self.use_null_token).to(device) 181 | context_projection_input_dim = self.md_projection_dim 182 | else: 183 | # If metadata is not combined hierarchically, all the metadata is 184 | # instead concated together. 
Computing the resulting size of the 185 | # concatenated embedding 186 | context_projection_input_dim = 0 187 | for md_group_size in self.md_group_sizes: 188 | if self.attention_mechanism: 189 | context_projection_input_dim += self.md_projection_dim 190 | else: 191 | context_projection_input_dim += self.md_projection_dim * md_group_size 192 | 193 | # Finally the metadata embedding (either concatenated or combined via attetntion) 194 | # are projected to size of context_dim 195 | self.context_projection = nn.Linear(context_projection_input_dim, 196 | self.context_dim).to(device) 197 | self.context_normalization = nn.LayerNorm(self.context_dim) 198 | 199 | def get_query_dim(self, dimension_params): 200 | ''' Simple helper function to return size of the query ''' 201 | if self.query_type == "word": 202 | query_dim = dimension_params["emb_dim"] 203 | elif self.query_type == "hidden": 204 | query_dim = dimension_params["hidden_dim"] 205 | return query_dim 206 | 207 | def is_precomputable(self): 208 | ''' Logic for determining is attention can be precomputed''' 209 | if not self.attention_mechanism or self.query_type == "word": 210 | return True 211 | else: 212 | return False 213 | 214 | @staticmethod 215 | def preprocess_md_util(md_embs, embedding_projection, md_dims, md_group_sizes): 216 | ''' Static helper to be used by base lstm model ''' 217 | processed_md = {} 218 | for idx, (md_transform, md) in enumerate(md_embs.items()): 219 | md_dim = md_dims[idx] 220 | md_group_size = md_group_sizes[idx] 221 | # Notice this breaks if md_dim == md_group_size 222 | if md.shape[-1] != md_dim*md_group_size: 223 | md = embedding_projection(md) 224 | 225 | processed_md[md_transform] = md 226 | return processed_md 227 | 228 | def preprocess_md(self, md_embs, embedding_projection): 229 | ''' Preprocess metadata input to have correct dimensionality ''' 230 | return MetadataConstructor.preprocess_md_util(md_embs, embedding_projection, 231 | self.md_dims, self.md_group_sizes) 232 | 233 | def concat_md(self, input_md): 234 | """ Flattens and concatenates input embeddings""" 235 | mds = [] 236 | for md in input_md: 237 | if not self.attention_mechanism: 238 | md = md.flatten(start_dim=1) 239 | mds.append(md) 240 | concat_md = torch.cat(mds, -1) 241 | return concat_md 242 | 243 | def forward(self, md_embs, query=None): 244 | ''' 245 | Query can be None if attention mechanism is not used 246 | ''' 247 | processed_mds = [] 248 | for idx, md in enumerate(md_embs.values()): 249 | if self.attention_mechanism: 250 | attention_module = self.attention_modules[idx] 251 | md = attention_module(md, query) 252 | 253 | # Only need to project data if more than one metadata group used 254 | projection_layer = self.projection_layers[idx] 255 | processed_md = projection_layer(md) 256 | processed_mds.append(processed_md) 257 | 258 | if self.use_hierarchical_attention: 259 | combined_md = self.hierarchical_attention_module(torch.stack(processed_mds), query) 260 | else: 261 | combined_md = self.concat_md(processed_mds) 262 | 263 | context_emb = self.context_projection(combined_md) 264 | context_emb = self.context_normalization(context_emb) 265 | context_emb = torch.tanh(context_emb) 266 | return context_emb 267 | 268 | 269 | ''' 270 | Base LSTM models that we are conditioned with non-linguistic context. In the paper 271 | we provide two methods for adding the context to the LSTM model, a concatenation-based 272 | method and a factor-based method. 
This open-source codebase provides the concatenation-method 273 | since it generally yielded the best results. 274 | ''' 275 | 276 | class ConcatLSTM(nn.Module): 277 | """ 278 | Implements the hidden state ConcatCell Approach proposed by Mikolov et al. 2012 279 | Paper Details: https://www.microsoft.com/en-us/research/wp-content/uploads/2016/02/rnn_ctxt.pdf 280 | """ 281 | 282 | def __init__(self, dimension_params, metadata_constructor_params, layer_params): 283 | 284 | super().__init__() 285 | self.emb_dim = dimension_params["emb_dim"] 286 | self.context_dim = dimension_params["context_dim"] 287 | self.hidden_dim = dimension_params["hidden_dim"] 288 | self.vocab_size = dimension_params["vocab_size"] 289 | 290 | self.n_layers = layer_params["n_layers"] 291 | self.use_softmax_adaptation = layer_params["use_softmax_adaptation"] 292 | self.use_layernorm = layer_params["use_layernorm"] 293 | self.use_weight_tying = layer_params["use_weight_tying"] 294 | 295 | self.metadata_constructor = MetadataConstructor(metadata_constructor_params, 296 | dimension_params) 297 | 298 | self.embeddings = nn.Embedding(self.vocab_size, self.emb_dim) 299 | 300 | self.gate_size = 4 * self.hidden_dim # input, forget, gate, output 301 | self._all_weights = nn.ParameterList() 302 | self._params_per_layer = 5 303 | 304 | for layer in range(self.n_layers): 305 | self.layer_input_size = self.emb_dim if layer == 0 else self.hidden_dim 306 | 307 | # weight matrix for meta data 308 | w_ih = nn.Parameter(torch.Tensor(self.gate_size, self.layer_input_size)) 309 | w_hh = nn.Parameter(torch.Tensor(self.gate_size, self.hidden_dim)) 310 | 311 | w_mh = nn.Parameter(torch.Tensor(self.gate_size, self.context_dim)) 312 | 313 | b_ih = nn.Parameter(torch.Tensor(self.gate_size)) 314 | b_hh = nn.Parameter(torch.Tensor(self.gate_size)) 315 | 316 | for param in (w_mh, w_ih, w_hh, b_ih, b_hh): 317 | self._all_weights.append(param) 318 | 319 | if self.use_softmax_adaptation: 320 | self.md_vocab_projection = nn.Linear(self.context_dim, self.vocab_size) 321 | 322 | if self.use_layernorm: 323 | self.layernorm = nn.LayerNorm(self.hidden_dim) 324 | 325 | if self.use_weight_tying: 326 | self.vocab_projection = nn.Linear(self.emb_dim, self.vocab_size) 327 | self.embedding_projection = nn.Linear(self.hidden_dim, self.emb_dim) 328 | self.vocab_projection.weight = self.embeddings.weight 329 | else: 330 | self.vocab_projection = nn.Linear(self.hidden_dim, self.vocab_size) 331 | 332 | self._reset_parameters() 333 | 334 | def _run_cell(self, input, md_layer, hidden, w_ih, w_hh, b_ih, b_hh): 335 | """ 336 | LSTM cell structure adapted from: 337 | github.com/pytorch/benchmark/blob/09eaadc1d05ad442b1f0beb82babf875bbafb24b/rnns/fastrnns/cells.py#L25-L40 338 | """ 339 | 340 | hx, cx = hidden 341 | gates = torch.matmul(input, w_ih.t()) + torch.matmul(hx, w_hh.t()) +\ 342 | md_layer + b_ih + b_hh 343 | 344 | ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1) 345 | 346 | if self.use_layernorm: 347 | ingate = self.layernorm(ingate) 348 | forgetgate = self.layernorm(forgetgate) 349 | cellgate = self.layernorm(cellgate) 350 | outgate = self.layernorm(outgate) 351 | 352 | ingate = torch.sigmoid(ingate) 353 | forgetgate = torch.sigmoid(forgetgate) 354 | cellgate = torch.tanh(cellgate) 355 | outgate = torch.sigmoid(outgate) 356 | 357 | cy = (forgetgate * cx) + (ingate * cellgate) 358 | hy = outgate * torch.tanh(cy) 359 | 360 | return hy, cy 361 | 362 | def _reset_parameters(self): 363 | """ Basic randomization of parameters""" 364 | for weight in 
self.parameters(): 365 | if weight.dim() > 1: 366 | nn.init.xavier_normal_(weight) 367 | else: 368 | torch.nn.init.zeros_(weight) # bias vector 369 | 370 | def forward(self, input): 371 | input_ids = input["input"].long() 372 | input_lens = input["text_len"] 373 | 374 | input_embs = self.embeddings(input_ids) 375 | max_batch_size = input_embs.size(0) 376 | seq_len = input_embs.size(1) 377 | 378 | # Assuming input is batch_first - permutting for sequence first 379 | input_embs = input_embs.permute(1, 0, 2) 380 | 381 | zeros = torch.zeros(self.n_layers, max_batch_size, self.hidden_dim).to(device) 382 | 383 | h_init = zeros 384 | c_init = zeros 385 | inputs = input_embs 386 | outputs = [] 387 | 388 | md_input = self.metadata_constructor.preprocess_md(input["md"], self.embeddings) 389 | 390 | if self.metadata_constructor.is_precomputable(): 391 | md = self.metadata_constructor(md_input, input_embs) 392 | 393 | for layer in range(self.n_layers): 394 | h = h_init[layer] 395 | c = c_init[layer] 396 | 397 | weight_start_index = layer * self._params_per_layer 398 | weight_end_index = (layer+1) * self._params_per_layer 399 | w_mh, w_ih, w_hh, b_ih, b_hh = self._all_weights[weight_start_index: weight_end_index] 400 | 401 | # Meta data can be computed in advance when not using attention 402 | 403 | if self.metadata_constructor.is_precomputable(): 404 | precomputed_md = torch.matmul(md, w_mh.t()) 405 | 406 | for t in range(seq_len): 407 | if not self.metadata_constructor.is_precomputable(): 408 | md = self.metadata_constructor(md_input, h) 409 | md_layer = torch.matmul(md, w_mh.t()) 410 | else: 411 | md_layer = precomputed_md[t] if precomputed_md.dim() == 3 else precomputed_md 412 | 413 | h, c = self._run_cell(inputs[t], md_layer, (h, c), w_ih, w_hh, b_ih, b_hh) 414 | outputs += [h] 415 | 416 | inputs = outputs 417 | outputs = [] 418 | 419 | # At the end the input variable will be set to outputs 420 | # Permutting to have batch - seq len - hidden dim 421 | lstm_out = torch.stack(inputs).permute(1, 0, 2) 422 | 423 | if self.use_weight_tying: 424 | vocab_predictions = self.vocab_projection(self.embedding_projection(lstm_out)) 425 | else: 426 | vocab_predictions = self.vocab_projection(lstm_out) 427 | 428 | if self.use_softmax_adaptation: 429 | md_embs = md_embs.view(max_batch_size, -1) 430 | md_context = self.md_vocab_projection(md_embs).unsqueeze(1) 431 | vocab_predictions += md_context 432 | 433 | return vocab_predictions 434 | 435 | 436 | class BaseLSTM(nn.Module): 437 | """Basic LSTM model - concatenating metadata to LSTM """ 438 | def __init__(self, dimension_params, metadata_constructor_params, layer_params): 439 | super().__init__() 440 | self.emb_dim = dimension_params["emb_dim"] 441 | self.hidden_dim = dimension_params["hidden_dim"] 442 | self.vocab_size = dimension_params["vocab_size"] 443 | 444 | self.md_dims = metadata_constructor_params["md_dims"] 445 | self.md_group_sizes = metadata_constructor_params["md_group_sizes"] 446 | self.use_md = True if self.md_dims and self.md_group_sizes else False 447 | 448 | self.n_layers = layer_params["n_layers"] 449 | self.use_weight_tying = layer_params["use_weight_tying"] 450 | 451 | self.embeddings = nn.Embedding(self.vocab_size, self.emb_dim) 452 | 453 | self.lstm = nn.LSTM(input_size=self.emb_dim, 454 | hidden_size=self.hidden_dim, 455 | num_layers=self.n_layers, 456 | batch_first=True, 457 | ) 458 | 459 | if self.use_weight_tying: 460 | self.vocab_projection = nn.Linear(self.emb_dim, self.vocab_size) 461 | self.embedding_projection = 
nn.Linear(self.hidden_dim, self.emb_dim) 462 | self.vocab_projection.weight = self.embeddings.weight 463 | else: 464 | self.vocab_projection = nn.Linear(self.hidden_dim, self.vocab_size) 465 | 466 | if self.use_md: 467 | self.metadata_constructor = MetadataConstructor(metadata_constructor_params, 468 | dimension_params) 469 | 470 | def forward(self, input): 471 | """ Appends meta data embeddings to input and passes through LSTM.""" 472 | input_ids = input["input"].long() 473 | input_lens = input["text_len"] 474 | input_embs = self.embeddings(input_ids) 475 | 476 | if self.use_md: 477 | md_input = self.metadata_constructor.preprocess_md(input["md"], self.embeddings) 478 | md_emb = self.metadata_constructor(md_input) 479 | 480 | # Prepending meta data information to text 481 | sos_embs = input_embs[:, :1, :] 482 | text_embs = input_embs[:, 1:, :] 483 | 484 | joined_embs = torch.cat((sos_embs, md_emb, text_embs), dim=1) 485 | joined_lens = 1 + input_lens # fixed metadata embedding 486 | 487 | packed_input = pack_padded_sequence(joined_embs, joined_lens, 488 | batch_first=True, enforce_sorted=False) 489 | else: 490 | packed_input = pack_padded_sequence(input_embs, input_lens, 491 | batch_first=True, enforce_sorted=False) 492 | 493 | # output: batch, seq_len, hidden_size 494 | lstm_out, _ = self.lstm(packed_input) 495 | lstm_out, _ = pad_packed_sequence(lstm_out, batch_first=True) 496 | 497 | if self.use_weight_tying: 498 | vocab_predictions = self.vocab_projection(self.embedding_projection(lstm_out)) 499 | else: 500 | vocab_predictions = self.vocab_projection(lstm_out) 501 | 502 | return vocab_predictions 503 | 504 | ##### Utility Wrapper ##### 505 | 506 | def init_base_lstm(model_config, vocab_size): 507 | 508 | dimension_params = { 509 | "emb_dim":int(model_config.get("emb_dim")), 510 | "context_dim":int(model_config.get("emb_dim")), 511 | "hidden_dim":int(model_config.get("hidden_dim")), 512 | "vocab_size":vocab_size, 513 | } 514 | 515 | metadata_constructor_params = { 516 | "md_projection_dim":int(model_config.get("md_projection_dim", fallback="50")), 517 | "md_dims":[int(x) for x in model_config.get("md_dims", fallback="").split(',') if x], 518 | "md_group_sizes":[int(x) for x in model_config.get("md_group_sizes", fallback="").split(',') if x], 519 | "attention_mechanism": "", 520 | "hierarchical_attention": "", 521 | } 522 | 523 | layer_params = { 524 | "n_layers":1, # Fixed in ACL paper 525 | "use_weight_tying":eval(model_config.get("use_weight_tying", fallback='False')), 526 | } 527 | 528 | 529 | model = BaseLSTM(dimension_params, metadata_constructor_params, layer_params) 530 | return model 531 | 532 | 533 | def init_concat_lstm(model_config, vocab_size): 534 | dimension_params = { 535 | "emb_dim":int(model_config.get("emb_dim")), 536 | "context_dim":int(model_config.get("context_dim", 537 | fallback=int(model_config.get("emb_dim")))), 538 | "hidden_dim":int(model_config.get("hidden_dim")), 539 | "vocab_size":vocab_size, 540 | } 541 | 542 | metadata_constructor_params = { 543 | "md_projection_dim":int(model_config.get("md_projection_dim", fallback="50")), 544 | "md_dims":[int(x) for x in model_config.get("md_dims").split(',')], 545 | "md_group_sizes":[int(x) for x in model_config.get("md_group_sizes").split(',')], 546 | "attention_mechanism":model_config.get("attention_mechanism", fallback=""), 547 | "query_type":model_config.get("query_type", fallback=""), 548 | "use_null_token":eval(model_config.get("use_null_token", fallback="False")), 549 | 
"hierarchical_attention":eval(model_config.get("hierarchical_attention", fallback="False")), 550 | } 551 | 552 | layer_params = { 553 | "n_layers":1, # Fixed in ACL paper 554 | "use_softmax_adaptation":eval(model_config.get("use_softmax_adaptation", fallback='False')), 555 | "use_layernorm":eval(model_config.get("use_layernorm", fallback='False')), 556 | "use_weight_tying":eval(model_config.get("use_weight_tying", fallback='False')), 557 | } 558 | 559 | model = ConcatLSTM(dimension_params, metadata_constructor_params, layer_params) 560 | 561 | return model 562 | 563 | 564 | MODEL_MAP = {"base_lstm": init_base_lstm, 565 | "concat_lstm": init_concat_lstm} 566 | 567 | 568 | def get_model(config, vocab_size): 569 | """Given a model type maps that to a model initiation function.""" 570 | model_config = config["MODEL"] 571 | model_type = model_config.get("model_type") 572 | assert(model_type in MODEL_MAP), f"Invalid model type: {model_type}" 573 | return MODEL_MAP[model_type](model_config, vocab_size) 574 | -------------------------------------------------------------------------------- /src/tokenizer.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: CC-BY-NC-4.0 3 | 4 | import pickle 5 | import os 6 | import logging 7 | from collections import Counter 8 | from src.data import get_next_utterance, list_files 9 | from abc import ABCMeta, abstractmethod 10 | 11 | import sentencepiece as spm 12 | 13 | """ 14 | Standard tokenization classes 15 | """ 16 | 17 | # NOTE: Keep this as is - padding token id must be 0 18 | BASE_TOKENS = {"":1, "":2, "": 3, "": 0} 19 | logger = logging.getLogger(__name__) 20 | 21 | class BaseTokenizer(metaclass=ABCMeta): 22 | @abstractmethod 23 | def get_vocab_size(self): 24 | pass 25 | 26 | @abstractmethod 27 | def encode_text(self, input_sentence): 28 | """Encodes an utterance """ 29 | pass 30 | 31 | @abstractmethod 32 | def encode(self, input_toks): 33 | """ Encodes a list of tokens to a list of ids """ 34 | pass 35 | 36 | @abstractmethod 37 | def decode(self, input_ids): 38 | """ Decodes a list of ids to a list of tokens """ 39 | pass 40 | 41 | @abstractmethod 42 | def add_special_tokens(self, special_tokens): 43 | """Adds a set of special tokens to the base """ 44 | pass 45 | 46 | @classmethod 47 | @abstractmethod 48 | def load_tokenizer(cls, tokenizer_path): 49 | pass 50 | 51 | class SPTokenizer(BaseTokenizer): 52 | def __init__(self, data_path, md_transformer, vocab_limit): 53 | """ Sentence Piece Tokenizer """ 54 | self.data_path = data_path 55 | self.md_transformer = md_transformer 56 | self.vocab_limit = vocab_limit 57 | self.tokenizer_model = self._generate_model() 58 | self.special_tokens = BASE_TOKENS 59 | 60 | def _process_data_files(self, dir_path): 61 | """ 62 | Reads in data in self.data_path and writes out to utterance text to 63 | files. 
64 |         """
65 |         line_count = 0
66 |         file_count = 0
67 | 
68 |         curr_out_file_name = os.path.join(dir_path, f"processed_{file_count}.txt")
69 |         out_file = open(curr_out_file_name, "w")
70 |         for utterance in get_next_utterance(self.data_path):
71 |             _, text = self.md_transformer.parse_raw_input(utterance)
72 |             line_count += 1
73 | 
74 |             if (line_count % 20_000 == 0):
75 |                 line_count = 0
76 |                 file_count += 1
77 |                 curr_out_file_name = os.path.join(dir_path,\
78 |                                                   f"processed_{file_count}.txt")
79 |                 out_file.close()
80 |                 out_file = open(curr_out_file_name, "w")
81 | 
82 |             out_file.write(text + '\n')
83 | 
84 |     def _generate_model(self):
85 |         """
86 |         Creates a dataset of processed text files, and trains a sentence piece
87 |         model.
88 |         """
89 |         dir_path = "cache/sp_tokenizer_data"
90 |         if not os.path.exists(dir_path):
91 |             os.mkdir(dir_path)
92 |             self._process_data_files(dir_path)
93 |         data_files = list_files(dir_path)
94 | 
95 |         model_cache_prefix = f'cache/sp_tokenizer_{self.vocab_limit}'
96 |         spm.SentencePieceTrainer.train(input=data_files,
97 |                                        model_prefix=model_cache_prefix,
98 |                                        vocab_size=self.vocab_limit,
99 |                                        bos_id=BASE_TOKENS["<sos>"],
100 |                                        eos_id=BASE_TOKENS["<eos>"],
101 |                                        unk_id=BASE_TOKENS["<unk>"],
102 |                                        pad_id=BASE_TOKENS["<pad>"])
103 |         model = spm.SentencePieceProcessor(model_file=f'{model_cache_prefix}.model')
104 |         return model
105 | 
106 |     def print_special_token_ids(self):
107 |         logger.info("Special Token: Token ID")
108 |         for token, id in self.special_tokens.items():
109 |             logger.info(f"\t {token}: {id}")
110 | 
111 |     def add_special_tokens(self, special_tokens):
112 |         logger.info(f"Using special tokens: {special_tokens}")
113 |         curr_vocab_size = self.get_vocab_size()
114 |         for idx, tok in enumerate(set(special_tokens)):
115 |             self.special_tokens[tok] = idx + curr_vocab_size
116 | 
117 |     def get_vocab_size(self):
118 |         """vocab_limit includes the base_tokens"""
119 |         return self.vocab_limit + len(self.special_tokens) - len(BASE_TOKENS)
120 | 
121 |     def encode_text(self, input_sentence, add_eos=False, add_sos=False):
122 |         """Encodes an utterance """
123 |         output_ids = []
124 |         input_toks = input_sentence.split()
125 |         if any(input_tok in self.special_tokens for input_tok in input_toks):
126 |             for tok in input_toks:
127 |                 id = self.special_tokens[tok]
128 |                 output_ids.append(id)
129 |         else:
130 |             output_ids = self.encode(input_sentence, add_eos=add_eos, add_sos=add_sos)
131 | 
132 |         return output_ids
133 | 
134 |     def encode(self, input_toks, add_eos=False, add_sos=False):
135 |         """ Encodes a list of tokens to a list of ids """
136 |         return self.tokenizer_model.encode(input_toks,
137 |                                            add_bos=add_sos,
138 |                                            add_eos=add_eos)
139 | 
140 |     def decode(self, input_ids):
141 |         """ Decodes a list of ids to a list of tokens """
142 |         return self.tokenizer_model.decode(input_ids)
143 | 
144 |     @classmethod
145 |     def load_tokenizer(cls, tokenizer_path):
146 |         return pickle.load(open(tokenizer_path, "rb"))
147 | 
148 | class BasicTokenizer(BaseTokenizer):
149 |     def __init__(self, data_path, md_transformer, vocab_limit=None):
150 |         """ Basic Tokenizer."""
151 |         self.data_path = data_path
152 |         self.md_transformer = md_transformer
153 | 
154 |         self._tok2id = {}
155 |         self._id2tok = {}
156 |         self.vocab = set()
157 |         self.special_tokens = dict(BASE_TOKENS)
158 | 
159 |         self.vocab_limit = vocab_limit
160 |         self.vocab_counter = Counter()
161 | 
162 |         # Initial construction of class
163 |         self._create_vocab(data_path)
164 |         self._create_token_to_id_map()
165 |         self._create_id_to_token_map()
166 | 
167 |     def _create_vocab(self, data_path):
| logger.info("Creating vocab for tokenizer") 169 | 170 | for utterance in get_next_utterance(data_path): 171 | _, text = self.md_transformer.parse_raw_input(utterance) 172 | tokens = text.split() 173 | 174 | self.vocab_counter.update(tokens) 175 | 176 | if self.vocab_limit: 177 | word_counts = self.vocab_counter.most_common(self.vocab_limit) 178 | words, _ = zip(*word_counts) 179 | else: 180 | words = self.vocab_counter.elements() 181 | self.vocab = set(words) 182 | 183 | def _create_token_to_id_map(self): 184 | for tok, id in self.special_tokens.items(): 185 | self._tok2id[tok] = id 186 | special_tok_offset = len(self.special_tokens) 187 | for id, tok in enumerate(self.vocab): 188 | self._tok2id[tok] = special_tok_offset+id 189 | 190 | def _create_id_to_token_map(self): 191 | self._id2tok = {val:key for key, val in self._tok2id.items()} 192 | 193 | def get_vocab_size(self): 194 | """Return vocab + number of special tokens """ 195 | return len(self._tok2id) 196 | 197 | def print_special_token_ids(self): 198 | logger.info("Special Token: Token ID") 199 | for token in self.special_tokens: 200 | id = self._tok2id[token] 201 | logger.info(f"\t {token}: {id}") 202 | 203 | def add_special_tokens(self, special_tokens): 204 | # Expanding _tok2id and _id2tok with new special_tokens 205 | logger.info(f"Using special tokens: {special_tokens}") 206 | last_id = len(self._tok2id) 207 | new_tok2id = {} 208 | for idx, tok in enumerate(set(special_tokens)): 209 | new_tok2id[tok] = idx+last_id 210 | new_id2tok = {val:key for key, val in new_tok2id.items()} 211 | self._tok2id = {**self._tok2id, **new_tok2id} 212 | self._id2tok = {**self._id2tok, **new_id2tok} 213 | 214 | self.special_tokens = {**self.special_tokens, **new_tok2id} 215 | self.print_special_token_ids() 216 | 217 | def encode_text(self, input_sentence, add_eos=False, add_sos=False): 218 | """Encodes an utterance, and optionally prepends or postpends eos/sos tokens.""" 219 | input_toks = input_sentence.split() 220 | ids = self.encode(input_toks) 221 | if add_eos: 222 | eos_id = self._tok2id[""] 223 | ids.append(eos_id) 224 | if add_sos: 225 | sos_id = self._tok2id[""] 226 | ids.insert(0, sos_id) 227 | return ids 228 | 229 | def encode(self, input_toks): 230 | """ Encodes a list of tokens to a list of ids """ 231 | output_ids = [] 232 | for tok in input_toks: 233 | if tok in self._tok2id.keys(): 234 | id = self._tok2id[tok] 235 | else: 236 | id = self._tok2id[""] 237 | output_ids.append(id) 238 | return output_ids 239 | 240 | def decode(self, input_ids): 241 | """ Decodes a list of ids to a list of tokens """ 242 | output_toks = [] 243 | for id in input_ids: 244 | tok = self._id2tok[id] 245 | output_toks.append(tok) 246 | return output_toks 247 | 248 | @classmethod 249 | def load_tokenizer(cls, tokenizer_path): 250 | return pickle.load(open(tokenizer_path, "rb")) 251 | 252 | ##### Utility Wrapper ##### 253 | 254 | TOKENIZER_MAP = {"basic_tokenizer": BasicTokenizer, 255 | "sentence_piece_tokenizer": SPTokenizer} 256 | 257 | def get_tokenizer(tokenizer_type, data_path, md_transformer, 258 | vocab_limit=0, force_new_creation=False): 259 | """Either loads in a pretrained tokenizer or creates a new one and saves it""" 260 | assert(tokenizer_type in TOKENIZER_MAP),\ 261 | f"Invalid tokenizer_type: {tokenizer_type}" 262 | tokenizer_class = TOKENIZER_MAP[tokenizer_type] 263 | if vocab_limit: 264 | saved_tokenizer_path = f"cache/{tokenizer_type}_{vocab_limit}.pkl" 265 | else: 266 | saved_tokenizer_path = f"cache/{tokenizer_type}_full.pkl" 267 | if 
os.path.exists(saved_tokenizer_path) and not force_new_creation: 268 | logger.info(f"Loading in saved tokenizer from: {saved_tokenizer_path}") 269 | tokenizer = tokenizer_class.load_tokenizer(saved_tokenizer_path) 270 | else: 271 | if not os.path.exists("cache"): 272 | os.mkdir("cache") 273 | logger.info(f"Creating new tokenizer") 274 | tokenizer = tokenizer_class(data_path, md_transformer, vocab_limit) 275 | pickle.dump(tokenizer, open(saved_tokenizer_path, "wb")) 276 | logger.info(f"Saving out tokenizer to: {saved_tokenizer_path}") 277 | logger.info(f"Size of vocab: {tokenizer.get_vocab_size()}") 278 | return tokenizer 279 | -------------------------------------------------------------------------------- /src/util.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: CC-BY-NC-4.0 3 | 4 | import torch 5 | import os 6 | import shutil 7 | import json 8 | import subprocess 9 | from src.data import custom_collate, MetaDataset 10 | from torch.utils.data import DataLoader 11 | from torch.optim.lr_scheduler import ReduceLROnPlateau, OneCycleLR 12 | 13 | """Utils for model training and evaluation""" 14 | 15 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 16 | 17 | def move_to_device(batch): 18 | updated_batch = {} 19 | for key, val in batch.items(): 20 | if isinstance(val, dict): 21 | if key not in updated_batch: 22 | updated_batch[key] = {} 23 | for sub_key, sub_val in val.items(): 24 | if sub_val is not None: 25 | updated_batch[key][sub_key] = sub_val.to(device) 26 | else: 27 | if val is not None: 28 | updated_batch[key] = val.to(device) 29 | return updated_batch 30 | 31 | def get_dataloader(config, tokenizer, md_transformer, partition, split, config_section="DATA"): 32 | dataset = MetaDataset(config.get(config_section, f"{partition}_data_directory_{split}"), 33 | tokenizer, 34 | md_transformer) 35 | dataloader = DataLoader(dataset, 36 | batch_size=int(config.get("MODEL", "batch_size")), 37 | collate_fn=custom_collate) 38 | return dataloader 39 | 40 | def save_model_checkpoint(model, step, config, 41 | dev_loss_full, dev_ppl_full, 42 | dev_loss_head, dev_ppl_head, 43 | dev_loss_tail, dev_ppl_tail): 44 | """Saves out model artifact along with basic statistics about checkpoint""" 45 | experiment_directory = config.get("EXPERIMENT", "experiment_directory") 46 | checkpoints_dir = os.path.join(experiment_directory, "checkpoints") 47 | 48 | if not os.path.exists(checkpoints_dir): 49 | os.mkdir(checkpoints_dir) 50 | 51 | curr_checkpoint = os.path.join(checkpoints_dir, f"checkpoint_step_{step}") 52 | os.mkdir(curr_checkpoint) 53 | with open(os.path.join(curr_checkpoint, "info.txt"), "w") as f_out : 54 | f_out.write(f"Step {step}\n") 55 | f_out.write(f"Full Dev Loss: {dev_loss_full} - Dev PPL {dev_ppl_full}\n") 56 | f_out.write(f"Head Dev Loss: {dev_loss_head} - Dev PPL {dev_ppl_head}\n") 57 | f_out.write(f"Tail Dev Loss: {dev_loss_tail} - Dev PPL {dev_ppl_tail}") 58 | 59 | torch.save(model, os.path.join(curr_checkpoint, "model.pt")) 60 | 61 | def eval_model(model, dataloader, loss_fn): 62 | """ 63 | Evaluates model on a given data loader. If per_utterance_ppl is set to 64 | True function returns the ppl of each utterance in the dataloader. 
65 | """ 66 | cumulative_loss = 0.0 67 | cumulative_tokens = 0 68 | for step, batch in enumerate(dataloader): 69 | 70 | batch = move_to_device(batch) 71 | pred_logits = model(batch) 72 | loss, tokens = loss_fn(pred_logits, batch) 73 | 74 | cumulative_loss += loss.item() 75 | cumulative_tokens += tokens.item() 76 | loss = cumulative_loss/cumulative_tokens 77 | ppl = torch.exp(torch.tensor(loss)) 78 | return (loss, ppl) 79 | 80 | def get_lr_scheduler(config, optimizer): 81 | """ 82 | Returns a bool of (update_lr_per_step, lr_scheduler) 83 | 84 | """ 85 | lr = float(config.get("TRAIN", "learning_rate")) 86 | scheduler_type = config.get("TRAIN", "scheduler", fallback="plateau") 87 | 88 | if scheduler_type == "plateau": 89 | eps_tolerance = float(config.get("TRAIN", "eps_tolerance", fallback='0')) 90 | patience = int(config.get("TRAIN", "patience", fallback='1')) 91 | decay_factor = float(config.get("TRAIN", "decay_factor", fallback='0.5')) 92 | scheduler = ReduceLROnPlateau(optimizer, 93 | factor=decay_factor, 94 | patience=patience, 95 | eps=eps_tolerance) 96 | update_lr_per_step = False 97 | elif scheduler_type == "one_cycle": 98 | max_train_steps = int(config.get("TRAIN", "max_train_steps")) 99 | anneal_strategy = config.get("TRAIN", "anneal_strategy", fallback="cos") 100 | pct_start = float(config.get("TRAIN", "pct_start", fallback=0.3)) 101 | scheduler = OneCycleLR(optimizer, 102 | max_lr=lr, 103 | total_steps=max_train_steps, 104 | anneal_strategy=anneal_strategy, 105 | pct_start=pct_start) 106 | update_lr_per_step = True 107 | else: 108 | raise Exception(f"Invalid scheduler type: {scheduler_type}") 109 | 110 | return (update_lr_per_step, scheduler) 111 | --------------------------------------------------------------------------------
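
Usage sketch (not part of the repository sources): one way the metadata transformer, tokenizer, and model constructors listed above could be wired together, roughly what run_model.py presumably does. The tokenizer type, data path, TSV column indices, vocabulary size, and the assumption that experiments/demo/train_config.ini defines a [MODEL] section (model_type, emb_dim, hidden_dim, ...) are illustrative guesses, not taken from the code above.

    from configparser import ConfigParser

    from src.metadata import MetaDataTransformer
    from src.tokenizer import get_tokenizer
    from src.model import get_model

    # Hypothetical TSV layout: column 0 holds the "YYYY-MM-DD-HH" timestamp,
    # column 1 holds the utterance text.
    md_transformer = MetaDataTransformer(text_index=1,
                                         md_indices="0",
                                         md_transformations="part_of_day_token")

    # Build (or load from cache/) a tokenizer over the training split, then register
    # the metadata tokens (<MORNING>, <MIDDAY>, <EVENING>) as special tokens.
    tokenizer = get_tokenizer("basic_tokenizer", "data/train/full",
                              md_transformer, vocab_limit=10000)
    tokenizer.add_special_tokens(md_transformer.get_md_tokens())

    # Instantiate the model described by the experiment config.
    config = ConfigParser()
    config.read("experiments/demo/train_config.ini")
    model = get_model(config, tokenizer.get_vocab_size())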