├── .gitignore ├── LICENSE ├── README.md ├── add_args.py ├── config.py ├── eval.py ├── models ├── coherence_models.py ├── gan_models.py ├── infersent_models.py └── language_models.py ├── prepare_data.py ├── preprocess.py ├── run_bigram_coherence.py ├── run_lm_coherence.py ├── train_lm.py └── utils ├── data_utils.py ├── lm_utils.py ├── logging_utils.py └── np_utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | **/*.py[cod] 2 | **/*~ 3 | **/*.swp 4 | **/*.pkl 5 | **/*.txt 6 | **/*.csv 7 | **/*.tsv 8 | **/log/ 9 | **/lt_records/ 10 | **/checkpoint/ 11 | **/slurm_job_tool/skynet_jobs/* 12 | **/slurm_job_tool/*.out 13 | **/*.bin 14 | **/*.pkl 15 | **/*.zip 16 | **/*.gz 17 | pytorch_coherence/InferSent/results/ 18 | **/.ipynb_checkpoints/ 19 | **/fastText-0.1.0/ 20 | **/*.vec 21 | **/wsj_NER/ 22 | **/wsj_NER_results/ 23 | **/cnn_coherence-master/ 24 | **/brown_coherence/ 25 | **/NER_v2.ipynb 26 | **/core 27 | **/reranking/ 28 | **/runs/ 29 | **/results/ 30 | *.out 31 | data/* 32 | checkpoint/* 33 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Attribution-NonCommercial-ShareAlike 4.0 International 2 | 3 | ======================================================================= 4 | 5 | Creative Commons Corporation ("Creative Commons") is not a law firm and 6 | does not provide legal services or legal advice. Distribution of 7 | Creative Commons public licenses does not create a lawyer-client or 8 | other relationship. Creative Commons makes its licenses and related 9 | information available on an "as-is" basis. Creative Commons gives no 10 | warranties regarding its licenses, any material licensed under their 11 | terms and conditions, or any related information. Creative Commons 12 | disclaims all liability for damages resulting from their use to the 13 | fullest extent possible. 14 | 15 | Using Creative Commons Public Licenses 16 | 17 | Creative Commons public licenses provide a standard set of terms and 18 | conditions that creators and other rights holders may use to share 19 | original works of authorship and other material subject to copyright 20 | and certain other rights specified in the public license below. The 21 | following considerations are for informational purposes only, are not 22 | exhaustive, and do not form part of our licenses. 23 | 24 | Considerations for licensors: Our public licenses are 25 | intended for use by those authorized to give the public 26 | permission to use material in ways otherwise restricted by 27 | copyright and certain other rights. Our licenses are 28 | irrevocable. Licensors should read and understand the terms 29 | and conditions of the license they choose before applying it. 30 | Licensors should also secure all rights necessary before 31 | applying our licenses so that the public can reuse the 32 | material as expected. Licensors should clearly mark any 33 | material not subject to the license. This includes other CC- 34 | licensed material, or material used under an exception or 35 | limitation to copyright. More considerations for licensors: 36 | wiki.creativecommons.org/Considerations_for_licensors 37 | 38 | Considerations for the public: By using one of our public 39 | licenses, a licensor grants the public permission to use the 40 | licensed material under specified terms and conditions. 
If 41 | the licensor's permission is not necessary for any reason--for 42 | example, because of any applicable exception or limitation to 43 | copyright--then that use is not regulated by the license. Our 44 | licenses grant only permissions under copyright and certain 45 | other rights that a licensor has authority to grant. Use of 46 | the licensed material may still be restricted for other 47 | reasons, including because others have copyright or other 48 | rights in the material. A licensor may make special requests, 49 | such as asking that all changes be marked or described. 50 | Although not required by our licenses, you are encouraged to 51 | respect those requests where reasonable. More considerations 52 | for the public: 53 | wiki.creativecommons.org/Considerations_for_licensees 54 | 55 | ======================================================================= 56 | 57 | Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International 58 | Public License 59 | 60 | By exercising the Licensed Rights (defined below), You accept and agree 61 | to be bound by the terms and conditions of this Creative Commons 62 | Attribution-NonCommercial-ShareAlike 4.0 International Public License 63 | ("Public License"). To the extent this Public License may be 64 | interpreted as a contract, You are granted the Licensed Rights in 65 | consideration of Your acceptance of these terms and conditions, and the 66 | Licensor grants You such rights in consideration of benefits the 67 | Licensor receives from making the Licensed Material available under 68 | these terms and conditions. 69 | 70 | 71 | Section 1 -- Definitions. 72 | 73 | a. Adapted Material means material subject to Copyright and Similar 74 | Rights that is derived from or based upon the Licensed Material 75 | and in which the Licensed Material is translated, altered, 76 | arranged, transformed, or otherwise modified in a manner requiring 77 | permission under the Copyright and Similar Rights held by the 78 | Licensor. For purposes of this Public License, where the Licensed 79 | Material is a musical work, performance, or sound recording, 80 | Adapted Material is always produced where the Licensed Material is 81 | synched in timed relation with a moving image. 82 | 83 | b. Adapter's License means the license You apply to Your Copyright 84 | and Similar Rights in Your contributions to Adapted Material in 85 | accordance with the terms and conditions of this Public License. 86 | 87 | c. BY-NC-SA Compatible License means a license listed at 88 | creativecommons.org/compatiblelicenses, approved by Creative 89 | Commons as essentially the equivalent of this Public License. 90 | 91 | d. Copyright and Similar Rights means copyright and/or similar rights 92 | closely related to copyright including, without limitation, 93 | performance, broadcast, sound recording, and Sui Generis Database 94 | Rights, without regard to how the rights are labeled or 95 | categorized. For purposes of this Public License, the rights 96 | specified in Section 2(b)(1)-(2) are not Copyright and Similar 97 | Rights. 98 | 99 | e. Effective Technological Measures means those measures that, in the 100 | absence of proper authority, may not be circumvented under laws 101 | fulfilling obligations under Article 11 of the WIPO Copyright 102 | Treaty adopted on December 20, 1996, and/or similar international 103 | agreements. 104 | 105 | f. 
Exceptions and Limitations means fair use, fair dealing, and/or 106 | any other exception or limitation to Copyright and Similar Rights 107 | that applies to Your use of the Licensed Material. 108 | 109 | g. License Elements means the license attributes listed in the name 110 | of a Creative Commons Public License. The License Elements of this 111 | Public License are Attribution, NonCommercial, and ShareAlike. 112 | 113 | h. Licensed Material means the artistic or literary work, database, 114 | or other material to which the Licensor applied this Public 115 | License. 116 | 117 | i. Licensed Rights means the rights granted to You subject to the 118 | terms and conditions of this Public License, which are limited to 119 | all Copyright and Similar Rights that apply to Your use of the 120 | Licensed Material and that the Licensor has authority to license. 121 | 122 | j. Licensor means the individual(s) or entity(ies) granting rights 123 | under this Public License. 124 | 125 | k. NonCommercial means not primarily intended for or directed towards 126 | commercial advantage or monetary compensation. For purposes of 127 | this Public License, the exchange of the Licensed Material for 128 | other material subject to Copyright and Similar Rights by digital 129 | file-sharing or similar means is NonCommercial provided there is 130 | no payment of monetary compensation in connection with the 131 | exchange. 132 | 133 | l. Share means to provide material to the public by any means or 134 | process that requires permission under the Licensed Rights, such 135 | as reproduction, public display, public performance, distribution, 136 | dissemination, communication, or importation, and to make material 137 | available to the public including in ways that members of the 138 | public may access the material from a place and at a time 139 | individually chosen by them. 140 | 141 | m. Sui Generis Database Rights means rights other than copyright 142 | resulting from Directive 96/9/EC of the European Parliament and of 143 | the Council of 11 March 1996 on the legal protection of databases, 144 | as amended and/or succeeded, as well as other essentially 145 | equivalent rights anywhere in the world. 146 | 147 | n. You means the individual or entity exercising the Licensed Rights 148 | under this Public License. Your has a corresponding meaning. 149 | 150 | 151 | Section 2 -- Scope. 152 | 153 | a. License grant. 154 | 155 | 1. Subject to the terms and conditions of this Public License, 156 | the Licensor hereby grants You a worldwide, royalty-free, 157 | non-sublicensable, non-exclusive, irrevocable license to 158 | exercise the Licensed Rights in the Licensed Material to: 159 | 160 | a. reproduce and Share the Licensed Material, in whole or 161 | in part, for NonCommercial purposes only; and 162 | 163 | b. produce, reproduce, and Share Adapted Material for 164 | NonCommercial purposes only. 165 | 166 | 2. Exceptions and Limitations. For the avoidance of doubt, where 167 | Exceptions and Limitations apply to Your use, this Public 168 | License does not apply, and You do not need to comply with 169 | its terms and conditions. 170 | 171 | 3. Term. The term of this Public License is specified in Section 172 | 6(a). 173 | 174 | 4. Media and formats; technical modifications allowed. The 175 | Licensor authorizes You to exercise the Licensed Rights in 176 | all media and formats whether now known or hereafter created, 177 | and to make technical modifications necessary to do so. 
The 178 | Licensor waives and/or agrees not to assert any right or 179 | authority to forbid You from making technical modifications 180 | necessary to exercise the Licensed Rights, including 181 | technical modifications necessary to circumvent Effective 182 | Technological Measures. For purposes of this Public License, 183 | simply making modifications authorized by this Section 2(a) 184 | (4) never produces Adapted Material. 185 | 186 | 5. Downstream recipients. 187 | 188 | a. Offer from the Licensor -- Licensed Material. Every 189 | recipient of the Licensed Material automatically 190 | receives an offer from the Licensor to exercise the 191 | Licensed Rights under the terms and conditions of this 192 | Public License. 193 | 194 | b. Additional offer from the Licensor -- Adapted Material. 195 | Every recipient of Adapted Material from You 196 | automatically receives an offer from the Licensor to 197 | exercise the Licensed Rights in the Adapted Material 198 | under the conditions of the Adapter's License You apply. 199 | 200 | c. No downstream restrictions. You may not offer or impose 201 | any additional or different terms or conditions on, or 202 | apply any Effective Technological Measures to, the 203 | Licensed Material if doing so restricts exercise of the 204 | Licensed Rights by any recipient of the Licensed 205 | Material. 206 | 207 | 6. No endorsement. Nothing in this Public License constitutes or 208 | may be construed as permission to assert or imply that You 209 | are, or that Your use of the Licensed Material is, connected 210 | with, or sponsored, endorsed, or granted official status by, 211 | the Licensor or others designated to receive attribution as 212 | provided in Section 3(a)(1)(A)(i). 213 | 214 | b. Other rights. 215 | 216 | 1. Moral rights, such as the right of integrity, are not 217 | licensed under this Public License, nor are publicity, 218 | privacy, and/or other similar personality rights; however, to 219 | the extent possible, the Licensor waives and/or agrees not to 220 | assert any such rights held by the Licensor to the limited 221 | extent necessary to allow You to exercise the Licensed 222 | Rights, but not otherwise. 223 | 224 | 2. Patent and trademark rights are not licensed under this 225 | Public License. 226 | 227 | 3. To the extent possible, the Licensor waives any right to 228 | collect royalties from You for the exercise of the Licensed 229 | Rights, whether directly or through a collecting society 230 | under any voluntary or waivable statutory or compulsory 231 | licensing scheme. In all other cases the Licensor expressly 232 | reserves any right to collect such royalties, including when 233 | the Licensed Material is used other than for NonCommercial 234 | purposes. 235 | 236 | 237 | Section 3 -- License Conditions. 238 | 239 | Your exercise of the Licensed Rights is expressly made subject to the 240 | following conditions. 241 | 242 | a. Attribution. 243 | 244 | 1. If You Share the Licensed Material (including in modified 245 | form), You must: 246 | 247 | a. retain the following if it is supplied by the Licensor 248 | with the Licensed Material: 249 | 250 | i. identification of the creator(s) of the Licensed 251 | Material and any others designated to receive 252 | attribution, in any reasonable manner requested by 253 | the Licensor (including by pseudonym if 254 | designated); 255 | 256 | ii. a copyright notice; 257 | 258 | iii. a notice that refers to this Public License; 259 | 260 | iv. 
a notice that refers to the disclaimer of 261 | warranties; 262 | 263 | v. a URI or hyperlink to the Licensed Material to the 264 | extent reasonably practicable; 265 | 266 | b. indicate if You modified the Licensed Material and 267 | retain an indication of any previous modifications; and 268 | 269 | c. indicate the Licensed Material is licensed under this 270 | Public License, and include the text of, or the URI or 271 | hyperlink to, this Public License. 272 | 273 | 2. You may satisfy the conditions in Section 3(a)(1) in any 274 | reasonable manner based on the medium, means, and context in 275 | which You Share the Licensed Material. For example, it may be 276 | reasonable to satisfy the conditions by providing a URI or 277 | hyperlink to a resource that includes the required 278 | information. 279 | 3. If requested by the Licensor, You must remove any of the 280 | information required by Section 3(a)(1)(A) to the extent 281 | reasonably practicable. 282 | 283 | b. ShareAlike. 284 | 285 | In addition to the conditions in Section 3(a), if You Share 286 | Adapted Material You produce, the following conditions also apply. 287 | 288 | 1. The Adapter's License You apply must be a Creative Commons 289 | license with the same License Elements, this version or 290 | later, or a BY-NC-SA Compatible License. 291 | 292 | 2. You must include the text of, or the URI or hyperlink to, the 293 | Adapter's License You apply. You may satisfy this condition 294 | in any reasonable manner based on the medium, means, and 295 | context in which You Share Adapted Material. 296 | 297 | 3. You may not offer or impose any additional or different terms 298 | or conditions on, or apply any Effective Technological 299 | Measures to, Adapted Material that restrict exercise of the 300 | rights granted under the Adapter's License You apply. 301 | 302 | 303 | Section 4 -- Sui Generis Database Rights. 304 | 305 | Where the Licensed Rights include Sui Generis Database Rights that 306 | apply to Your use of the Licensed Material: 307 | 308 | a. for the avoidance of doubt, Section 2(a)(1) grants You the right 309 | to extract, reuse, reproduce, and Share all or a substantial 310 | portion of the contents of the database for NonCommercial purposes 311 | only; 312 | 313 | b. if You include all or a substantial portion of the database 314 | contents in a database in which You have Sui Generis Database 315 | Rights, then the database in which You have Sui Generis Database 316 | Rights (but not its individual contents) is Adapted Material, 317 | including for purposes of Section 3(b); and 318 | 319 | c. You must comply with the conditions in Section 3(a) if You Share 320 | all or a substantial portion of the contents of the database. 321 | 322 | For the avoidance of doubt, this Section 4 supplements and does not 323 | replace Your obligations under this Public License where the Licensed 324 | Rights include other Copyright and Similar Rights. 325 | 326 | 327 | Section 5 -- Disclaimer of Warranties and Limitation of Liability. 328 | 329 | a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE 330 | EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS 331 | AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF 332 | ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS, 333 | IMPLIED, STATUTORY, OR OTHER. 
THIS INCLUDES, WITHOUT LIMITATION, 334 | WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR 335 | PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS, 336 | ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT 337 | KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT 338 | ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU. 339 | 340 | b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE 341 | TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION, 342 | NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT, 343 | INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES, 344 | COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR 345 | USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN 346 | ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR 347 | DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR 348 | IN PART, THIS LIMITATION MAY NOT APPLY TO YOU. 349 | 350 | c. The disclaimer of warranties and limitation of liability provided 351 | above shall be interpreted in a manner that, to the extent 352 | possible, most closely approximates an absolute disclaimer and 353 | waiver of all liability. 354 | 355 | 356 | Section 6 -- Term and Termination. 357 | 358 | a. This Public License applies for the term of the Copyright and 359 | Similar Rights licensed here. However, if You fail to comply with 360 | this Public License, then Your rights under this Public License 361 | terminate automatically. 362 | 363 | b. Where Your right to use the Licensed Material has terminated under 364 | Section 6(a), it reinstates: 365 | 366 | 1. automatically as of the date the violation is cured, provided 367 | it is cured within 30 days of Your discovery of the 368 | violation; or 369 | 370 | 2. upon express reinstatement by the Licensor. 371 | 372 | For the avoidance of doubt, this Section 6(b) does not affect any 373 | right the Licensor may have to seek remedies for Your violations 374 | of this Public License. 375 | 376 | c. For the avoidance of doubt, the Licensor may also offer the 377 | Licensed Material under separate terms or conditions or stop 378 | distributing the Licensed Material at any time; however, doing so 379 | will not terminate this Public License. 380 | 381 | d. Sections 1, 5, 6, 7, and 8 survive termination of this Public 382 | License. 383 | 384 | 385 | Section 7 -- Other Terms and Conditions. 386 | 387 | a. The Licensor shall not be bound by any additional or different 388 | terms or conditions communicated by You unless expressly agreed. 389 | 390 | b. Any arrangements, understandings, or agreements regarding the 391 | Licensed Material not stated herein are separate from and 392 | independent of the terms and conditions of this Public License. 393 | 394 | 395 | Section 8 -- Interpretation. 396 | 397 | a. For the avoidance of doubt, this Public License does not, and 398 | shall not be interpreted to, reduce, limit, restrict, or impose 399 | conditions on any use of the Licensed Material that could lawfully 400 | be made without permission under this Public License. 401 | 402 | b. To the extent possible, if any provision of this Public License is 403 | deemed unenforceable, it shall be automatically reformed to the 404 | minimum extent necessary to make it enforceable. If the provision 405 | cannot be reformed, it shall be severed from this Public License 406 | without affecting the enforceability of the remaining terms and 407 | conditions. 408 | 409 | c. 
No term or condition of this Public License will be waived and no 410 | failure to comply consented to unless expressly agreed to by the 411 | Licensor. 412 | 413 | d. Nothing in this Public License constitutes or may be interpreted 414 | as a limitation upon, or waiver of, any privileges and immunities 415 | that apply to the Licensor or You, including from the legal 416 | processes of any jurisdiction or authority. 417 | 418 | ======================================================================= 419 | 420 | Creative Commons is not a party to its public 421 | licenses. Notwithstanding, Creative Commons may elect to apply one of 422 | its public licenses to material it publishes and in those instances 423 | will be considered the “Licensor.” The text of the Creative Commons 424 | public licenses is dedicated to the public domain under the CC0 Public 425 | Domain Dedication. Except for the limited purpose of indicating that 426 | material is shared under a Creative Commons public license or as 427 | otherwise permitted by the Creative Commons policies published at 428 | creativecommons.org/policies, Creative Commons does not authorize the 429 | use of the trademark "Creative Commons" or any other trademark or logo 430 | of Creative Commons without its prior written consent including, 431 | without limitation, in connection with any unauthorized modifications 432 | to any of its public licenses or any other arrangements, 433 | understandings, or agreements concerning use of licensed material. For 434 | the avoidance of doubt, this paragraph does not form part of the 435 | public licenses. 436 | 437 | Creative Commons may be contacted at creativecommons.org. 438 | 439 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Cross-Domain Coherence Modeling 2 | 3 | A Cross-Domain Transferable Neural Coherence Model 4 | 5 | Paper published in ACL 2019: [arxiv.org/abs/1905.11912](https://arxiv.org/abs/1905.11912) 6 | 7 | This implementation is based on PyTorch 0.4.1. 8 | 9 | ### Dataset 10 | 11 | To download the dataset: 12 | 13 | ``` 14 | python prepare_data.py 15 | ``` 16 | 17 | This downloads the WikiCoherence dataset we constructed, the 300-dim GloVe embeddings, and the pre-trained InferSent model. 18 | 19 | WikiCoherence contains: 20 | 21 | - 7 categories under **Person**: 22 | - Artist 23 | - Athlete 24 | - Politician 25 | - Writer 26 | - MilitaryPerson 27 | - OfficeHolder 28 | - Scientist 29 | - 3 categories from unrelated domains: 30 | - Plant 31 | - EducationalInstitution 32 | - CelestialBody 33 | - parsed\_wsj: the original split of the Wall Street Journal (WSJ) corpus 34 | - parsed\_random: all paragraphs of the seven categories under **Person**, randomly split into a training part and a testing part 35 | 36 | Check `config.py` for the data\_name of each setting. 37 | 38 | ### Preprocessing 39 | 40 | Permute the original documents or paragraphs to obtain the negative samples for evaluation: 41 | 42 | ``` 43 | python preprocess.py --data_name <data_name> 44 | ``` 45 | 46 | ### LM Pre-training 47 | 48 | Train the forward and reverse LMs with the following commands: 49 | 50 | ``` 51 | python train_lm.py --data_name <data_name> 52 | python train_lm.py --data_name <data_name> --reverse 53 | ``` 54 | 55 | The pre-trained models will be saved in `./checkpoint`. 
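For example, with the WSJ setting (data name `wsj_bigram` from `config.py`), the two pre-training runs would be:

```
python train_lm.py --data_name wsj_bigram
python train_lm.py --data_name wsj_bigram --reverse
```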
56 | 57 | ### Training and Evaluation 58 | 59 | To train and evaluate our proposed models: 60 | 61 | ``` 62 | python run_bigram_coherence.py --data_name <data_name> --sent_encoder <sent_encoder> [--bidirectional] 63 | ``` 64 | 65 | where `sent_encoder` can be average\_glove, infersent, or lm\_hidden. 66 | 67 | ``` 68 | python eval.py --data_name <data_name> --sent_encoder <sent_encoder> [--bidirectional] 69 | ``` 70 | 71 | Running the above script runs the experiment multiple times and reports the mean and standard deviation. 72 | The log will be saved in `./log`. 73 | 74 | ### Cite 75 | 76 | If you find this codebase or our work useful, please cite: 77 | 78 | ``` 79 | @InProceedings{xu2019cross, 80 | author = {Xu, Peng and Saghir, Hamidreza and Kang, Jin Sung and Long, Teng and Bose, Avishek Joey and Cao, Yanshuai and Cheung, Jackie Chi Kit}, 81 | title = {A Cross-Domain Transferable Neural Coherence Model}, 82 | booktitle = {The 57th Annual Meeting of the Association for Computational Linguistics (ACL 2019)}, 83 | month = {July}, 84 | year = {2019}, 85 | publisher = {ACL} 86 | } 87 | ``` 88 | -------------------------------------------------------------------------------- /add_args.py: -------------------------------------------------------------------------------- 1 | def add_bigram_args(parser): 2 | # System Hyperparameters 3 | parser.add_argument('--data_name', type=str, default='wsj_bigram', 4 | help="data name") 5 | parser.add_argument('--random_seed', type=int, default=2018, 6 | help="random seed") 7 | parser.add_argument('--test', default=False, action='store_true', 8 | help="test with smaller InferSent embeddings") 9 | parser.add_argument('--batch_size', type=int, default=128, 10 | help="batch size") 11 | parser.add_argument('--save', default=False, action='store_true', 12 | help="whether to save the model") 13 | parser.add_argument('--portion', type=float, default=1.0, 14 | help="portion of negative samples to use") 15 | 16 | # Model Hyperparameters 17 | parser.add_argument('--loss', type=str, default='margin', 18 | help="training loss") 19 | parser.add_argument('--input_dropout', type=float, default=0.6, 20 | help="input dropout rate") 21 | parser.add_argument('--hidden_state', type=int, default=500, 22 | help="hidden state size") 23 | parser.add_argument('--hidden_layers', type=int, default=1, 24 | help="number of hidden layers") 25 | parser.add_argument('--hidden_dropout', type=float, default=0.3, 26 | help="hidden dropout rate") 27 | parser.add_argument('--num_epochs', type=int, default=50, 28 | help="number of training epochs") 29 | parser.add_argument('--margin', type=float, default=5.0, 30 | help="margin for the ranking loss") 31 | parser.add_argument('--lr', type=float, default=0.001, 32 | help="learning rate") 33 | parser.add_argument('--l2_reg_lambda', type=float, default=0.0, 34 | help="L2 regularization strength") 35 | parser.add_argument('--use_bn', default=False, action='store_true', 36 | help="use batch normalization") 37 | parser.add_argument('--embed_dim', type=int, default=100, 38 | help="embedding dimension") 39 | parser.add_argument('--dpout_model', type=float, default=0.0, 40 | help="encoder dropout rate") 41 | parser.add_argument('--sent_encoder', type=str, default='infersent', 42 | help="sentence encoder") 43 | parser.add_argument('--bidirectional', default=False, action='store_true', 44 | help="use a bidirectional (forward + backward) discriminator") 45 | 46 | parser.add_argument('--note', type=str, default='', 47 | help="human readable note") 48 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | from utils.data_utils import WSJ_Bigram_Dataset, 
WIKI_Bigram_Dataset 2 | 3 | # ------------------- PATH ------------------- 4 | ROOT_PATH = "." 5 | DATA_PATH = "%s/data" % ROOT_PATH 6 | LOG_PATH = "%s/log" % ROOT_PATH 7 | CHECKPOINT_PATH = "%s/checkpoint" % ROOT_PATH 8 | 9 | # ------------------- DATA ------------------- 10 | 11 | INFERSENT_MODEL = "%s/infersent1.pkl" % DATA_PATH 12 | WORD_EMBEDDING = "%s/glove.840B.300d.txt" % DATA_PATH 13 | 14 | DATASET = {} 15 | 16 | WSJ_DATA_PATH = "%s/parsed_wsj" % DATA_PATH 17 | SAMPLE_WSJ_DATA_PATH = "%s/parsed_wsj" % DATA_PATH 18 | WSJ_VALID_PERM = "%s/valid_perm.tsv" % WSJ_DATA_PATH 19 | WSJ_TEST_PERM = "%s/test_perm.tsv" % WSJ_DATA_PATH 20 | 21 | DATASET["wsj_bigram"] = { 22 | "dataset": WSJ_Bigram_Dataset, 23 | "data_path": WSJ_DATA_PATH, 24 | "sample_path": SAMPLE_WSJ_DATA_PATH, 25 | "valid_perm": WSJ_VALID_PERM, 26 | "test_perm": WSJ_TEST_PERM, 27 | "kwargs": {}, 28 | } 29 | 30 | WIKI_DATA_PATH = DATA_PATH 31 | SAMPLE_WIKI_DATA_PATH = DATA_PATH 32 | WIKI_IN_DOMAIN = ["Artist", "Athlete", "Politician", "Writer", "MilitaryPerson", 33 | "OfficeHolder", "Scientist"] 34 | WIKI_OUT_DOMAIN = ["Plant", "CelestialBody", "EducationalInstitution"] 35 | 36 | WIKI_EASY_DATA_PATH = "%s/parsed_random" % DATA_PATH 37 | WIKI_EASY_VALID_PERM = "%s/valid_perm.tsv" % WIKI_EASY_DATA_PATH 38 | WIKI_EASY_TEST_PERM = "%s/test_perm.tsv" % WIKI_EASY_DATA_PATH 39 | WIKI_EASY_TRAIN_LIST = ["train"] 40 | WIKI_EASY_TEST_LIST = ["test"] 41 | 42 | for i in range(7): 43 | category = WIKI_IN_DOMAIN[i] 44 | DATASET["wiki_bigram_%s" % category] = { 45 | "dataset": WIKI_Bigram_Dataset, 46 | "data_path": WIKI_DATA_PATH, 47 | "sample_path": SAMPLE_WIKI_DATA_PATH, 48 | "valid_perm": "%s/wiki_%s_valid_perm.tsv" % (DATA_PATH, category.lower()), 49 | "test_perm": "%s/wiki_%s_test_perm.tsv" % (DATA_PATH, category.lower()), 50 | "kwargs": { 51 | "train_list": WIKI_IN_DOMAIN[:i] + WIKI_IN_DOMAIN[i + 1:], 52 | "test_list": [category], 53 | }, 54 | } 55 | 56 | for category in WIKI_OUT_DOMAIN: 57 | DATASET["wiki_bigram_%s" % category] = { 58 | "dataset": WIKI_Bigram_Dataset, 59 | "data_path": WIKI_DATA_PATH, 60 | "sample_path": SAMPLE_WIKI_DATA_PATH, 61 | "valid_perm": "%s/wiki_%s_valid_perm.tsv" % (DATA_PATH, category.lower()), 62 | "test_perm": "%s/wiki_%s_test_perm.tsv" % (DATA_PATH, category.lower()), 63 | "kwargs": { 64 | "train_list": WIKI_IN_DOMAIN, 65 | "test_list": [category], 66 | }, 67 | } 68 | 69 | DATASET["wiki_bigram_easy"] = { 70 | "dataset": WIKI_Bigram_Dataset, 71 | "data_path": WIKI_EASY_DATA_PATH, 72 | "sample_path": SAMPLE_WIKI_DATA_PATH, 73 | "valid_perm": WIKI_EASY_VALID_PERM, 74 | "test_perm": WIKI_EASY_TEST_PERM, 75 | "kwargs": { 76 | "train_list": WIKI_EASY_TRAIN_LIST, 77 | "test_list": WIKI_EASY_TEST_LIST, 78 | }, 79 | } 80 | 81 | # ------------------- PARAM ------------------ 82 | 83 | RANDOM_SEED = 2018 84 | 85 | MAX_SENT_LENGTH = 40 86 | 87 | NEG_PERM = 20 88 | -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | from run_bigram_coherence import run_bigram_coherence 2 | from utils.logging_utils import _get_logger 3 | from add_args import add_bigram_args 4 | from datetime import datetime 5 | import config 6 | import numpy as np 7 | import argparse 8 | import gc 9 | 10 | def experiments(args): 11 | runs = 5 12 | time_str = datetime.now().date().isoformat() 13 | logname = "[Data@%s]_[Encoder@%s]" % (args.data_name, args.sent_encoder) 14 | if args.bidirectional: 15 | logname += "_[Bi]" 16 | 
logname += "_%s.log" % time_str 17 | logger = _get_logger(config.LOG_PATH, logname) 18 | dis_accs = [] 19 | ins_accs = [] 20 | for i in range(runs): 21 | dis_acc, ins_acc = run_bigram_coherence(args) 22 | dis_accs.append(dis_acc[0]) 23 | ins_accs.append(ins_acc[0]) 24 | for _ in range(10): 25 | gc.collect() 26 | 27 | logger.info("=" * 50) 28 | for i in range(runs): 29 | logger.info("Run %d" % (i + 1)) 30 | logger.info("Dis Acc: %.6f" % dis_accs[i]) 31 | logger.info("Ins Acc: %.6f" % ins_accs[i]) 32 | logger.info("=" * 50) 33 | logger.info("Average Dis Acc: %.6f (%.6f)" % (np.mean(dis_accs), np.std(dis_accs))) 34 | logger.info("Average Ins Acc: %.6f (%.6f)" % (np.mean(ins_accs), np.std(ins_accs))) 35 | 36 | 37 | if __name__ == "__main__": 38 | parser = argparse.ArgumentParser() 39 | add_bigram_args(parser) 40 | args = parser.parse_args() 41 | 42 | experiments(args) 43 | -------------------------------------------------------------------------------- /models/coherence_models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torch.optim as optim 5 | from torch.autograd import Variable 6 | from .gan_models import MLP_Discriminator 7 | import numpy as np 8 | import pickle 9 | from datetime import datetime 10 | from utils.np_utils import generate_random_pmatrices 11 | from tqdm import tqdm 12 | import utils.lm_utils as utils 13 | 14 | 15 | class MarginRankingLoss(nn.Module): 16 | def __init__(self, margin): 17 | super(MarginRankingLoss, self).__init__() 18 | self.margin = margin 19 | 20 | def forward(self, p_scores, n_scores, weights=None): 21 | scores = self.margin - p_scores + n_scores 22 | scores = scores.clamp(min=0) 23 | if weights is not None: 24 | scores = weights * scores 25 | return scores.mean() 26 | 27 | class BigramCoherence: 28 | def __init__(self, embed_dim, sent_encoder, hparams): 29 | self.embed_dim = embed_dim 30 | self.sent_encoder = sent_encoder 31 | self.num_epochs = hparams["num_epochs"] 32 | margin = hparams["margin"] 33 | lr = hparams["lr"] 34 | l2_reg_lambda = hparams["l2_reg_lambda"] 35 | self.task = hparams["task"] 36 | self.hparams = hparams 37 | 38 | self.use_cuda = torch.cuda.is_available() 39 | self.use_pretrained = isinstance(self.sent_encoder, dict) 40 | self.discriminator = MLP_Discriminator( 41 | embed_dim, hparams, self.use_cuda) 42 | model_parameters = list(self.discriminator.parameters()) 43 | if not self.use_pretrained: 44 | model_parameters += list(self.sent_encoder.parameters()) 45 | self.optimizer = optim.Adam(model_parameters, 46 | lr=lr, weight_decay=l2_reg_lambda) 47 | 48 | self.loss_name = hparams['loss'] 49 | if hparams['loss'] == 'margin': 50 | self.loss_fn = MarginRankingLoss(margin) 51 | elif hparams['loss'] == 'log': 52 | self.loss_fn = nn.BCEWithLogitsLoss() 53 | elif hparams['loss'] == 'margin+log': 54 | self.loss_fn = [MarginRankingLoss(margin), nn.BCEWithLogitsLoss()] 55 | else: 56 | raise NotImplementedError() 57 | 58 | if self.use_cuda: 59 | self.discriminator.cuda() 60 | if not self.use_pretrained: 61 | self.sent_encoder.cuda() 62 | self.best_discriminator = self.discriminator.state_dict() 63 | self.intervals = [0, 10, 20, np.inf] 64 | 65 | def init(self): 66 | def init_weights(model): 67 | if type(model) in [nn.Linear]: 68 | nn.init.xavier_normal_(model.weight.data) 69 | 70 | self.discriminator.apply(init_weights) 71 | 72 | def _variable(self, data): 73 | data = np.array(data) 74 | data = 
Variable(torch.from_numpy(data)) 75 | return data.cuda() if self.use_cuda else data 76 | 77 | def encode(self, sentences): 78 | if self.use_pretrained: 79 | sentences = np.array(list(map(self.sent_encoder.get, sentences))) 80 | sentences = self._variable(sentences) 81 | return sentences 82 | sentences, lengths, idx_sort = self.sent_encoder.prepare_samples( 83 | sentences, -1, 40, False, False) 84 | with torch.autograd.no_grad(): 85 | batch = Variable(self.sent_encoder.get_batch(sentences)) 86 | if self.use_cuda: 87 | batch = batch.cuda() 88 | batch = self.sent_encoder.forward((batch, lengths)) 89 | return batch 90 | 91 | def fit(self, train, valid=None, df=None): 92 | best_valid_acc = 0 93 | best_valid_epoch = 0 94 | step = 0 95 | for epoch in range(1, self.num_epochs + 1): 96 | for sentences in train: 97 | step += 1 98 | self.discriminator.zero_grad() 99 | 100 | sent1 = [] 101 | pos_sent2 = [] 102 | neg_sent2 = [] 103 | slens = [] 104 | for s in sentences: 105 | s1, ps2, ns2, slen = s.split('') 106 | sent1.append(s1) 107 | pos_sent2.append(ps2) 108 | neg_sent2.append(ns2) 109 | slens.append(int(slen)) 110 | sent1 = self.encode(sent1) 111 | slens = np.array(slens, dtype=np.float32) 112 | weights = 1. / slens / (slens - 1) 113 | weights /= np.mean(weights) 114 | weights = self._variable(weights) 115 | 116 | pos_sent2 = self.encode(pos_sent2) 117 | pos_scores = self.discriminator(sent1, pos_sent2) 118 | 119 | neg_sent2 = self.encode(neg_sent2) 120 | neg_scores = self.discriminator(sent1, neg_sent2) 121 | 122 | if self.loss_name == 'margin': 123 | loss = self.loss_fn(pos_scores, neg_scores) 124 | 125 | elif self.loss_name == 'log': 126 | loss = self.loss_fn(-pos_scores, 127 | torch.ones_like(pos_scores)) 128 | loss += self.loss_fn(-neg_scores, 129 | torch.zeros_like(neg_scores)) 130 | elif self.loss_name == 'margin+log': 131 | loss = self.loss_fn[0](pos_scores, neg_scores) 132 | loss += .1 * \ 133 | self.loss_fn[1](-pos_scores, 134 | torch.ones_like(pos_scores)) 135 | loss += .1 * \ 136 | self.loss_fn[1](-neg_scores, 137 | torch.zeros_like(neg_scores)) 138 | else: 139 | raise NotImplementedError() 140 | 141 | if step % 100 == 0: 142 | time_str = datetime.now().isoformat() 143 | print("{}: step {}, loss {:g}".format( 144 | time_str, step, loss.item())) 145 | 146 | loss.backward() 147 | self.optimizer.step() 148 | 149 | if valid is not None: 150 | print("\nValidation:") 151 | print("previous best epoch {}, acc {:g}".format( 152 | best_valid_epoch, best_valid_acc)) 153 | acc, _ = self.evaluate(valid, df, self.task) 154 | print("epoch {} acc {:g}".format(epoch, acc)) 155 | print("") 156 | if acc > best_valid_acc: 157 | best_valid_acc = acc 158 | best_valid_epoch = epoch 159 | self.best_discriminator = self.discriminator.state_dict() 160 | if epoch - best_valid_epoch > 3: 161 | break 162 | return best_valid_epoch, best_valid_acc 163 | 164 | def evaluate(self, test, df, task="discrimination"): 165 | if task == "discrimination": 166 | return self.evaluate_dis(test, df) 167 | elif task == "insertion": 168 | return self.evaluate_ins(test, df) 169 | else: 170 | raise ValueError("Invalid task name!") 171 | 172 | def score_article(self, article, reverse=False): 173 | if reverse: 174 | article = article[::-1] 175 | 176 | first_sentences = self.encode(article[:-1]) 177 | second_sentences = self.encode(article[1:]) 178 | y = self.discriminator(first_sentences, second_sentences) 179 | local_y = y.mean().data.cpu().numpy() 180 | 181 | return local_y 182 | 183 | def evaluate_dis(self, test, df, debug=False): 
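        # Discrimination task: score an article by the mean bigram score over
        # adjacent sentence pairs, then compare the original order against each
        # of its permuted versions; the model is credited whenever the original
        # outscores a permutation. Counts are accumulated overall and bucketed
        # by article length using self.intervals.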
184 | correct_pred = [0, 0, 0] 185 | total_samples = [0, 0, 0] 186 | 187 | if debug: 188 | all_pos_scores = [] 189 | all_neg_scores = [] 190 | 191 | self.discriminator.train(False) 192 | for article in test: 193 | sentences = df.loc[article[0], "sentences"].split("") 194 | sent_num = len(sentences) 195 | sentences = [""] + sentences + [""] 196 | neg_sentences_list = df.loc[article[0], 197 | "neg_list"].split("") 198 | neg_sentences_list = [s.split("") 199 | for s in neg_sentences_list] 200 | 201 | pos_sent1 = sentences[:-1] 202 | pos_sent1 = self.encode(pos_sent1) 203 | pos_sent2 = sentences[1:] 204 | pos_sent2 = self.encode(pos_sent2) 205 | pos_scores = self.discriminator(pos_sent1, pos_sent2) 206 | # import ipdb 207 | # ipdb.set_trace() 208 | mean_pos_score = pos_scores.mean().data.cpu().numpy() 209 | 210 | if debug: 211 | all_pos_scores.append(pos_scores.data.cpu().numpy().squeeze()) 212 | 213 | mean_neg_scores = [] 214 | for neg_sentences in neg_sentences_list: 215 | neg_sentences = [""] + neg_sentences + [""] 216 | neg_sent1 = neg_sentences[:-1] 217 | neg_sent1 = self.encode(neg_sent1) 218 | neg_sent2 = neg_sentences[1:] 219 | neg_sent2 = self.encode(neg_sent2) 220 | neg_scores = self.discriminator(neg_sent1, neg_sent2) 221 | mean_neg_score = neg_scores.mean().data.cpu().numpy() 222 | mean_neg_scores.append(mean_neg_score) 223 | 224 | if debug: 225 | all_neg_scores.append( 226 | neg_scores.data.cpu().numpy().squeeze()) 227 | 228 | for mean_neg_score in mean_neg_scores: 229 | if mean_pos_score > mean_neg_score: 230 | for i in range(3): 231 | lower_bound = self.intervals[i] 232 | upper_bound = self.intervals[i + 1] 233 | if (sent_num > lower_bound) and (sent_num <= upper_bound): 234 | correct_pred[i] += 1 235 | for i in range(3): 236 | lower_bound = self.intervals[i] 237 | upper_bound = self.intervals[i + 1] 238 | if (sent_num > lower_bound) and (sent_num <= upper_bound): 239 | total_samples[i] += 1 240 | 241 | print(" ".join(sentences), mean_pos_score) 242 | print(" ".join(neg_sentences), mean_neg_score) 243 | self.discriminator.train(True) 244 | accs = np.true_divide(correct_pred, total_samples) 245 | acc = np.true_divide(np.sum(correct_pred), np.sum(total_samples)) 246 | 247 | if debug: 248 | all_pos_scores = np.concatenate(all_pos_scores) 249 | all_neg_scores = np.concatenate(all_neg_scores) 250 | 251 | import pandas as pd 252 | 253 | print('pos score stats') 254 | print(pd.DataFrame(all_pos_scores).describe()) 255 | 256 | print('neg score stats') 257 | print(pd.DataFrame(all_neg_scores).describe()) 258 | 259 | return acc, accs 260 | 261 | def evaluate_ins(self, test, df): 262 | correct_pred = [0, 0, 0] 263 | total_samples = [0, 0, 0] 264 | self.discriminator.train(False) 265 | for article in tqdm(test, disable=False): 266 | sentences = df.loc[article[0], "sentences"].split("") 267 | sent_num = len(sentences) 268 | sentences = [""] + sentences + [""] 269 | 270 | pos_sent1 = sentences[:-1] 271 | pos_sent1 = self.encode(pos_sent1) 272 | pos_sent2 = sentences[1:] 273 | pos_sent2 = self.encode(pos_sent2) 274 | pos_scores = self.discriminator(pos_sent1, pos_sent2) 275 | mean_pos_score = pos_scores.mean().data.cpu().numpy() 276 | 277 | cnt = 0.0 278 | for i in range(1, sent_num + 1): 279 | tmp = sentences[:i] + sentences[i + 1:] 280 | flag = True 281 | for j in range(1, sent_num + 1): 282 | if j == i: 283 | continue 284 | neg_sentences = tmp[:j] + sentences[i:i + 1] + tmp[j:] 285 | neg_sent1 = neg_sentences[:-1] 286 | neg_sent1 = self.encode(neg_sent1) 287 | neg_sent2 = neg_sentences[1:] 
288 | neg_sent2 = self.encode(neg_sent2) 289 | neg_scores = self.discriminator(neg_sent1, neg_sent2) 290 | mean_neg_score = neg_scores.mean().data.cpu().numpy() 291 | if mean_pos_score < mean_neg_score: 292 | flag = False 293 | if flag: 294 | cnt += 1.0 295 | for i in range(3): 296 | lower_bound = self.intervals[i] 297 | upper_bound = self.intervals[i + 1] 298 | if (sent_num > lower_bound) and (sent_num <= upper_bound): 299 | correct_pred[i] += cnt / sent_num 300 | for i in range(3): 301 | lower_bound = self.intervals[i] 302 | upper_bound = self.intervals[i + 1] 303 | if (sent_num > lower_bound) and (sent_num <= upper_bound): 304 | total_samples[i] += 1 305 | self.discriminator.train(True) 306 | accs = np.true_divide(correct_pred, total_samples) 307 | acc = np.true_divide(np.sum(correct_pred), np.sum(total_samples)) 308 | return acc, accs 309 | 310 | def save(self, path): 311 | torch.save(self.best_discriminator, path + ".pt") 312 | with open(path + ".pkl", "wb") as f: 313 | pickle.dump(self.hparams, f, -1) 314 | 315 | def load(self, path): 316 | self.discriminator.load_state_dict(torch.load(path)) 317 | 318 | def load_best_state(self): 319 | self.discriminator.load_state_dict(self.best_discriminator) 320 | -------------------------------------------------------------------------------- /models/gan_models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Variable 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from collections import OrderedDict 6 | 7 | from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence 8 | 9 | 10 | def _sequence_mask(sequence_length, max_len=None): 11 | if max_len is None: 12 | max_len = sequence_length.data.max() 13 | batch_size = sequence_length.size(0) 14 | seq_range = torch.arange(0, max_len).long() 15 | seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len) 16 | seq_range_expand = Variable(seq_range_expand) 17 | if sequence_length.is_cuda: 18 | seq_range_expand = seq_range_expand.cuda() 19 | seq_length_expand = (sequence_length.unsqueeze(1) 20 | .expand_as(seq_range_expand)) 21 | return seq_range_expand < seq_length_expand 22 | 23 | 24 | def compute_loss(logit, target, length): 25 | logit_flat = logit.view(-1, logit.size(-1)) 26 | target_flat = target.view(-1) 27 | losses_flat = F.cross_entropy(logit_flat, target_flat, reduction='none') 28 | losses = losses_flat.view(*target.size()) 29 | mask = _sequence_mask(length, target.size(1)) 30 | losses = losses * mask.float() 31 | loss = losses.sum() / length.float().sum() 32 | return loss 33 | 34 | 35 | class MLP(nn.Module): 36 | def __init__(self, input_dims, n_hiddens, n_class, dropout, use_bn): 37 | super(MLP, self).__init__() 38 | assert isinstance(input_dims, int), 'Invalid type for input_dims!' 
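        # Stack one Linear -> ReLU -> Dropout (optionally + BatchNorm) block per
        # entry in n_hiddens, then project down to n_class with a final Linear
        # layer; the blocks are collected in an OrderedDict so nn.Sequential
        # keeps them in order.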
39 | self.input_dims = input_dims 40 | current_dims = input_dims 41 | layers = OrderedDict() 42 | 43 | if isinstance(n_hiddens, int): 44 | n_hiddens = [n_hiddens] 45 | else: 46 | n_hiddens = list(n_hiddens) 47 | 48 | for i, n_hidden in enumerate(n_hiddens): 49 | l_i = i + 1 50 | layers['fc{}'.format(l_i)] = nn.Linear(current_dims, n_hidden) 51 | layers['relu{}'.format(l_i)] = nn.ReLU() 52 | layers['drop{}'.format(l_i)] = nn.Dropout(dropout) 53 | if use_bn: 54 | layers['bn{}'.format(l_i)] = nn.BatchNorm1d(n_hidden) 55 | current_dims = n_hidden 56 | layers['out'] = nn.Linear(current_dims, n_class) 57 | 58 | self.model = nn.Sequential(layers) 59 | 60 | def forward(self, input): 61 | return self.model.forward(input) 62 | 63 | class MLP_Discriminator(nn.Module): 64 | def __init__(self, embed_dim, hparams, use_cuda): 65 | super(MLP_Discriminator, self).__init__() 66 | self.embed_dim = embed_dim 67 | self.hidden_state = hparams["hidden_state"] 68 | self.hidden_layers = hparams["hidden_layers"] 69 | self.hidden_dropout = hparams["hidden_dropout"] 70 | self.input_dropout = hparams["input_dropout"] 71 | self.use_bn = hparams["use_bn"] 72 | self.bidirectional = hparams["bidirectional"] 73 | self.use_cuda = use_cuda 74 | 75 | self.mlp = MLP(embed_dim * 5, [self.hidden_state] * self.hidden_layers, 76 | 1, self.hidden_dropout, self.use_bn) 77 | self.dropout = nn.Dropout(self.input_dropout) 78 | if self.bidirectional: 79 | self.backward_mlp = MLP(embed_dim * 5, [self.hidden_state] * self.hidden_layers, 80 | 1, self.hidden_dropout, self.use_bn) 81 | self.backward_dropout = nn.Dropout(self.input_dropout) 82 | 83 | def forward(self, s1, s2): 84 | inputs = torch.cat([s1, s2, s1 - s2, s1 * s2, torch.abs(s1 - s2)], -1) 85 | scores = self.mlp(self.dropout(inputs)) 86 | if self.bidirectional: 87 | backward_inputs = torch.cat( 88 | [s2, s1, s2 - s1, s1 * s2, torch.abs(s1 - s2)], -1) 89 | backward_scores = self.backward_mlp( 90 | self.backward_dropout(backward_inputs)) 91 | scores = (scores + backward_scores) / 2 92 | return scores 93 | 94 | class RNN_LM(nn.Module): 95 | def __init__(self, vocab_size, embed_size, hparams, use_cuda): 96 | super(RNN_LM, self).__init__() 97 | self.vocab_size = vocab_size 98 | self.embed_size = embed_size 99 | self.hidden_size = hparams["hidden_size"] 100 | self.num_layers = hparams["num_layers"] 101 | self.cell_type = hparams["cell_type"] 102 | self.tie_embed = hparams["tie_embed"] 103 | self.rnn_dropout = hparams["rnn_dropout"] 104 | self.hidden_dropout = hparams["hidden_dropout"] 105 | self.use_cuda = use_cuda 106 | 107 | self.embedding = nn.Embedding(vocab_size, embed_size) 108 | rnn_class = { 109 | 'rnn': nn.RNN, 110 | 'gru': nn.GRU, 111 | 'lstm': nn.LSTM, 112 | }[self.cell_type] 113 | self.rnn = rnn_class(self.embed_size, self.hidden_size, self.num_layers, 114 | dropout=self.rnn_dropout) 115 | self.dropout = nn.Dropout(self.hidden_dropout) 116 | 117 | if self.tie_embed: 118 | self.linear_out = nn.Linear(embed_size, vocab_size) 119 | if embed_size != self.hidden_size: 120 | in_size = self.hidden_size 121 | self.linear_proj = nn.Linear( 122 | in_size, embed_size, bias=None) 123 | self.linear_out.weight = self.embedding.weight 124 | else: 125 | self.linear_out = nn.Linear(self.hidden_size, vocab_size) 126 | self.linear_proj = lambda x: x 127 | 128 | def set_embed(self, emb): 129 | with torch.no_grad(): 130 | self.embedding.weight.fill_(0.) 
131 | self.embedding.weight += emb 132 | 133 | def init_hidden(self, batch_size): 134 | h0 = Variable(torch.zeros(self.num_layers, 135 | batch_size, self.hidden_size)) 136 | if self.cell_type == 'lstm': 137 | c0 = Variable(torch.zeros(self.num_layers, 138 | batch_size, self.hidden_size)) 139 | return (h0.cuda(), c0.cuda()) if self.use_cuda else (h0, c0) 140 | else: 141 | return h0.cuda() if self.use_cuda else h0 142 | 143 | def encode(self, input, hidden): 144 | embedded = self.dropout(self.embedding(input)) 145 | output, _ = self.rnn(embedded, hidden) 146 | return output 147 | 148 | def forward(self, input, hidden): 149 | embedded = self.dropout(self.embedding(input)) 150 | output, hidden = self.rnn(embedded, hidden) 151 | max_len, batch_size, _ = output.size() 152 | output = output.view(max_len * batch_size, -1) 153 | output = self.dropout(output) 154 | 155 | output = self.linear_proj(output) 156 | output = self.dropout(output) 157 | output = self.linear_out(output) 158 | 159 | output = output.view(max_len, batch_size, -1) 160 | return output, hidden 161 | -------------------------------------------------------------------------------- /models/infersent_models.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | """ 9 | This file contains the definition of encoders used in https://arxiv.org/pdf/1705.02364.pdf 10 | """ 11 | 12 | import numpy as np 13 | import time 14 | 15 | import torch 16 | from torch.autograd import Variable 17 | import torch.nn as nn 18 | 19 | import spacy 20 | from tqdm import tqdm 21 | 22 | nlp = spacy.load("en", disable=['tagger', 'ner', 'textcat']) 23 | 24 | """ 25 | BLSTM (max/mean) encoder 26 | """ 27 | 28 | class InferSent(nn.Module): 29 | 30 | def __init__(self, config): 31 | super(InferSent, self).__init__() 32 | self.bsize = config['bsize'] 33 | self.word_emb_dim = config['word_emb_dim'] 34 | self.enc_lstm_dim = config['enc_lstm_dim'] 35 | self.pool_type = config['pool_type'] 36 | self.dpout_model = config['dpout_model'] 37 | self.version = 1 if 'version' not in config else config['version'] 38 | 39 | self.enc_lstm = nn.LSTM(self.word_emb_dim, self.enc_lstm_dim, 1, 40 | bidirectional=True, dropout=self.dpout_model) 41 | 42 | assert self.version in [1, 2] 43 | if self.version == 1: 44 | self.bos = '<s>' 45 | self.eos = '</s>' 46 | self.max_pad = True 47 | self.moses_tok = False 48 | elif self.version == 2: 49 | self.bos = '<p>' 50 | self.eos = '</p>
' 51 | self.max_pad = False 52 | self.moses_tok = True 53 | 54 | def is_cuda(self): 55 | # either all weights are on cpu or they are on gpu 56 | return self.enc_lstm.bias_hh_l0.data.is_cuda 57 | 58 | def forward(self, sent_tuple): 59 | # sent_len: [max_len, ..., min_len] (bsize) 60 | # sent: Variable(seqlen x bsize x worddim) 61 | sent, sent_len = sent_tuple 62 | 63 | # Sort by length (keep idx) 64 | sent_len_sorted, idx_sort = np.sort(sent_len)[::-1], np.argsort(-sent_len) 65 | idx_unsort = np.argsort(idx_sort) 66 | 67 | idx_sort = torch.from_numpy(idx_sort).cuda() if self.is_cuda() \ 68 | else torch.from_numpy(idx_sort) 69 | sent = sent.index_select(1, Variable(idx_sort)) 70 | 71 | # Handling padding in Recurrent Networks 72 | sent_packed = nn.utils.rnn.pack_padded_sequence(sent, sent_len_sorted) 73 | sent_output = self.enc_lstm(sent_packed)[0] # seqlen x batch x 2*nhid 74 | sent_output = nn.utils.rnn.pad_packed_sequence(sent_output)[0] 75 | 76 | # Un-sort by length 77 | idx_unsort = torch.from_numpy(idx_unsort).cuda() if self.is_cuda() \ 78 | else torch.from_numpy(idx_unsort) 79 | sent_output = sent_output.index_select(1, Variable(idx_unsort)) 80 | 81 | # Pooling 82 | if self.pool_type == "mean": 83 | sent_len = Variable(torch.FloatTensor(sent_len.copy())).unsqueeze(1).cuda() 84 | emb = torch.sum(sent_output, 0).squeeze(0) 85 | emb = emb / sent_len.expand_as(emb) 86 | elif self.pool_type == "max": 87 | if not self.max_pad: 88 | sent_output[sent_output == 0] = -1e9 89 | emb = torch.max(sent_output, 0)[0] 90 | if emb.ndimension() == 3: 91 | emb = emb.squeeze(0) 92 | assert emb.ndimension() == 2 93 | 94 | return emb 95 | 96 | def set_w2v_path(self, w2v_path): 97 | self.w2v_path = w2v_path 98 | 99 | def get_word_dict(self, sentences, tokenize=True): 100 | # create vocab of words 101 | word_dict = {} 102 | sentences = [s.split() if not tokenize else self.tokenize(s) for s in sentences] 103 | for sent in sentences: 104 | for word in sent: 105 | if word not in word_dict: 106 | word_dict[word] = '' 107 | word_dict[self.bos] = '' 108 | word_dict[self.eos] = '' 109 | return word_dict 110 | 111 | def get_w2v(self, word_dict): 112 | assert hasattr(self, 'w2v_path'), 'w2v path not set' 113 | # create word_vec with w2v vectors 114 | word_vec = {} 115 | with open(self.w2v_path) as f: 116 | for line in f: 117 | word, vec = line.split(' ', 1) 118 | if word in word_dict: 119 | word_vec[word] = np.fromstring(vec, sep=' ') 120 | print('Found %s(/%s) words with w2v vectors' % (len(word_vec), len(word_dict))) 121 | return word_vec 122 | 123 | def get_w2v_k(self, K): 124 | assert hasattr(self, 'w2v_path'), 'w2v path not set' 125 | # create word_vec with k first w2v vectors 126 | k = 0 127 | word_vec = {} 128 | with open(self.w2v_path) as f: 129 | for line in f: 130 | word, vec = line.split(' ', 1) 131 | if k <= K: 132 | word_vec[word] = np.fromstring(vec, sep=' ') 133 | k += 1 134 | if k > K: 135 | if word in [self.bos, self.eos]: 136 | word_vec[word] = np.fromstring(vec, sep=' ') 137 | 138 | if k > K and all([w in word_vec for w in [self.bos, self.eos]]): 139 | break 140 | return word_vec 141 | 142 | def build_vocab(self, sentences, tokenize=True): 143 | assert hasattr(self, 'w2v_path'), 'w2v path not set' 144 | word_dict = self.get_word_dict(sentences, tokenize) 145 | self.word_vec = self.get_w2v(word_dict) 146 | print('Vocab size : %s' % (len(self.word_vec))) 147 | 148 | # build w2v vocab with k most frequent words 149 | def build_vocab_k_words(self, K): 150 | assert hasattr(self, 'w2v_path'), 'w2v path 
not set' 151 | self.word_vec = self.get_w2v_k(K) 152 | print('Vocab size : %s' % (K)) 153 | 154 | def update_vocab(self, sentences, tokenize=True): 155 | assert hasattr(self, 'w2v_path'), 'warning : w2v path not set' 156 | assert hasattr(self, 'word_vec'), 'build_vocab before updating it' 157 | word_dict = self.get_word_dict(sentences, tokenize) 158 | 159 | # keep only new words 160 | for word in self.word_vec: 161 | if word in word_dict: 162 | del word_dict[word] 163 | 164 | # udpate vocabulary 165 | if word_dict: 166 | new_word_vec = self.get_w2v(word_dict) 167 | self.word_vec.update(new_word_vec) 168 | else: 169 | new_word_vec = [] 170 | print('New vocab size : %s (added %s words)'% (len(self.word_vec), len(new_word_vec))) 171 | 172 | def get_batch(self, batch): 173 | # sent in batch in decreasing order of lengths 174 | # batch: (bsize, max_len, word_dim) 175 | embed = np.zeros((len(batch[0]), len(batch), self.word_emb_dim)) 176 | 177 | for i in range(len(batch)): 178 | for j in range(len(batch[i])): 179 | embed[j, i, :] = self.word_vec[batch[i][j]] 180 | 181 | return torch.FloatTensor(embed) 182 | 183 | def tokenize(self, s): 184 | if self.moses_tok: 185 | s = ' '.join(list(nlp(s))) 186 | s = s.replace(" n't ", "n 't ") # HACK to get ~MOSES tokenization 187 | return s.split() 188 | else: 189 | return list(nlp(s)) 190 | 191 | def prepare_samples(self, sentences, bsize, maxlen, tokenize, verbose): 192 | sentences = [[self.bos] + s.split() + [self.eos] if not tokenize else 193 | [self.bos] + self.tokenize(s) + [self.eos] for s in sentences] 194 | n_w = np.sum([len(x) for x in sentences]) 195 | 196 | # filters words without w2v vectors 197 | for i in range(len(sentences)): 198 | s_f = [word for word in sentences[i] if word in self.word_vec] 199 | if not s_f: 200 | import warnings 201 | warnings.warn('No words in "%s" (idx=%s) have w2v vectors. \ 202 | Replacing by ""..' 
% (sentences[i], i)) 203 | s_f = [self.eos] 204 | sentences[i] = s_f[:maxlen] 205 | 206 | lengths = np.array([len(s) for s in sentences]) 207 | n_wk = np.sum(lengths) 208 | if verbose: 209 | print('Nb words kept : %s/%s (%.1f%s)' % ( 210 | n_wk, n_w, 100.0 * n_wk / n_w, '%')) 211 | 212 | # sort by decreasing length 213 | lengths, idx_sort = np.sort(lengths)[::-1], np.argsort(-lengths) 214 | sentences = np.array(sentences)[idx_sort] 215 | 216 | return sentences, lengths, idx_sort 217 | 218 | def encode(self, sentences, bsize=64, maxlen=40, tokenize=True, verbose=False): 219 | tic = time.time() 220 | sentences, lengths, idx_sort = self.prepare_samples( 221 | sentences, bsize, maxlen, tokenize, verbose) 222 | 223 | embeddings = [] 224 | for stidx in tqdm(range(0, len(sentences), bsize), disable=True): 225 | with torch.autograd.no_grad(): 226 | batch = Variable(self.get_batch( 227 | sentences[stidx:stidx + bsize])) 228 | if self.is_cuda(): 229 | batch = batch.cuda() 230 | batch = self.forward( 231 | (batch, lengths[stidx:stidx + bsize])).data.cpu().numpy() 232 | embeddings.append(batch) 233 | embeddings = np.vstack(embeddings) 234 | 235 | # unsort 236 | idx_unsort = np.argsort(idx_sort) 237 | embeddings = embeddings[idx_unsort] 238 | 239 | if verbose: 240 | print('Speed : %.1f sentences/s (%s mode, bsize=%s)' % ( 241 | len(embeddings)/(time.time()-tic), 242 | 'gpu' if self.is_cuda() else 'cpu', bsize)) 243 | return embeddings 244 | 245 | def visualize(self, sent, tokenize=True): 246 | 247 | sent = sent.split() if not tokenize else self.tokenize(sent) 248 | sent = [[self.bos] + [word for word in sent if word in self.word_vec] + [self.eos]] 249 | 250 | if ' '.join(sent[0]) == '%s %s' % (self.bos, self.eos): 251 | import warnings 252 | warnings.warn('No words in "%s" have w2v vectors. Replacing \ 253 | by "%s %s"..' 
% (sent, self.bos, self.eos)) 254 | batch = Variable(self.get_batch(sent), volatile=True) 255 | 256 | if self.is_cuda(): 257 | batch = batch.cuda() 258 | output = self.enc_lstm(batch)[0] 259 | output, idxs = torch.max(output, 0) 260 | # output, idxs = output.squeeze(), idxs.squeeze() 261 | idxs = idxs.data.cpu().numpy() 262 | argmaxs = [np.sum((idxs == k)) for k in range(len(sent[0]))] 263 | 264 | # visualize model 265 | import matplotlib.pyplot as plt 266 | x = range(len(sent[0])) 267 | y = [100.0 * n / np.sum(argmaxs) for n in argmaxs] 268 | plt.xticks(x, sent[0], rotation=45) 269 | plt.bar(x, y) 270 | plt.ylabel('%') 271 | plt.title('Visualisation of words importance') 272 | plt.show() 273 | 274 | return output, idxs 275 | 276 | """ 277 | BiGRU encoder (first/last hidden states) 278 | """ 279 | 280 | 281 | class BGRUlastEncoder(nn.Module): 282 | def __init__(self, config): 283 | super(BGRUlastEncoder, self).__init__() 284 | self.bsize = config['bsize'] 285 | self.word_emb_dim = config['word_emb_dim'] 286 | self.enc_lstm_dim = config['enc_lstm_dim'] 287 | self.pool_type = config['pool_type'] 288 | self.dpout_model = config['dpout_model'] 289 | 290 | self.enc_lstm = nn.GRU(self.word_emb_dim, self.enc_lstm_dim, 1, 291 | bidirectional=True, dropout=self.dpout_model) 292 | self.init_lstm = Variable(torch.FloatTensor(2, self.bsize, 293 | self.enc_lstm_dim).zero_()).cuda() 294 | 295 | def forward(self, sent_tuple): 296 | # sent_len: [max_len, ..., min_len] (batch) 297 | # sent: Variable(seqlen x batch x worddim) 298 | 299 | sent, sent_len = sent_tuple 300 | bsize = sent.size(1) 301 | 302 | self.init_lstm = self.init_lstm if bsize == self.init_lstm.size(1) else \ 303 | Variable(torch.FloatTensor(2, bsize, self.enc_lstm_dim).zero_()).cuda() 304 | 305 | # Sort by length (keep idx) 306 | sent_len, idx_sort = np.sort(sent_len)[::-1], np.argsort(-sent_len) 307 | sent = sent.index_select(1, Variable(torch.cuda.LongTensor(idx_sort))) 308 | 309 | # Handling padding in Recurrent Networks 310 | sent_packed = nn.utils.rnn.pack_padded_sequence(sent, sent_len) 311 | _, hn = self.enc_lstm(sent_packed, self.init_lstm) 312 | emb = torch.cat((hn[0], hn[1]), 1) # batch x 2*nhid 313 | 314 | # Un-sort by length 315 | idx_unsort = np.argsort(idx_sort) 316 | emb = emb.index_select(0, Variable(torch.cuda.LongTensor(idx_unsort))) 317 | 318 | return emb 319 | 320 | 321 | """ 322 | BLSTM encoder with projection after BiLSTM 323 | """ 324 | 325 | 326 | class BLSTMprojEncoder(nn.Module): 327 | def __init__(self, config): 328 | super(BLSTMprojEncoder, self).__init__() 329 | self.bsize = config['bsize'] 330 | self.word_emb_dim = config['word_emb_dim'] 331 | self.enc_lstm_dim = config['enc_lstm_dim'] 332 | self.pool_type = config['pool_type'] 333 | self.dpout_model = config['dpout_model'] 334 | 335 | self.enc_lstm = nn.LSTM(self.word_emb_dim, self.enc_lstm_dim, 1, 336 | bidirectional=True, dropout=self.dpout_model) 337 | self.init_lstm = Variable(torch.FloatTensor(2, self.bsize, 338 | self.enc_lstm_dim).zero_()).cuda() 339 | self.proj_enc = nn.Linear(2*self.enc_lstm_dim, 2*self.enc_lstm_dim, 340 | bias=False) 341 | 342 | def forward(self, sent_tuple): 343 | # sent_len: [max_len, ..., min_len] (batch) 344 | # sent: Variable(seqlen x batch x worddim) 345 | 346 | sent, sent_len = sent_tuple 347 | bsize = sent.size(1) 348 | 349 | self.init_lstm = self.init_lstm if bsize == self.init_lstm.size(1) else \ 350 | Variable(torch.FloatTensor(2, bsize, self.enc_lstm_dim).zero_()).cuda() 351 | 352 | # Sort by length (keep idx) 353 | 
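# Note on the pattern used here and in the other recurrent encoders in this file:
# sentences are sorted by decreasing length so pack_padded_sequence can skip the
# zero padding, the packed batch is fed to the (Bi)LSTM/GRU, the output is padded
# back, and np.argsort(idx_sort) restores the caller's original sentence order.
# A rough shape walkthrough for a hypothetical batch of 3 sentences with lengths
# [5, 2, 4] and 300-d word vectors (illustrative numbers, not taken from the code):
#   sent                 : (5, 3, 300)  seqlen x batch x worddim, zero padded
#   after sorting        : lengths become [5, 4, 2] and the columns are reordered
#   padded BiLSTM output : (5, 3, 2 * enc_lstm_dim)
#   after un-sorting     : columns return to the original order for the projection
#                          and pooling steps that follow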
sent_len, idx_sort = np.sort(sent_len)[::-1], np.argsort(-sent_len) 354 | sent = sent.index_select(1, Variable(torch.cuda.LongTensor(idx_sort))) 355 | 356 | # Handling padding in Recurrent Networks 357 | sent_packed = nn.utils.rnn.pack_padded_sequence(sent, sent_len) 358 | sent_output = self.enc_lstm(sent_packed, 359 | (self.init_lstm, self.init_lstm))[0] 360 | # seqlen x batch x 2*nhid 361 | sent_output = nn.utils.rnn.pad_packed_sequence(sent_output)[0] 362 | 363 | # Un-sort by length 364 | idx_unsort = np.argsort(idx_sort) 365 | sent_output = sent_output.index_select(1, 366 | Variable(torch.cuda.LongTensor(idx_unsort))) 367 | 368 | sent_output = self.proj_enc(sent_output.view(-1, 369 | 2*self.enc_lstm_dim)).view(-1, bsize, 2*self.enc_lstm_dim) 370 | # Pooling 371 | if self.pool_type == "mean": 372 | sent_len = Variable(torch.FloatTensor(sent_len)).unsqueeze(1).cuda() 373 | emb = torch.sum(sent_output, 0).squeeze(0) 374 | emb = emb / sent_len.expand_as(emb) 375 | elif self.pool_type == "max": 376 | emb = torch.max(sent_output, 0)[0].squeeze(0) 377 | 378 | return emb 379 | 380 | 381 | """ 382 | LSTM encoder 383 | """ 384 | 385 | 386 | class LSTMEncoder(nn.Module): 387 | def __init__(self, config): 388 | super(LSTMEncoder, self).__init__() 389 | self.bsize = config['bsize'] 390 | self.word_emb_dim = config['word_emb_dim'] 391 | self.enc_lstm_dim = config['enc_lstm_dim'] 392 | self.pool_type = config['pool_type'] 393 | self.dpout_model = config['dpout_model'] 394 | 395 | self.enc_lstm = nn.LSTM(self.word_emb_dim, self.enc_lstm_dim, 1, 396 | bidirectional=False, dropout=self.dpout_model) 397 | self.init_lstm = Variable(torch.FloatTensor(1, self.bsize, 398 | self.enc_lstm_dim).zero_()).cuda() 399 | 400 | def forward(self, sent_tuple): 401 | # sent_len [max_len, ..., min_len] (batch) | sent Variable(seqlen x batch x worddim) 402 | 403 | sent, sent_len = sent_tuple 404 | bsize = sent.size(1) 405 | 406 | self.init_lstm = self.init_lstm if bsize == self.init_lstm.size(1) else \ 407 | Variable(torch.FloatTensor(1, bsize, self.enc_lstm_dim).zero_()).cuda() 408 | 409 | # Sort by length (keep idx) 410 | sent_len, idx_sort = np.sort(sent_len)[::-1], np.argsort(-sent_len) 411 | sent = sent.index_select(1, Variable(torch.cuda.LongTensor(idx_sort))) 412 | 413 | # Handling padding in Recurrent Networks 414 | sent_packed = nn.utils.rnn.pack_padded_sequence(sent, sent_len) 415 | sent_output = self.enc_lstm(sent_packed, (self.init_lstm, 416 | self.init_lstm))[1][0].squeeze(0) # batch x 2*nhid 417 | 418 | # Un-sort by length 419 | idx_unsort = np.argsort(idx_sort) 420 | emb = sent_output.index_select(0, Variable(torch.cuda.LongTensor(idx_unsort))) 421 | 422 | return emb 423 | 424 | 425 | """ 426 | GRU encoder 427 | """ 428 | 429 | 430 | class GRUEncoder(nn.Module): 431 | def __init__(self, config): 432 | super(GRUEncoder, self).__init__() 433 | self.bsize = config['bsize'] 434 | self.word_emb_dim = config['word_emb_dim'] 435 | self.enc_lstm_dim = config['enc_lstm_dim'] 436 | self.pool_type = config['pool_type'] 437 | self.dpout_model = config['dpout_model'] 438 | 439 | self.enc_lstm = nn.GRU(self.word_emb_dim, self.enc_lstm_dim, 1, 440 | bidirectional=False, dropout=self.dpout_model) 441 | self.init_lstm = Variable(torch.FloatTensor(1, self.bsize, 442 | self.enc_lstm_dim).zero_()).cuda() 443 | 444 | def forward(self, sent_tuple): 445 | # sent_len: [max_len, ..., min_len] (batch) 446 | # sent: Variable(seqlen x batch x worddim) 447 | 448 | sent, sent_len = sent_tuple 449 | bsize = sent.size(1) 450 | 451 | 
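# The zero initial hidden state below is cached for the batch size given in the
# config and is rebuilt only when an incoming batch has a different size (for
# example the smaller final batch of an epoch). For this unidirectional GRU its
# shape is (num_layers * num_directions, bsize, enc_lstm_dim), i.e.
# (1, bsize, enc_lstm_dim); the bidirectional encoders above use a leading size of 2.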
self.init_lstm = self.init_lstm if bsize == self.init_lstm.size(1) else \ 452 | Variable(torch.FloatTensor(1, bsize, self.enc_lstm_dim).zero_()).cuda() 453 | 454 | # Sort by length (keep idx) 455 | sent_len, idx_sort = np.sort(sent_len)[::-1], np.argsort(-sent_len) 456 | sent = sent.index_select(1, Variable(torch.cuda.LongTensor(idx_sort))) 457 | 458 | # Handling padding in Recurrent Networks 459 | sent_packed = nn.utils.rnn.pack_padded_sequence(sent, sent_len) 460 | 461 | sent_output = self.enc_lstm(sent_packed, self.init_lstm)[1].squeeze(0) 462 | # batch x 2*nhid 463 | 464 | # Un-sort by length 465 | idx_unsort = np.argsort(idx_sort) 466 | emb = sent_output.index_select(0, Variable(torch.cuda.LongTensor(idx_unsort))) 467 | 468 | return emb 469 | 470 | 471 | """ 472 | Inner attention from "hierarchical attention for document classification" 473 | """ 474 | 475 | 476 | class InnerAttentionNAACLEncoder(nn.Module): 477 | def __init__(self, config): 478 | super(InnerAttentionNAACLEncoder, self).__init__() 479 | self.bsize = config['bsize'] 480 | self.word_emb_dim = config['word_emb_dim'] 481 | self.enc_lstm_dim = config['enc_lstm_dim'] 482 | self.pool_type = config['pool_type'] 483 | 484 | 485 | self.enc_lstm = nn.LSTM(self.word_emb_dim, self.enc_lstm_dim, 1, 486 | bidirectional=True) 487 | self.init_lstm = Variable(torch.FloatTensor(2, self.bsize, 488 | self.enc_lstm_dim).zero_()).cuda() 489 | 490 | self.proj_key = nn.Linear(2*self.enc_lstm_dim, 2*self.enc_lstm_dim, 491 | bias=False) 492 | self.proj_lstm = nn.Linear(2*self.enc_lstm_dim, 2*self.enc_lstm_dim, 493 | bias=False) 494 | self.query_embedding = nn.Embedding(1, 2*self.enc_lstm_dim) 495 | self.softmax = nn.Softmax() 496 | 497 | def forward(self, sent_tuple): 498 | # sent_len: [max_len, ..., min_len] (batch) 499 | # sent: Variable(seqlen x batch x worddim) 500 | 501 | sent, sent_len = sent_tuple 502 | bsize = sent.size(1) 503 | 504 | self.init_lstm = self.init_lstm if bsize == self.init_lstm.size(1) else \ 505 | Variable(torch.FloatTensor(2, bsize, self.enc_lstm_dim).zero_()).cuda() 506 | 507 | # Sort by length (keep idx) 508 | sent_len, idx_sort = np.sort(sent_len)[::-1], np.argsort(-sent_len) 509 | sent = sent.index_select(1, Variable(torch.cuda.LongTensor(idx_sort))) 510 | # Handling padding in Recurrent Networks 511 | sent_packed = nn.utils.rnn.pack_padded_sequence(sent, sent_len) 512 | sent_output = self.enc_lstm(sent_packed, 513 | (self.init_lstm, self.init_lstm))[0] 514 | # seqlen x batch x 2*nhid 515 | sent_output = nn.utils.rnn.pad_packed_sequence(sent_output)[0] 516 | # Un-sort by length 517 | idx_unsort = np.argsort(idx_sort) 518 | sent_output = sent_output.index_select(1, Variable(torch.cuda.LongTensor(idx_unsort))) 519 | 520 | sent_output = sent_output.transpose(0,1).contiguous() 521 | 522 | sent_output_proj = self.proj_lstm(sent_output.view(-1, 523 | 2*self.enc_lstm_dim)).view(bsize, -1, 2*self.enc_lstm_dim) 524 | 525 | sent_key_proj = self.proj_key(sent_output.view(-1, 526 | 2*self.enc_lstm_dim)).view(bsize, -1, 2*self.enc_lstm_dim) 527 | 528 | sent_key_proj = torch.tanh(sent_key_proj) 529 | # NAACL paper: u_it=tanh(W_w.h_it + b_w) (bsize, seqlen, 2nhid) 530 | 531 | sent_w = self.query_embedding(Variable(torch.LongTensor(bsize*[0]).cuda())).unsqueeze(2) #(bsize, 2*nhid, 1) 532 | 533 | Temp = 2 534 | keys = sent_key_proj.bmm(sent_w).squeeze(2) / Temp 535 | 536 | # Set probas of padding to zero in softmax 537 | keys = keys + ((keys == 0).float()*-10000) 538 | 539 | alphas = 
self.softmax(keys/Temp).unsqueeze(2).expand_as(sent_output) 540 | if int(time.time()) % 100 == 0: 541 | print('w', torch.max(sent_w), torch.min(sent_w)) 542 | print('alphas', alphas[0, :, 0]) 543 | emb = torch.sum(alphas * sent_output_proj, 1).squeeze(1) 544 | 545 | return emb 546 | 547 | 548 | """ 549 | Inner attention inspired from "Self-attentive ..." 550 | """ 551 | 552 | 553 | class InnerAttentionMILAEncoder(nn.Module): 554 | def __init__(self, config): 555 | super(InnerAttentionMILAEncoder, self).__init__() 556 | self.bsize = config['bsize'] 557 | self.word_emb_dim = config['word_emb_dim'] 558 | self.enc_lstm_dim = config['enc_lstm_dim'] 559 | self.pool_type = config['pool_type'] 560 | 561 | self.enc_lstm = nn.LSTM(self.word_emb_dim, self.enc_lstm_dim, 1, 562 | bidirectional=True) 563 | self.init_lstm = Variable(torch.FloatTensor(2, self.bsize, 564 | self.enc_lstm_dim).zero_()).cuda() 565 | 566 | self.proj_key = nn.Linear(2*self.enc_lstm_dim, 2*self.enc_lstm_dim, 567 | bias=False) 568 | self.proj_lstm = nn.Linear(2*self.enc_lstm_dim, 2*self.enc_lstm_dim, 569 | bias=False) 570 | self.query_embedding = nn.Embedding(2, 2*self.enc_lstm_dim) 571 | self.softmax = nn.Softmax() 572 | 573 | def forward(self, sent_tuple): 574 | # sent_len: [max_len, ..., min_len] (batch) 575 | # sent: Variable(seqlen x batch x worddim) 576 | 577 | sent, sent_len = sent_tuple 578 | bsize = sent.size(1) 579 | 580 | self.init_lstm = self.init_lstm if bsize == self.init_lstm.size(1) else \ 581 | Variable(torch.FloatTensor(2, bsize, self.enc_lstm_dim).zero_()).cuda() 582 | 583 | # Sort by length (keep idx) 584 | sent_len, idx_sort = np.sort(sent_len)[::-1], np.argsort(-sent_len) 585 | sent = sent.index_select(1, Variable(torch.cuda.LongTensor(idx_sort))) 586 | # Handling padding in Recurrent Networks 587 | sent_packed = nn.utils.rnn.pack_padded_sequence(sent, sent_len) 588 | sent_output = self.enc_lstm(sent_packed, 589 | (self.init_lstm, self.init_lstm))[0] 590 | # seqlen x batch x 2*nhid 591 | sent_output = nn.utils.rnn.pad_packed_sequence(sent_output)[0] 592 | # Un-sort by length 593 | idx_unsort = np.argsort(idx_sort) 594 | sent_output = sent_output.index_select(1, 595 | Variable(torch.cuda.LongTensor(idx_unsort))) 596 | 597 | sent_output = sent_output.transpose(0,1).contiguous() 598 | sent_output_proj = self.proj_lstm(sent_output.view(-1, 599 | 2*self.enc_lstm_dim)).view(bsize, -1, 2*self.enc_lstm_dim) 600 | sent_key_proj = self.proj_key(sent_output.view(-1, 601 | 2*self.enc_lstm_dim)).view(bsize, -1, 2*self.enc_lstm_dim) 602 | sent_key_proj = torch.tanh(sent_key_proj) 603 | # NAACL : u_it=tanh(W_w.h_it + b_w) like in NAACL paper 604 | 605 | # Temperature 606 | Temp = 3 607 | 608 | sent_w1 = self.query_embedding(Variable(torch.LongTensor(bsize*[0]).cuda())).unsqueeze(2) #(bsize, nhid, 1) 609 | keys1 = sent_key_proj.bmm(sent_w1).squeeze(2) / Temp 610 | keys1 = keys1 + ((keys1 == 0).float()*-1000) 611 | alphas1 = self.softmax(keys1).unsqueeze(2).expand_as(sent_key_proj) 612 | emb1 = torch.sum(alphas1 * sent_output_proj, 1).squeeze(1) 613 | 614 | 615 | sent_w2 = self.query_embedding(Variable(torch.LongTensor(bsize*[1]).cuda())).unsqueeze(2) #(bsize, nhid, 1) 616 | keys2 = sent_key_proj.bmm(sent_w2).squeeze(2) / Temp 617 | keys2 = keys2 + ((keys2 == 0).float()*-1000) 618 | alphas2 = self.softmax(keys2).unsqueeze(2).expand_as(sent_key_proj) 619 | emb2 = torch.sum(alphas2 * sent_output_proj, 1).squeeze(1) 620 | 621 | sent_w3 = self.query_embedding(Variable(torch.LongTensor(bsize*[1]).cuda())).unsqueeze(2) #(bsize, 
nhid, 1) 622 | keys3 = sent_key_proj.bmm(sent_w3).squeeze(2) / Temp 623 | keys3 = keys3 + ((keys3 == 0).float()*-1000) 624 | alphas3 = self.softmax(keys3).unsqueeze(2).expand_as(sent_key_proj) 625 | emb3 = torch.sum(alphas3 * sent_output_proj, 1).squeeze(1) 626 | 627 | sent_w4 = self.query_embedding(Variable(torch.LongTensor(bsize*[1]).cuda())).unsqueeze(2) #(bsize, nhid, 1) 628 | keys4 = sent_key_proj.bmm(sent_w4).squeeze(2) / Temp 629 | keys4 = keys4 + ((keys4 == 0).float()*-1000) 630 | alphas4 = self.softmax(keys4).unsqueeze(2).expand_as(sent_key_proj) 631 | emb4 = torch.sum(alphas4 * sent_output_proj, 1).squeeze(1) 632 | 633 | 634 | if int(time.time()) % 100 == 0: 635 | print('alphas', torch.cat((alphas1.data[0, :, 0], 636 | alphas2.data[0, :, 0], 637 | torch.abs(alphas1.data[0, :, 0] - 638 | alphas2.data[0, :, 0])), 1)) 639 | 640 | emb = torch.cat((emb1, emb2, emb3, emb4), 1) 641 | return emb 642 | 643 | 644 | """ 645 | Inner attention from Yang et al. 646 | """ 647 | 648 | 649 | class InnerAttentionYANGEncoder(nn.Module): 650 | def __init__(self, config): 651 | super(InnerAttentionYANGEncoder, self).__init__() 652 | self.bsize = config['bsize'] 653 | self.word_emb_dim = config['word_emb_dim'] 654 | self.enc_lstm_dim = config['enc_lstm_dim'] 655 | self.pool_type = config['pool_type'] 656 | 657 | self.enc_lstm = nn.LSTM(self.word_emb_dim, self.enc_lstm_dim, 1, 658 | bidirectional=True) 659 | self.init_lstm = Variable(torch.FloatTensor(2, self.bsize, 660 | self.enc_lstm_dim).zero_()).cuda() 661 | 662 | self.proj_lstm = nn.Linear(2*self.enc_lstm_dim, 2*self.enc_lstm_dim, 663 | bias=True) 664 | self.proj_query = nn.Linear(2*self.enc_lstm_dim, 2*self.enc_lstm_dim, 665 | bias=True) 666 | self.proj_enc = nn.Linear(2*self.enc_lstm_dim, 2*self.enc_lstm_dim, 667 | bias=True) 668 | 669 | self.query_embedding = nn.Embedding(1, 2*self.enc_lstm_dim) 670 | self.softmax = nn.Softmax() 671 | 672 | def forward(self, sent_tuple): 673 | # sent_len: [max_len, ..., min_len] (batch) 674 | # sent: Variable(seqlen x batch x worddim) 675 | 676 | sent, sent_len = sent_tuple 677 | bsize = sent.size(1) 678 | 679 | self.init_lstm = self.init_lstm if bsize == self.init_lstm.size(1) else \ 680 | Variable(torch.FloatTensor(2, bsize, self.enc_lstm_dim).zero_()).cuda() 681 | 682 | # Sort by length (keep idx) 683 | sent_len, idx_sort = np.sort(sent_len)[::-1], np.argsort(-sent_len) 684 | sent = sent.index_select(1, Variable(torch.cuda.LongTensor(idx_sort))) 685 | # Handling padding in Recurrent Networks 686 | sent_packed = nn.utils.rnn.pack_padded_sequence(sent, sent_len) 687 | sent_output = self.enc_lstm(sent_packed, 688 | (self.init_lstm, self.init_lstm))[0] 689 | # seqlen x batch x 2*nhid 690 | sent_output = nn.utils.rnn.pad_packed_sequence(sent_output)[0] 691 | # Un-sort by length 692 | idx_unsort = np.argsort(idx_sort) 693 | sent_output = sent_output.index_select(1, 694 | Variable(torch.cuda.LongTensor(idx_unsort))) 695 | 696 | sent_output = sent_output.transpose(0,1).contiguous() 697 | 698 | sent_output_proj = self.proj_lstm(sent_output.view(-1, 699 | 2*self.enc_lstm_dim)).view(bsize, -1, 2*self.enc_lstm_dim) 700 | 701 | sent_keys = self.proj_enc(sent_output.view(-1, 702 | 2*self.enc_lstm_dim)).view(bsize, -1, 2*self.enc_lstm_dim) 703 | 704 | sent_max = torch.max(sent_output, 1)[0].squeeze(1) # (bsize, 2*nhid) 705 | sent_summary = self.proj_query( 706 | sent_max).unsqueeze(1).expand_as(sent_keys) 707 | # (bsize, seqlen, 2*nhid) 708 | 709 | sent_M = torch.tanh(sent_keys + sent_summary) 710 | # (bsize, seqlen, 
2*nhid) YANG : M = tanh(Wh_i + Wh_avg 711 | sent_w = self.query_embedding(Variable(torch.LongTensor( 712 | bsize*[0]).cuda())).unsqueeze(2) # (bsize, 2*nhid, 1) 713 | 714 | sent_alphas = self.softmax(sent_M.bmm(sent_w).squeeze(2)).unsqueeze(1) 715 | # (bsize, 1, seqlen) 716 | 717 | if int(time.time()) % 200 == 0: 718 | print('w', torch.max(sent_w[0]), torch.min(sent_w[0])) 719 | print('alphas', sent_alphas[0][0][0:sent_len[0]]) 720 | # Get attention vector 721 | emb = sent_alphas.bmm(sent_output_proj).squeeze(1) 722 | 723 | return emb 724 | 725 | 726 | 727 | """ 728 | Hierarchical ConvNet 729 | """ 730 | class ConvNetEncoder(nn.Module): 731 | def __init__(self, config): 732 | super(ConvNetEncoder, self).__init__() 733 | 734 | self.bsize = config['bsize'] 735 | self.word_emb_dim = config['word_emb_dim'] 736 | self.enc_lstm_dim = config['enc_lstm_dim'] 737 | self.pool_type = config['pool_type'] 738 | 739 | self.convnet1 = nn.Sequential( 740 | nn.Conv1d(self.word_emb_dim, 2*self.enc_lstm_dim, kernel_size=3, 741 | stride=1, padding=1), 742 | nn.ReLU(inplace=True), 743 | ) 744 | self.convnet2 = nn.Sequential( 745 | nn.Conv1d(2*self.enc_lstm_dim, 2*self.enc_lstm_dim, kernel_size=3, 746 | stride=1, padding=1), 747 | nn.ReLU(inplace=True), 748 | ) 749 | self.convnet3 = nn.Sequential( 750 | nn.Conv1d(2*self.enc_lstm_dim, 2*self.enc_lstm_dim, kernel_size=3, 751 | stride=1, padding=1), 752 | nn.ReLU(inplace=True), 753 | ) 754 | self.convnet4 = nn.Sequential( 755 | nn.Conv1d(2*self.enc_lstm_dim, 2*self.enc_lstm_dim, kernel_size=3, 756 | stride=1, padding=1), 757 | nn.ReLU(inplace=True), 758 | ) 759 | 760 | 761 | 762 | def forward(self, sent_tuple): 763 | # sent_len: [max_len, ..., min_len] (batch) 764 | # sent: Variable(seqlen x batch x worddim) 765 | 766 | sent, sent_len = sent_tuple 767 | 768 | sent = sent.transpose(0,1).transpose(1,2).contiguous() 769 | # batch, nhid, seqlen) 770 | 771 | sent = self.convnet1(sent) 772 | u1 = torch.max(sent, 2)[0] 773 | 774 | sent = self.convnet2(sent) 775 | u2 = torch.max(sent, 2)[0] 776 | 777 | sent = self.convnet3(sent) 778 | u3 = torch.max(sent, 2)[0] 779 | 780 | sent = self.convnet4(sent) 781 | u4 = torch.max(sent, 2)[0] 782 | 783 | emb = torch.cat((u1, u2, u3, u4), 1) 784 | 785 | return emb 786 | 787 | 788 | """ 789 | Main module for Natural Language Inference 790 | """ 791 | 792 | 793 | class NLINet(nn.Module): 794 | def __init__(self, config): 795 | super(NLINet, self).__init__() 796 | 797 | # classifier 798 | self.nonlinear_fc = config['nonlinear_fc'] 799 | self.fc_dim = config['fc_dim'] 800 | self.n_classes = config['n_classes'] 801 | self.enc_lstm_dim = config['enc_lstm_dim'] 802 | self.encoder_type = config['encoder_type'] 803 | self.dpout_fc = config['dpout_fc'] 804 | 805 | self.encoder = eval(self.encoder_type)(config) 806 | self.inputdim = 4*2*self.enc_lstm_dim 807 | self.inputdim = 4*self.inputdim if self.encoder_type in \ 808 | ["ConvNetEncoder", "InnerAttentionMILAEncoder"] else self.inputdim 809 | self.inputdim = self.inputdim/2 if self.encoder_type == "LSTMEncoder" \ 810 | else self.inputdim 811 | if self.nonlinear_fc: 812 | self.classifier = nn.Sequential( 813 | nn.Dropout(p=self.dpout_fc), 814 | nn.Linear(self.inputdim, self.fc_dim), 815 | nn.Tanh(), 816 | nn.Dropout(p=self.dpout_fc), 817 | nn.Linear(self.fc_dim, self.fc_dim), 818 | nn.Tanh(), 819 | nn.Dropout(p=self.dpout_fc), 820 | nn.Linear(self.fc_dim, self.n_classes), 821 | ) 822 | else: 823 | self.classifier = nn.Sequential( 824 | nn.Linear(self.inputdim, self.fc_dim), 825 | 
nn.Linear(self.fc_dim, self.fc_dim), 826 | nn.Linear(self.fc_dim, self.n_classes) 827 | ) 828 | 829 | def forward(self, s1, s2): 830 | # s1 : (s1, s1_len) 831 | u = self.encoder(s1) 832 | v = self.encoder(s2) 833 | 834 | features = torch.cat((u, v, torch.abs(u-v), u*v), 1) 835 | output = self.classifier(features) 836 | return output 837 | 838 | def encode(self, s1): 839 | emb = self.encoder(s1) 840 | return emb 841 | 842 | 843 | """ 844 | Main module for Classification 845 | """ 846 | 847 | 848 | class ClassificationNet(nn.Module): 849 | def __init__(self, config): 850 | super(ClassificationNet, self).__init__() 851 | 852 | # classifier 853 | self.nonlinear_fc = config['nonlinear_fc'] 854 | self.fc_dim = config['fc_dim'] 855 | self.n_classes = config['n_classes'] 856 | self.enc_lstm_dim = config['enc_lstm_dim'] 857 | self.encoder_type = config['encoder_type'] 858 | self.dpout_fc = config['dpout_fc'] 859 | 860 | self.encoder = eval(self.encoder_type)(config) 861 | self.inputdim = 2*self.enc_lstm_dim 862 | self.inputdim = 4*self.inputdim if self.encoder_type == "ConvNetEncoder" else self.inputdim 863 | self.inputdim = self.enc_lstm_dim if self.encoder_type =="LSTMEncoder" else self.inputdim 864 | self.classifier = nn.Sequential( 865 | nn.Linear(self.inputdim, 512), 866 | nn.Linear(512, self.n_classes), 867 | ) 868 | 869 | def forward(self, s1): 870 | # s1 : (s1, s1_len) 871 | u = self.encoder(s1) 872 | 873 | output = self.classifier(u) 874 | return output 875 | 876 | def encode(self, s1): 877 | emb = self.encoder(s1) 878 | return emb 879 | -------------------------------------------------------------------------------- /models/language_models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.optim as optim 4 | from torch.autograd import Variable 5 | from .gan_models import RNN_LM 6 | import numpy as np 7 | import pickle 8 | import time 9 | from tqdm import tqdm 10 | 11 | def repackage_hidden(h): 12 | if isinstance(h, tuple): 13 | return tuple(repackage_hidden(v) for v in h) 14 | else: 15 | return h.detach() 16 | 17 | def batchify(data, bsz, use_cuda): 18 | nbatch = data.size(0) // bsz 19 | data = data.narrow(0, 0, nbatch * bsz) 20 | data = data.view(bsz, -1).t().contiguous() 21 | if use_cuda: 22 | data = data.cuda() 23 | return data 24 | 25 | def get_batch(source, i, seq_len): 26 | seq_len = min(seq_len, len(source) - 1 - i) 27 | data = Variable(source[i:i + seq_len]) 28 | target = Variable(source[i + 1: i + 1 + seq_len]) 29 | return data, target 30 | 31 | class LanguageModel: 32 | def __init__(self, vocab_size, embed_dim, corpus, hparams): 33 | self.vocab_size = vocab_size 34 | self.hparams = hparams 35 | self.num_epochs = hparams["num_epochs"] 36 | self.batch_size = hparams["batch_size"] 37 | self.bptt = hparams["bptt"] 38 | self.log_interval = hparams["log_interval"] 39 | self.save_path = hparams["save_path"] 40 | lr = hparams["lr"] 41 | wdecay = hparams["wdecay"] 42 | self.hparams = hparams 43 | self.use_cuda = torch.cuda.is_available() 44 | 45 | self.train_data = batchify(corpus.train, self.batch_size, self.use_cuda) 46 | self.valid_data = batchify(corpus.valid, self.batch_size, self.use_cuda) 47 | self.lm = RNN_LM(vocab_size, embed_dim, hparams, self.use_cuda) 48 | if self.use_cuda: 49 | self.lm.cuda() 50 | self.lm.set_embed(self._variable(corpus.glove_embed)) 51 | model_parameters = filter(lambda p: p.requires_grad, self.lm.parameters()) 52 | self.optimizer = optim.Adam(model_parameters, 53 
| lr=lr, weight_decay=wdecay) 54 | self.loss_fn = nn.CrossEntropyLoss() 55 | 56 | def _variable(self, data): 57 | data = np.array(data) 58 | data = Variable(torch.from_numpy(data)) 59 | return data.cuda() if self.use_cuda else data 60 | 61 | def logging(self, s, print_=True): 62 | if print_: 63 | print(s) 64 | 65 | def train(self, epoch): 66 | self.lm.train() 67 | 68 | total_loss = 0 69 | hidden = self.lm.init_hidden(self.batch_size) 70 | i, batch = np.random.choice(self.bptt), 0 71 | start_time = time.time() 72 | 73 | while i < self.train_data.size(0) - 1 - 1: 74 | data, target = get_batch(self.train_data, i, self.bptt) 75 | self.optimizer.zero_grad() 76 | 77 | output, hidden = self.lm(data, hidden) 78 | hidden = repackage_hidden(hidden) 79 | 80 | loss = self.loss_fn(output.view(-1, self.vocab_size), target.view(-1)) 81 | total_loss += loss.data 82 | loss.backward() 83 | self.optimizer.step() 84 | 85 | if batch % self.log_interval == 0 and batch > 0: 86 | cur_loss = total_loss.item() / self.log_interval 87 | elapsed = time.time() - start_time 88 | self.logging('| epoch {:2d} | {:5d}/{:5d} batches | {:5.2f} ms/batch | ' 89 | 'loss {:.5f} | ppl {:3.3f}'.format( 90 | epoch, batch, len(self.train_data) // self.bptt, 91 | elapsed * 1000 / self.batch_size, cur_loss, np.exp(cur_loss))) 92 | total_loss = 0 93 | start_time = time.time() 94 | 95 | batch += 1 96 | i += self.bptt 97 | 98 | def evaluate(self, data_source): 99 | self.lm.eval() 100 | 101 | total_loss = 0 102 | hidden = self.lm.init_hidden(self.batch_size) 103 | with torch.no_grad(): 104 | for i in range(0, data_source.size(0) - 1, self.bptt): 105 | data, target = get_batch(data_source, i, self.bptt) 106 | 107 | output, hidden = self.lm(data, hidden) 108 | hidden = repackage_hidden(hidden) 109 | 110 | loss = self.loss_fn(output.view(-1, self.vocab_size), target.view(-1)) 111 | total_loss += loss.data * len(data) 112 | return total_loss.item() / len(data_source) 113 | 114 | def fit(self): 115 | best_valid_loss = np.inf 116 | best_valid_epoch = 0 117 | for epoch in range(1, self.num_epochs + 1): 118 | epoch_start_time = time.time() 119 | self.train(epoch) 120 | val_loss = self.evaluate(self.valid_data) 121 | self.logging('-' * 50) 122 | self.logging('| end of epoch {:2d} | time {:5.2f}s | valid loss {:.5f} | ' 123 | 'valid ppl {:3.3f}'.format(epoch, (time.time() - epoch_start_time), 124 | val_loss, np.exp(val_loss))) 125 | self.logging('-' * 50) 126 | if val_loss < best_valid_loss: 127 | self.save(self.save_path) 128 | best_valid_loss = val_loss 129 | best_valid_epoch = epoch 130 | elif epoch - best_valid_epoch > 5: 131 | break 132 | return best_valid_loss 133 | 134 | def save(self, path): 135 | torch.save(self.lm.state_dict(), path + ".pt") 136 | with open(path + ".pkl", "wb") as f: 137 | pickle.dump(self.hparams, f, -1) 138 | 139 | def load(self, path): 140 | self.lm.load_state_dict(torch.load(path)) 141 | 142 | class LMCoherence: 143 | def __init__(self, forward_lm, backward_lm, corpus): 144 | self.forward_lm = forward_lm 145 | self.backward_lm = backward_lm 146 | self.corpus = corpus 147 | self.use_cuda = torch.cuda.is_available() 148 | 149 | self.loss = nn.CrossEntropyLoss() 150 | 151 | def score_article(self, sentences, reverse=False): 152 | vocab_size = len(self.corpus.vocab) 153 | lm = self.backward_lm if reverse else self.forward_lm 154 | 155 | if reverse: 156 | sentences = sentences[::-1] 157 | 158 | if reverse: 159 | sentences_inds = [[self.corpus.vocab[w] 160 | for w in [''] + sent.split()[::-1] + ['']] 161 | for sent in 
sentences] 162 | else: 163 | sentences_inds = [[self.corpus.vocab[w] 164 | for w in [''] + sent.split() + ['']] 165 | for sent in sentences] 166 | 167 | scores = [] 168 | hidden_f = lm.init_hidden(1) 169 | for s in sentences_inds: 170 | s = torch.LongTensor(s).unsqueeze(1) 171 | if self.use_cuda: 172 | s = s.to('cuda') 173 | x = s[:-1] 174 | y = s[1:].squeeze() 175 | 176 | c_f_outs, hidden_f = lm(x, hidden_f) 177 | c_loss_f = self.loss(c_f_outs.view(-1, vocab_size), y.view(-1)) 178 | 179 | scores.append(- c_loss_f.item()) 180 | 181 | return np.mean(scores) 182 | 183 | def evaluate_dis(self, test, df): 184 | self.forward_lm.eval() 185 | self.backward_lm.eval() 186 | 187 | correct_pred = 0 188 | total = 0 189 | for article in tqdm(test): 190 | if total % 2000 == 0 and total: 191 | print(correct_pred / total) 192 | sentences = df.loc[article[0], "sentences"].split("") 193 | neg_sentences_list = df.loc[article[0], "neg_list"].split("") 194 | neg_sentences_list = [s.split('') for s in neg_sentences_list] 195 | 196 | pos_score_f = self.score_article(sentences) 197 | pos_score_b = self.score_article(sentences, True) 198 | pos_score = pos_score_f + pos_score_b 199 | 200 | for neg_sentences in neg_sentences_list: 201 | neg_score_f = self.score_article(neg_sentences) 202 | neg_score_b = self.score_article(neg_sentences, True) 203 | neg_score = neg_score_f + neg_score_b 204 | if pos_score > neg_score: 205 | correct_pred += 1 206 | total += 1 207 | return correct_pred / total 208 | 209 | def evaluate_ins(self, test, df): 210 | self.forward_lm.eval() 211 | self.backward_lm.eval() 212 | 213 | correct_pred = 0.0 214 | total = 0 215 | for article in tqdm(test): 216 | if total % 100 == 0 and total: 217 | print(correct_pred / total) 218 | sentences = df.loc[article[0], "sentences"].split("") 219 | sent_num = len(sentences) 220 | 221 | pos_score_f = self.score_article(sentences) 222 | pos_score_b = self.score_article(sentences, True) 223 | pos_score = pos_score_f + pos_score_b 224 | 225 | cnt = 0.0 226 | for i in range(sent_num): 227 | tmp = sentences[:i] + sentences[i + 1:] 228 | flag = True 229 | for j in range(sent_num): 230 | if j == i: 231 | continue 232 | neg_sentences = tmp[:j] + sentences[i:i + 1] + tmp[j:] 233 | neg_score_f = self.score_article(neg_sentences) 234 | neg_score_b = self.score_article(neg_sentences, True) 235 | neg_score = neg_score_f + neg_score_b 236 | if pos_score < neg_score: 237 | flag = False 238 | if flag: 239 | cnt += 1.0 240 | correct_pred += cnt / sent_num 241 | total += 1 242 | return correct_pred / total 243 | -------------------------------------------------------------------------------- /prepare_data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import requests 3 | import tarfile 4 | 5 | def download_file_from_google_drive(id, destination): 6 | URL = "https://docs.google.com/uc?export=download" 7 | 8 | session = requests.Session() 9 | 10 | response = session.get(URL, params={'id': id}, stream=True) 11 | token = get_confirm_token(response) 12 | 13 | if token: 14 | params = {'id': id, 'confirm': token} 15 | response = session.get(URL, params=params, stream=True) 16 | 17 | save_response_content(response, destination) 18 | 19 | def get_confirm_token(response): 20 | for key, value in response.cookies.items(): 21 | if key.startswith('download_warning'): 22 | return value 23 | 24 | return None 25 | 26 | def save_response_content(response, destination): 27 | CHUNK_SIZE = 32768 28 | 29 | with open(destination, "wb") as f: 30 
| for chunk in response.iter_content(CHUNK_SIZE): 31 | if chunk: # filter out keep-alive new chunks 32 | f.write(chunk) 33 | 34 | print("Downloading WikiCoherence Corpus...") 35 | download_file_from_google_drive("1Il9mZt111kRAkzy8IXirp_7NUZ2dYAs2", "WikiCoherence.tar.gz") 36 | tar = tarfile.open("WikiCoherence.tar.gz", "r:gz") 37 | tar.extractall() 38 | tar.close() 39 | os.remove("WikiCoherence.tar.gz") 40 | 41 | print("Downloading GloVe embeddings...") 42 | os.system("wget http://nlp.stanford.edu/data/glove.840B.300d.zip") 43 | os.system("unzip glove.840B.300d.zip") 44 | os.system("rm glove.840B.300d.zip") 45 | os.system("mv glove.840B.300d.txt data") 46 | 47 | print("Downloading infersent pre-trained models...") 48 | os.system("curl -Lo data/infersent1.pkl https://dl.fbaipublicfiles.com/senteval/infersent/infersent1.pkl") 49 | -------------------------------------------------------------------------------- /preprocess.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | import config 4 | from utils.logging_utils import _set_basic_logging 5 | from utils.data_utils import DataSet 6 | from models.infersent_models import InferSent 7 | from models.language_models import LanguageModel 8 | import torch 9 | import numpy as np 10 | import random 11 | import copy 12 | import itertools 13 | import pickle 14 | from tqdm import tqdm 15 | import argparse 16 | 17 | def permute_articles(cliques, num_perm): 18 | permuted_articles = [] 19 | for clique in cliques: 20 | clique = list(clique) 21 | old_clique = copy.deepcopy(clique) 22 | random.shuffle(clique) 23 | perms = itertools.permutations(clique) 24 | inner_perm = [] 25 | i = 0 26 | for perm in perms: 27 | comparator = [old_sent == sent for old_sent, sent 28 | in zip(old_clique, perm)] 29 | if not np.all(comparator): 30 | inner_perm.append(list(perm)) 31 | i += 1 32 | if i >= num_perm: 33 | break 34 | permuted_articles.append(inner_perm) 35 | return permuted_articles 36 | 37 | def permute_articles_with_replacement(cliques, num_perm): 38 | permuted_articles = [] 39 | for clique in cliques: 40 | clique = list(clique) 41 | old_clique = copy.deepcopy(clique) 42 | inner_perm = [] 43 | i = 0 44 | while i < num_perm: 45 | random_perm = copy.deepcopy(clique) 46 | random.shuffle(random_perm) 47 | comparator = [old_sent == sent for old_sent, sent 48 | in zip(old_clique, random_perm)] 49 | if not np.all(comparator): 50 | inner_perm.append(random_perm) 51 | i += 1 52 | if i >= num_perm: 53 | break 54 | permuted_articles.append(inner_perm) 55 | return permuted_articles 56 | 57 | def prep_wsj_lm_data(data_path): 58 | train_list = ['00', '01', '02', '03', '04', '05', '06', 59 | '07', '08', '09', '10'] 60 | 61 | valid_list = ['11', '12', '13'] 62 | 63 | test_list = ['14', '15', '16', '17', '18', '19', '20', 64 | '21', '22', '23', '24'] 65 | 66 | datasets = [('train', train_list), 67 | ('valid', valid_list), 68 | ('test', test_list)] 69 | 70 | for dname, dlist in datasets: 71 | with open(os.path.join('./', dname+'.txt'), 'w') as wr: 72 | for dirname in os.listdir(data_path): 73 | 74 | if dirname in dlist: 75 | print(dname, dirname) 76 | subdirpath = os.path.join(data_path, dirname) 77 | 78 | for filename in os.listdir(subdirpath): 79 | fname = os.path.join(subdirpath, filename) 80 | 81 | with open(fname) as fr: 82 | wr.write(""+"\n") 83 | wr.write(fr.read().strip()+'\n') 84 | wr.write(""+"\n") 85 | 86 | def load_wsj_file_list(data_path): 87 | dir_list = ['00', '01', '02', '03', '04', '05', '06', '07', 
'08', 88 | '09', '10', '11', '12', '13', '14', '15', '16', '17', 89 | '18', '19', '20', '21', '22', '23', '24'] 90 | 91 | file_list = [] 92 | for dirname in os.listdir(data_path): 93 | if dirname in dir_list: 94 | subdirpath = os.path.join(data_path, dirname) 95 | for filename in os.listdir(subdirpath): 96 | file_list.append(os.path.join(subdirpath, filename)) 97 | return file_list 98 | 99 | def load_wiki_file_list(data_path, dir_list): 100 | file_list = [] 101 | for dirname in os.listdir(data_path): 102 | if dirname in dir_list: 103 | subdirpath = os.path.join(data_path, dirname) 104 | file_list.append(os.path.join(subdirpath, "extracted_paras.txt")) 105 | return file_list 106 | 107 | def load_file_list(data_name, if_sample): 108 | if data_name in ["wsj", "wsj_bigram", "wsj_trigram"]: 109 | if if_sample: 110 | return load_wsj_file_list(config.SAMPLE_WSJ_DATA_PATH) 111 | return load_wsj_file_list(config.WSJ_DATA_PATH) 112 | elif data_name in ["wiki_random", "wiki_bigram_easy"]: 113 | dir_list = config.WIKI_EASY_TRAIN_LIST + config.WIKI_EASY_TEST_LIST 114 | if if_sample: 115 | return load_wiki_file_list(config.SAMPLE_WIKI_DATA_PATH, dir_list) 116 | return load_wiki_file_list(config.WIKI_EASY_DATA_PATH, dir_list) 117 | elif (data_name in ["wiki_domain"]) or ("wiki_bigram" in data_name): 118 | category = data_name[12:] 119 | if category in config.WIKI_OUT_DOMAIN: 120 | dir_list = config.WIKI_IN_DOMAIN + [category] 121 | else: 122 | dir_list = config.WIKI_IN_DOMAIN 123 | if if_sample: 124 | return load_wiki_file_list(config.SAMPLE_WIKI_DATA_PATH, dir_list) 125 | return load_wiki_file_list(config.WIKI_DATA_PATH, dir_list) 126 | else: 127 | raise ValueError("Invalid data name!") 128 | 129 | def get_infersent(data_name, on_gpu=True, if_sample=False, return_model=False): 130 | logging.info("Start parsing...") 131 | file_list = load_file_list(data_name, if_sample) 132 | 133 | sentences = [] 134 | for file_path in file_list: 135 | with open(file_path) as f: 136 | for line in f: 137 | line = line.strip() 138 | if (line != '') and (line != ''): 139 | sentences.append(line) 140 | logging.info("%d sentences in total." % len(sentences)) 141 | 142 | logging.info("Loading infersent models...") 143 | params = { 144 | 'bsize': 64, 145 | 'word_emb_dim': 300, 146 | 'enc_lstm_dim': 2048, 147 | 'pool_type': 'max', 148 | 'dpout_model': 0.0, 149 | 'version': 1 150 | } 151 | model = InferSent(params) 152 | model.load_state_dict(torch.load(config.INFERSENT_MODEL)) 153 | model.set_w2v_path(config.WORD_EMBEDDING) 154 | vocab_size = 10000 if if_sample else 2196017 155 | model.build_vocab_k_words(K=vocab_size) 156 | if on_gpu: 157 | model.cuda() 158 | 159 | logging.info("Encoding sentences...") 160 | embeddings = model.encode( 161 | sentences, 128, config.MAX_SENT_LENGTH, tokenize=False, verbose=True) 162 | logging.info("number of sentences encoded: %d" % len(embeddings)) 163 | 164 | assert len(sentences) == len(embeddings), "Lengths don't match!" 
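# Build a lookup from raw sentence text to its 4096-d InferSent vector so the
# coherence models can fetch embeddings by string without re-encoding. A hedged
# micro-example of the lookup (the sentence is hypothetical and would have to be
# present in the parsed corpus):
#     vec = embed_dict["The market rallied ."]
#     assert vec.shape == (4096,)
# The two extra entries added after the zip below are sentence-boundary
# placeholders; their vectors are drawn from a seeded uniform distribution so that
# repeated runs stay reproducible.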
165 | embed_dict = dict(zip(sentences, embeddings)) 166 | np.random.seed(0) 167 | embed_dict[""] = np.random.uniform(size=4096).astype(np.float32) 168 | embed_dict[""] = np.random.uniform(size=4096).astype(np.float32) 169 | 170 | if return_model: 171 | return embed_dict, model 172 | else: 173 | return embed_dict 174 | 175 | def get_average_glove(data_name, if_sample=False): 176 | logging.info("Start parsing...") 177 | file_list = load_file_list(data_name, if_sample) 178 | 179 | sentences = [] 180 | for file_path in file_list: 181 | with open(file_path) as f: 182 | for line in f: 183 | line = line.strip() 184 | if (line != '') and (line != ''): 185 | sentences.append(line) 186 | logging.info("%d sentences in total." % len(sentences)) 187 | 188 | logging.info("Loading glove...") 189 | word_vec = {} 190 | with open(config.WORD_EMBEDDING) as f: 191 | for line in f: 192 | word, vec = line.split(' ', 1) 193 | word_vec[word] = np.fromstring(vec, sep=' ') 194 | 195 | embed_dict = {} 196 | for s in sentences: 197 | tokens = s.split() 198 | embed_dict[s] = np.zeros(300, dtype=np.float32) 199 | sent_len = 0 200 | for token in tokens: 201 | if token in word_vec: 202 | embed_dict[s] += word_vec[token] 203 | sent_len += 1 204 | if sent_len > 0: 205 | embed_dict[s] = np.true_divide(embed_dict[s], sent_len) 206 | np.random.seed(0) 207 | embed_dict[""] = np.random.uniform(size=300).astype(np.float32) 208 | embed_dict[""] = np.random.uniform(size=300).astype(np.float32) 209 | return embed_dict 210 | 211 | def get_lm_hidden(data_name, lm_name, corpus): 212 | logging.info("Start parsing...") 213 | file_list = load_file_list(data_name, False) 214 | 215 | sentences = [] 216 | for file_path in file_list: 217 | with open(file_path) as f: 218 | for line in f: 219 | line = line.strip() 220 | if (line != '') and (line != ''): 221 | sentences.append(line) 222 | logging.info("%d sentences in total." 
% len(sentences)) 223 | 224 | with open(os.path.join(config.CHECKPOINT_PATH, lm_name + "_forward.pkl"), "rb") as f: 225 | hparams = pickle.load(f) 226 | 227 | kwargs = { 228 | "vocab_size": corpus.glove_embed.shape[0], 229 | "embed_dim": corpus.glove_embed.shape[1], 230 | "corpus": corpus, 231 | "hparams": hparams, 232 | } 233 | 234 | forward_lm = LanguageModel(**kwargs) 235 | forward_lm.load(os.path.join(config.CHECKPOINT_PATH, lm_name + "_forward.pt")) 236 | forward_lm = forward_lm.lm 237 | forward_lm.eval() 238 | 239 | backward_lm = LanguageModel(**kwargs) 240 | backward_lm.load(os.path.join(config.CHECKPOINT_PATH, lm_name + "_backward.pt")) 241 | backward_lm = backward_lm.lm 242 | backward_lm.eval() 243 | 244 | embed_dict = {} 245 | ini_hidden = forward_lm.init_hidden(1) 246 | for sent in tqdm(sentences): 247 | fs = [corpus.vocab[w] for w in [''] + sent.split() + ['']] 248 | fs = torch.LongTensor(fs).unsqueeze(1) 249 | fs = fs.to('cuda') 250 | fout = forward_lm.encode(fs, ini_hidden) 251 | fout = torch.max(fout, 0)[0].squeeze().data.cpu().numpy().astype(np.float32) 252 | 253 | bs = [corpus.vocab[w] for w in [''] + sent.split()[::-1] + ['']] 254 | bs = torch.LongTensor(bs).unsqueeze(1) 255 | bs = bs.to('cuda') 256 | bout = backward_lm.encode(bs, ini_hidden) 257 | bout = torch.max(bout, 0)[0].squeeze().data.cpu().numpy().astype(np.float32) 258 | 259 | embed_dict[sent] = np.hstack((fout, bout)) 260 | np.random.seed(0) 261 | embed_dict[""] = np.random.uniform(size=2048).astype(np.float32) 262 | embed_dict[""] = np.random.uniform(size=2048).astype(np.float32) 263 | return embed_dict 264 | 265 | def get_s2s_hidden(data_name, model_name, corpus): 266 | logging.info("Start parsing...") 267 | file_list = load_file_list(data_name, False) 268 | 269 | sentences = [] 270 | for file_path in file_list: 271 | with open(file_path) as f: 272 | for line in f: 273 | line = line.strip() 274 | if (line != '') and (line != ''): 275 | sentences.append(line) 276 | logging.info("%d sentences in total." 
% len(sentences)) 277 | 278 | with open(os.path.join(config.CHECKPOINT_PATH, model_name + "_forward.pkl"), "rb") as f: 279 | hparams = pickle.load(f) 280 | 281 | kwargs = { 282 | "vocab_size": corpus.glove_embed.shape[0], 283 | "embed_dim": corpus.glove_embed.shape[1], 284 | "corpus": corpus, 285 | "hparams": hparams, 286 | } 287 | 288 | forward_model = Seq2SeqModel(**kwargs) 289 | forward_model.load(os.path.join(config.CHECKPOINT_PATH, model_name + "_forward.pt")) 290 | forward_model = forward_model.model 291 | forward_model.eval() 292 | 293 | backward_model = Seq2SeqModel(**kwargs) 294 | backward_model.load(os.path.join(config.CHECKPOINT_PATH, model_name + "_backward.pt")) 295 | backward_model = backward_model.model 296 | backward_model.eval() 297 | 298 | embed_dict = {} 299 | for sent in tqdm(sentences): 300 | fs = [corpus.vocab[w] for w in sent.split() + ['']] 301 | # fs_len = torch.LongTensor([len(fs)]) 302 | # fs_len = fs_len.to('cuda') 303 | fs = torch.LongTensor(fs).unsqueeze(0) 304 | fs = fs.to('cuda') 305 | # fout = forward_model.encoding(fs, fs_len) 306 | fout = forward_model.encode(fs) 307 | # fout = fout.squeeze().data.cpu().numpy().astype(np.float32) 308 | fout = torch.max(fout, 1)[0].squeeze().data.cpu().numpy().astype(np.float32) 309 | 310 | bs = [corpus.vocab[w] for w in sent.split()[::-1] + ['']] 311 | # bs_len = torch.LongTensor([len(bs)]) 312 | # bs_len = bs_len.to('cuda') 313 | bs = torch.LongTensor(bs).unsqueeze(0) 314 | bs = bs.to('cuda') 315 | # bout = backward_model.encoding(bs, bs_len) 316 | bout = backward_model.encode(bs) 317 | # bout = bout.squeeze().data.cpu().numpy().astype(np.float32) 318 | bout = torch.max(bout, 1)[0].squeeze().data.cpu().numpy().astype(np.float32) 319 | 320 | embed_dict[sent] = np.hstack((fout, bout)) 321 | np.random.seed(0) 322 | embed_dict[""] = np.random.uniform(size=2048).astype(np.float32) 323 | embed_dict[""] = np.random.uniform(size=2048).astype(np.float32) 324 | return embed_dict 325 | 326 | def save_eval_perm(data_name, if_sample=False, random_seed=config.RANDOM_SEED): 327 | random.seed(random_seed) 328 | 329 | logging.info("Loading valid and test data...") 330 | if data_name not in config.DATASET: 331 | raise ValueError("Invalid data name!") 332 | dataset = DataSet(config.DATASET[data_name]) 333 | # dataset.random_seed = random_seed 334 | if if_sample: 335 | valid_dataset = dataset.load_valid_sample() 336 | else: 337 | valid_dataset = dataset.load_valid() 338 | if if_sample: 339 | test_dataset = dataset.load_test_sample() 340 | else: 341 | test_dataset = dataset.load_test() 342 | valid_df = valid_dataset.article_df 343 | test_df = test_dataset.article_df 344 | 345 | logging.info("Generating permuted articles...") 346 | 347 | def permute(x): 348 | x = np.array(x).squeeze() 349 | # neg_x_list = permute_articles([x], config.NEG_PERM)[0] 350 | neg_x_list = permute_articles_with_replacement([x], config.NEG_PERM)[0] 351 | return "".join(["".join(i) for i in neg_x_list]) 352 | 353 | valid_df["neg_list"] = valid_df.sentences.map(permute) 354 | valid_df["sentences"] = valid_df.sentences.map(lambda x: "".join(x)) 355 | valid_nums = valid_df.neg_list.map(lambda x: len(x.split(""))).sum() 356 | test_df["neg_list"] = test_df.sentences.map(permute) 357 | test_df["sentences"] = test_df.sentences.map(lambda x: "".join(x)) 358 | test_nums = test_df.neg_list.map(lambda x: len(x.split(""))).sum() 359 | 360 | logging.info("Number of validation pairs %d" % valid_nums) 361 | logging.info("Number of test pairs %d" % test_nums) 362 | 363 | 
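# At this point valid_df and test_df hold, for each article, the original sentence
# order in "sentences" plus up to config.NEG_PERM shuffled orderings in "neg_list";
# these are the positive/negative pairs that evaluate_dis and evaluate_ins later
# read back through DataSet.load_valid_perm and load_test_perm. The next lines
# write them as tab-separated files to the valid_perm/test_perm paths configured
# in config.DATASET[data_name].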
logging.info("Saving...") 364 | dataset.save_valid_perm(valid_df) 365 | dataset.save_test_perm(test_df) 366 | logging.info("Finished!") 367 | 368 | def add_args(parser): 369 | parser.add_argument('--data_name', type=str, default='wsj_bigram') 370 | 371 | 372 | if __name__ == "__main__": 373 | _set_basic_logging() 374 | parser = argparse.ArgumentParser() 375 | add_args(parser) 376 | args = parser.parse_args() 377 | save_eval_perm(args.data_name, False) 378 | -------------------------------------------------------------------------------- /run_bigram_coherence.py: -------------------------------------------------------------------------------- 1 | from models.coherence_models import BigramCoherence 2 | from preprocess import get_infersent, get_average_glove, save_eval_perm, get_lm_hidden 3 | from preprocess import get_s2s_hidden 4 | from utils.data_utils import DataSet 5 | from utils.lm_utils import Corpus, SentCorpus 6 | from utils.logging_utils import _set_basic_logging 7 | import logging 8 | import config 9 | from torch.utils.data import DataLoader 10 | import os 11 | import argparse 12 | from add_args import add_bigram_args 13 | import torch 14 | 15 | 16 | def run_bigram_coherence(args): 17 | logging.info("Loading data...") 18 | if args.data_name not in config.DATASET: 19 | raise ValueError("Invalid data name!") 20 | dataset = DataSet(config.DATASET[args.data_name]) 21 | # dataset.random_seed = args.random_seed 22 | if not os.path.isfile(dataset.test_perm): 23 | save_eval_perm(args.data_name, random_seed=args.random_seed) 24 | 25 | train_dataset = dataset.load_train(args.portion) 26 | train_dataloader = DataLoader(dataset=train_dataset, 27 | batch_size=args.batch_size, 28 | shuffle=True, drop_last=True) 29 | valid_dataset = dataset.load_valid(args.portion) 30 | valid_dataloader = DataLoader(dataset=valid_dataset, batch_size=1, shuffle=False) 31 | valid_df = dataset.load_valid_perm() 32 | test_dataset = dataset.load_test(args.portion) 33 | test_dataloader = DataLoader(dataset=test_dataset, batch_size=1, shuffle=False) 34 | test_df = dataset.load_test_perm() 35 | 36 | logging.info("Loading sent embedding...") 37 | if args.sent_encoder == "infersent": 38 | sent_embedding = get_infersent(args.data_name, if_sample=args.test) 39 | embed_dim = 4096 40 | elif args.sent_encoder == "average_glove": 41 | sent_embedding = get_average_glove(args.data_name, if_sample=args.test) 42 | embed_dim = 300 43 | elif args.sent_encoder == "lm_hidden": 44 | corpus = Corpus(train_dataset.file_list, test_dataset.file_list) 45 | sent_embedding = get_lm_hidden(args.data_name, "lm_" + args.data_name, corpus) 46 | embed_dim = 2048 47 | elif args.sent_encoder == "s2s_hidden": 48 | corpus = SentCorpus(train_dataset.file_list, test_dataset.file_list) 49 | sent_embedding = get_s2s_hidden(args.data_name, "s2s_" + args.data_name, corpus) 50 | embed_dim = 2048 51 | else: 52 | raise ValueError("Invalid sent encoder name!") 53 | 54 | logging.info("Training BigramCoherence model...") 55 | kwargs = { 56 | "embed_dim": embed_dim, 57 | "sent_encoder": sent_embedding, 58 | "hparams": { 59 | "loss": args.loss, 60 | "input_dropout": args.input_dropout, 61 | "hidden_state": args.hidden_state, 62 | "hidden_layers": args.hidden_layers, 63 | "hidden_dropout": args.hidden_dropout, 64 | "num_epochs": args.num_epochs, 65 | "margin": args.margin, 66 | "lr": args.lr, 67 | "l2_reg_lambda": args.l2_reg_lambda, 68 | "use_bn": args.use_bn, 69 | "task": "discrimination", 70 | "bidirectional": args.bidirectional, 71 | } 72 | } 73 | 74 | model = 
BigramCoherence(**kwargs) 75 | model.init() 76 | best_step, valid_acc = model.fit(train_dataloader, valid_dataloader, valid_df) 77 | if args.save: 78 | model_path = os.path.join(config.CHECKPOINT_PATH, "%s-%.4f" % (args.data_name, valid_acc)) 79 | # model.save(model_path) 80 | torch.save(model, model_path + '.pth') 81 | model.load_best_state() 82 | 83 | # dataset = DataSet(config.DATASET["wsj_bigram"]) 84 | # test_dataset = dataset.load_test() 85 | # test_dataloader = DataLoader(dataset=test_dataset, batch_size=1, shuffle=False) 86 | # test_df = dataset.load_test_perm() 87 | # if args.sent_encoder == "infersent": 88 | # model.sent_encoder = get_infersent("wsj_bigram", if_sample=args.test) 89 | # elif args.sent_encoder == "average_glove": 90 | # model.sent_encoder = get_average_glove("wsj_bigram", if_sample=args.test) 91 | # else: 92 | # model.sent_encoder = get_lm_hidden("wsj_bigram", "lm_" + args.data_name, corpus) 93 | 94 | logging.info("Results for discrimination:") 95 | dis_acc = model.evaluate_dis(test_dataloader, test_df) 96 | print("Test Acc:", dis_acc) 97 | logging.info("Disc Accuracy: {}".format(dis_acc[0])) 98 | 99 | logging.info("Results for insertion:") 100 | ins_acc = model.evaluate_ins(test_dataloader, test_df) 101 | print("Test Acc:", ins_acc) 102 | logging.info("Insert Accuracy: {}".format(ins_acc[0])) 103 | 104 | return dis_acc, ins_acc 105 | 106 | 107 | if __name__ == "__main__": 108 | parser = argparse.ArgumentParser() 109 | add_bigram_args(parser) 110 | args = parser.parse_args() 111 | 112 | _set_basic_logging() 113 | run_bigram_coherence(args) 114 | -------------------------------------------------------------------------------- /run_lm_coherence.py: -------------------------------------------------------------------------------- 1 | from models.language_models import LanguageModel, LMCoherence 2 | from utils.lm_utils import Corpus 3 | from utils.data_utils import DataSet 4 | from utils.logging_utils import _set_basic_logging 5 | import logging 6 | import config 7 | from torch.utils.data import DataLoader 8 | import os 9 | import argparse 10 | import pickle 11 | 12 | 13 | def run_lm_coherence(args): 14 | logging.info("Loading data...") 15 | if args.data_name not in config.DATASET: 16 | raise ValueError("Invalid data name!") 17 | 18 | dataset = DataSet(config.DATASET[args.data_name]) 19 | train_dataset = dataset.load_train() 20 | test_df = dataset.load_test_perm() 21 | test_dataset = dataset.load_test() 22 | test_dataloader = DataLoader(dataset=test_dataset, batch_size=1, shuffle=False) 23 | corpus = Corpus(train_dataset.file_list, test_dataset.file_list) 24 | 25 | # dataset = DataSet(config.DATASET["wsj_bigram"]) 26 | # test_df = dataset.load_test_perm() 27 | # test_dataset = dataset.load_test() 28 | # test_dataloader = DataLoader(dataset=test_dataset, batch_size=1, shuffle=False) 29 | 30 | with open(os.path.join(config.CHECKPOINT_PATH, args.lm_name + "_forward.pkl"), "rb") as f: 31 | hparams = pickle.load(f) 32 | 33 | kwargs = { 34 | "vocab_size": corpus.glove_embed.shape[0], 35 | "embed_dim": corpus.glove_embed.shape[1], 36 | "corpus": corpus, 37 | "hparams": hparams, 38 | } 39 | 40 | forward_model = LanguageModel(**kwargs) 41 | forward_model.load(os.path.join(config.CHECKPOINT_PATH, args.lm_name + "_forward.pt")) 42 | backward_model = LanguageModel(**kwargs) 43 | backward_model.load(os.path.join(config.CHECKPOINT_PATH, args.lm_name + "_backward.pt")) 44 | 45 | logging.info("Results for discrimination:") 46 | model = LMCoherence(forward_model.lm, backward_model.lm, 
corpus) 47 | dis_acc = model.evaluate_dis(test_dataloader, test_df) 48 | logging.info("Disc Accuracy: {}".format(dis_acc)) 49 | 50 | logging.info("Results for insertion:") 51 | ins_acc = model.evaluate_ins(test_dataloader, test_df) 52 | logging.info("Insert Accuracy: {}".format(ins_acc)) 53 | 54 | return dis_acc, ins_acc 55 | 56 | def add_args(parser): 57 | parser.add_argument('--data_name', type=str, default="wiki_bigram_easy", 58 | help='data name') 59 | parser.add_argument('--lm_name', type=str, default="lm_wiki_bigram_easy", 60 | help='language model name') 61 | 62 | 63 | if __name__ == "__main__": 64 | _set_basic_logging() 65 | 66 | parser = argparse.ArgumentParser() 67 | add_args(parser) 68 | args = parser.parse_args() 69 | 70 | run_lm_coherence(args) 71 | -------------------------------------------------------------------------------- /train_lm.py: -------------------------------------------------------------------------------- 1 | from utils.lm_utils import Corpus 2 | from utils.data_utils import DataSet 3 | from models.language_models import LanguageModel 4 | import config 5 | import argparse 6 | 7 | def train_lm(args): 8 | if args.data_name not in config.DATASET: 9 | raise ValueError("Invalid data name!") 10 | dataset = DataSet(config.DATASET[args.data_name]) 11 | train_dataset = dataset.load_train() 12 | test_dataset = dataset.load_test() 13 | corpus = Corpus(train_dataset.file_list, test_dataset.file_list, args.reverse) 14 | suffix = "backward" if args.reverse else "forward" 15 | 16 | kwargs = { 17 | "vocab_size": corpus.glove_embed.shape[0], 18 | "embed_dim": corpus.glove_embed.shape[1], 19 | "corpus": corpus, 20 | "hparams": { 21 | "hidden_size": args.hidden_size, 22 | "num_layers": args.num_layers, 23 | "cell_type": args.cell_type, 24 | "tie_embed": args.tie_embed, 25 | "rnn_dropout": args.rnn_dropout, 26 | "hidden_dropout": args.hidden_dropout, 27 | "num_epochs": args.num_epochs, 28 | "batch_size": args.batch_size, 29 | "bptt": args.bptt, 30 | "log_interval": args.log_interval, 31 | "save_path": args.save_path + '_' + args.data_name + '_' + suffix, 32 | "lr": args.lr, 33 | "wdecay": args.wdecay, 34 | } 35 | } 36 | 37 | lm = LanguageModel(**kwargs) 38 | best_valid_loss = lm.fit() 39 | print("Best Valid Loss:", best_valid_loss) 40 | 41 | def add_args(parser): 42 | parser.add_argument('--data_name', type=str, default="wsj_bigram", 43 | help='data name') 44 | parser.add_argument('--batch_size', type=int, default=128, 45 | help='batch size') 46 | parser.add_argument('--num_epochs', type=int, default=1000, 47 | help='number of training epochs') 48 | parser.add_argument('--bptt', type=int, default=35, 49 | help='sequence length') 50 | parser.add_argument('--log_interval', type=int, default=100, 51 | help='log interval for training') 52 | parser.add_argument('--save_path', type=str, default='checkpoint/lm', 53 | help='save path') 54 | parser.add_argument('--reverse', default=False, action='store_true', 55 | help='reverse the text') 56 | 57 | parser.add_argument('--hidden_size', type=int, default=1024, 58 | help='hidden size') 59 | parser.add_argument('--num_layers', type=int, default=2, 60 | help='number of hidden layers') 61 | parser.add_argument('--cell_type', type=str, default='lstm', 62 | help='RNN cell type (i.e. 
rnn, gru or lstm)') 63 | parser.add_argument('--tie_embed', default=False, action='store_true', 64 | help='Tie embedding and softmax weights') 65 | parser.add_argument('--rnn_dropout', type=float, default=0.5, 66 | help='RNN dropout') 67 | parser.add_argument('--hidden_dropout', type=float, default=0.5, 68 | help='hidden dropout') 69 | 70 | parser.add_argument('--lr', type=float, default=0.001, 71 | help='learning rate') 72 | parser.add_argument('--wdecay', type=float, default=0.0, 73 | help='weight decay') 74 | 75 | 76 | if __name__ == "__main__": 77 | parser = argparse.ArgumentParser() 78 | add_args(parser) 79 | args = parser.parse_args() 80 | 81 | train_lm(args) 82 | -------------------------------------------------------------------------------- /utils/data_utils.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import numpy as np 3 | from torch.utils.data import Dataset 4 | import os 5 | 6 | class WSJ_Bigram_Dataset(Dataset): 7 | def __init__(self, scr_path, portion=1.0, mode='train'): 8 | self.data_path = scr_path 9 | self.train_list = ['00', '01', '02', '03', '04', '05', '06', 10 | '07', '08', '09', '10'] 11 | self.valid_list = ['11', '12', '13'] 12 | self.test_list = ['14', '15', '16', '17', '18', '19', '20', 13 | '21', '22', '23', '24'] 14 | self.portion = portion 15 | self.mode = mode 16 | 17 | if self.mode == 'train': 18 | self.file_list = self.get_file_list(self.train_list) 19 | elif self.mode == 'valid': 20 | self.file_list = self.get_file_list(self.valid_list) 21 | elif self.mode == 'test': 22 | self.file_list = self.get_file_list(self.test_list) 23 | else: 24 | raise ValueError("Invalid mode name!") 25 | 26 | self.sentences = self.get_all_sentences(self.file_list) 27 | self.total_sent = len(self.sentences) 28 | self.examples = [] 29 | 30 | self.sent_index, self.article_index = self.create_index(self.file_list) 31 | self.sent_df = pd.DataFrame(self.sent_index) 32 | self.sent_df.columns = ['article', 'sentences'] 33 | self.sent_df.set_index('article', inplace=True) 34 | self.article_df = pd.DataFrame(self.article_index) 35 | self.article_df.columns = ['article', 'sentences'] 36 | self.article_df.reset_index(level=0, inplace=True) 37 | 38 | if self.mode in ['train']: 39 | self.total_cliques = len(self.examples) 40 | else: 41 | self.total_cliques = len(self.article_df) 42 | 43 | def __len__(self): 44 | return self.total_cliques 45 | 46 | def __getitem__(self, index): 47 | if self.mode in ['train']: 48 | article = self.examples[index] 49 | samples = self.sent_df.loc[article, 'sentences'] 50 | sample = np.random.choice(samples) 51 | return sample 52 | else: 53 | article_row = self.article_df.loc[index] 54 | return article_row['article'] 55 | 56 | def get_file_list(self, dir_list): 57 | file_list = [] 58 | for dirname in os.listdir(self.data_path): 59 | if dirname in dir_list: 60 | subdirpath = os.path.join(self.data_path, dirname) 61 | for filename in os.listdir(subdirpath): 62 | file_list.append(os.path.join(subdirpath, filename)) 63 | return file_list 64 | 65 | def preprocess(self, article, index): 66 | sentences = [] 67 | with open(article) as f: 68 | for line in f: 69 | line = line.strip() 70 | if line != '' and line != '': 71 | sentences.append(line) 72 | sent_num = len(sentences) 73 | if sent_num < 3: 74 | return sentences 75 | sentences = [''] + sentences + [''] 76 | sent1 = sentences[:-1] 77 | pos_sent2 = sentences[1:] 78 | samples = [] 79 | weights = [] 80 | for i in range(sent_num + 1): 81 | for j in 
list(range(i)) + list(range(i + 2, sent_num + 2)): 82 | sent = "".join( 83 | [sent1[i], pos_sent2[i], sentences[j], str(sent_num)]) 84 | samples.append(sent) 85 | factor = np.sqrt(max(1, np.abs(i - j))) 86 | weights.append(1.0 / factor) 87 | index.append([article, np.random.choice( 88 | samples, max(1, int(len(samples) * self.portion)), False)]) 89 | for _ in range(50): 90 | self.examples.append(article) 91 | return sentences[1:-1] 92 | 93 | def create_index(self, file_list): 94 | sidx = [] 95 | aidx = [] 96 | for article in file_list: 97 | sentences = self.preprocess(article, sidx) 98 | if len(sentences) > 2: 99 | aidx.append([article, sentences]) 100 | return sidx, aidx 101 | 102 | def get_all_sentences(self, file_list): 103 | sentences = [] 104 | for article in file_list: 105 | with open(article) as f: 106 | for line in f: 107 | line = line.strip() 108 | if line != '' and line != '': 109 | sentences.append(line) 110 | return sentences 111 | 112 | class WIKI_Bigram_Dataset(Dataset): 113 | def __init__(self, scr_path, portion, mode='train', 114 | train_list=[], test_list=[], article_index=None): 115 | self.data_path = scr_path 116 | self.portion = portion 117 | self.mode = mode 118 | self.train_list = train_list 119 | self.test_list = test_list 120 | 121 | if self.mode in ['train', 'valid']: 122 | self.file_list = self.get_file_list(self.train_list) 123 | elif self.mode in ['test']: 124 | self.file_list = self.get_file_list(self.test_list) 125 | else: 126 | raise ValueError("Invalid mode name!") 127 | 128 | # self.sent_index, self.article_index = self.create_index(self.file_list) 129 | # self.sent_df = pd.DataFrame(self.sent_index) 130 | # self.sent_df.columns = ['article', 'sentences'] 131 | # self.sent_df.reset_index(level=0, inplace=True) 132 | 133 | self.article_index = article_index 134 | if article_index is None: 135 | self.article_index = self.create_index(self.file_list) 136 | 137 | self.article_df = pd.DataFrame(self.article_index) 138 | self.article_df.columns = ['article', 'sentences'] 139 | self.article_df.reset_index(level=0, inplace=True) 140 | 141 | np.random.seed(0) 142 | mask = np.random.uniform(0, 1, len(self.article_df)) > 0.1 143 | if self.mode == 'train': 144 | self.article_df = self.article_df.iloc[mask] 145 | elif self.mode == 'valid': 146 | self.article_df = self.article_df.iloc[~mask] 147 | self.article_df.reset_index(drop=True, inplace=True) 148 | 149 | if self.mode == 'train': 150 | self.sentences = self.get_all_sentences(self.file_list) 151 | self.total_sent = len(self.sentences) 152 | self.examples = [] 153 | 154 | self.sent_index = self.create_sent_index() 155 | self.sent_df = pd.DataFrame(self.sent_index) 156 | self.sent_df.columns = ['article', 'sentences'] 157 | self.sent_df.set_index('article', inplace=True) 158 | 159 | self.total_cliques = len(self.examples) 160 | else: 161 | self.total_cliques = len(self.article_df) 162 | 163 | def __len__(self): 164 | return self.total_cliques 165 | 166 | def __getitem__(self, index): 167 | if self.mode in ['train']: 168 | article = self.examples[index] 169 | samples = self.sent_df.loc[article, 'sentences'] 170 | sample = np.random.choice(samples) 171 | return sample 172 | else: 173 | article_row = self.article_df.loc[index] 174 | return article_row['article'] 175 | 176 | def get_file_list(self, dir_list): 177 | file_list = [] 178 | for dirname in os.listdir(self.data_path): 179 | if dirname in dir_list: 180 | subdirpath = os.path.join(self.data_path, dirname) 181 | file_list.append(os.path.join( 182 | subdirpath, 
"extracted_paras.txt")) 183 | return file_list 184 | 185 | def create_index(self, file_list): 186 | aidx = [] 187 | for article in file_list: 188 | idx = 0 189 | with open(article) as f: 190 | sentences = [] 191 | for line in f: 192 | line = line.strip() 193 | if line == '': 194 | idx += 1 195 | aidx.append([article + str(idx), sentences]) 196 | sentences = [] 197 | elif line != "": 198 | sentences.append(line) 199 | return aidx 200 | 201 | def create_sent_index(self): 202 | sidx = [] 203 | for i in range(self.article_df.shape[0]): 204 | article = self.article_df.article[i] 205 | sentences = self.article_df.sentences[i] 206 | sent_num = len(sentences) 207 | sentences = [""] + sentences + [""] 208 | sent1 = sentences[:-1] 209 | pos_sent2 = sentences[1:] 210 | samples = [] 211 | weights = [] 212 | for i in range(sent_num + 1): 213 | for j in list(range(i)) + list(range(i + 2, sent_num + 2)): 214 | sent = "".join( 215 | [sent1[i], pos_sent2[i], sentences[j], str(sent_num)]) 216 | samples.append(sent) 217 | factor = np.sqrt(max(1, np.abs(i - j))) 218 | weights.append(1.0 / factor) 219 | sidx.append([article, np.random.choice( 220 | samples, max(1, int(len(samples) * self.portion)), False)]) 221 | for _ in range(50): 222 | self.examples.append(article) 223 | return sidx 224 | 225 | def get_all_sentences(self, file_list): 226 | sentences = [] 227 | for article in file_list: 228 | with open(article) as f: 229 | for line in f: 230 | line = line.strip() 231 | if line != '' and line != '': 232 | sentences.append(line) 233 | return sentences 234 | 235 | 236 | class DataSet: 237 | def __init__(self, d): 238 | self.dataset = d["dataset"] 239 | self.data_path = d["data_path"] 240 | self.sample_path = d["sample_path"] 241 | self.valid_perm = d["valid_perm"] 242 | self.test_perm = d["test_perm"] 243 | self.kwargs = d["kwargs"] 244 | self.col_names = ["article", "sentences", "neg_list"] 245 | 246 | def load_train(self, portion=1.0): 247 | return self.dataset(self.data_path, portion, 'train', **self.kwargs) 248 | 249 | def load_valid(self, portion=1.0): 250 | return self.dataset(self.data_path, portion, 'valid', **self.kwargs) 251 | 252 | def load_test(self, portion=1.0, **kwargs): 253 | kwargs.update(self.kwargs) 254 | return self.dataset(self.data_path, portion, 'test', **kwargs) 255 | 256 | def load_train_sample(self): 257 | return self.dataset(self.sample_path, 1.0, 'train', **self.kwargs) 258 | 259 | def load_valid_sample(self): 260 | return self.dataset(self.sample_path, 1.0, 'valid', **self.kwargs) 261 | 262 | def load_test_sample(self): 263 | return self.dataset(self.sample_path, 1.0, 'test', **self.kwargs) 264 | 265 | def load_valid_perm(self): 266 | df = pd.read_csv(self.valid_perm, sep="\t", names=[ 267 | "article", "sentences", "neg_list"]) 268 | df.set_index("article", inplace=True) 269 | return df 270 | 271 | def load_test_perm(self): 272 | df = pd.read_csv(self.test_perm, sep="\t", names=[ 273 | "article", "sentences", "neg_list"]) 274 | df.set_index("article", inplace=True) 275 | return df 276 | 277 | def save_valid_perm(self, df): 278 | df[self.col_names].to_csv( 279 | self.valid_perm, sep="\t", index=False, header=False) 280 | 281 | def save_test_perm(self, df): 282 | df[self.col_names].to_csv( 283 | self.test_perm, sep="\t", index=False, header=False) 284 | -------------------------------------------------------------------------------- /utils/lm_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from collections import Counter 
3 | import config 4 | import numpy as np 5 | 6 | class Vocabulary(object): 7 | OOV = 0 8 | EOS = 1 9 | 10 | def __init__(self): 11 | self.word2idx = {"": 0, "": 1} 12 | self.idx2word = ["", ""] 13 | self.size = 2 14 | 15 | def add_word(self, word): 16 | if word not in self.word2idx: 17 | self.idx2word.append(word) 18 | self.word2idx[word] = len(self.idx2word) - 1 19 | self.size += 1 20 | 21 | def __getitem__(self, word): 22 | return self.word2idx.get(word, 0) 23 | 24 | def to_word(self, idx): 25 | return self.idx2word[idx] 26 | 27 | def __len__(self): 28 | return self.size 29 | 30 | class Corpus(object): 31 | def __init__(self, train_list, test_list, reverse=False): 32 | self.vocab = Vocabulary() 33 | 34 | corpus_vocab = self.create_vocab(train_list + test_list) 35 | embedding_file = config.WORD_EMBEDDING 36 | 37 | words = [] 38 | glove_embed = [] 39 | with open(embedding_file) as f: 40 | for line in f: 41 | word, vec = line.split(' ', 1) 42 | if word in corpus_vocab: 43 | self.vocab.add_word(word) 44 | glove_embed.append(np.fromstring( 45 | vec, sep=' ', dtype=np.float32)) 46 | words.append(word) 47 | 48 | self.glove_embed = np.vstack(glove_embed) 49 | _mu = self.glove_embed.mean() 50 | _std = self.glove_embed.std() 51 | self.glove_embed = np.vstack([np.random.randn( 52 | 2, self.glove_embed.shape[1]) * _std + _mu, self.glove_embed]).astype(np.float32) 53 | print(self.glove_embed.shape) 54 | 55 | train = self.tokenize(train_list, reverse) 56 | train_len = int(train.shape[0] * 0.9) 57 | 58 | self.train = train[:train_len] 59 | self.valid = train[train_len:] 60 | self.test = self.tokenize(test_list, reverse) 61 | 62 | def create_vocab(self, file_list, top=100000): 63 | counter = Counter() 64 | for article in file_list: 65 | with open(article) as f: 66 | for line in f: 67 | line = line.strip() 68 | if line == '': 69 | continue 70 | for word in line.split(): 71 | counter[word] += 1 72 | counter = sorted(counter, key=counter.get, reverse=True) 73 | counter = counter[:top] 74 | return set(counter) 75 | 76 | def tokenize(self, file_list, reverse=False): 77 | words = [] 78 | for article in file_list: 79 | with open(article) as f: 80 | for line in f: 81 | line = line.strip() 82 | if (line == '') or (line == ''): 83 | continue 84 | for word in line.split() + ['']: 85 | words.append(word) 86 | 87 | idxs = [self.vocab[w] for w in words] 88 | if reverse: 89 | idxs = idxs[::-1] 90 | idxs = torch.LongTensor(idxs) 91 | return idxs 92 | 93 | class SentCorpus(object): 94 | def __init__(self, train_list, test_list, reverse=False, max_len=40, shuffle=True): 95 | self.vocab = Vocabulary() 96 | self.reverse = reverse 97 | self.shuffle = shuffle 98 | 99 | corpus_vocab = self.create_vocab(train_list + test_list) 100 | embedding_file = config.WORD_EMBEDDING 101 | 102 | words = [] 103 | glove_embed = [] 104 | with open(embedding_file) as f: 105 | for line in f: 106 | word, vec = line.split(' ', 1) 107 | if word in corpus_vocab: 108 | self.vocab.add_word(word) 109 | glove_embed.append(np.fromstring( 110 | vec, sep=' ', dtype=np.float32)) 111 | words.append(word) 112 | 113 | self.glove_embed = np.vstack(glove_embed) 114 | _mu = self.glove_embed.mean() 115 | _std = self.glove_embed.std() 116 | self.glove_embed = np.vstack([np.random.randn( 117 | 2, self.glove_embed.shape[1]) * _std + _mu, self.glove_embed]).astype(np.float32) 118 | print(self.glove_embed.shape) 119 | 120 | data = self.get_data(train_list) 121 | 122 | train_len = int(len(data) * 0.9) 123 | 124 | self.train = data[:train_len] 125 | self.valid = 
data[train_len:] 126 | self.max_len = max_len 127 | 128 | def create_vocab(self, file_list, top=100000): 129 | counter = Counter() 130 | for article in file_list: 131 | with open(article) as f: 132 | for line in f: 133 | line = line.strip() 134 | if line == '': 135 | continue 136 | for word in line.split(): 137 | counter[word] += 1 138 | counter = sorted(counter, key=counter.get, reverse=True) 139 | counter = counter[:top] 140 | return set(counter) 141 | 142 | def get_data(self, file_list): 143 | source = [] 144 | target = [] 145 | for article in file_list: 146 | sentences = [] 147 | with open(article) as f: 148 | for line in f: 149 | line = line.strip() 150 | if (line == '') or (line == ''): 151 | continue 152 | sentences.append(line) 153 | source.extend(sentences[:-1]) 154 | target.extend(sentences[1:]) 155 | if self.reverse: 156 | return list(zip(target, source)) 157 | else: 158 | return list(zip(source, target)) 159 | 160 | def tokenize(self, sent, reverse=False): 161 | words = sent.split() 162 | if reverse or self.reverse: 163 | words = words[::-1] 164 | indices = [self.vocab[w] for w in words][:self.max_len - 1] 165 | indices += [self.vocab.EOS] * (self.max_len - len(indices)) 166 | return indices 167 | 168 | def fetch_train_batches(self, batch_size): 169 | data_size = len(self.train) 170 | nbatch = data_size // batch_size 171 | 172 | if self.shuffle: 173 | shuffle_indices = np.random.permutation(np.arange(data_size)) 174 | else: 175 | shuffle_indices = np.arange(data_size) 176 | 177 | for i in range(nbatch): 178 | start_index = i * batch_size 179 | end_index = (i + 1) * batch_size 180 | source, source_len = [], [] 181 | target, target_len = [], [] 182 | for j in shuffle_indices[start_index:end_index]: 183 | s, t = self.train[j] 184 | source_len.append(min(len(s.split()) + 1, self.max_len)) 185 | source.append(self.tokenize(s)) 186 | target_len.append(min(len(t.split()) + 1, self.max_len)) 187 | target.append(self.tokenize(t)) 188 | yield source, source_len, target, target_len 189 | 190 | def fetch_valid_batches(self, batch_size): 191 | data_size = len(self.valid) 192 | nbatch = data_size // batch_size 193 | for i in range(nbatch): 194 | start_index = i * batch_size 195 | end_index = (i + 1) * batch_size 196 | source, source_len = [], [] 197 | target, target_len = [], [] 198 | for j in range(start_index, end_index): 199 | s, t = self.valid[j] 200 | source_len.append(min(len(s.split()) + 1, self.max_len)) 201 | source.append(self.tokenize(s)) 202 | target_len.append(min(len(t.split()) + 1, self.max_len)) 203 | target.append(self.tokenize(t)) 204 | yield source, source_len, target, target_len 205 | -------------------------------------------------------------------------------- /utils/logging_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | import logging.handlers 4 | 5 | def _get_logger(logdir, logname, loglevel=logging.INFO): 6 | fmt = "[%(asctime)s] %(levelname)s: %(message)s" 7 | formatter = logging.Formatter(fmt) 8 | 9 | handler = logging.handlers.RotatingFileHandler( 10 | filename=os.path.join(logdir, logname), 11 | maxBytes=10 * 1024 * 1024, 12 | backupCount=10 13 | ) 14 | handler.setFormatter(formatter) 15 | 16 | logger = logging.getLogger("") 17 | logger.addHandler(handler) 18 | logger.setLevel(loglevel) 19 | return logger 20 | 21 | def _set_basic_logging(): 22 | fmt = "[%(asctime)s] %(levelname)s: %(message)s" 23 | logging.basicConfig(level=logging.INFO, 24 | format=fmt) 25 | 
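A minimal usage sketch (not part of the repository) showing how the two logging helpers above can be combined: _set_basic_logging configures console output on the root logger, while _get_logger attaches a rotating file handler to the same root logger. The "log" directory, the "train.log" filename, and the logged message are illustrative assumptions only.

import os
from utils.logging_utils import _get_logger, _set_basic_logging

_set_basic_logging()                       # console output via logging.basicConfig
os.makedirs("log", exist_ok=True)          # hypothetical log directory
logger = _get_logger("log", "train.log")   # adds a 10 MB rotating file handler to the root logger
logger.info("Disc Accuracy: %.4f", 0.87)   # written to both the console and log/train.log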
-------------------------------------------------------------------------------- /utils/np_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | def random_permutation_matrix(N, dtype=np.float32): 4 | """ 5 | Generate a random permutation matrix. 6 | 7 | :param N: dimension of the permutation matrix 8 | :return: a numpy array with shape (N, N) 9 | """ 10 | A = np.identity(N, dtype=dtype) 11 | idx = np.random.permutation(N) 12 | return A[idx, :] 13 | 14 | def generate_random_pmatrices(N, size): 15 | """ 16 | Generate a batch of random permutation matrices. 17 | 18 | :param N: dimension of the permutation matrices 19 | :param size: number of generated matrices 20 | :return: a numpy array with shape (size, N, N) 21 | """ 22 | res = [] 23 | for i in range(size): 24 | res.append(random_permutation_matrix(N)) 25 | return np.array(res) 26 | --------------------------------------------------------------------------------
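To illustrate the helpers in utils/np_utils.py: left-multiplying by a permutation matrix reorders the rows of a matrix, e.g. to produce a shuffled ordering of sentence representations from an ordered one. The sketch below is an assumption-laden toy example; the 4x3 "sentence embedding" matrix is made up and not taken from the repository.

import numpy as np
from utils.np_utils import random_permutation_matrix, generate_random_pmatrices

# Toy stand-in for 4 sentence embeddings of dimension 3.
sents = np.arange(12, dtype=np.float32).reshape(4, 3)

P = random_permutation_matrix(4)       # shape (4, 4), exactly one 1 per row and column
shuffled = P @ sents                   # same rows as `sents`, in permuted order

batch = generate_random_pmatrices(4, size=8)
print(batch.shape)                     # (8, 4, 4)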