├── README.md
├── __init__.py
├── gpu_requirements.txt
├── images
│   ├── best_dev_ppl.svg
│   ├── best_train_ppl.svg
│   ├── gru_dev_ppl.svg
│   ├── gru_train_ppl.svg
│   ├── lstm_2_layer_dev_ppl.svg
│   ├── lstm_2_layer_train_ppl.svg
│   ├── lstm_dev_ppl.svg
│   ├── lstm_leaky_dev_ppl.svg
│   ├── lstm_leaky_train_ppl.svg
│   ├── lstm_relu_dev_ppl.svg
│   ├── lstm_relu_train_ppl.svg
│   └── lstm_train_ppl.svg
├── local_env.yml
├── model_embeddings.py
├── nmt_model.py
├── run.py
├── run.sh
├── utils.py
└── vocab.py


/README.md:
--------------------------------------------------------------------------------
# ZH-EN NMT Chinese to English Neural Machine Translation

> This project is inspired by Stanford's CS224N NMT Project

> Dataset used in this project: [News Commentary v14](http://data.statmt.org/news-commentary/v14)

## Intro

This is primarily a learning project to familiarize myself with PyTorch, machine translation, and NLP model training.

To investigate how different setups of the recurrent layer affect final performance, I compared the training efficiency and effectiveness of several encoder RNN configurations, changing one feature at a time while keeping all other parameters fixed:

- RNN types
  - GRU
  - LSTM
- Activation functions on the output layer
  - Tanh
  - ReLU
  - LeakyReLU
- Number of layers
  - single layer
  - double layer

## Code Files

```
_/
├─ utils.py             # utilities
├─ vocab.py             # generate vocab
├─ model_embeddings.py  # embedding layer
├─ nmt_model.py         # NMT model definition
├─ run.py               # training and testing
```

## Good Translation Examples

- ***source***: 相反,这意味着合作的基础应当是共同的长期战略利益,而不是共同的价值观。
- ***target***: Instead, it means that cooperation must be anchored not in shared values, but in shared long-term strategic interests.
- ***translation***: On the contrary, that means cooperation should be a common long-term strategic interests, rather than shared values.

- ***source***: 但这个问题其实很简单: 谁来承受这些用以降低预算赤字的紧缩措施的冲击。
- ***target***: But the issue is actually simple: Who will bear the brunt of measures to reduce the budget deficit?
- ***translation***: But the question is simple: Who is to bear the impact of austerity measures to reduce budget deficits?

- ***source***: 上述合作对打击恐怖主义、贩卖人口和移民可能发挥至关重要的作用。
- ***target***: Such cooperation is essential to combat terrorism, human trafficking, and migration.
- ***translation***: Such cooperation is essential to fighting terrorism, trafficking, and migration.

- ***source***: 与此同时, 政治危机妨碍着政府追求艰难的改革。
- ***target***: At the same time, political crisis is impeding the government’s pursuit of difficult reforms.
- ***translation***: Meanwhile, political crises hamper the government’s pursuit of difficult reforms.

## Preprocessing

> Preprocessing Colab [notebook](https://colab.research.google.com/drive/1IJTdk7hj3uoPEE0Ox7QaeW4rTuUzuxPJ?usp=sharing)

- Use [`jieba`](https://github.com/fxsjy/jieba) to segment Chinese text into space-separated words

## Generate Vocab From Training Data

- Input: Chinese and English training data
- Output: a vocab file mapping Chinese and English (sub)words to ids; a limited-size vocabulary covering around 99.95% of the training data is built with [SentencePiece](https://github.com/google/sentencepiece), which learns subword units (for example via [Byte Pair Encoding](https://en.wikipedia.org/wiki/Byte_pair_encoding))

## Model Definition

- A Seq2Seq model with attention

> This image is from the book [Dive into Deep Learning](https://zh-v2.d2l.ai/index.html)

![](https://zh-v2.d2l.ai/_images/seq2seq-attention-details.svg)

- Encoder
  - A bidirectional recurrent layer (GRU or LSTM)
- Decoder
  - LSTMCell (hidden_size=512)
- Attention
  - Multiplicative attention

## Training And Testing Results

> Training Colab [notebook](https://colab.research.google.com/drive/1HYbOh0AUMEasBAH7QPGNq9joH2dRRZwg?usp=sharing)

- **Hyperparameters:**
  - Embedding Size & Hidden Size: 512
  - Dropout Rate: 0.25
  - Starting Learning Rate: 5e-4
  - Batch Size: 32
  - Beam Size for Beam Search: 10
- **NOTE:** The BLEU scores below are computed on this project's own `Test Set`, so they should only be used to compare the **relative effectiveness** of models trained and evaluated on this data.

#### Experiment Setup

- **Dataset:** randomly split into a training set (~260,000), a validation set (~20,000), and a test set (~20,000); the same split is used for every experiment group
- **Max Number of Iterations:** 50,000
- **NOTE:** I tried a vanilla RNN (`nn.RNN`) in various configurations, but its BLEU score was extremely low (the absence of `residual connections` might be the issue)
  - I have excluded it from the comparison until the issue is resolved

| | Training Time (sec) | BLEU Score on Test Set | Training Perplexities | Validation Perplexities |
| ------------------------------------------------ | ------------------ | ---------------------- | ------------------------------------------------------------ | ------------------------------------------------------------ |
| **A.** Bidirectional 1-Layer GRU with Tanh | 5158.99 | 14.26 | ![](https://raw.githubusercontent.com/JasonFengGit/ZH-EN-Neural-Model-Translation/edb3bca3d2c190398ab211195d9b14a16a163d76/images/gru_train_ppl.svg) | ![](https://raw.githubusercontent.com/JasonFengGit/ZH-EN-Neural-Model-Translation/edb3bca3d2c190398ab211195d9b14a16a163d76/images/gru_dev_ppl.svg) |
| **B.** Bidirectional 1-Layer LSTM with Tanh | 5150.31 | 16.20 | ![](https://raw.githubusercontent.com/JasonFengGit/ZH-EN-Neural-Model-Translation/edb3bca3d2c190398ab211195d9b14a16a163d76/images/lstm_train_ppl.svg) | ![](https://raw.githubusercontent.com/JasonFengGit/ZH-EN-Neural-Model-Translation/edb3bca3d2c190398ab211195d9b14a16a163d76/images/lstm_dev_ppl.svg) |
| **C.** Bidirectional 2-Layer LSTM with Tanh | 6197.58 | **16.38** | ![](https://raw.githubusercontent.com/JasonFengGit/ZH-EN-Neural-Model-Translation/4e70246b618a0fa35d5ab75193df638ac1e27562/images/lstm_2_layer_train_ppl.svg) | ![](https://raw.githubusercontent.com/JasonFengGit/ZH-EN-Neural-Model-Translation/4e70246b618a0fa35d5ab75193df638ac1e27562/images/lstm_2_layer_dev_ppl.svg) |
| **D.** Bidirectional 1-Layer LSTM with ReLU | 5275.12 | 14.01 | ![](https://raw.githubusercontent.com/JasonFengGit/ZH-EN-Neural-Model-Translation/4e70246b618a0fa35d5ab75193df638ac1e27562/images/lstm_relu_train_ppl.svg) | ![](https://raw.githubusercontent.com/JasonFengGit/ZH-EN-Neural-Model-Translation/4e70246b618a0fa35d5ab75193df638ac1e27562/images/lstm_relu_dev_ppl.svg) |
| **E.** Bidirectional 1-Layer LSTM with LeakyReLU(slope=0.1) | 5292.58 | 14.87 | ![](https://raw.githubusercontent.com/JasonFengGit/ZH-EN-Neural-Model-Translation/437f11b1a3004156fd97122cbe4f7d6d92f3bc53/images/lstm_leaky_train_ppl.svg) | ![](https://raw.githubusercontent.com/JasonFengGit/ZH-EN-Neural-Model-Translation/437f11b1a3004156fd97122cbe4f7d6d92f3bc53/images/lstm_leaky_dev_ppl.svg) |

#### Current Best Version

Bidirectional 2-Layer LSTM with Tanh, **embed_size & hidden_size of 1024**, trained for 11517.19 sec (44,000 iterations), BLEU score **17.95**

| | Training Time (sec) | BLEU Score on Test Set | Training Perplexities | Validation Perplexities |
|:----------:|:------------:|:----------------------:|-----------------------|-------------------------|
| Best Model | 11517.19 | **17.95** | ![](https://raw.githubusercontent.com/JasonFengGit/ZH-EN-Neural-Model-Translation/a46acb697327c7e35804bcc169f79362d2a8f99a/images/best_train_ppl.svg) | ![](https://raw.githubusercontent.com/JasonFengGit/ZH-EN-Neural-Model-Translation/a46acb697327c7e35804bcc169f79362d2a8f99a/images/best_dev_ppl.svg) |

#### Analysis

- LSTM outperformed GRU (16.20 vs 14.26 BLEU), possibly because it has an extra set of gate parameters
- Tanh outperformed ReLU and LeakyReLU, possibly because less information is lost (ReLU maps every negative input to zero)
- Making the LSTM deeper (2 layers instead of 1) slightly improved BLEU (16.38 vs 16.20) but increased training time by about 20% (6197.58 sec vs 5150.31 sec)
- Surprisingly, the training times for **A**, **B**, and **D** are roughly the same
  - Possible reasons: the dataset is not large enough, or the cloud service I used for training does not perform consistently

## Bad Examples & Case Analysis

- ***source***: **全球目击组织(Global Witness)**的报告记录, 光是2015年就有**16个国家**的185人被杀。
- ***target***: A **Global Witness** report documented 185 killings across **16 countries** in 2015 alone.
- ***translation***: According to the **Global eye**, the World Health Organization reported that 185 people were killed in 2015.
- ***problems***:
  - Information Loss: 16 countries
  - Unknown Proper Noun: Global Witness

- ***source***: 大自然给了足以满足每个人需要的东西, **但无法满足每个人的贪婪**。
- ***target***: Nature provides enough for everyone’s needs, **but not for everyone’s greed**.
- ***translation***: Nature provides enough to satisfy everyone.
- ***problems***:
  - Huge Information Loss

- ***source***: 我衷心希望全球经济危机和巴拉克·奥巴马当选总统能对新冷战的荒唐理念进行正确的评估。
- ***target***: It is my hope that the global economic crisis and Barack Obama’s presidency will put the farcical idea of a new Cold War into proper perspective.
- ***translation***: I do hope that the global economic crisis and President Barack Obama will be corrected for a new Cold War.
- ***problems***:
  - Action Sender And Receiver Exchanged
  - Failed To Translate Complex Sentence

- ***source***: 人们纷纷**猜测**欧元区将崩溃。
- ***target***: **Speculation** about a possible breakup was widespread.
- ***translation***: The eurozone would collapse.
- ***problems***:
  - Significant Information Loss

## Ways to Improve the NMT Model

- Dataset
  - The dataset is fairly small, and the model is not trained thoroughly on all of the data
  - Even as a native Chinese speaker, I could not understand what some of the source sentences are saying
  - Some target sentences are not self-contained; they need context to be understood (e.g. the target sentence in the last "Bad Examples" entry)
  - Even for humans, some of the source sentences are too hard to translate
- Model Architecture
  - CNN & Transformer
  - Character-based model
  - Make the model even larger & deeper (... I need GPUs)
- Tricks that might help
  - Add a proper-noun dictionary to translate unknown proper nouns word-by-word (phrase-by-phrase)
  - Initialize (sub)word embeddings with pretrained embeddings

## How To Run

- Download the dataset you want to use, and replace every "./zh_en_data" in `run.sh` with the path where your data is stored
- To run locally on a CPU (mostly for sanity checks; a CPU is too slow to actually train the model)
  - Set up the environment with conda/miniconda: `conda env create --file local_env.yml`
- To run on a GPU
  - Set up the environment and the training process by following the Colab [notebook](https://colab.research.google.com/drive/1HYbOh0AUMEasBAH7QPGNq9joH2dRRZwg?usp=sharing)

## Contact

If you have any questions or have trouble running the code, feel free to contact me via [email](mailto:jasonfen@usc.edu)

--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JasonFengGit/Neural-Model-Translation/5585389efc841a0cb3a26656caaba912b29d1770/__init__.py

--------------------------------------------------------------------------------
/gpu_requirements.txt:
--------------------------------------------------------------------------------
nltk
docopt
tqdm==4.29.1
sentencepiece
sacrebleu
torch

--------------------------------------------------------------------------------
/images/best_dev_ppl.svg, /images/gru_dev_ppl.svg, /images/lstm_2_layer_dev_ppl.svg, /images/lstm_dev_ppl.svg, /images/lstm_leaky_dev_ppl.svg, /images/lstm_relu_dev_ppl.svg:
--------------------------------------------------------------------------------
(SVG plots of validation perplexity; the SVG markup was lost in extraction)

--------------------------------------------------------------------------------
/local_env.yml:
--------------------------------------------------------------------------------
name: local_nmt
channels:
  - pytorch
  - defaults
dependencies:
  - python=3.7
  - numpy
  - scipy
  - tqdm
  - docopt
  - pytorch
  - nltk
  - torchvision
  - pip
  - pip:
    - sentencepiece
    - sacrebleu
    - jieba

--------------------------------------------------------------------------------
/model_embeddings.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import torch.nn as nn

class ModelEmbeddings(nn.Module):
    """
    Class that converts input words to their embeddings.
    """
    def __init__(self, embed_size, vocab):
        """
        Init the Embedding layers.

        @param embed_size (int): Embedding size (dimensionality)
        @param vocab (Vocab): Vocabulary object containing src and tgt languages
                              See vocab.py for documentation.
        """
        super(ModelEmbeddings, self).__init__()
        self.embed_size = embed_size

        # default values
        self.source = None
        self.target = None

        # index of the padding token, so that its embedding stays zero and receives no gradient
        src_pad_token_idx = vocab.src['<pad>']
        tgt_pad_token_idx = vocab.tgt['<pad>']

        self.source = nn.Embedding(len(vocab.src), self.embed_size, padding_idx=src_pad_token_idx)
        self.target = nn.Embedding(len(vocab.tgt), self.embed_size, padding_idx=tgt_pad_token_idx)


--------------------------------------------------------------------------------
/nmt_model.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import sys
from collections import namedtuple
from typing import Dict, List, Set, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.utils
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

from model_embeddings import ModelEmbeddings

Hypothesis = namedtuple('Hypothesis', ['value', 'score'])


class NMT(nn.Module):
    """ Simple Neural Machine Translation Model:
        - Bidirectional RNN Encoder (e.g. LSTM or GRU)
        - Unidirectional LSTM Decoder
        - Global Attention Model (Luong, et al. 2015)
    """
    def __init__(self, embed_size, hidden_size, vocab, dropout_rate=0.2, rnn_layer=nn.LSTM, num_layers=1, activation=torch.tanh):
        """ Init NMT Model.

        @param embed_size (int): Embedding size (dimensionality)
        @param hidden_size (int): Hidden Size, the size of hidden states (dimensionality)
        @param vocab (Vocab): Vocabulary object containing src and tgt languages
                              See vocab.py for documentation.
        @param dropout_rate (float): Dropout probability, for attention
        @param rnn_layer (type): Encoder RNN class, e.g. nn.LSTM or nn.GRU
        @param num_layers (int): Number of stacked encoder RNN layers
        @param activation (callable): Activation function on the output layer (default: torch.tanh)
        """
        super(NMT, self).__init__()
        self.model_embeddings = ModelEmbeddings(embed_size, vocab)
        self.hidden_size = hidden_size
        self.dropout_rate = dropout_rate
        self.vocab = vocab
        self.rnn_layer = rnn_layer
        self.num_layers = num_layers
        self.activation = activation

        # default values
        self.encoder = None
        self.decoder = None
        self.h_projection = None
        self.c_projection = None
        self.att_projection = None
        self.combined_output_projection = None
        self.target_vocab_projection = None
        self.dropout = None

        # model layers
        self.is_lstm = (rnn_layer == nn.LSTM)
        self.encoder = rnn_layer(input_size=embed_size, hidden_size=hidden_size, bidirectional=True, bias=True, num_layers=num_layers)
        self.decoder = nn.LSTMCell(input_size=embed_size + hidden_size, hidden_size=hidden_size, bias=True)
        self.h_projection = nn.Linear(hidden_size * 2, hidden_size, bias=False)
        if self.is_lstm:
            # only an LSTM encoder produces a final cell state that needs projecting
            self.c_projection = nn.Linear(hidden_size * 2, hidden_size, bias=False)
        self.att_projection = nn.Linear(hidden_size * 2, hidden_size, bias=False)
        self.combined_output_projection = nn.Linear(hidden_size * 3, hidden_size, bias=False)
        self.target_vocab_projection = nn.Linear(hidden_size, len(vocab.tgt), bias=False)
        self.dropout = nn.Dropout(p=self.dropout_rate)


    def forward(self, source: List[List[str]], target: List[List[str]]) -> torch.Tensor:
        """ Take a mini-batch of source and target sentences, compute the log-likelihood of
        target sentences under the language models learned by the NMT system.

        @param source (List[List[str]]): list of source sentence tokens
        @param target (List[List[str]]): list of target sentence tokens, wrapped by the sentence start and end tokens

        @returns scores (Tensor): a variable/tensor of shape (b, ) representing the
                                  log-likelihood of generating the gold-standard target sentence for
                                  each example in the input batch. Here b = batch size.
        """
        # Compute sentence lengths
        source_lengths = [len(s) for s in source]

        # Convert list of lists into tensors
        source_padded = self.vocab.src.to_input_tensor(source, device=self.device)   # Tensor: (src_len, b)
        target_padded = self.vocab.tgt.to_input_tensor(target, device=self.device)   # Tensor: (tgt_len, b)

        enc_hiddens, dec_init_state = self.encode(source_padded, source_lengths)
        enc_masks = self.generate_sent_masks(enc_hiddens, source_lengths)
        combined_outputs = self.decode(enc_hiddens, enc_masks, dec_init_state, target_padded)
        P = F.log_softmax(self.target_vocab_projection(combined_outputs), dim=-1)

        # Zero out probabilities for padding positions, where the target text has no word
        target_masks = (target_padded != self.vocab.tgt['<pad>']).float()

        # Compute log probability of generating true target words
        target_gold_words_log_prob = torch.gather(P, index=target_padded[1:].unsqueeze(-1), dim=-1).squeeze(-1) * target_masks[1:]
        scores = target_gold_words_log_prob.sum(dim=0)
        return scores


    def encode(self, source_padded: torch.Tensor, source_lengths: List[int]) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """ Apply the encoder to source sentences to obtain encoder hidden states.
            Additionally, take the final states of the encoder and project them to obtain initial states for decoder.

        @param source_padded (Tensor): Tensor of padded source sentences with shape (src_len, b), where
                                       b = batch_size, src_len = maximum source sentence length. Note that
                                       these have already been sorted in order of longest to shortest sentence.
        @param source_lengths (List[int]): List of actual lengths for each of the source sentences in the batch
        @returns enc_hiddens (Tensor): Tensor of hidden units with shape (b, src_len, h*2), where
                                       b = batch size, src_len = maximum source sentence length, h = hidden size.
        @returns dec_init_state (tuple(Tensor, Tensor)): Tuple of tensors representing the decoder's initial
                                                         hidden state and cell (the cell is all zeros for a GRU encoder).
        """
        enc_hiddens, dec_init_state = None, None
        last_hidden, last_cell = None, None
        X = self.model_embeddings.source(source_padded)  # (src_len, b, e)
        enc_hiddens, last_state = self.encoder(pack_padded_sequence(X, source_lengths))
        if self.is_lstm:
            last_hidden, last_cell = last_state  # nn.LSTM returns (h_n, c_n)
        else:
            last_hidden = last_state  # nn.GRU returns only h_n
        enc_hiddens, _ = pad_packed_sequence(enc_hiddens)  # returns seq_unpacked, lens_unpacked
        enc_hiddens = enc_hiddens.permute(1, 0, 2)  # (b, src_len, h*2)

        # h_n has shape (num_layers * 2, b, h); indices 0 and 1 hold the forward and
        # backward final states of the first encoder layer
        init_decoder_hidden = self.h_projection(torch.cat((last_hidden[0], last_hidden[1]), 1))

        if self.is_lstm:
            init_decoder_cell = self.c_projection(torch.cat((last_cell[0], last_cell[1]), 1))
            dec_init_state = (init_decoder_hidden, init_decoder_cell)
        else:
            # a GRU has no cell state, so the decoder's LSTMCell starts from a zero cell
            dec_init_state = (init_decoder_hidden, torch.zeros_like(init_decoder_hidden))

        return enc_hiddens, dec_init_state


    def decode(self, enc_hiddens: torch.Tensor, enc_masks: torch.Tensor,
               dec_init_state: Tuple[torch.Tensor, torch.Tensor], target_padded: torch.Tensor) -> torch.Tensor:
        """Compute combined output vectors for a batch.

        @param enc_hiddens (Tensor): Hidden states (b, src_len, h*2), where
                                     b = batch size, src_len = maximum source sentence length, h = hidden size.
141 | @param enc_masks (Tensor): Tensor of sentence masks (b, src_len), where 142 | b = batch size, src_len = maximum source sentence length. 143 | @param dec_init_state (tuple(Tensor, Tensor)): Initial state and cell for the decoder 144 | @param target_padded (Tensor): Gold-standard padded target sentences (tgt_len, b), where 145 | tgt_len = maximum target sentence length, b = batch size. 146 | 147 | @returns combined_outputs (Tensor): combined output tensor (tgt_len, b, h), where 148 | tgt_len = maximum target sentence length, b = batch_size, h = hidden size 149 | """ 150 | # Chop off the `</s>` token for max length sentences. 151 | target_padded = target_padded[:-1] 152 | 153 | # Initialize the decoder state (hidden and cell) 154 | dec_state = dec_init_state 155 | 156 | # Initialize previous combined output vector o_{t-1} as zero 157 | batch_size = enc_hiddens.size(0) 158 | o_prev = torch.zeros(batch_size, self.hidden_size, device=self.device) 159 | 160 | # Initialize a list we will use to collect the combined output o_t on each step 161 | combined_outputs = [] 162 | 163 | enc_hiddens_proj = self.att_projection(enc_hiddens) # (b, src_len, h) 164 | Y = self.model_embeddings.target(target_padded) # (tgt_len, b, e) 165 | for Y_t in torch.split(Y, 1, dim=0): 166 | Y_t = torch.squeeze(Y_t, 0) # (b, e) 167 | Ybar_t = torch.cat((Y_t, o_prev), 1) 168 | dec_state, o_t, e_t = self.step(Ybar_t, dec_state, enc_hiddens, enc_hiddens_proj, enc_masks) 169 | combined_outputs.append(o_t) 170 | o_prev = o_t 171 | combined_outputs = torch.stack(combined_outputs, dim=0) # (tgt_len, b, h) 172 | 173 | return combined_outputs 174 | 175 | 176 | def step(self, Ybar_t: torch.Tensor, 177 | dec_state: Tuple[torch.Tensor, torch.Tensor], 178 | enc_hiddens: torch.Tensor, 179 | enc_hiddens_proj: torch.Tensor, 180 | enc_masks: torch.Tensor) -> Tuple[Tuple, torch.Tensor, torch.Tensor]: 181 | """ Compute one forward step of the LSTM decoder, including the attention computation.
182 | 183 | @param Ybar_t (Tensor): Concatenated Tensor of [Y_t o_prev], with shape (b, e + h). The input for the decoder, 184 | where b = batch size, e = embedding size, h = hidden size. 185 | @param dec_state (tuple(Tensor, Tensor)): Tuple of tensors both with shape (b, h), where b = batch size, h = hidden size. 186 | First tensor is the decoder's previous hidden state, second tensor is its previous cell. 187 | @param enc_hiddens (Tensor): Encoder hidden states Tensor, with shape (b, src_len, h * 2), where b = batch size, 188 | src_len = maximum source length, h = hidden size. 189 | @param enc_hiddens_proj (Tensor): Encoder hidden states Tensor, projected from (h * 2) to h, with shape (b, src_len, h), 190 | where b = batch size, src_len = maximum source length, h = hidden size. 191 | @param enc_masks (Tensor): Tensor of sentence masks of shape (b, src_len), or None to skip masking, 192 | where b = batch size, src_len is maximum source length. 193 | 194 | @returns dec_state (tuple (Tensor, Tensor)): Tuple of tensors both of shape (b, h), where b = batch size, h = hidden size. 195 | First tensor is the decoder's new hidden state, second tensor is its new cell. 196 | @returns combined_output (Tensor): Combined output Tensor at timestep t, shape (b, h), where b = batch size, h = hidden size. 197 | @returns e_t (Tensor): Tensor of shape (b, src_len) holding the attention scores before the softmax. 198 | Note: the callers in this file ignore this value; 199 | it is returned only so the attention scores can be 200 | inspected when sanity-checking the implementation.
201 | """ 202 | 203 | combined_output = None 204 | dec_state = self.decoder(Ybar_t, dec_state) 205 | dec_hidden, dec_cell = dec_state 206 | e_t = torch.squeeze(torch.bmm(enc_hiddens_proj, torch.unsqueeze(dec_hidden, dim=2)), dim=2) 207 | 208 | # Set e_t to -inf where enc_masks has 1 to ignore tokens 209 | if enc_masks is not None: 210 | e_t.data.masked_fill_(enc_masks.bool(), -float('inf')) 211 | 212 | alpha_t = F.softmax(e_t, dim=1) # (b, src_len) 213 | a_t = torch.squeeze(torch.bmm(torch.unsqueeze(alpha_t, 1), enc_hiddens), dim=1) # (b, 2h) 214 | U_t = torch.cat((dec_hidden, a_t), dim=1) 215 | V_t = self.combined_output_projection(U_t) 216 | O_t = self.dropout(self.activation(V_t)) 217 | 218 | combined_output = O_t 219 | return dec_state, combined_output, e_t 220 | 221 | def generate_sent_masks(self, enc_hiddens: torch.Tensor, source_lengths: List[int]) -> torch.Tensor: 222 | """ Generate sentence masks for encoder hidden states. 223 | 224 | @param enc_hiddens (Tensor): encodings of shape (b, src_len, 2*h), where b = batch size, 225 | src_len = max source length, h = hidden size. 226 | @param source_lengths (List[int]): List of actual lengths for each of the sentences in the batch. 227 | 228 | @returns enc_masks (Tensor): Tensor of sentence masks of shape (b, src_len), 229 | where src_len = max source length, h = hidden size. 230 | """ 231 | enc_masks = torch.zeros(enc_hiddens.size(0), enc_hiddens.size(1), dtype=torch.float) 232 | for e_id, src_len in enumerate(source_lengths): 233 | enc_masks[e_id, src_len:] = 1 234 | return enc_masks.to(self.device) 235 | 236 | 237 | def beam_search(self, src_sent: List[str], beam_size: int=5, max_decoding_time_step: int=70) -> List[Hypothesis]: 238 | """ Given a single source sentence, perform beam search, yielding translations in the target language. 
239 | @param src_sent (List[str]): a single source sentence (words) 240 | @param beam_size (int): beam size 241 | @param max_decoding_time_step (int): maximum number of time steps to unroll the decoding RNN 242 | @returns hypotheses (List[Hypothesis]): a list of hypotheses, each with two fields: 243 | value: List[str]: the decoded target sentence, represented as a list of words 244 | score: float: the log-likelihood of the target sentence 245 | """ 246 | src_sents_var = self.vocab.src.to_input_tensor([src_sent], self.device) 247 | 248 | src_encodings, dec_init_vec = self.encode(src_sents_var, [len(src_sent)]) 249 | src_encodings_att_linear = self.att_projection(src_encodings) 250 | 251 | h_tm1 = dec_init_vec 252 | att_tm1 = torch.zeros(1, self.hidden_size, device=self.device) 253 | 254 | eos_id = self.vocab.tgt['</s>'] 255 | 256 | hypotheses = [['<s>']] 257 | hyp_scores = torch.zeros(len(hypotheses), dtype=torch.float, device=self.device) 258 | completed_hypotheses = [] 259 | 260 | t = 0 261 | while len(completed_hypotheses) < beam_size and t < max_decoding_time_step: 262 | t += 1 263 | hyp_num = len(hypotheses) 264 | 265 | exp_src_encodings = src_encodings.expand(hyp_num, 266 | src_encodings.size(1), 267 | src_encodings.size(2)) 268 | 269 | exp_src_encodings_att_linear = src_encodings_att_linear.expand(hyp_num, 270 | src_encodings_att_linear.size(1), 271 | src_encodings_att_linear.size(2)) 272 | 273 | y_tm1 = torch.tensor([self.vocab.tgt[hyp[-1]] for hyp in hypotheses], dtype=torch.long, device=self.device) 274 | y_t_embed = self.model_embeddings.target(y_tm1) 275 | 276 | x = torch.cat([y_t_embed, att_tm1], dim=-1) 277 | 278 | (h_t, cell_t), att_t, _ = self.step(x, h_tm1, 279 | exp_src_encodings, exp_src_encodings_att_linear, enc_masks=None) 280 | 281 | # log probabilities over target words 282 | log_p_t = F.log_softmax(self.target_vocab_projection(att_t), dim=-1) 283 | 284 | live_hyp_num = beam_size - len(completed_hypotheses) 285 | contiuating_hyp_scores =
(hyp_scores.unsqueeze(1).expand_as(log_p_t) + log_p_t).view(-1) 286 | top_cand_hyp_scores, top_cand_hyp_pos = torch.topk(contiuating_hyp_scores, k=live_hyp_num) 287 | 288 | prev_hyp_ids = top_cand_hyp_pos // len(self.vocab.tgt) 289 | hyp_word_ids = top_cand_hyp_pos % len(self.vocab.tgt) 290 | 291 | new_hypotheses = [] 292 | live_hyp_ids = [] 293 | new_hyp_scores = [] 294 | 295 | for prev_hyp_id, hyp_word_id, cand_new_hyp_score in zip(prev_hyp_ids, hyp_word_ids, top_cand_hyp_scores): 296 | prev_hyp_id = prev_hyp_id.item() 297 | hyp_word_id = hyp_word_id.item() 298 | cand_new_hyp_score = cand_new_hyp_score.item() 299 | 300 | hyp_word = self.vocab.tgt.id2word[hyp_word_id] 301 | new_hyp_sent = hypotheses[prev_hyp_id] + [hyp_word] 302 | if hyp_word == '</s>': 303 | completed_hypotheses.append(Hypothesis(value=new_hyp_sent[1:-1], 304 | score=cand_new_hyp_score)) 305 | else: 306 | new_hypotheses.append(new_hyp_sent) 307 | live_hyp_ids.append(prev_hyp_id) 308 | new_hyp_scores.append(cand_new_hyp_score) 309 | 310 | if len(completed_hypotheses) == beam_size: 311 | break 312 | 313 | live_hyp_ids = torch.tensor(live_hyp_ids, dtype=torch.long, device=self.device) 314 | h_tm1 = (h_t[live_hyp_ids], cell_t[live_hyp_ids]) 315 | att_tm1 = att_t[live_hyp_ids] 316 | 317 | hypotheses = new_hypotheses 318 | hyp_scores = torch.tensor(new_hyp_scores, dtype=torch.float, device=self.device) 319 | 320 | if len(completed_hypotheses) == 0: 321 | completed_hypotheses.append(Hypothesis(value=hypotheses[0][1:], 322 | score=hyp_scores[0].item())) 323 | 324 | completed_hypotheses.sort(key=lambda hyp: hyp.score, reverse=True) 325 | 326 | return completed_hypotheses 327 | 328 | @property 329 | def device(self) -> torch.device: 330 | """ Determine which device to place the Tensors upon, CPU or GPU. 331 | """ 332 | return self.model_embeddings.source.weight.device 333 | 334 | @staticmethod 335 | def load(model_path: str): 336 | """ Load the model from a file.
337 | @param model_path (str): path to model 338 | """ 339 | params = torch.load(model_path, map_location=lambda storage, loc: storage) 340 | args = params['args'] 341 | model = NMT(vocab=params['vocab'], **args) 342 | model.load_state_dict(params['state_dict']) 343 | 344 | return model 345 | 346 | def save(self, path: str): 347 | """ Save the model to a file. 348 | @param path (str): path to the model 349 | """ 350 | print('save model parameters to [%s]' % path, file=sys.stderr) 351 | 352 | params = { 353 | 'args': dict(embed_size=self.model_embeddings.embed_size, hidden_size=self.hidden_size, dropout_rate=self.dropout_rate, rnn_layer=self.rnn_layer, num_layers=self.num_layers, activation=self.activation), 354 | 'vocab': self.vocab, 355 | 'state_dict': self.state_dict() 356 | } 357 | 358 | torch.save(params, path) 359 | -------------------------------------------------------------------------------- /run.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | """ 5 | Usage: 6 | run.py train --train-src=<file> --train-tgt=<file> --dev-src=<file> --dev-tgt=<file> --vocab=<file> [options] 7 | run.py decode [options] MODEL_PATH TEST_SOURCE_FILE OUTPUT_FILE 8 | run.py decode [options] MODEL_PATH TEST_SOURCE_FILE TEST_TARGET_FILE OUTPUT_FILE 9 | 10 | Options: 11 | -h --help show this screen.
12 | --cuda use GPU 13 | --train-src=<file> train source file 14 | --train-tgt=<file> train target file 15 | --dev-src=<file> dev source file 16 | --dev-tgt=<file> dev target file 17 | --vocab=<file> vocab file 18 | --seed=<int> seed [default: 0] 19 | --batch-size=<int> batch size [default: 32] 20 | --embed-size=<int> embedding size [default: 256] 21 | --hidden-size=<int> hidden size [default: 256] 22 | --clip-grad=<float> gradient clipping [default: 5.0] 23 | --log-every=<int> log every [default: 10] 24 | --max-epoch=<int> max epoch [default: 30] 25 | --input-feed use input feeding 26 | --patience=<int> wait for how many iterations to decay learning rate [default: 5] 27 | --max-num-trial=<int> terminate training after how many trials [default: 5] 28 | --lr-decay=<float> learning rate decay [default: 0.5] 29 | --beam-size=<int> beam size [default: 5] 30 | --sample-size=<int> sample size [default: 5] 31 | --lr=<float> learning rate [default: 0.001] 32 | --uniform-init=<float> uniformly initialize all parameters [default: 0.1] 33 | --save-to=<file> model save path [default: model.bin] 34 | --valid-niter=<int> perform validation after how many iterations [default: 2000] 35 | --dropout=<float> dropout [default: 0.3] 36 | --max-decoding-time-step=<int> maximum number of decoding time steps [default: 70] 37 | """ 38 | import math 39 | import sys 40 | import time 41 | 42 | 43 | from docopt import docopt 44 | # from nltk.translate.bleu_score import corpus_bleu, sentence_bleu, SmoothingFunction 45 | import sacrebleu 46 | #from nmt_model import Hypothesis, NMT 47 | from nmt_model import Hypothesis, NMT 48 | import numpy as np 49 | from typing import List, Tuple, Dict, Set, Union 50 | from tqdm import tqdm 51 | from utils import read_corpus, batch_iter, read_sent_zh 52 | from vocab import Vocab 53 | 54 | import torch 55 | import torch.nn.utils 56 | from torch import nn 57 | 58 | 59 | def evaluate_ppl(model, dev_data, batch_size=32): 60 | """ Evaluate perplexity on dev 
sentences 61 | @param model (NMT): NMT Model 62 | @param dev_data (list of (src_sent, tgt_sent)): list of tuples containing source and
target sentence 63 | @param batch_size (int): batch size 64 | @returns ppl (perplexity on dev sentences) 65 | """ 66 | was_training = model.training 67 | model.eval() 68 | 69 | cum_loss = 0. 70 | cum_tgt_words = 0. 71 | 72 | # no_grad() disables gradient tracking, saving memory and computation during evaluation 73 | with torch.no_grad(): 74 | for src_sents, tgt_sents in batch_iter(dev_data, batch_size): 75 | loss = -model(src_sents, tgt_sents).sum() 76 | 77 | cum_loss += loss.item() 78 | tgt_word_num_to_predict = sum(len(s[1:]) for s in tgt_sents) # omitting leading `<s>` 79 | cum_tgt_words += tgt_word_num_to_predict 80 | 81 | ppl = np.exp(cum_loss / cum_tgt_words) 82 | 83 | if was_training: 84 | model.train() 85 | 86 | return ppl 87 | 88 | 89 | def compute_corpus_level_bleu_score(references: List[List[str]], hypotheses: List[Hypothesis]) -> float: 90 | """ Given decoding results and reference sentences, compute corpus-level BLEU score. 91 | @param references (List[List[str]]): a list of gold-standard reference target sentences 92 | @param hypotheses (List[Hypothesis]): a list of hypotheses, one for each reference 93 | @returns bleu_score: corpus-level BLEU score 94 | """ 95 | # remove the start and end tokens 96 | if references[0][0] == '<s>': 97 | references = [ref[1:-1] for ref in references] 98 | 99 | # detokenize the subword pieces to get full sentences 100 | detokened_refs = [''.join(pieces).replace('▁', ' ') for pieces in references] 101 | detokened_hyps = [''.join(hyp.value).replace('▁', ' ') for hyp in hypotheses] 102 | 103 | 104 | # sacreBLEU can take multiple references (several gold translations per sentence), but we only feed it one 105 | bleu = sacrebleu.corpus_bleu(detokened_hyps, [detokened_refs]) 106 | 107 | return bleu.score 108 | 109 | 110 | def train(args: Dict): 111 | """ Train the NMT Model.
112 | @param args (Dict): args from cmd line 113 | """ 114 | train_data_src = read_corpus(args['--train-src'], source='src', vocab_size=21000) 115 | train_data_tgt = read_corpus(args['--train-tgt'], source='tgt', vocab_size=8000) 116 | 117 | dev_data_src = read_corpus(args['--dev-src'], source='src', vocab_size=3000) 118 | dev_data_tgt = read_corpus(args['--dev-tgt'], source='tgt', vocab_size=2000) 119 | 120 | train_data = list(zip(train_data_src, train_data_tgt)) 121 | dev_data = list(zip(dev_data_src, dev_data_tgt)) 122 | 123 | train_batch_size = int(args['--batch-size']) 124 | clip_grad = float(args['--clip-grad']) 125 | valid_niter = int(args['--valid-niter']) 126 | log_every = int(args['--log-every']) 127 | model_save_path = args['--save-to'] 128 | 129 | vocab = Vocab.load(args['--vocab']) 130 | 131 | model = NMT(embed_size=512, 132 | hidden_size=512, 133 | dropout_rate=float(args['--dropout']), 134 | vocab=vocab, 135 | rnn_layer=nn.LSTM, 136 | bidirectional=False) 137 | 138 | 139 | model.train() 140 | 141 | uniform_init = float(args['--uniform-init']) 142 | if np.abs(uniform_init) > 0.: 143 | print('uniformly initialize parameters [-%f, +%f]' % (uniform_init, uniform_init), file=sys.stderr) 144 | for p in model.parameters(): 145 | p.data.uniform_(-uniform_init, uniform_init) 146 | 147 | vocab_mask = torch.ones(len(vocab.tgt)) 148 | vocab_mask[vocab.tgt['<pad>']] = 0 149 | 150 | device = torch.device("cuda:0" if args['--cuda'] else "cpu") 151 | print('use device: %s' % device, file=sys.stderr) 152 | 153 | model = model.to(device) 154 | 155 | optimizer = torch.optim.Adam(model.parameters(), lr=float(args['--lr'])) 156 | 157 | num_trial = 0 158 | train_iter = patience = cum_loss = report_loss = cum_tgt_words = report_tgt_words = 0 159 | cum_examples = report_examples = epoch = valid_num = 0 160 | hist_valid_scores = [] 161 | train_time = begin_time = time.time() 162 | print('begin Maximum Likelihood training') 163 | train_ppl_log = open("ppl.log", "w")
dev_ppl_log = open("dev_ppl.log", "w") 165 | while True: 166 | epoch += 1 167 | 168 | for src_sents, tgt_sents in batch_iter(train_data, batch_size=train_batch_size, shuffle=True): 169 | train_iter += 1 170 | 171 | optimizer.zero_grad() 172 | 173 | batch_size = len(src_sents) 174 | 175 | example_losses = -model(src_sents, tgt_sents) # (batch_size,) 176 | batch_loss = example_losses.sum() 177 | loss = batch_loss / batch_size 178 | 179 | loss.backward() 180 | 181 | # clip gradient 182 | grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), clip_grad) 183 | 184 | optimizer.step() 185 | 186 | batch_losses_val = batch_loss.item() 187 | report_loss += batch_losses_val 188 | cum_loss += batch_losses_val 189 | 190 | tgt_words_num_to_predict = sum(len(s[1:]) for s in tgt_sents) # omitting leading `` 191 | report_tgt_words += tgt_words_num_to_predict 192 | cum_tgt_words += tgt_words_num_to_predict 193 | report_examples += batch_size 194 | cum_examples += batch_size 195 | 196 | if train_iter % log_every == 0: 197 | print('epoch %d, iter %d, avg. loss %.2f, avg. ppl %.2f ' \ 198 | 'cum. examples %d, speed %.2f words/sec, time elapsed %.2f sec' % (epoch, train_iter, 199 | report_loss / report_examples, 200 | math.exp(report_loss / report_tgt_words), 201 | cum_examples, 202 | report_tgt_words / (time.time() - train_time), 203 | time.time() - begin_time), file=sys.stderr) 204 | train_ppl_log.write("{} {}\n".format(train_iter, math.exp(report_loss / report_tgt_words))) 205 | train_time = time.time() 206 | report_loss = report_tgt_words = report_examples = 0. 207 | 208 | # perform validation 209 | if train_iter % valid_niter == 0: 210 | print('epoch %d, iter %d, cum. loss %.2f, cum. ppl %.2f cum. examples %d' % (epoch, train_iter, 211 | cum_loss / cum_examples, 212 | np.exp(cum_loss / cum_tgt_words), 213 | cum_examples), file=sys.stderr) 214 | 215 | cum_loss = cum_examples = cum_tgt_words = 0. 
216 | valid_num += 1 217 | 218 | print('begin validation ...', file=sys.stderr) 219 | 220 | # compute dev. ppl 221 | dev_ppl = evaluate_ppl(model, dev_data, batch_size=128) # dev batch size can be a bit larger 222 | valid_metric = -dev_ppl 223 | 224 | print('validation: iter %d, dev. ppl %f' % (train_iter, dev_ppl), file=sys.stderr) 225 | dev_ppl_log.write("{} {}\n".format(train_iter, dev_ppl)) 226 | is_better = len(hist_valid_scores) == 0 or valid_metric > max(hist_valid_scores) 227 | hist_valid_scores.append(valid_metric) 228 | 229 | if is_better: 230 | patience = 0 231 | print('save the current best model to [%s]' % model_save_path, file=sys.stderr) 232 | model.save(model_save_path) 233 | 234 | # also save the optimizers' state 235 | torch.save(optimizer.state_dict(), model_save_path + '.optim') 236 | elif patience < int(args['--patience']): 237 | patience += 1 238 | print('hit patience %d' % patience, file=sys.stderr) 239 | 240 | if patience == int(args['--patience']): 241 | num_trial += 1 242 | print('hit #%d trial' % num_trial, file=sys.stderr) 243 | if num_trial == int(args['--max-num-trial']): 244 | print('early stop!', file=sys.stderr) 245 | sys.exit(0) 246 | 247 | # decay lr, and restore from previously best checkpoint 248 | lr = optimizer.param_groups[0]['lr'] * float(args['--lr-decay']) 249 | print('load previously best model and decay learning rate to %f' % lr, file=sys.stderr) 250 | 251 | # load model 252 | params = torch.load(model_save_path, map_location=lambda storage, loc: storage) 253 | model.load_state_dict(params['state_dict']) 254 | model = model.to(device) 255 | 256 | print('restore parameters of the optimizers', file=sys.stderr) 257 | optimizer.load_state_dict(torch.load(model_save_path + '.optim')) 258 | 259 | # set new lr 260 | for param_group in optimizer.param_groups: 261 | param_group['lr'] = lr 262 | 263 | # reset patience 264 | patience = 0 265 | 266 | if epoch == int(args['--max-epoch']): 267 | print('reached maximum number of
epochs!', file=sys.stderr) 268 | sys.exit(0) 269 | 270 | 271 | def decode(args: Dict[str, str]): 272 | """ Perform decoding on a test set and save the best-scoring decoding results. 273 | If the target gold-standard sentences are given, the function also computes 274 | corpus-level BLEU score. 275 | @param args (Dict): args from cmd line 276 | """ 277 | 278 | print("load test source sentences from [{}]".format(args['TEST_SOURCE_FILE']), file=sys.stderr) 279 | test_data_src = read_corpus(args['TEST_SOURCE_FILE'], source='src', vocab_size=3000) 280 | if args['TEST_TARGET_FILE']: 281 | print("load test target sentences from [{}]".format(args['TEST_TARGET_FILE']), file=sys.stderr) 282 | test_data_tgt = read_corpus(args['TEST_TARGET_FILE'], source='tgt', vocab_size=2000) 283 | print("load model from {}".format(args['MODEL_PATH']), file=sys.stderr) 284 | model = NMT.load(args['MODEL_PATH']) 285 | 286 | if args['--cuda']: 287 | model = model.to(torch.device("cuda:0")) 288 | 289 | hypotheses = beam_search(model, test_data_src, 290 | beam_size=10, 291 | max_decoding_time_step=int(args['--max-decoding-time-step'])) 292 | 293 | if args['TEST_TARGET_FILE']: 294 | top_hypotheses = [hyps[0] for hyps in hypotheses] 295 | bleu_score = compute_corpus_level_bleu_score(test_data_tgt, top_hypotheses) 296 | print('Corpus BLEU: {}'.format(bleu_score), file=sys.stderr) 297 | 298 | with open(args['OUTPUT_FILE'], 'w', encoding='utf8') as f: 299 | for src_sent, hyps in zip(test_data_src, hypotheses): 300 | top_hyp = hyps[0] 301 | src_sent = ''.join(src_sent).replace('▁', ' ') 302 | hyp_sent = ''.join(top_hyp.value).replace('▁', ' ') 303 | f.write(src_sent+'\n'+hyp_sent + '\n\n') 304 | 305 | def beam_search(model: NMT, test_data_src: List[List[str]], beam_size: int, max_decoding_time_step: int) -> List[List[Hypothesis]]: 306 | """ Run beam search to construct hypotheses for a list of src-language sentences.
307 | @param model (NMT): NMT Model 308 | @param test_data_src (List[List[str]]): List of sentences (words) in source language, from test set. 309 | @param beam_size (int): beam_size (# of hypotheses to hold for a translation at every step) 310 | @param max_decoding_time_step (int): maximum sentence length that Beam search can produce 311 | @returns hypotheses (List[List[Hypothesis]]): List of Hypothesis translations for every source sentence. 312 | """ 313 | was_training = model.training 314 | model.eval() 315 | 316 | hypotheses = [] 317 | with torch.no_grad(): 318 | for src_sent in tqdm(test_data_src, desc='Decoding', file=sys.stdout): 319 | example_hyps = model.beam_search(src_sent, beam_size=beam_size, max_decoding_time_step=max_decoding_time_step) 320 | 321 | hypotheses.append(example_hyps) 322 | 323 | if was_training: 324 | model.train(was_training) 325 | 326 | return hypotheses 327 | 328 | 329 | if __name__ == '__main__': 330 | args = docopt(__doc__) 331 | # Check pytorch version 332 | assert(torch.__version__ >= "1.0.0"), "Please update your installation of PyTorch. 
You have {} and you should have version 1.0.0".format(torch.__version__) 333 | 334 | # seed the random number generators 335 | seed = int(args['--seed']) 336 | torch.manual_seed(seed) 337 | if args['--cuda']: 338 | torch.cuda.manual_seed(seed) 339 | np.random.seed(seed * 13 // 7) 340 | if args['train']: 341 | train(args) 342 | elif args['decode']: 343 | decode(args) 344 | else: 345 | raise RuntimeError('invalid run mode') 346 | -------------------------------------------------------------------------------- /run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # PLEASE change all "./zh_en_data" to the path where your data is stored 3 | 4 | if [ "$1" = "train" ]; then 5 | CUDA_VISIBLE_DEVICES=0 python run.py train --train-src=./zh_en_data/train.zh --train-tgt=./zh_en_data/train.en --dev-src=./zh_en_data/dev.zh --dev-tgt=./zh_en_data/dev.en --vocab=./zh_en_data/vocab_zh_en.json --cuda --lr=5e-4 --patience=1 --valid-niter=1000 --batch-size=32 --dropout=.25 6 | elif [ "$1" = "test" ]; then 7 | if [ "$2" = "" ]; then 8 | CUDA_VISIBLE_DEVICES=0 python run.py decode model.bin ./zh_en_data/test.zh ./zh_en_data/test.en outputs/test_outputs.txt --cuda 9 | else 10 | CUDA_VISIBLE_DEVICES=0 python run.py decode $2 ./zh_en_data/test.zh ./zh_en_data/test.en outputs/test_outputs.txt --cuda 11 | fi 12 | elif [ "$1" = "train_local" ]; then 13 | python run.py train --train-src=./zh_en_data/train.zh --train-tgt=./zh_en_data/train.en --dev-src=./zh_en_data/dev.zh --dev-tgt=./zh_en_data/dev.en --vocab=./zh_en_data/vocab_zh_en.json --lr=5e-4 14 | elif [ "$1" = "test_local" ]; then 15 | python run.py decode model.bin ./zh_en_data/test.zh ./zh_en_data/test.en outputs/test_outputs.txt 16 | elif [ "$1" = "vocab" ]; then 17 | python vocab.py --train-src=./zh_en_data/train.zh --train-tgt=./zh_en_data/train.en ./zh_en_data/vocab_zh_en.json 18 | else 19 | echo "Invalid Option Selected" 20 | fi 21 | 
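The masked attention in `NMT.generate_sent_masks` and `NMT.step` (nmt_model.py above) can be illustrated without PyTorch. What follows is a minimal sketch that uses plain Python lists for a single decoder step. The helper names `sentence_mask`, `masked_softmax`, and `attend` are hypothetical and not part of this repository; the real code does the same work on batched tensors with `masked_fill_`, `F.softmax`, and `torch.bmm`.

```python
import math
from typing import List


def sentence_mask(lengths: List[int], max_len: int) -> List[List[int]]:
    """Build masks like NMT.generate_sent_masks: 1 marks padding, 0 a real token."""
    return [[0] * n + [1] * (max_len - n) for n in lengths]


def masked_softmax(scores: List[float], mask: List[int]) -> List[float]:
    """Softmax over one row of attention scores that ignores padded positions.

    Padded scores are replaced by -inf (what masked_fill_ does in NMT.step),
    so exp(-inf) == 0.0 and those positions get zero attention weight.
    Assumes at least one unmasked position (batch_iter drops empty sentences).
    """
    filled = [-math.inf if m else s for s, m in zip(scores, mask)]
    peak = max(filled)  # subtract the max for numerical stability
    exps = [math.exp(s - peak) for s in filled]
    total = sum(exps)
    return [e / total for e in exps]


def attend(weights: List[float], enc_hiddens: List[List[float]]) -> List[float]:
    """Attention output a_t: the attention-weighted sum of encoder hidden states."""
    dim = len(enc_hiddens[0])
    return [sum(w * h[d] for w, h in zip(weights, enc_hiddens)) for d in range(dim)]


if __name__ == "__main__":
    masks = sentence_mask([3, 2], max_len=3)           # second sentence has one pad
    alpha = masked_softmax([0.0, 0.0, 9.0], masks[1])  # the 9.0 sits on padding
    print(masks, alpha)                                # [[0, 0, 0], [0, 0, 1]] [0.5, 0.5, 0.0]
    print(attend(alpha, [[1.0, 2.0], [3.0, 4.0], [100.0, 100.0]]))  # [2.0, 3.0]
```

The large score on the padded position gets no attention weight, so the padded encoder state `[100.0, 100.0]` has no effect on the attention output.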
-------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | import math 5 | 6 | import numpy as np 7 | import nltk 8 | import sentencepiece as spm 9 | import jieba 10 | 11 | nltk.download('punkt') 12 | 13 | 14 | def pad_sents(sents, pad_token): 15 | """ Pad list of sentences according to the longest sentence in the batch. 16 | Padding is appended to the end of each sentence. 17 | @param sents (list[list[str]]): list of sentences, where each sentence 18 | is represented as a list of words 19 | @param pad_token (str): padding token 20 | @returns sents_padded (list[list[str]]): list of sentences where sentences shorter 21 | than the max length sentence are padded out with the pad_token, such that 22 | each sentence in the batch now has equal length. 23 | """ 24 | sents_padded = [] 25 | 26 | # pad every sentence to the length of the longest one 27 | max_len = max(len(sent) for sent in sents) 28 | for sent in sents: 29 | # build a new list so the caller's sentences are not modified in place 30 | sents_padded.append(sent + [pad_token] * (max_len - len(sent))) 31 | 32 | 33 | 34 | return sents_padded 35 | 36 | 37 | def read_corpus(file_path, source, vocab_size=2500): 38 | """ Read file, where each sentence is delineated by a `\n`.
39 | @param file_path (str): path to file containing corpus 40 | @param source (str): "tgt" or "src" indicating whether text 41 | is of the source language or target language 42 | @param vocab_size (int): unused; the subword vocabulary size is fixed 43 | when the SentencePiece model is trained (see get_vocab_list in vocab.py) 44 | """ 45 | data = [] 46 | sp = spm.SentencePieceProcessor() 47 | sp.load('{}.model'.format(source)) 48 | 49 | with open(file_path, 'r', encoding='utf8') as f: 50 | for line in f: 51 | subword_tokens = sp.encode_as_pieces(line) 52 | 53 | # only wrap the target sentence with `<s>` and `</s>` 54 | if source == 'tgt': 55 | subword_tokens = ["<s>"] + subword_tokens + ["</s>"] 56 | 57 | data.append(subword_tokens) 58 | 59 | return data 60 | 61 | 62 | def read_sent_zh(sent, source): 63 | """ Read a Chinese sentence, separate the words using jieba, and generate subword tokens 64 | @param sent (str): raw Chinese sentence to tokenize 65 | @param source (str): "tgt" or "src" for selecting sp model 66 | """ 67 | sp = spm.SentencePieceProcessor() 68 | sp.load('{}.model'.format(source)) 69 | 70 | sent = " ".join(jieba.cut(sent, HMM=True)) 71 | subword_tokens = sp.encode_as_pieces(sent) 72 | 73 | return subword_tokens 74 | 75 | 76 | def batch_iter(data, batch_size, shuffle=False): 77 | """ Yield batches of source and target sentences reverse sorted by length (largest to smallest).
78 | @param data (list of (src_sent, tgt_sent)): list of tuples containing source and target sentence 79 | @param batch_size (int): batch size 80 | @param shuffle (boolean): whether to randomly shuffle the dataset 81 | """ 82 | batch_num = math.ceil(len(data) / batch_size) 83 | index_array = list(range(len(data))) 84 | 85 | if shuffle: 86 | np.random.shuffle(index_array) 87 | 88 | for i in range(batch_num): 89 | indices = index_array[i * batch_size: (i + 1) * batch_size] 90 | examples = [data[idx] for idx in indices] 91 | 92 | examples = sorted(examples, key=lambda e: len(e[0]), reverse=True) 93 | examples = [e for e in examples if len(e[0]) > 0 and len(e[1]) > 0] # drop empty sentences 94 | 95 | src_sents = [e[0] for e in examples] 96 | tgt_sents = [e[1] for e in examples] 97 | yield src_sents, tgt_sents 98 | 99 | -------------------------------------------------------------------------------- /vocab.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | """ 5 | Usage: 6 | vocab.py --train-src=<file> --train-tgt=<file> [options] VOCAB_FILE 7 | 8 | Options: 9 | -h --help Show this screen. 10 | --train-src=<file> File of training source sentences 11 | --train-tgt=<file> File of training target sentences 12 | --size=<int> vocab size [default: 50000] 13 | --freq-cutoff=<int> frequency cutoff [default: 2] 14 | """ 15 | 16 | from collections import Counter 17 | from docopt import docopt 18 | from itertools import chain 19 | import json 20 | import torch 21 | from typing import List 22 | from utils import pad_sents 23 | import sentencepiece as spm 24 | 25 | 26 | class VocabEntry(object): 27 | """ Vocabulary Entry, i.e. structure containing either 28 | src or tgt language terms. 29 | """ 30 | def __init__(self, word2id=None): 31 | """ Init VocabEntry Instance.
32 | @param word2id (dict): dictionary mapping words to indices 33 | """ 34 | if word2id: 35 | self.word2id = word2id 36 | else: 37 | self.word2id = dict() 38 | self.word2id['<pad>'] = 0 # Pad Token 39 | self.word2id['<s>'] = 1 # Start Token 40 | self.word2id['</s>'] = 2 # End Token 41 | self.word2id['<unk>'] = 3 # Unknown Token 42 | self.unk_id = self.word2id['<unk>'] 43 | self.id2word = {v: k for k, v in self.word2id.items()} 44 | 45 | def __getitem__(self, word): 46 | """ Retrieve word's index. Return the index for the unk 47 | token if the word is out of vocabulary. 48 | @param word (str): word to look up. 49 | @returns index (int): index of word 50 | """ 51 | return self.word2id.get(word, self.unk_id) 52 | 53 | def __contains__(self, word): 54 | """ Check if word is captured by VocabEntry. 55 | @param word (str): word to look up 56 | @returns contains (bool): whether word is contained 57 | """ 58 | return word in self.word2id 59 | 60 | def __len__(self): 61 | """ Compute number of words in VocabEntry. 62 | @returns len (int): number of words in VocabEntry 63 | """ 64 | return len(self.word2id) 65 | 66 | def __repr__(self): 67 | """ Representation of VocabEntry to be used 68 | when printing the object. 69 | """ 70 | return 'Vocabulary[size=%d]' % len(self) 71 | 72 | def id2word(self, wid): 73 | """ Return the word for an index. Note: the dict self.id2word set in __init__ shadows this method on every instance, so callers index it as entry.id2word[wid]. 74 | @param wid (int): word index 75 | @returns word (str): word corresponding to index 76 | """ 77 | return self.id2word[wid] 78 | 79 | def add(self, word): 80 | """ Add word to VocabEntry, if it is previously unseen. 81 | @param word (str): word to add to VocabEntry 82 | @return index (int): index that the word has been assigned 83 | """ 84 | if word not in self: 85 | wid = self.word2id[word] = len(self) 86 | self.id2word[wid] = word 87 | return wid 88 | else: 89 | return self[word] 90 | 91 | def words2indices(self, sents): 92 | """ Convert list of words or list of sentences of words 93 | into list or list of list of indices.
94 | @param sents (list[str] or list[list[str]]): sentence(s) in words 95 | @return word_ids (list[int] or list[list[int]]): sentence(s) in indices 96 | """ 97 | if isinstance(sents[0], list): 98 | return [[self[w] for w in s] for s in sents] 99 | else: 100 | return [self[w] for w in sents] 101 | 102 | def indices2words(self, word_ids): 103 | """ Convert list of indices into words. 104 | @param word_ids (list[int]): list of word ids 105 | @return sents (list[str]): list of words 106 | """ 107 | return [self.id2word[w_id] for w_id in word_ids] 108 | 109 | def to_input_tensor(self, sents: List[List[str]], device: torch.device) -> torch.Tensor: 110 | """ Convert list of sentences (words) into tensor with necessary padding for 111 | shorter sentences. 112 | 113 | @param sents (List[List[str]]): list of sentences (words) 114 | @param device: device on which to load the tensor, i.e. CPU or GPU 115 | 116 | @returns sents_var: tensor of shape (max_sentence_length, batch_size) 117 | """ 118 | word_ids = self.words2indices(sents) 119 | sents_t = pad_sents(word_ids, self['<pad>']) 120 | sents_var = torch.tensor(sents_t, dtype=torch.long, device=device) 121 | return torch.t(sents_var) 122 | 123 | @staticmethod 124 | def from_corpus(corpus, size, freq_cutoff=2): 125 | """ Given a corpus, construct a VocabEntry.
        @param corpus (list[list[str]]): corpus of text produced by read_corpus function
        @param size (int): # of words in vocabulary
        @param freq_cutoff (int): if word occurs n < freq_cutoff times, drop the word
        @returns vocab_entry (VocabEntry): VocabEntry instance produced from provided corpus
        """
        vocab_entry = VocabEntry()
        word_freq = Counter(chain(*corpus))
        valid_words = [w for w, v in word_freq.items() if v >= freq_cutoff]
        print('number of word types: {}, number of word types w/ frequency >= {}: {}'
              .format(len(word_freq), freq_cutoff, len(valid_words)))
        top_k_words = sorted(valid_words, key=lambda w: word_freq[w], reverse=True)[:size]
        for word in top_k_words:
            vocab_entry.add(word)
        return vocab_entry

    @staticmethod
    def from_subword_list(subword_list):
        """ Construct a VocabEntry from a list of subwords (e.g. SentencePiece pieces).
        @param subword_list (list[str]): subwords to add, in order
        @returns vocab_entry (VocabEntry): VocabEntry containing the subwords
        """
        vocab_entry = VocabEntry()
        for subword in subword_list:
            vocab_entry.add(subword)  # already-present pieces (e.g. special tokens) are skipped
        return vocab_entry


class Vocab(object):
    """ Vocab encapsulating src and target languages.
    """
    def __init__(self, src_vocab: VocabEntry, tgt_vocab: VocabEntry):
        """ Init Vocab.
        @param src_vocab (VocabEntry): VocabEntry for source language
        @param tgt_vocab (VocabEntry): VocabEntry for target language
        """
        self.src = src_vocab
        self.tgt = tgt_vocab

    @staticmethod
    def build(src_sents, tgt_sents) -> 'Vocab':
        """ Build Vocabulary.
        @param src_sents (list[str]): Source subwords provided by SentencePiece
        @param tgt_sents (list[str]): Target subwords provided by SentencePiece
        """

        print('initialize source vocabulary ..')
        src = VocabEntry.from_subword_list(src_sents)

        print('initialize target vocabulary ..')
        tgt = VocabEntry.from_subword_list(tgt_sents)

        return Vocab(src, tgt)

    def save(self, file_path):
        """ Save Vocab to file as JSON dump.
        @param file_path (str): file path to vocab file
        """
        with open(file_path, 'w') as f:
            json.dump(dict(src_word2id=self.src.word2id, tgt_word2id=self.tgt.word2id), f, indent=2)

    @staticmethod
    def load(file_path):
        """ Load vocabulary from JSON dump.
        @param file_path (str): file path to vocab file
        @returns Vocab object loaded from JSON dump
        """
        with open(file_path, 'r') as f:
            entry = json.load(f)
        src_word2id = entry['src_word2id']
        tgt_word2id = entry['tgt_word2id']
        return Vocab(VocabEntry(src_word2id), VocabEntry(tgt_word2id))

    def __repr__(self):
        """ Representation of Vocab to be used
        when printing the object.
        """
        return 'Vocab(source %d words, target %d words)' % (len(self.src), len(self.tgt))


def get_vocab_list(file_path, source, vocab_size):
    """ Use SentencePiece to tokenize and acquire list of unique subwords.
    @param file_path (str): file path to corpus
    @param source (str): 'src' or 'tgt'; also used as the SentencePiece model prefix
    @param vocab_size (int): desired vocabulary size
    @returns sp_list (list[str]): list of subword pieces
    """
    spm.SentencePieceTrainer.train(input=file_path, model_prefix=source, vocab_size=vocab_size)  # train the spm model; writes <source>.model and <source>.vocab
    sp = spm.SentencePieceProcessor()  # create a processor instance
    sp.load('{}.model'.format(source))  # loads tgt.model or src.model
    sp_list = [sp.id_to_piece(piece_id) for piece_id in range(sp.get_piece_size())]  # this is the list of subwords
    return sp_list


if __name__ == '__main__':
    args = docopt(__doc__)

    print('read in source sentences: %s' % args['--train-src'])
    print('read in target sentences: %s' % args['--train-tgt'])

    src_sents = get_vocab_list(args['--train-src'], source='src', vocab_size=21000)
    tgt_sents = get_vocab_list(args['--train-tgt'], source='tgt', vocab_size=8000)
    vocab = Vocab.build(src_sents, tgt_sents)
    print('generated vocabulary, source %d words, target %d words' % (len(vocab.src), len(vocab.tgt)))

    vocab.save(args['VOCAB_FILE'])
    print('vocabulary saved to %s' % args['VOCAB_FILE'])

--------------------------------------------------------------------------------
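The two core steps of `VocabEntry`, building the word-to-index map with a frequency cutoff (`from_corpus`) and turning a batch of sentences into a padded, time-major index matrix (`to_input_tensor`), can be sketched without PyTorch. This is a minimal illustration under stated assumptions, not the project's code: it assumes the conventional `<pad>`, `<s>`, `</s>`, `<unk>` special tokens; `pad_sents` lives in `utils.py`, which is not shown here, so its right-padding behaviour is assumed; plain nested lists stand in for `torch.tensor` and `torch.t`; and the helpers `build_word2id`, `lookup`, `pad_sents_sketch` and `to_time_major` are hypothetical names.

```python
from collections import Counter
from itertools import chain
from typing import Dict, List

# Assumed special tokens, in the order VocabEntry.__init__ assigns them.
PAD, UNK = '<pad>', '<unk>'


def build_word2id(corpus: List[List[str]], size: int, freq_cutoff: int = 2) -> Dict[str, int]:
    """Hypothetical mirror of VocabEntry.from_corpus: keep words seen at least
    freq_cutoff times, then add the `size` most frequent after the special tokens."""
    word2id = {'<pad>': 0, '<s>': 1, '</s>': 2, '<unk>': 3}
    word_freq = Counter(chain(*corpus))
    valid = [w for w, c in word_freq.items() if c >= freq_cutoff]
    for w in sorted(valid, key=lambda w: word_freq[w], reverse=True)[:size]:
        word2id.setdefault(w, len(word2id))  # like VocabEntry.add: only unseen words get a new id
    return word2id


def lookup(word2id: Dict[str, int], word: str) -> int:
    """Hypothetical mirror of VocabEntry.__getitem__: unknown words map to <unk>."""
    return word2id.get(word, word2id[UNK])


def pad_sents_sketch(sents: List[List[int]], pad_id: int) -> List[List[int]]:
    """Assumed behaviour of utils.pad_sents: right-pad every sentence to the longest length."""
    max_len = max(len(s) for s in sents)
    return [s + [pad_id] * (max_len - len(s)) for s in sents]


def to_time_major(sents: List[List[str]], word2id: Dict[str, int]) -> List[List[int]]:
    """Word lists -> id matrix of shape (max_len, batch_size), the layout
    to_input_tensor produces by transposing with torch.t."""
    ids = [[lookup(word2id, w) for w in s] for s in sents]
    padded = pad_sents_sketch(ids, word2id[PAD])  # (batch_size, max_len)
    return [list(col) for col in zip(*padded)]    # transpose to (max_len, batch_size)


if __name__ == '__main__':
    corpus = [['we', 'agree'], ['we', 'agree', 'too'], ['we']]
    word2id = build_word2id(corpus, size=10)  # 'too' occurs once, so it is dropped
    print(word2id)
    # {'<pad>': 0, '<s>': 1, '</s>': 2, '<unk>': 3, 'we': 4, 'agree': 5}
    print(to_time_major([['we', 'agree'], ['we'], ['they', 'agree', 'too']], word2id))
    # [[4, 4, 3], [5, 0, 5], [0, 0, 3]]
```

Each row of the result is one time step across the whole batch, which matches the `(max_sentence_length, batch_size)` shape in the `to_input_tensor` docstring; presumably this time-major layout was chosen because PyTorch recurrent layers default to `batch_first=False`.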