├── .gitignore ├── Deep Reinforcement Learning For Language Models.pdf ├── Final_Saves ├── RL_BLUE │ ├── bleu_0.135_177.dat │ └── emb_dict.dat ├── RL_COSINE │ ├── cosine_0.621_03.dat │ └── emb_dict.dat ├── RL_Mutual │ ├── emb_dict.dat │ └── epoch_180_-4.325_-7.192.dat ├── RL_Perplexity │ ├── emb_dict.dat │ └── epoch_050_1.463_3.701.dat ├── backward_seq2seq │ ├── emb_dict.dat │ └── epoch_080_0.780_0.104.dat └── seq2seq │ ├── emb_dict.dat │ └── epoch_090_0.800_0.107.dat ├── Images ├── MMI comparison.PNG ├── Training BLEU RL.PNG ├── Training CE Seq2Seq.PNG ├── Training RL Cosine.PNG ├── Training RL MMI.PNG ├── Training RL Perplexity.PNG ├── automatic evaluations.PNG └── many generated sequences.PNG ├── README.md ├── cur_reader.py ├── data_test.py ├── libbots ├── cornell.py ├── data.py ├── model.py └── utils.py ├── model_test.py ├── test_all_modells.py ├── tests ├── test_data.py └── test_subtitles.py ├── train_crossent.py ├── train_rl_BLEU.py ├── train_rl_MMI.py ├── train_rl_PREPLEXITY.py ├── train_rl_cosine.py └── use_model.py /.gitignore: -------------------------------------------------------------------------------- 1 | cornell 2 | data -------------------------------------------------------------------------------- /Deep Reinforcement Learning For Language Models.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eyalbd2/RL-based-Language-Modeling/b3467430be102ba23b1ab7e7a9515496c140f524/Deep Reinforcement Learning For Language Models.pdf -------------------------------------------------------------------------------- /Final_Saves/RL_BLUE/bleu_0.135_177.dat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eyalbd2/RL-based-Language-Modeling/b3467430be102ba23b1ab7e7a9515496c140f524/Final_Saves/RL_BLUE/bleu_0.135_177.dat -------------------------------------------------------------------------------- /Final_Saves/RL_BLUE/emb_dict.dat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eyalbd2/RL-based-Language-Modeling/b3467430be102ba23b1ab7e7a9515496c140f524/Final_Saves/RL_BLUE/emb_dict.dat -------------------------------------------------------------------------------- /Final_Saves/RL_COSINE/cosine_0.621_03.dat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eyalbd2/RL-based-Language-Modeling/b3467430be102ba23b1ab7e7a9515496c140f524/Final_Saves/RL_COSINE/cosine_0.621_03.dat -------------------------------------------------------------------------------- /Final_Saves/RL_COSINE/emb_dict.dat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eyalbd2/RL-based-Language-Modeling/b3467430be102ba23b1ab7e7a9515496c140f524/Final_Saves/RL_COSINE/emb_dict.dat -------------------------------------------------------------------------------- /Final_Saves/RL_Mutual/emb_dict.dat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eyalbd2/RL-based-Language-Modeling/b3467430be102ba23b1ab7e7a9515496c140f524/Final_Saves/RL_Mutual/emb_dict.dat -------------------------------------------------------------------------------- /Final_Saves/RL_Mutual/epoch_180_-4.325_-7.192.dat: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/eyalbd2/RL-based-Language-Modeling/b3467430be102ba23b1ab7e7a9515496c140f524/Final_Saves/RL_Mutual/epoch_180_-4.325_-7.192.dat -------------------------------------------------------------------------------- /Final_Saves/RL_Perplexity/emb_dict.dat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eyalbd2/RL-based-Language-Modeling/b3467430be102ba23b1ab7e7a9515496c140f524/Final_Saves/RL_Perplexity/emb_dict.dat -------------------------------------------------------------------------------- /Final_Saves/RL_Perplexity/epoch_050_1.463_3.701.dat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eyalbd2/RL-based-Language-Modeling/b3467430be102ba23b1ab7e7a9515496c140f524/Final_Saves/RL_Perplexity/epoch_050_1.463_3.701.dat -------------------------------------------------------------------------------- /Final_Saves/backward_seq2seq/emb_dict.dat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eyalbd2/RL-based-Language-Modeling/b3467430be102ba23b1ab7e7a9515496c140f524/Final_Saves/backward_seq2seq/emb_dict.dat -------------------------------------------------------------------------------- /Final_Saves/backward_seq2seq/epoch_080_0.780_0.104.dat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eyalbd2/RL-based-Language-Modeling/b3467430be102ba23b1ab7e7a9515496c140f524/Final_Saves/backward_seq2seq/epoch_080_0.780_0.104.dat -------------------------------------------------------------------------------- /Final_Saves/seq2seq/emb_dict.dat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eyalbd2/RL-based-Language-Modeling/b3467430be102ba23b1ab7e7a9515496c140f524/Final_Saves/seq2seq/emb_dict.dat -------------------------------------------------------------------------------- /Final_Saves/seq2seq/epoch_090_0.800_0.107.dat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eyalbd2/RL-based-Language-Modeling/b3467430be102ba23b1ab7e7a9515496c140f524/Final_Saves/seq2seq/epoch_090_0.800_0.107.dat -------------------------------------------------------------------------------- /Images/MMI comparison.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eyalbd2/RL-based-Language-Modeling/b3467430be102ba23b1ab7e7a9515496c140f524/Images/MMI comparison.PNG -------------------------------------------------------------------------------- /Images/Training BLEU RL.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eyalbd2/RL-based-Language-Modeling/b3467430be102ba23b1ab7e7a9515496c140f524/Images/Training BLEU RL.PNG -------------------------------------------------------------------------------- /Images/Training CE Seq2Seq.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eyalbd2/RL-based-Language-Modeling/b3467430be102ba23b1ab7e7a9515496c140f524/Images/Training CE Seq2Seq.PNG -------------------------------------------------------------------------------- /Images/Training RL Cosine.PNG: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/eyalbd2/RL-based-Language-Modeling/b3467430be102ba23b1ab7e7a9515496c140f524/Images/Training RL Cosine.PNG -------------------------------------------------------------------------------- /Images/Training RL MMI.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eyalbd2/RL-based-Language-Modeling/b3467430be102ba23b1ab7e7a9515496c140f524/Images/Training RL MMI.PNG -------------------------------------------------------------------------------- /Images/Training RL Perplexity.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eyalbd2/RL-based-Language-Modeling/b3467430be102ba23b1ab7e7a9515496c140f524/Images/Training RL Perplexity.PNG -------------------------------------------------------------------------------- /Images/automatic evaluations.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eyalbd2/RL-based-Language-Modeling/b3467430be102ba23b1ab7e7a9515496c140f524/Images/automatic evaluations.PNG -------------------------------------------------------------------------------- /Images/many generated sequences.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eyalbd2/RL-based-Language-Modeling/b3467430be102ba23b1ab7e7a9515496c140f524/Images/many generated sequences.PNG -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Improving Language Models Using Reinforcement Learning 2 | 3 | In this work, we show how to use traditional 4 | evaluation methods from NLP to improve a language model. We construct 5 | a couple of models, each trained with a different objective, compare 6 | them, and offer some intuition regarding the results. 7 | 8 | The project implementation contains: 9 | 1. Basic cross-entropy seq2seq model training - using a couple of tricks to get faster training convergence 10 | 2. Four Policy Gradient models - using 4 different rewards (for further explanation, please refer to "Deep Reinforcement Learning For Language Models.pdf", which is in this repository; a minimal sketch of the policy-gradient update is given right after this list): 11 | - [ ] BLEU score 12 | - [ ] Maximum Mutual Information score 13 | - [ ] Perplexity score 14 | - [ ] Cosine Similarity score 15 | 3. Testing generated dialogues - we test each of our models and summarize the results (for all results, please refer to "Deep Reinforcement Learning For Language Models.pdf", which is in this repository). 16 | 4. Automatic evaluation - we evaluate each of the models using traditional NLP evaluation methods. 17 | 18 |
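All four RL scripts share the same policy-gradient recipe: start from the pre-trained seq2seq model, sample a few candidate replies per source phrase (num_of_samples), score each candidate with the chosen reward, and reinforce samples that beat a greedy (argmax-decoded) baseline. The snippet below is only a minimal, self-contained sketch of that REINFORCE-style update; the function name, the toy tensors and the scalar advantage are illustrative, not the repository's actual training code.

```
import torch
import torch.nn.functional as F

def policy_gradient_loss(step_logits, sampled_tokens, advantage):
    """REINFORCE loss for one sampled reply.

    step_logits    - tensor (seq_len, vocab_size): decoder output at every
                     step of the sampled sequence
    sampled_tokens - the token ids that were actually sampled
    advantage      - scalar: reward(sample) - reward(greedy baseline)
    """
    log_probs = F.log_softmax(step_logits, dim=1)
    tokens = torch.as_tensor(sampled_tokens, dtype=torch.long,
                             device=step_logits.device)
    # log-probability the policy assigned to each sampled token
    chosen = log_probs[torch.arange(len(tokens)), tokens]
    # maximize advantage * log pi(token | state) == minimize its negative
    return -advantage * chosen.sum()

# toy usage: 5 decoding steps over a 10-word vocabulary
logits = torch.randn(5, 10, requires_grad=True)
loss = policy_gradient_loss(logits, [3, 1, 7, 0, 2], advantage=0.2)
loss.backward()
```

Using the greedy decode as the baseline (self-critical training) keeps the advantage centred and reduces the variance of the gradient estimate.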

20 | 21 |

22 | Here we present a comparison of the target sentences generated for a given source sequence by the Perplexity, MMI, BLEU, and Seq2Seq models.
24 |
25 |
26 |
27 |

28 | 29 |

30 | This is a comparison of the target sentences generated for a given source sequence by an MMI model trained with the reinforcement-learning policy gradient and a model that computes MMI only at test time. 31 | 32 | 33 | ## Getting Started 34 | 35 | - [ ] Install and arrange an environment that meets the prerequisites (see the subsection below) 36 | - [ ] PyCharm is recommended for code development 37 | - [ ] Clone this directory to your computer 38 | - [ ] Train a cross-entropy seq2seq model using the 'train_crossent' script 39 | - [ ] Train RL models using the appropriate scripts (e.g. use 'train_rl_MMI' to use the MMI criterion for policy gradient) 40 | - [ ] Test models with automatic evaluations using 'test_all_modells' 41 | - [ ] Generate responses using 'use_model' 42 | 43 | ### Prerequisites 44 | 45 | Before starting, make sure you already have the following installed; if you don't, install it before trying to use this code: 46 | - [ ] the code was tested only on Python 3.6 47 | - [ ] numpy 48 | - [ ] scipy 49 | - [ ] pickle 50 | - [ ] pytorch 51 | - [ ] TensorboardX 52 | 53 | 54 | ## Running the Code 55 | 56 | ### Training all models 57 | 58 | #### Train Cross-entropy Seq2Seq and backward Seq2Seq 59 | Parameters for calling 'train_crossent' (can be seen with --help): 60 | 61 | | Name Of Param | Description | Default Value | Type | 62 | | --- | --- | --- | --- | 63 | | SAVES_DIR | Save directory | 'saves' | str | 64 | | name | Specific model saves directory | 'seq2seq' | str | 65 | | BATCH_SIZE | Batch Size for training | 32 | int | 66 | | LEARNING_RATE | Learning Rate | 1e-3 | float | 67 | | MAX_EPOCHES | Number of training iterations | 100 | int | 68 | | TEACHER_PROB | Probability to force reference inputs | 0.5 | float | 69 | | data | Genre to use - for data | 'comedy' | str | 70 | | train_backward | Choose - train backward/forward model | False | bool | 71 | 72 | Run Seq2Seq using: 73 | ``` 74 | python train_crossent.py -SAVES_DIR <'saves'> -name <'seq2seq'> -BATCH_SIZE <32> -LEARNING_RATE <0.001> -MAX_EPOCHES <100> -TEACHER_PROB <0.5> -data <'comedy'> 75 | ``` 76 | 77 | Run backward Seq2Seq using: 78 | ``` 79 | python train_crossent.py -SAVES_DIR <'saves'> -name <'backward_seq2seq'> -BATCH_SIZE <32> -LEARNING_RATE <0.001> -MAX_EPOCHES <100> -TEACHER_PROB <0.5> -data <'comedy'> -train_backward 80 | ``` 81 | 82 | We present below some training graphs of the CE Seq2Seq and backward Seq2Seq (loss and BLEU score for both models); a short teacher-forcing sketch follows the graphs. 83 |

84 | 85 |
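The TEACHER_PROB parameter above controls teacher forcing: with probability TEACHER_PROB the decoder is fed the reference tokens, otherwise it runs on its own argmax predictions, and either way it is scored against the reference. train_crossent.py itself is not reproduced in this listing, so the loop below is an assumption of this sketch, but the two PhraseModel methods it uses (decode_teacher and decode_chain_argmax) are the ones defined in libbots/model.py.

```
import random

import torch
import torch.nn.functional as F

def crossent_step(net, hid, beg_emb, ref_tokens, ref_seq_packed,
                  end_token, teacher_prob=0.5):
    """One decoding pass for a single (source, reference) pair.

    net            - a libbots.model.PhraseModel
    hid            - encoder state returned by net.encode(...)
    beg_emb        - embedding of the #BEG token
    ref_tokens     - reference token ids without the leading #BEG token
    ref_seq_packed - packed embeddings of #BEG + reference[:-1] (teacher input)
    """
    if random.random() < teacher_prob:
        # teacher forcing: the decoder sees the reference tokens
        logits = net.decode_teacher(hid, ref_seq_packed)
    else:
        # free running: the decoder feeds back its own argmax predictions
        logits, _ = net.decode_chain_argmax(hid, beg_emb,
                                            seq_len=len(ref_tokens),
                                            stop_at_token=end_token)
    # score whatever was produced against the reference tokens
    target = torch.LongTensor(ref_tokens[:logits.size(0)]).to(logits.device)
    return F.cross_entropy(logits, target)
```

Teacher forcing keeps the decoder close to the data distribution early on, while the free-running batches reduce exposure bias, which is one reason to mix the two.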

86 | 87 | 88 | #### Train Policy Gradient using BLEU reward 89 | Parameters for calling 'train_rl_BLEU' ( can be seen when writing --help ): 90 | 91 | | Name Of Param | Description | Default Value | Type | 92 | | --- | --- | --- | --- | 93 | | SAVES_DIR | Save directory | 'saves' | str | 94 | | name | Specific model saves directory | 'RL_BLUE' | str | 95 | | BATCH_SIZE | Batch Size for training | 16 | int | 96 | | LEARNING_RATE | Learning Rate | 1e-4 | float | 97 | | MAX_EPOCHES | Number of training iterations | 10000 | int | 98 | | data | Genre to use - for data | 'comedy' | str | 99 | | num_of_samples | Number of samples per per each example | 4 | int | 100 | | load_seq2seq_path | Pre-trained seq2seq model location | 'Final_Saves/seq2seq/epoch_090_0.800_0.107.dat' | str | 101 | 102 | Run using: 103 | ``` 104 | python train_rl_BLEU.py -SAVES_DIR <'saves'> -name <'RL_BLUE'> -BATCH_SIZE <16> -LEARNING_RATE <0.0001> -MAX_EPOCHES <10000> -data <'comedy'> -num_of_samples <4> -load_seq2seq_path <'Final_Saves/seq2seq/epoch_090_0.800_0.107.dat'> 105 | ``` 106 | 107 | BLEU agent training graph: 108 |

109 | 110 |
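The reward itself is the smoothed sentence-level BLEU from libbots/utils.py (calc_bleu_many, bigram weights 0.5/0.5), computed against the group of reference replies that share the same source phrase. A small illustration of that reward and of the sample-minus-baseline advantage (the variable names here are only for the example):

```
from nltk.translate import bleu_score

def bleu_reward(candidate_tokens, reference_group):
    """Smoothed sentence BLEU against several references,
    mirroring utils.calc_bleu_many."""
    sf = bleu_score.SmoothingFunction()
    return bleu_score.sentence_bleu(reference_group, candidate_tokens,
                                    smoothing_function=sf.method1,
                                    weights=(0.5, 0.5))

# token ids work just as well as words for the n-gram counting
sampled = [12, 7, 44, 2]
greedy  = [12, 9, 44, 2]
refs    = [[12, 7, 44, 2], [12, 7, 3, 2]]
advantage = bleu_reward(sampled, refs) - bleu_reward(greedy, refs)
```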

111 | 112 | #### Train Policy Gradient using MMI reward 113 | Parameters for calling 'train_rl_MMI' ( can be seen when writing --help ): 114 | 115 | | Name Of Param | Description | Default Value | Type | 116 | | --- | --- | --- | --- | 117 | | SAVES_DIR | Save directory | 'saves' | str | 118 | | name | Specific model saves directory | 'RL_Mutual' | str | 119 | | BATCH_SIZE | Batch Size for training | 32 | int | 120 | | LEARNING_RATE | Learning Rate | 1e-4 | float | 121 | | MAX_EPOCHES | Number of training iterations | 10000 | int | 122 | | CROSS_ENT_PROB | Probability to run a CE batch | 0.3 | float | 123 | | TEACHER_PROB | Probability to run an imitation batch in case of using CE | 0.8 | float | 124 | | data | Genre to use - for data | 'comedy' | str | 125 | | num_of_samples | Number of samples per per each example | 4 | int | 126 | | load_seq2seq_path | Pre-trained seq2seq model location | 'Final_Saves/seq2seq/epoch_090_0.800_0.107.dat' | str | 127 | | laod_b_seq2seq_path | Pre-trained backward seq2seq model location | 'Final_Saves/backward_seq2seq/epoch_080_0.780_0.104.dat' | str | 128 | 129 | Run using: 130 | ``` 131 | python train_rl_MMI.py -SAVES_DIR <'saves'> -name <'RL_Mutual'> -BATCH_SIZE <32> -LEARNING_RATE <0.0001> -MAX_EPOCHES <10000> -CROSS_ENT_PROB <0.3> -TEACHER_PROB <0.8> -data <'comedy'> -num_of_samples <4> -load_seq2seq_path <'Final_Saves/seq2seq/epoch_090_0.800_0.107.dat'> -laod_b_seq2seq_path <'Final_Saves/backward_seq2seq/epoch_080_0.780_0.104.dat'> 132 | ``` 133 | 134 | Training graph of the MMI agent: 135 |

136 | 137 |
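The mutual-information reward uses both seq2seq models: a generated reply scores well if the forward model predicts it from the source and the backward model recovers the source from it. This mirrors utils.calc_mutual in libbots/utils.py, which negates the sum of the two cross-entropies; how train_rl_MMI.py wires this value into the policy gradient is not shown here, so treat the function below as a schematic of the score only.

```
import torch
import torch.nn as nn

def mmi_score(forward_logits, reply_tokens, backward_logits, source_tokens):
    """Higher log p(reply | source) + log p(source | reply) -> higher score.

    forward_logits  - decoder outputs of the forward seq2seq for the reply
    backward_logits - decoder outputs of the backward seq2seq for the source
    """
    ce = nn.CrossEntropyLoss()
    fw = ce(forward_logits, torch.as_tensor(reply_tokens))
    bw = ce(backward_logits, torch.as_tensor(source_tokens))
    # cross-entropy is a negative log-likelihood, so negate the sum
    return -(fw + bw).item()
```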

138 | 139 | #### Train Policy Gradient using Perplexity reward 140 | Parameters for calling 'train_rl_PREPLEXITY' ( can be seen when writing --help ): 141 | 142 | | Name Of Param | Description | Default Value | Type | 143 | | --- | --- | --- | --- | 144 | | SAVES_DIR | Save directory | 'saves' | str | 145 | | name | Specific model saves directory | 'RL_PREPLEXITY' | str | 146 | | BATCH_SIZE | Batch Size for training | 32 | int | 147 | | LEARNING_RATE | Learning Rate | 1e-4 | float | 148 | | MAX_EPOCHES | Number of training iterations | 10000 | int | 149 | | CROSS_ENT_PROB | Probability to run a CE batch | 0.5 | float | 150 | | TEACHER_PROB | Probability to run an imitation batch in case of using CE | 0.5 | float | 151 | | data | Genre to use - for data | 'comedy' | str | 152 | | num_of_samples | Number of samples per per each example | 4 | int | 153 | | load_seq2seq_path | Pre-trained seq2seq model location | 'Final_Saves/seq2seq/epoch_090_0.800_0.107.dat' | str | 154 | 155 | Run using: 156 | ``` 157 | python train_rl_PREPLEXITY.py -SAVES_DIR <'saves'> -name <'RL_PREPLEXITY'> -BATCH_SIZE <32> -LEARNING_RATE <0.0001> -MAX_EPOCHES <10000> -CROSS_ENT_PROB <0.5> -TEACHER_PROB <0.5> -data <'comedy'> -num_of_samples <4> -load_seq2seq_path <'Final_Saves/seq2seq/epoch_090_0.800_0.107.dat'> 158 | ``` 159 | 160 | Below is presented with training graph of the perplexity agent: 161 |

162 | 163 |
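The perplexity reward re-scores the agent's generated reply with the frozen pre-trained seq2seq model (net.get_logits in libbots/model.py) and turns the result into a perplexity exactly as utils.calc_preplexity_many does. Lower perplexity means the reply is more fluent under the base model; turning that into a reward by negating it, as below, is an assumption of this sketch rather than the script's verbatim code.

```
import numpy as np
import torch
import torch.nn as nn

def perplexity(step_logits, tokens):
    """exp of the mean cross-entropy of `tokens` under `step_logits`,
    mirroring utils.calc_preplexity_many."""
    ce = nn.CrossEntropyLoss()
    loss = ce(step_logits, torch.as_tensor(tokens))
    return float(np.exp(loss.item()))

# toy example: 4 decoding steps over a 10-word vocabulary
logits = torch.randn(4, 10)
reward = -perplexity(logits, [1, 5, 5, 2])   # lower perplexity -> higher reward
```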

164 | 165 | #### Train Policy Gradient using Cosine Similarity reward 166 | Parameters for calling 'train_rl_cosine' ( can be seen when writing --help ): 167 | 168 | | Name Of Param | Description | Default Value | Type | 169 | | --- | --- | --- | --- | 170 | | SAVES_DIR | Save directory | 'saves' | str | 171 | | name | Specific model saves directory | 'RL_COSINE' | str | 172 | | BATCH_SIZE | Batch Size for training | 16 | int | 173 | | LEARNING_RATE | Learning Rate | 1e-4 | float | 174 | | MAX_EPOCHES | Number of training iterations | 10000 | int | 175 | | data | Genre to use - for data | 'comedy' | str | 176 | | num_of_samples | Number of samples per per each example | 4 | int | 177 | | load_seq2seq_path | Pre-trained seq2seq model location | 'Final_Saves/seq2seq/epoch_090_0.800_0.107.dat' | str | 178 | 179 | Run using: 180 | ``` 181 | python train_rl_cosine.py -SAVES_DIR <'saves'> -name <'RL_COSINE'> -BATCH_SIZE <16> -LEARNING_RATE <0.0001> -MAX_EPOCHES <10000> -data <'comedy'> -num_of_samples <4> -load_seq2seq_path <'Final_Saves/seq2seq/epoch_090_0.800_0.107.dat'> 182 | ``` 183 | 184 | Below we present training graph of the cosine agent: 185 |

186 | 187 |
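The cosine reward compares sentence embeddings: each sentence is represented by the mean of its token embeddings (PhraseModel.get_mean_emb) and the reward is the cosine similarity between the generated reply and the reference, as computed by utils.calc_cosine_many. A compact restatement with toy tensors (the 50-dimensional size matches EMBEDDING_DIM in libbots/model.py):

```
import torch

def cosine_reward(pred_token_embs, ref_token_embs):
    """Cosine similarity between the mean embeddings of two sentences.
    Both arguments are tensors of shape (num_tokens, emb_dim)."""
    mean_pred = pred_token_embs.mean(dim=0, keepdim=True)
    mean_ref = ref_token_embs.mean(dim=0, keepdim=True)
    mean_pred = mean_pred / mean_pred.norm(dim=1, keepdim=True)
    mean_ref = mean_ref / mean_ref.norm(dim=1, keepdim=True)
    return torch.mm(mean_pred, mean_ref.t()).item()

reward = cosine_reward(torch.randn(6, 50), torch.randn(4, 50))
```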

188 | 189 | ### Testing all models 190 | There are two types of test: automatic and qualitative. 191 | 192 | For qualitative test, it is recommended to use the 'use_model' function with pycharm interface. I have suggested (can be seen in code) many sentences to use as source sentences. 193 | 194 | For automatic test, a user can call 'test_all_modells' function, with the parameters: 195 | 196 | | Name Of Param | Description | Default Value | Type | 197 | | --- | --- | --- | --- | 198 | | SAVES_DIR | Save directory | 'saves' | str | 199 | | BATCH_SIZE | Batch Size for training | 32 | int | 200 | | data | Genre to use - for data | 'comedy' | str | 201 | | load_seq2seq_path | Pre-trained seq2seq model location | 'Final_Saves/seq2seq/epoch_090_0.800_0.107.dat' | str | 202 | | laod_b_seq2seq_path | Pre-trained backward seq2seq model location | 'Final_Saves/backward_seq2seq/epoch_080_0.780_0.104.dat' | str | 203 | | bleu_model_path | Pre-trained BLEU model location | 'Final_Saves/RL_BLUE/bleu_0.135_177.dat' | str | 204 | | mutual_model_path | Pre-trained MMI model location | 'Final_Saves/RL_Mutual/epoch_180_-4.325_-7.192.dat' | str | 205 | | prep_model_path | Pre-trained Perplexity model location | 'Final_Saves/RL_Perplexity/epoch_050_1.463_3.701.dat' | str | 206 | | cos_model_path | Pre-trained Cosine Similarity model location | 'Final_Saves/RL_COSINE/cosine_0.621_03.dat' | str | 207 | 208 | 209 | Run using: 210 | ``` 211 | python test_all_modells.py -SAVES_DIR <'saves'> -BATCH_SIZE <16> -data <'comedy'> -load_seq2seq_path <'Final_Saves/seq2seq/epoch_090_0.800_0.107.dat'> -laod_b_seq2seq_path <'Final_Saves/backward_seq2seq/epoch_080_0.780_0.104.dat'> -bleu_model_path <'Final_Saves/RL_BLUE/bleu_0.135_177.dat'> -mutual_model_path <'Final_Saves/RL_Mutual/epoch_180_-4.325_-7.192.dat'> -prep_model_path <'Final_Saves/RL_Perplexity/epoch_050_1.463_3.701.dat'> -cos_model_path <'Final_Saves/RL_COSINE/cosine_0.621_03.dat'> 212 | ``` 213 | 214 | 215 | ## Results for tests 216 | WE present below a table with automatic tests results on 'comedy' movies. For more results and conclusions, please review 'Deep Reinforcement Learning For Language Models.pdf' file in this directory. 217 | 218 |

219 | 220 |

221 | 222 | ## Contributing 223 | 224 | Please read [CONTRIBUTING.md](https://gist.github.com/PurpleBooth/b24679402957c63ec426) for details on our code of conduct, and the process for submitting pull requests to us. 225 | 226 | 227 | 228 | ## Authors 229 | 230 | * **Eyal Ben David** 231 | 232 | 233 | ## License 234 | 235 | This project is licensed under the MIT License 236 | 237 | ## Acknowledgments 238 | 239 | * Inspiration for using RL in text tasks - "Deep Reinforcement Learning For Dialogue Generation", (Ritter et al., 1996) 240 | * Policy Gradient Implementation basics in dialogue agents - "Deep Reinforcement Learning Hands-On", by Maxim Lapan (Chapter 12) -------------------------------------------------------------------------------- /cur_reader.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import argparse 3 | import collections 4 | 5 | from libbots import cornell, data 6 | 7 | 8 | if __name__ == "__main__": 9 | parser = argparse.ArgumentParser() 10 | parser.add_argument("-g", "--genre", default='', help="Genre to show dialogs from") 11 | parser.add_argument("--show-genres", action='store_true', default=False, help="Display genres stats") 12 | parser.add_argument("--show-dials", action='store_true', default=False, help="Display dialogs") 13 | parser.add_argument("--show-train", action='store_true', default=False, help="Display training pairs") 14 | parser.add_argument("--show-dict-freq", action='store_true', default=False, help="Display dictionary frequency") 15 | args = parser.parse_args() 16 | 17 | if args.show_genres: 18 | genre_counts = collections.Counter() 19 | genres = cornell.read_genres(cornell.DATA_DIR) 20 | for movie, g_list in genres.items(): 21 | for g in g_list: 22 | genre_counts[g] += 1 23 | print("Genres:") 24 | for g, count in genre_counts.most_common(): 25 | print("%s: %d" % (g, count)) 26 | 27 | if args.show_dials: 28 | dials = cornell.load_dialogues(genre_filter=args.genre) 29 | for d_idx, dial in enumerate(dials): 30 | print("Dialog %d with %d phrases:" % (d_idx, len(dial))) 31 | for p in dial: 32 | print(" ".join(p)) 33 | print() 34 | 35 | if args.show_train or args.show_dict_freq: 36 | phrase_pairs, emb_dict = data.load_data(genre_filter=args.genre) 37 | 38 | if args.show_train: 39 | rev_emb_dict = {idx: word for word, idx in emb_dict.items()} 40 | train_data = data.encode_phrase_pairs(phrase_pairs, emb_dict) 41 | train_data = data.group_train_data(train_data) 42 | unk_token = emb_dict[data.UNKNOWN_TOKEN] 43 | 44 | print("Training pairs (%d total)" % len(train_data)) 45 | train_data.sort(key=lambda p: len(p[1]), reverse=True) 46 | for idx, (p1, p2_group) in enumerate(train_data): 47 | w1 = data.decode_words(p1, rev_emb_dict) 48 | w2_group = [data.decode_words(p2, rev_emb_dict) for p2 in p2_group] 49 | print("%d:" % idx, " ".join(w1)) 50 | for w2 in w2_group: 51 | print("%s:" % (" " * len(str(idx))), " ".join(w2)) 52 | 53 | if args.show_dict_freq: 54 | words_stat = collections.Counter() 55 | for p1, p2 in phrase_pairs: 56 | words_stat.update(p1) 57 | print("Frequency stats for %d tokens in the dict" % len(emb_dict)) 58 | for token, count in words_stat.most_common(): 59 | print("%s: %d" % (token, count)) 60 | pass -------------------------------------------------------------------------------- /data_test.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import argparse 3 | import logging 4 | 5 | from libbots import data, model, 
utils 6 | 7 | import torch 8 | 9 | log = logging.getLogger("data_test") 10 | 11 | if __name__ == "__main__": 12 | logging.basicConfig(format="%(asctime)-15s %(levelname)s %(message)s", level=logging.INFO) 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument("--data", required=True, 15 | help="Category to use for training. Empty string to train on full dataset") 16 | parser.add_argument("-m", "--model", required=True, help="Model name to load") 17 | args = parser.parse_args() 18 | 19 | phrase_pairs, emb_dict = data.load_data(args.data) 20 | log.info("Obtained %d phrase pairs with %d uniq words", len(phrase_pairs), len(emb_dict)) 21 | train_data = data.encode_phrase_pairs(phrase_pairs, emb_dict) 22 | train_data = data.group_train_data(train_data) 23 | rev_emb_dict = {idx: word for word, idx in emb_dict.items()} 24 | 25 | net = model.PhraseModel(emb_size=model.EMBEDDING_DIM, dict_size=len(emb_dict), hid_size=model.HIDDEN_STATE_SIZE) 26 | net.load_state_dict(torch.load(args.model)) 27 | 28 | end_token = emb_dict[data.END_TOKEN] 29 | 30 | seq_count = 0 31 | sum_bleu = 0.0 32 | 33 | for seq_1, targets in train_data: 34 | input_seq = model.pack_input(seq_1, net.emb) 35 | enc = net.encode(input_seq) 36 | _, tokens = net.decode_chain_argmax(enc, input_seq.data[0:1], 37 | seq_len=data.MAX_TOKENS, stop_at_token=end_token) 38 | references = [seq[1:] for seq in targets] 39 | bleu = utils.calc_bleu_many(tokens, references) 40 | sum_bleu += bleu 41 | seq_count += 1 42 | 43 | log.info("Processed %d phrases, mean BLEU = %.4f", seq_count, sum_bleu / seq_count) -------------------------------------------------------------------------------- /libbots/cornell.py: -------------------------------------------------------------------------------- 1 | """ 2 | Cornel Movies Dialogs Corpus 3 | https://www.cs.cornell.edu/~cristian/Cornell_Movie-Dialogs_Corpus.html 4 | """ 5 | import os 6 | import logging 7 | 8 | from . 
import utils 9 | 10 | log = logging.getLogger("cornell") 11 | DATA_DIR = "data/cornell" 12 | SEPARATOR = "+++$+++" 13 | 14 | 15 | def load_dialogues(data_dir=DATA_DIR, genre_filter=''): 16 | """ 17 | Load dialogues from cornell data 18 | :return: list of list of list of words 19 | """ 20 | movie_set = None 21 | if genre_filter: 22 | movie_set = read_movie_set(data_dir, genre_filter) 23 | log.info("Loaded %d movies with genre %s", len(movie_set), genre_filter) 24 | log.info("Read and tokenise phrases...") 25 | lines = read_phrases(data_dir, movies=movie_set) 26 | log.info("Loaded %d phrases", len(lines)) 27 | dialogues = load_conversations(data_dir, lines, movie_set) 28 | return dialogues 29 | 30 | 31 | def iterate_entries(data_dir, file_name): 32 | with open(os.path.join(data_dir, file_name), "rb") as fd: 33 | for l in fd: 34 | l = str(l, encoding='utf-8', errors='ignore') 35 | yield list(map(str.strip, l.split(SEPARATOR))) 36 | 37 | 38 | def read_movie_set(data_dir, genre_filter): 39 | res = set() 40 | for parts in iterate_entries(data_dir, "movie_titles_metadata.txt"): 41 | m_id, m_genres = parts[0], parts[5] 42 | if m_genres.find(genre_filter) != -1: 43 | res.add(m_id) 44 | return res 45 | 46 | 47 | def read_phrases(data_dir, movies=None): 48 | res = {} 49 | for parts in iterate_entries(data_dir, "movie_lines.txt"): 50 | l_id, m_id, l_str = parts[0], parts[2], parts[4] 51 | if movies and m_id not in movies: 52 | continue 53 | tokens = utils.tokenize(l_str) 54 | if tokens: 55 | res[l_id] = tokens 56 | return res 57 | 58 | 59 | def load_conversations(data_dir, lines, movies=None): 60 | res = [] 61 | for parts in iterate_entries(data_dir, "movie_conversations.txt"): 62 | m_id, dial_s = parts[2], parts[3] 63 | if movies and m_id not in movies: 64 | continue 65 | l_ids = dial_s.strip("[]").split(", ") 66 | l_ids = list(map(lambda s: s.strip("'"), l_ids)) 67 | dial = [lines[l_id] for l_id in l_ids if l_id in lines] 68 | if dial: 69 | res.append(dial) 70 | return res 71 | 72 | 73 | def read_genres(data_dir): 74 | res = {} 75 | for parts in iterate_entries(data_dir, "movie_titles_metadata.txt"): 76 | m_id, m_genres = parts[0], parts[5] 77 | l_genres = m_genres.strip("[]").split(", ") 78 | l_genres = list(map(lambda s: s.strip("'"), l_genres)) 79 | res[m_id] = l_genres 80 | return res -------------------------------------------------------------------------------- /libbots/data.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import os 3 | import sys 4 | import logging 5 | import itertools 6 | import pickle 7 | 8 | from . 
import cornell 9 | 10 | UNKNOWN_TOKEN = '#UNK' 11 | BEGIN_TOKEN = "#BEG" 12 | END_TOKEN = "#END" 13 | MAX_TOKENS = 20 14 | MIN_TOKEN_FEQ = 10 15 | SHUFFLE_SEED = 5871 16 | 17 | EMB_DICT_NAME = "emb_dict.dat" 18 | EMB_NAME = "emb.npy" 19 | 20 | log = logging.getLogger("data") 21 | 22 | 23 | def save_emb_dict(dir_name, emb_dict): 24 | with open(os.path.join(dir_name, EMB_DICT_NAME), "wb") as fd: 25 | pickle.dump(emb_dict, fd) 26 | 27 | 28 | def load_emb_dict(dir_name): 29 | with open(os.path.join(dir_name, EMB_DICT_NAME), "rb") as fd: 30 | return pickle.load(fd) 31 | 32 | 33 | def encode_words(words, emb_dict): 34 | """ 35 | Convert list of words into list of embeddings indices, adding our tokens 36 | :param words: list of strings 37 | :param emb_dict: embeddings dictionary 38 | :return: list of IDs 39 | """ 40 | res = [emb_dict[BEGIN_TOKEN]] 41 | unk_idx = emb_dict[UNKNOWN_TOKEN] 42 | for w in words: 43 | idx = emb_dict.get(w.lower(), unk_idx) 44 | res.append(idx) 45 | res.append(emb_dict[END_TOKEN]) 46 | return res 47 | 48 | 49 | def encode_phrase_pairs(phrase_pairs, emb_dict, filter_unknows=True): 50 | """ 51 | Convert list of phrase pairs to training data 52 | :param phrase_pairs: list of (phrase, phrase) 53 | :param emb_dict: embeddings dictionary (word -> id) 54 | :return: list of tuples ([input_id_seq], [output_id_seq]) 55 | """ 56 | unk_token = emb_dict[UNKNOWN_TOKEN] 57 | result = [] 58 | for q, p in phrase_pairs: 59 | pq = encode_words(q, emb_dict), encode_words(p, emb_dict) 60 | if unk_token in pq[0] or unk_token in pq[1]: 61 | continue 62 | result.append(pq) 63 | return result 64 | 65 | 66 | def group_train_data(training_data): 67 | """ 68 | Group training pairs by first phrase 69 | :param training_data: list of (seq1, seq2) pairs 70 | :return: list of (seq1, [seq*]) pairs 71 | """ 72 | groups = collections.defaultdict(list) 73 | for p1, p2 in training_data: 74 | l = groups[tuple(p1)] 75 | l.append(p2) 76 | return list(groups.items()) 77 | 78 | 79 | def iterate_batches(data, batch_size): 80 | assert isinstance(data, list) 81 | assert isinstance(batch_size, int) 82 | 83 | ofs = 0 84 | while True: 85 | batch = data[ofs*batch_size:(ofs+1)*batch_size] 86 | if len(batch) <= 1: 87 | break 88 | yield batch 89 | ofs += 1 90 | 91 | def yield_samples(data): 92 | assert isinstance(data, list) 93 | 94 | ofs = 0 95 | while True: 96 | batch = data[ofs:(ofs+1)] 97 | if len(batch) <= 1: 98 | break 99 | yield batch 100 | ofs += 1 101 | 102 | 103 | def load_data(genre_filter, max_tokens=MAX_TOKENS, min_token_freq=MIN_TOKEN_FEQ): 104 | dialogues = cornell.load_dialogues(genre_filter=genre_filter) 105 | if not dialogues: 106 | log.error("No dialogues found, exit!") 107 | sys.exit() 108 | log.info("Loaded %d dialogues with %d phrases, generating training pairs", 109 | len(dialogues), sum(map(len, dialogues))) 110 | phrase_pairs = dialogues_to_pairs(dialogues, max_tokens=max_tokens) 111 | log.info("Counting freq of words...") 112 | word_counts = collections.Counter() 113 | for dial in dialogues: 114 | for p in dial: 115 | word_counts.update(p) 116 | freq_set = set(map(lambda p: p[0], filter(lambda p: p[1] >= min_token_freq, word_counts.items()))) 117 | log.info("Data has %d uniq words, %d of them occur more than %d", 118 | len(word_counts), len(freq_set), min_token_freq) 119 | phrase_dict = phrase_pairs_dict(phrase_pairs, freq_set) 120 | return phrase_pairs, phrase_dict 121 | 122 | 123 | def phrase_pairs_dict(phrase_pairs, freq_set): 124 | """ 125 | Return the dict of words in the dialogues mapped to 
their IDs 126 | :param phrase_pairs: list of (phrase, phrase) pairs 127 | :return: dict 128 | """ 129 | res = {UNKNOWN_TOKEN: 0, BEGIN_TOKEN: 1, END_TOKEN: 2} 130 | next_id = 3 131 | for p1, p2 in phrase_pairs: 132 | for w in map(str.lower, itertools.chain(p1, p2)): 133 | if w not in res and w in freq_set: 134 | res[w] = next_id 135 | next_id += 1 136 | return res 137 | 138 | 139 | def dialogues_to_pairs(dialogues, max_tokens=None): 140 | """ 141 | Convert dialogues to training pairs of phrases 142 | :param dialogues: 143 | :param max_tokens: limit of tokens in both question and reply 144 | :return: list of (phrase, phrase) pairs 145 | """ 146 | result = [] 147 | for dial in dialogues: 148 | prev_phrase = None 149 | for phrase in dial: 150 | if prev_phrase is not None: 151 | if max_tokens is None or (len(prev_phrase) <= max_tokens and len(phrase) <= max_tokens): 152 | result.append((prev_phrase, phrase)) 153 | prev_phrase = phrase 154 | return result 155 | 156 | 157 | def decode_words(indices, rev_emb_dict): 158 | return [rev_emb_dict.get(int(idx), UNKNOWN_TOKEN) for idx in indices] 159 | 160 | 161 | def trim_tokens_seq(tokens, end_token): 162 | res = [] 163 | for t in tokens: 164 | res.append(t) 165 | if t == end_token: 166 | break 167 | return res 168 | 169 | 170 | def split_train_test(data, train_ratio=0.95): 171 | count = int(len(data) * train_ratio) 172 | return data[:count], data[count:] -------------------------------------------------------------------------------- /libbots/model.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.utils.rnn as rnn_utils 7 | import torch.nn.functional as F 8 | from . import utils, data 9 | from torch.autograd import Variable 10 | 11 | 12 | HIDDEN_STATE_SIZE = 512 13 | EMBEDDING_DIM = 50 14 | 15 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 16 | 17 | class PhraseModel(nn.Module): 18 | def __init__(self, emb_size, dict_size, hid_size): 19 | super(PhraseModel, self).__init__() 20 | 21 | self.emb = nn.Embedding(num_embeddings=dict_size, embedding_dim=emb_size) 22 | self.encoder = nn.LSTM(input_size=emb_size, hidden_size=hid_size, 23 | num_layers=1, batch_first=True) 24 | self.decoder = nn.LSTM(input_size=emb_size, hidden_size=hid_size, 25 | num_layers=1, batch_first=True) 26 | self.output = nn.Sequential(nn.Linear(hid_size, dict_size)) 27 | 28 | def encode(self, x): 29 | _, hid = self.encoder(x) 30 | return hid 31 | 32 | def get_encoded_item(self, encoded, index): 33 | # For RNN 34 | # return encoded[:, index:index+1] 35 | # For LSTM 36 | return encoded[0][:, index:index+1].contiguous(), \ 37 | encoded[1][:, index:index+1].contiguous() 38 | 39 | def decode_teacher(self, hid, input_seq, detach=False): 40 | # Method assumes batch of size=1 41 | if detach: 42 | out, _ = self.decoder(input_seq, hid) 43 | out = self.output(out.data).detach() 44 | else: 45 | out, _ = self.decoder(input_seq, hid) 46 | out = self.output(out.data) 47 | return out 48 | 49 | def decode_one(self, hid, input_x, detach=False): 50 | if detach: 51 | out, new_hid = self.decoder(input_x.unsqueeze(0), hid) 52 | out = self.output(out).detach() 53 | else: 54 | out, new_hid = self.decoder(input_x.unsqueeze(0), hid) 55 | out = self.output(out) 56 | return out.squeeze(dim=0), new_hid 57 | 58 | def decode_chain_argmax(self, hid, begin_emb, seq_len, stop_at_token=None): 59 | """ 60 | Decode sequence by feeding predicted token to the net again. 
Act greedily 61 | """ 62 | res_logits = [] 63 | res_tokens = [] 64 | cur_emb = begin_emb 65 | 66 | for _ in range(seq_len): 67 | out_logits, hid = self.decode_one(hid, cur_emb) 68 | out_token_v = torch.max(out_logits, dim=1)[1] 69 | out_token = out_token_v.data.cpu().numpy()[0] 70 | 71 | cur_emb = self.emb(out_token_v) 72 | 73 | res_logits.append(out_logits) 74 | res_tokens.append(out_token) 75 | if stop_at_token is not None and out_token == stop_at_token: 76 | break 77 | return torch.cat(res_logits), res_tokens 78 | 79 | def decode_chain_sampling(self, hid, begin_emb, seq_len, stop_at_token=None): 80 | """ 81 | Decode sequence by feeding predicted token to the net again. 82 | Act according to probabilities 83 | """ 84 | res_logits = [] 85 | res_actions = [] 86 | cur_emb = begin_emb 87 | 88 | for _ in range(seq_len): 89 | out_logits, hid = self.decode_one(hid, cur_emb) 90 | out_probs_v = F.softmax(out_logits, dim=1) 91 | out_probs = out_probs_v.data.cpu().numpy()[0] 92 | action = np.random.choice(out_probs.shape[0], p=out_probs) 93 | action_v = torch.LongTensor([action]).to(begin_emb.device) 94 | cur_emb = self.emb(action_v) 95 | 96 | res_logits.append(out_logits) 97 | res_actions.append(action) 98 | if stop_at_token is not None and action == stop_at_token: 99 | break 100 | return torch.cat(res_logits), res_actions 101 | 102 | def decode_rl_chain_argmax(self, hid, begin_emb, seq_len, stop_at_token=None): 103 | """ 104 | Decode sequence by feeding predicted token to the net again. Act greedily 105 | """ 106 | q_list = [] 107 | res_tokens = [] 108 | hidden_list = [] 109 | emb_list = [] 110 | cur_emb = begin_emb 111 | hidden_list.append(hid) 112 | emb_list.append(cur_emb) 113 | 114 | 115 | for _ in range(seq_len): 116 | out_q, hid = self.decode_one(hid, cur_emb) 117 | out_token_v = torch.max(out_q, dim=1)[1] 118 | out_token = out_token_v.data.cpu().numpy()[0] 119 | 120 | cur_emb = self.emb(out_token_v) 121 | 122 | 123 | hidden_list.append(hid) 124 | emb_list.append(cur_emb) 125 | q_list.append(out_q[0][out_token]) 126 | res_tokens.append(out_token) 127 | if stop_at_token is not None and out_token == stop_at_token: 128 | break 129 | return q_list, hidden_list, torch.cat(emb_list), res_tokens 130 | 131 | def decode_rl_chain_sampling(self, hid, begin_emb, seq_len, stop_at_token=None): 132 | """ 133 | Decode sequence by feeding predicted token to the net again. 134 | Act according to probabilities 135 | """ 136 | q_list = [] 137 | res_tokens = [] 138 | hidden_list = [] 139 | emb_list = [] 140 | 141 | cur_emb = begin_emb 142 | hidden_list.append(hid) 143 | emb_list.append(cur_emb) 144 | 145 | for _ in range(seq_len): 146 | out_logits, hid = self.decode_one(hid, cur_emb) 147 | out_probs_v = F.softmax(out_logits, dim=1) 148 | out_probs = out_probs_v.data.cpu().numpy()[0] 149 | action = int(np.random.choice(out_probs.shape[0], p=out_probs)) 150 | action_v = torch.LongTensor([action]).to(begin_emb.device) 151 | cur_emb = self.emb(action_v) 152 | 153 | q_list.append(out_logits[0][action]) 154 | res_tokens.append(action) 155 | hidden_list.append(hid) 156 | emb_list.append(cur_emb) 157 | if stop_at_token is not None and action == stop_at_token: 158 | break 159 | return q_list, hidden_list, torch.cat(emb_list), res_tokens 160 | 161 | def decode_batch(self, state): 162 | """ 163 | Decode sequence by feeding predicted token to the net again. 
Act greedily 164 | """ 165 | res_logits = [] 166 | for idx in range(len(state)): 167 | cur_q, _ = self.decode_one(state[idx][0], state[idx][1].unsqueeze(0)) 168 | res_logits.append(cur_q) 169 | 170 | return torch.cat(res_logits) 171 | 172 | def decode_k_best(self, hid, begin_emb, k, seq_len, stop_at_token=None): 173 | """ 174 | Decode sequence by feeding predicted token to the net again. 175 | Act according to probabilities 176 | """ 177 | cur_emb = begin_emb 178 | 179 | out_logits, hid = self.decode_one(hid, cur_emb) 180 | out_probs_v = F.log_softmax(out_logits, dim=1) 181 | highest_values = torch.topk(out_probs_v, k) 182 | first_beam_probs = highest_values[0] 183 | first_beam_vals = highest_values[1] 184 | beam_emb = [] 185 | beam_hid = [] 186 | beam_vals = [] 187 | beam_prob = [] 188 | for i in range(k): 189 | beam_vals.append([]) 190 | for i in range(k): 191 | beam_emb.append(self.emb(first_beam_vals)[:, i, :]) 192 | beam_hid.append(hid) 193 | beam_vals[i].append(int(first_beam_vals[0][i].data.cpu().numpy())) 194 | beam_prob.append(first_beam_probs[0][i].data.cpu().numpy()) 195 | 196 | for num_words in range(1, seq_len): 197 | possible_sentences = [] 198 | possible_probs = [] 199 | for i in range(k): 200 | cur_emb = beam_emb[i] 201 | cur_hid = beam_hid[i] 202 | cur_prob = beam_prob[i] 203 | if stop_at_token is not None and beam_vals[i][-1] == stop_at_token: 204 | possible_probs.append(cur_prob) 205 | possible_sentences.append((beam_vals[i], beam_hid[i], beam_emb[i])) 206 | continue 207 | 208 | else: 209 | out_logits, hid = self.decode_one(cur_hid, cur_emb) 210 | log_probs = torch.tensor(F.log_softmax(out_logits, dim=1)).to(device) 211 | highest_values = torch.topk(log_probs, k) 212 | probs = highest_values[0] 213 | vals = highest_values[1] 214 | for j in range(k): 215 | prob = (probs[0][j].data.cpu().numpy() + cur_prob) * (num_words / (num_words + 1)) 216 | possible_probs.append(prob) 217 | emb = self.emb(vals)[:, j, :] 218 | val = vals[0][j].data.cpu().numpy() 219 | cur_seq_vals = beam_vals[i][:] 220 | cur_seq_vals.append(int(val)) 221 | possible_sentences.append((cur_seq_vals, hid, emb)) 222 | 223 | max_indices = sorted(range(len(possible_probs)), key=lambda s: possible_probs[s])[-k:] 224 | beam_emb = [] 225 | beam_hid = [] 226 | beam_vals = [] 227 | beam_prob = [] 228 | for i in range(k): 229 | (seq_vals, hid, emb) = possible_sentences[max_indices[i]] 230 | beam_emb.append(emb) 231 | beam_hid.append(hid) 232 | beam_vals.append(seq_vals) 233 | beam_prob.append(possible_probs[max_indices[i]]) 234 | t_beam_prob = [] 235 | for i in range(k): 236 | t_beam_prob.append(Variable(torch.tensor(beam_prob[i]).to(device), requires_grad=True)) 237 | 238 | return beam_vals, t_beam_prob 239 | 240 | def decode_k_sampling(self, hid, begin_emb, k, seq_len, stop_at_token=None): 241 | list_res_logit = [] 242 | list_res_actions = [] 243 | for i in range(k): 244 | res_logits = [] 245 | res_actions = [] 246 | cur_emb = begin_emb 247 | 248 | for _ in range(seq_len): 249 | out_logits, hid = self.decode_one(hid, cur_emb) 250 | out_probs_v = F.softmax(out_logits, dim=1) 251 | out_probs = out_probs_v.data.cpu().numpy()[0] 252 | action = int(np.random.choice(out_probs.shape[0], p=out_probs)) 253 | action_v = torch.LongTensor([action]).to(begin_emb.device) 254 | cur_emb = self.emb(action_v) 255 | 256 | res_logits.append(out_logits) 257 | res_actions.append(action) 258 | if stop_at_token is not None and action == stop_at_token: 259 | break 260 | 261 | list_res_logit.append(torch.cat(res_logits)) 262 | 
list_res_actions.append(res_actions) 263 | 264 | return list_res_actions, list_res_logit 265 | 266 | def get_qp_prob(self, hid, begin_emb, p_tokens): 267 | """ 268 | Find the probability of P(q|p) - the question asked (q, source) given the answer (p, target) 269 | """ 270 | cur_emb = begin_emb 271 | prob = 0 272 | for word_num in range(len(p_tokens)): 273 | out_logits, hid = self.decode_one(hid, cur_emb) 274 | out_token_v = F.log_softmax(out_logits, dim=1)[0] 275 | prob = prob + out_token_v[p_tokens[word_num]].data.cpu().numpy() 276 | 277 | out_token = p_tokens[word_num] 278 | action_v = torch.LongTensor([out_token]).to(begin_emb.device) 279 | cur_emb = self.emb(action_v) 280 | if (len(p_tokens)) == 0: 281 | return -20 282 | else: 283 | return prob/(len(p_tokens)) 284 | 285 | def get_mean_emb(self, begin_emb, p_tokens): 286 | """ 287 | Find the embbeding of the entire sentence 288 | """ 289 | emb_list = [] 290 | cur_emb = begin_emb 291 | emb_list.append(cur_emb) 292 | if isinstance(p_tokens, list): 293 | if len(p_tokens) == 1: 294 | p_tokens = p_tokens[0] 295 | for word_num in range(len(p_tokens)): 296 | out_token = p_tokens[word_num] 297 | action_v = torch.LongTensor([out_token]).to(begin_emb.device) 298 | cur_emb = self.emb(action_v).detach() 299 | emb_list.append(cur_emb) 300 | else: 301 | out_token = p_tokens 302 | action_v = torch.LongTensor([out_token]).to(begin_emb.device) 303 | cur_emb = self.emb(action_v).detach() 304 | emb_list.append(cur_emb) 305 | 306 | return sum(emb_list)/len(emb_list) 307 | 308 | def get_beam_sentences(self, tokens, emb_dict, k_sentences): 309 | # Forward 310 | source_seq = pack_input(tokens, self.emb) 311 | enc = self.encode(source_seq) 312 | end_token = emb_dict[data.END_TOKEN] 313 | probs, list_of_out_tokens = self.decode_k_sampling(enc, source_seq.data[0:1], k_sentences, 314 | seq_len=data.MAX_TOKENS, stop_at_token=end_token) 315 | # list_of_out_tokens, probs = self.decode_k_sent(enc, source_seq.data[0:1], k_sentences, seq_len=data.MAX_TOKENS, 316 | # stop_at_token=end_token) 317 | return list_of_out_tokens, probs 318 | 319 | def get_action_prob(self, source, target, emb_dict): 320 | source_seq = pack_input(source, self.emb, "cuda") 321 | hid = self.encode(source_seq) 322 | end_token = emb_dict[data.END_TOKEN] 323 | cur_emb = source_seq.data[0:1] 324 | probs = [] 325 | for word_num in range(len(target)): 326 | out_logits, hid = self.decode_one(hid, cur_emb) 327 | out_token_v = F.log_softmax(out_logits, dim=1)[0] 328 | 329 | probs.append(out_token_v[target[word_num]]) 330 | out_token = target[word_num] 331 | action_v = torch.LongTensor([out_token]).to(device) 332 | cur_emb = self.emb(action_v) 333 | t_prob = probs[0] 334 | for i in range(1, len(probs)): 335 | t_prob = t_prob + probs[i] 336 | 337 | return t_prob/len(probs) 338 | 339 | def get_logits(self, hid, begin_emb, seq_len, res_action, stop_at_token=None): 340 | """ 341 | Decode sequence by feeding predicted token to the net again. 
Act greedily 342 | """ 343 | res_logits = [] 344 | cur_emb = begin_emb 345 | 346 | for i in range(seq_len): 347 | out_logits, hid = self.decode_one(hid, cur_emb, True) 348 | out_token_v = torch.tensor([res_action[i]]).to(device) 349 | out_token = out_token_v.data.cpu().numpy()[0] 350 | 351 | cur_emb = self.emb(out_token_v).detach() 352 | 353 | res_logits.append(out_logits) 354 | if stop_at_token is not None and out_token == stop_at_token: 355 | break 356 | return torch.cat(res_logits) 357 | 358 | 359 | 360 | def pack_batch_no_out(batch, embeddings, device=device): 361 | assert isinstance(batch, list) 362 | # Sort descending (CuDNN requirements) 363 | batch.sort(key=lambda s: len(s[0]), reverse=True) 364 | input_idx, output_idx = zip(*batch) 365 | # create padded matrix of inputs 366 | lens = list(map(len, input_idx)) 367 | input_mat = np.zeros((len(batch), lens[0]), dtype=np.int64) 368 | for idx, x in enumerate(input_idx): 369 | input_mat[idx, :len(x)] = x 370 | input_v = torch.tensor(input_mat).to(device) 371 | input_seq = rnn_utils.pack_padded_sequence(input_v, lens, batch_first=True) 372 | # lookup embeddings 373 | r = embeddings(input_seq.data) 374 | emb_input_seq = rnn_utils.PackedSequence(r, input_seq.batch_sizes) 375 | return emb_input_seq, input_idx, output_idx 376 | 377 | def pack_batch_no_in(batch, embeddings, device=device): 378 | assert isinstance(batch, list) 379 | # Sort descending (CuDNN requirements) 380 | batch.sort(key=lambda s: len(s[1]), reverse=True) 381 | input_idx, output_idx = zip(*batch) 382 | # create padded matrix of inputs 383 | lens = list(map(len, output_idx)) 384 | output_mat = np.zeros((len(batch), lens[0]), dtype=np.int64) 385 | for idx, x in enumerate(output_idx): 386 | output_mat[idx, :len(x)] = x 387 | ouput_v = torch.tensor(output_mat).to(device) 388 | output_seq = rnn_utils.pack_padded_sequence(ouput_v, lens, batch_first=True) 389 | # lookup embeddings 390 | r = embeddings(output_seq.data) 391 | emb_input_seq = rnn_utils.PackedSequence(r, output_seq.batch_sizes) 392 | return emb_input_seq, output_idx, input_idx 393 | 394 | def pack_input(input_data, embeddings, device=device, detach=False): 395 | input_v = torch.LongTensor([input_data]).to(device) 396 | if detach: 397 | r = embeddings(input_v).detach() 398 | else: 399 | r = embeddings(input_v) 400 | return rnn_utils.pack_padded_sequence(r, [len(input_data)], batch_first=True) 401 | 402 | def pack_batch(batch, embeddings, device=device, detach=False): 403 | emb_input_seq, input_idx, output_idx = pack_batch_no_out(batch, embeddings, device) 404 | 405 | # prepare output sequences, with end token stripped 406 | output_seq_list = [] 407 | for out in output_idx: 408 | if len(out) == 1: 409 | out = out[0] 410 | output_seq_list.append(pack_input(out[:-1], embeddings, device, detach)) 411 | return emb_input_seq, output_seq_list, input_idx, output_idx 412 | 413 | def pack_backward_batch(batch, embeddings, device=device): 414 | emb_target_seq, target_idx, source_idx = pack_batch_no_in(batch, embeddings, device) 415 | 416 | # prepare output sequences, with end token stripped 417 | source_seq = [] 418 | for src in source_idx: 419 | source_seq.append(pack_input(src[:-1], embeddings, device)) 420 | backward_input = emb_target_seq 421 | backward_output = source_seq 422 | return backward_input, backward_output, target_idx, source_idx 423 | 424 | def seq_bleu(model_out, ref_seq): 425 | model_seq = torch.max(model_out.data, dim=1)[1] 426 | model_seq = model_seq.cpu().numpy() 427 | return utils.calc_bleu(model_seq, 
ref_seq) 428 | 429 | def mutual_words_to_words(words, data, k, emb_dict, rev_emb_dict, net, back_net): 430 | # Forward 431 | tokens = data.encode_words(words, emb_dict) 432 | source_seq = pack_input(tokens, net.emb, "cuda") 433 | enc = net.encode(source_seq) 434 | end_token = emb_dict[data.END_TOKEN] 435 | list_of_out_tokens, probs = net.decode_k_best(enc, source_seq.data[0:1], k, seq_len=data.MAX_TOKENS, 436 | stop_at_token=end_token) 437 | list_of_out_words = [] 438 | for iTokens in range(len(list_of_out_tokens)): 439 | if list_of_out_tokens[iTokens][-1] == end_token: 440 | list_of_out_tokens[iTokens] = list_of_out_tokens[iTokens][:-1] 441 | list_of_out_words.append(data.decode_words(list_of_out_tokens[iTokens], rev_emb_dict)) 442 | 443 | # Backward 444 | back_seq2seq_prob = [] 445 | for iTarget in range(len(list_of_out_words)): 446 | b_tokens = data.encode_words(list_of_out_words[iTarget], emb_dict) 447 | target_seq = pack_input(b_tokens, back_net.emb, "cuda") 448 | b_enc = back_net.encode(target_seq) 449 | back_seq2seq_prob.append(back_net.get_qp_prob(b_enc, target_seq.data[0:1], tokens[1:])) 450 | 451 | mutual_prob = [] 452 | for i in range(len(probs)): 453 | mutual_prob.append(probs[i] + back_seq2seq_prob[i]) 454 | 455 | return list_of_out_words, mutual_prob 456 | 457 | 458 | 459 | 460 | 461 | 462 | -------------------------------------------------------------------------------- /libbots/utils.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | import string 4 | from nltk.translate import bleu_score 5 | from nltk.tokenize import TweetTokenizer 6 | import numpy as np 7 | from . import model 8 | import torch 9 | import torch.nn as nn 10 | 11 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 12 | 13 | 14 | def calc_bleu_many(cand_seq, ref_sequences): 15 | sf = bleu_score.SmoothingFunction() 16 | return bleu_score.sentence_bleu(ref_sequences, cand_seq, 17 | smoothing_function=sf.method1, 18 | weights=(0.5, 0.5)) 19 | 20 | def calc_preplexity_many(probs, actions): 21 | criterion = nn.CrossEntropyLoss() 22 | loss = criterion(torch.tensor(probs).to(device), torch.tensor(actions).to(device)) 23 | return np.exp(loss.item()) 24 | 25 | def calc_bleu(cand_seq, ref_seq): 26 | return calc_bleu_many(cand_seq, [ref_seq]) 27 | 28 | def tokenize(s): 29 | return TweetTokenizer(preserve_case=False).tokenize(s) 30 | 31 | def untokenize(words): 32 | return "".join([" " + i if not i.startswith("'") and i not in string.punctuation else i for i in words]).strip() 33 | 34 | def calc_mutual(net, back_net, p1, p2): 35 | # Pack forward and backward 36 | criterion = nn.CrossEntropyLoss() 37 | p2 = [1] + p2 38 | input_seq_forward = model.pack_input(p1, net.emb, device) 39 | output_seq_forward = model.pack_input(p2[:-1], net.emb, device) 40 | input_seq_backward = model.pack_input(tuple(p2), back_net.emb, device) 41 | output_seq_backward = model.pack_input(list(p1)[:-1], back_net.emb, device) 42 | # Enc forward and backward 43 | enc_forward = net.encode(input_seq_forward) 44 | enc_backward = back_net.encode(input_seq_backward) 45 | 46 | r_forward = net.decode_teacher(enc_forward, output_seq_forward).detach() 47 | r_backward = back_net.decode_teacher(enc_backward, output_seq_backward).detach() 48 | fw = criterion(torch.tensor(r_forward).to(device), torch.tensor(p2[1:]).to(device)).detach() 49 | bw = criterion(torch.tensor(r_backward).to(device), torch.tensor(p1[1:]).to(device)).detach() 50 | return (-1)*float(fw + bw) 51 | 52 | def 
calc_cosine_many(mean_emb_pred, mean_emb_ref): 53 | norm_pred = mean_emb_pred / mean_emb_pred.norm(dim=1)[:, None] 54 | norm_ref = mean_emb_ref / mean_emb_ref.norm(dim=1)[:, None] 55 | return torch.mm(norm_pred, norm_ref.transpose(1, 0)) 56 | 57 | -------------------------------------------------------------------------------- /model_test.py: -------------------------------------------------------------------------------- 1 | 2 | from libbots import model, utils, data 3 | 4 | def run_test(test_data, per_net, end_token, device="cuda"): 5 | bleu_sum = 0.0 6 | bleu_count = 0 7 | for p1, p2 in test_data: 8 | input_seq = model.pack_input(p1, per_net.emb, device) 9 | enc = per_net.encode(input_seq) 10 | _, tokens = per_net.decode_chain_argmax(enc, input_seq.data[0:1], seq_len=data.MAX_TOKENS, 11 | stop_at_token=end_token) 12 | ref_indices = p2[1:] 13 | bleu_sum += utils.calc_bleu_many(tokens, [ref_indices]) 14 | bleu_count += 1 15 | return bleu_sum / bleu_count 16 | 17 | def run_test_preplexity(test_data, per_net, net, end_token, device="cpu"): 18 | preplexity_sum = 0.0 19 | preplexity_count = 0 20 | for p1, p2 in test_data: 21 | input_seq = model.pack_input(p1, per_net.emb, device) 22 | enc = per_net.encode(input_seq) 23 | logits, tokens = per_net.decode_chain_argmax(enc, input_seq.data[0:1], seq_len=data.MAX_TOKENS, 24 | stop_at_token=end_token) 25 | r = net.get_logits(enc, input_seq.data[0:1], data.MAX_TOKENS, tokens, 26 | stop_at_token=end_token) 27 | preplexity_sum += utils.calc_preplexity_many(r, tokens) 28 | preplexity_count += 1 29 | return preplexity_sum / preplexity_count 30 | 31 | def run_test_mutual(test_data, rl_net, net, back_net, end_token, device="cuda"): 32 | mutual_sum = 0.0 33 | mutual_count = 0 34 | for p1, p2 in test_data: 35 | input_seq_bleu = model.pack_input(p1, rl_net.emb, device) 36 | enc_bleu = rl_net.encode(input_seq_bleu) 37 | _, actions = rl_net.decode_chain_argmax(enc_bleu, input_seq_bleu.data[0:1], seq_len=data.MAX_TOKENS, 38 | stop_at_token=end_token) 39 | 40 | mutual_sum += utils.calc_mutual(net, back_net, p1, actions) 41 | mutual_count += 1 42 | return mutual_sum / mutual_count 43 | 44 | def run_test_cosine(test_data, rl_net, net, beg_token, end_token, device="cuda"): 45 | cosine_sum = 0.0 46 | cosine_count = 0 47 | beg_embedding = rl_net.emb(beg_token) 48 | for p1, p2 in test_data: 49 | input_seq_cosine = model.pack_input(p1, rl_net.emb, device) 50 | enc_cosine = rl_net.encode(input_seq_cosine) 51 | r_argmax, actions = rl_net.decode_chain_argmax(enc_cosine, beg_embedding, data.MAX_TOKENS, 52 | stop_at_token=end_token) 53 | mean_emb_max = net.get_mean_emb(beg_embedding, actions) 54 | mean_emb_ref_list = [] 55 | for iRef in p2: 56 | mean_emb_ref_list.append(rl_net.get_mean_emb(beg_embedding, iRef)) 57 | mean_emb_ref = sum(mean_emb_ref_list) / len(mean_emb_ref_list) 58 | cosine_sum += utils.calc_cosine_many(mean_emb_max, mean_emb_ref) 59 | cosine_count += 1 60 | return cosine_sum / cosine_count -------------------------------------------------------------------------------- /test_all_modells.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import logging 4 | import numpy as np 5 | import torch 6 | import argparse 7 | 8 | from libbots import data, model 9 | from model_test import run_test, run_test_mutual, run_test_preplexity, run_test_cosine 10 | 11 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 12 | 13 | if __name__ == '__main__': 14 | parser = argparse.ArgumentParser() 15 | 
parser.add_argument('-SAVES_DIR', type=str, default='saves', help='Save directory') 16 | parser.add_argument('-BATCH_SIZE', type=int, default=32, help='Batch Size for training') 17 | parser.add_argument('-data', type=str, default='comedy', help='Genre to use - for data') 18 | parser.add_argument('-load_seq2seq_path', type=str, default='Final_Saves/seq2seq/epoch_090_0.800_0.107.dat', 19 | help='Pre-trained seq2seq model location') 20 | parser.add_argument('-laod_b_seq2seq_path', type=str, 21 | default='Final_Saves/backward_seq2seq/epoch_080_0.780_0.104.dat', 22 | help='Pre-trained backward seq2seq model location') 23 | parser.add_argument('-bleu_model_path', type=str, default='Final_Saves/RL_BLUE/bleu_0.135_177.dat', 24 | help='Pre-trained BLEU model location') 25 | parser.add_argument('-mutual_model_path', type=str, default='Final_Saves/RL_Mutual/epoch_180_-4.325_-7.192.dat', 26 | help='Pre-trained MMI model location') 27 | parser.add_argument('-prep_model_path', type=str, default='Final_Saves/RL_Perplexity/epoch_050_1.463_3.701.dat', 28 | help='Pre-trained Perplexity model location') 29 | parser.add_argument('-cos_model_path', type=str, default='Final_Saves/RL_COSINE/cosine_0.621_03.dat', 30 | help='Pre-trained Cosine Similarity model location') 31 | args = parser.parse_args() 32 | 33 | 34 | log = logging.getLogger("test") 35 | 36 | logging.basicConfig(format="%(asctime)-15s %(levelname)s %(message)s", level=logging.INFO) 37 | arg_data = 'comedy' 38 | 39 | # Load Seq2Seq 40 | seq2seq_emb_dict = data.load_emb_dict(os.path.dirname(args.load_seq2seq_path)) 41 | seq2seq_rev_emb_dict = {idx: word for word, idx in seq2seq_emb_dict.items()} 42 | net = model.PhraseModel(emb_size=model.EMBEDDING_DIM, dict_size=len(seq2seq_emb_dict), 43 | hid_size=model.HIDDEN_STATE_SIZE).to(device) 44 | net.load_state_dict(torch.load(args.load_seq2seq_path)) 45 | 46 | # Load Back Seq2Seq 47 | b_seq2seq_emb_dict = data.load_emb_dict(os.path.dirname(args.laod_b_seq2seq_path)) 48 | b_seq2seq_rev_emb_dict = {idx: word for word, idx in b_seq2seq_emb_dict.items()} 49 | b_net = model.PhraseModel(emb_size=model.EMBEDDING_DIM, dict_size=len(b_seq2seq_emb_dict), 50 | hid_size=model.HIDDEN_STATE_SIZE).to(device) 51 | b_net.load_state_dict(torch.load(args.laod_b_seq2seq_path)) 52 | 53 | # Load BLEU 54 | bleu_emb_dict = data.load_emb_dict(os.path.dirname(args.bleu_model_path)) 55 | bleu_rev_emb_dict = {idx: word for word, idx in bleu_emb_dict.items()} 56 | bleu_net = model.PhraseModel(emb_size=model.EMBEDDING_DIM, dict_size=len(bleu_emb_dict), 57 | hid_size=model.HIDDEN_STATE_SIZE).to(device) 58 | bleu_net.load_state_dict(torch.load(args.bleu_model_path)) 59 | 60 | # Load Mutual 61 | mutual_emb_dict = data.load_emb_dict(os.path.dirname(args.mutual_model_path)) 62 | mutual_rev_emb_dict = {idx: word for word, idx in mutual_emb_dict.items()} 63 | mutual_net = model.PhraseModel(emb_size=model.EMBEDDING_DIM, dict_size=len(mutual_emb_dict), 64 | hid_size=model.HIDDEN_STATE_SIZE).to(device) 65 | mutual_net.load_state_dict(torch.load(args.mutual_model_path)) 66 | 67 | 68 | # Load Preplexity 69 | prep_emb_dict = data.load_emb_dict(os.path.dirname(args.prep_model_path)) 70 | prep_rev_emb_dict = {idx: word for word, idx in prep_emb_dict.items()} 71 | prep_net = model.PhraseModel(emb_size=model.EMBEDDING_DIM, dict_size=len(prep_emb_dict), 72 | hid_size=model.HIDDEN_STATE_SIZE).to(device) 73 | prep_net.load_state_dict(torch.load(args.prep_model_path)) 74 | 75 | 76 | # Load Cosine Similarity 77 | cos_emb_dict = 
data.load_emb_dict(os.path.dirname(args.cos_model_path)) 78 |     cos_rev_emb_dict = {idx: word for word, idx in cos_emb_dict.items()} 79 |     cos_net = model.PhraseModel(emb_size=model.EMBEDDING_DIM, dict_size=len(cos_emb_dict), 80 |                                 hid_size=model.HIDDEN_STATE_SIZE).to(device) 81 |     cos_net.load_state_dict(torch.load(args.cos_model_path)) 82 | 83 | 84 |     phrase_pairs, emb_dict = data.load_data(genre_filter=arg_data) 85 |     end_token = emb_dict[data.END_TOKEN] 86 |     train_data = data.encode_phrase_pairs(phrase_pairs, emb_dict) 87 |     rand = np.random.RandomState(data.SHUFFLE_SEED) 88 |     rand.shuffle(train_data) 89 |     train_data, test_data = data.split_train_test(train_data) 90 | 91 |     # BEGIN token 92 |     beg_token = torch.LongTensor([emb_dict[data.BEGIN_TOKEN]]).to(device) 93 | 94 |     # # Test Seq2Seq model 95 |     bleu_test_seq2seq = run_test(test_data, net, end_token, device) 96 |     mutual_test_seq2seq = run_test_mutual(test_data, net, net, b_net, beg_token, end_token, device) 97 |     preplexity_test_seq2seq = run_test_preplexity(test_data, net, net, end_token, device) 98 |     cosine_test_seq2seq = run_test_cosine(test_data, net, net, beg_token, end_token, device) 99 |     # 100 |     # # Test BLEU model 101 |     bleu_test_bleu = run_test(test_data, bleu_net, end_token, device) 102 |     mutual_test_bleu = run_test_mutual(test_data, bleu_net, net, b_net, beg_token, end_token, device) 103 |     preplexity_test_bleu = run_test_preplexity(test_data, bleu_net, net, end_token, device) 104 |     cosine_test_bleu = run_test_cosine(test_data, bleu_net, net, beg_token, end_token, device) 105 | 106 |     # # Test Mutual Information model 107 |     bleu_test_mutual = run_test(test_data, mutual_net, end_token, device) 108 |     mutual_test_mutual = run_test_mutual(test_data, mutual_net, net, b_net, beg_token, end_token, device) 109 |     preplexity_test_mutual = run_test_preplexity(test_data, mutual_net, net, end_token, device) 110 |     cosine_test_mutual = run_test_cosine(test_data, mutual_net, net, beg_token, end_token, device) 111 | 112 |     # Test Perplexity model 113 |     bleu_test_per = run_test(test_data, prep_net, end_token, device) 114 |     mutual_test_per = run_test_mutual(test_data, prep_net, net, b_net, beg_token, end_token, device) 115 |     preplexity_test_per = run_test_preplexity(test_data, prep_net, net, end_token, device) 116 |     cosine_test_per = run_test_cosine(test_data, prep_net, net, beg_token, end_token, device) 117 | 118 |     # Test Cosine Similarity model 119 |     bleu_test_cos = run_test(test_data, cos_net, end_token, device) 120 |     mutual_test_cos = run_test_mutual(test_data, cos_net, net, b_net, beg_token, end_token, device) 121 |     preplexity_test_cos = run_test_preplexity(test_data, cos_net, net, end_token, device) 122 |     cosine_test_cos = run_test_cosine(test_data, cos_net, net, beg_token, end_token, device) 123 | 124 |     log.info("Obtained %d phrase pairs with %d uniq words", len(phrase_pairs), len(emb_dict)) 125 |     log.info("-----------------------------------------------") 126 |     log.info("BLEU scores:") 127 |     log.info(" Seq2Seq - %.3f", bleu_test_seq2seq) 128 |     log.info(" BLEU - %.3f", bleu_test_bleu) 129 |     log.info(" MMI - %.3f", bleu_test_mutual) 130 |     log.info(" Perplexity - %.3f", bleu_test_per) 131 |     log.info(" Cosine Similarity - %.3f", bleu_test_cos) 132 |     log.info("-----------------------------------------------") 133 |     log.info("Max Mutual Information scores:") 134 |     log.info(" Seq2Seq - %.3f", mutual_test_seq2seq) 135 |     log.info(" BLEU - %.3f", mutual_test_bleu) 136 |     log.info(" MMI - %.3f", mutual_test_mutual) 137 |     log.info(" Perplexity - %.3f", mutual_test_per) 138
| log.info(" Cosine Similarity - %.3f", mutual_test_cos) 139 |     log.info("-----------------------------------------------") 140 |     log.info("Perplexity scores:") 141 |     log.info(" Seq2Seq - %.3f", preplexity_test_seq2seq) 142 |     log.info(" BLEU - %.3f", preplexity_test_bleu) 143 |     log.info(" MMI - %.3f", preplexity_test_mutual) 144 |     log.info(" Perplexity - %.3f", preplexity_test_per) 145 |     log.info(" Cosine Similarity - %.3f", preplexity_test_cos) 146 |     log.info("-----------------------------------------------") 147 |     log.info("Cosine Similarity scores:") 148 |     log.info(" Seq2Seq - %.3f", cosine_test_seq2seq) 149 |     log.info(" BLEU - %.3f", cosine_test_bleu) 150 |     log.info(" MMI - %.3f", cosine_test_mutual) 151 |     log.info(" Perplexity - %.3f", cosine_test_per) 152 |     log.info(" Cosine Similarity - %.3f", cosine_test_cos) 153 |     log.info("-----------------------------------------------") 154 | 155 | -------------------------------------------------------------------------------- /tests/test_data.py: -------------------------------------------------------------------------------- 1 | from unittest import TestCase 2 | 3 | import libbots.data 4 | from libbots import data, subtitles 5 | 6 | 7 | class TestData(TestCase): 8 |     emb_dict = { 9 |         data.BEGIN_TOKEN: 0, 10 |         data.END_TOKEN: 1, 11 |         data.UNKNOWN_TOKEN: 2, 12 |         'a': 3, 13 |         'b': 4 14 |     } 15 | 16 |     def test_encode_words(self): 17 |         res = data.encode_words(['a', 'b', 'c'], self.emb_dict) 18 |         self.assertEqual(res, [0, 3, 4, 2, 1]) 19 | 20 |     # def test_dialogues_to_train(self): 21 |     #     dialogues = [ 22 |     #         [ 23 |     #             libbots.data.Phrase(words=['a', 'b'], time_start=0, time_stop=1), 24 |     #             libbots.data.Phrase(words=['b', 'a'], time_start=2, time_stop=3), 25 |     #             libbots.data.Phrase(words=['b', 'a'], time_start=2, time_stop=3), 26 |     #         ], 27 |     #         [ 28 |     #             libbots.data.Phrase(words=['a', 'b'], time_start=0, time_stop=1), 29 |     #         ] 30 |     #     ] 31 |     # 32 |     #     res = data.dialogues_to_train(dialogues, self.emb_dict) 33 |     #     self.assertEqual(res, [ 34 |     #         ([0, 3, 4, 1], [0, 4, 3, 1]), 35 |     #         ([0, 4, 3, 1], [0, 4, 3, 1]), 36 |     #     ]) -------------------------------------------------------------------------------- /tests/test_subtitles.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | from unittest import TestCase 3 | 4 | import libbots.data 5 | from libbots import subtitles 6 | 7 | 8 | class TestPhrases(TestCase): 9 |     def test_split_phrase(self): 10 |         phrase = libbots.data.Phrase(words=["a", "b", "c"], time_start=datetime.timedelta(seconds=0), 11 |                                      time_stop=datetime.timedelta(seconds=10)) 12 |         res = subtitles.split_phrase(phrase) 13 |         self.assertIsInstance(res, list) 14 |         self.assertEqual(len(res), 1) 15 |         self.assertEqual(res[0], phrase) 16 | 17 |         phrase = libbots.data.Phrase(words=["a", "b", "-", "c"], time_start=datetime.timedelta(seconds=0), 18 |                                      time_stop=datetime.timedelta(seconds=10)) 19 |         res = subtitles.split_phrase(phrase) 20 |         self.assertEqual(len(res), 2) 21 |         self.assertEqual(res[0].words, ["a", "b"]) 22 |         self.assertEqual(res[1].words, ["c"]) 23 |         self.assertAlmostEqual(res[0].time_start.total_seconds(), 0) 24 |         self.assertAlmostEqual(res[0].time_stop.total_seconds(), 5) 25 |         self.assertAlmostEqual(res[1].time_start.total_seconds(), 5) 26 |         self.assertAlmostEqual(res[1].time_stop.total_seconds(), 10) 27 | 28 |         phrase = libbots.data.Phrase(words=['-', 'Wait', 'a', 'sec', '.', '-'], time_start=datetime.timedelta(0, 588, 204000), 29 |                                      time_stop=datetime.timedelta(0, 590, 729000)) 30 |         res = 
subtitles.split_phrase(phrase) 31 | self.assertEqual(res[0].words, ["Wait", "a", "sec", "."]) 32 | 33 | 34 | class TestUtils(TestCase): 35 | def test_parse_time(self): 36 | self.assertEqual(subtitles.parse_time("00:00:33,074"), 37 | datetime.timedelta(seconds=33, milliseconds=74)) 38 | 39 | def test_remove_braced_words(self): 40 | self.assertEqual(subtitles.remove_braced_words(['a', 'b', 'c']), 41 | ['a', 'b', 'c']) 42 | self.assertEqual(subtitles.remove_braced_words(['a', '[', 'b', ']', 'c']), 43 | ['a', 'c']) 44 | self.assertEqual(subtitles.remove_braced_words(['a', '[', 'b', 'c']), 45 | ['a']) 46 | self.assertEqual(subtitles.remove_braced_words(['a', ']', 'b', 'c']), 47 | ['a', 'b', 'c']) 48 | self.assertEqual(subtitles.remove_braced_words(['a', '(', 'b', ']', 'c']), 49 | ['a', 'c']) 50 | self.assertEqual(subtitles.remove_braced_words(['a', '(', 'b', 'c']), 51 | ['a']) 52 | self.assertEqual(subtitles.remove_braced_words(['a', ')', 'b', 'c']), 53 | ['a', 'b', 'c']) -------------------------------------------------------------------------------- /train_crossent.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import random 4 | import logging 5 | import numpy as np 6 | from tensorboardX import SummaryWriter 7 | import torch 8 | import torch.optim as optim 9 | import torch.nn.functional as F 10 | import argparse 11 | 12 | from libbots import data, model, utils 13 | from model_test import run_test 14 | 15 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 16 | 17 | if __name__ == '__main__': 18 | parser = argparse.ArgumentParser() 19 | parser.add_argument('-SAVES_DIR', type=str, default='saves', help='Save directory') 20 | parser.add_argument('-name', type=str, default='seq2seq', help='Specific model saves directory') 21 | parser.add_argument('-BATCH_SIZE', type=int, default=32, help='Batch Size for training') 22 | parser.add_argument('-LEARNING_RATE', type=float, default=1e-3, help='Learning Rate') 23 | parser.add_argument('-MAX_EPOCHES', type=int, default=100, help='Number of training iterations') 24 | parser.add_argument('-TEACHER_PROB', type=float, default=0.5, help='Probability to force reference inputs') 25 | parser.add_argument('-data', type=str, default='comedy', help='Genre to use - for data') 26 | parser.add_argument('-train_backward', type=bool, default=False, help='Choose - train backward/forward model') 27 | args = parser.parse_args() 28 | 29 | saves_path = os.path.join(args.SAVES_DIR, args.name) 30 | os.makedirs(saves_path, exist_ok=True) 31 | 32 | log = logging.getLogger("train") 33 | logging.basicConfig(format="%(asctime)-15s %(levelname)s %(message)s", level=logging.INFO) 34 | 35 | phrase_pairs, emb_dict = data.load_data(genre_filter=args.data) 36 | data.save_emb_dict(saves_path, emb_dict) 37 | end_token = emb_dict[data.END_TOKEN] 38 | train_data = data.encode_phrase_pairs(phrase_pairs, emb_dict) 39 | rand = np.random.RandomState(data.SHUFFLE_SEED) 40 | rand.shuffle(train_data) 41 | train_data, test_data = data.split_train_test(train_data) 42 | 43 | log.info("Obtained %d phrase pairs with %d uniq words", len(phrase_pairs), len(emb_dict)) 44 | log.info("Training data converted, got %d samples", len(train_data)) 45 | log.info("Train set has %d phrases, test %d", len(train_data), len(test_data)) 46 | 47 | net = model.PhraseModel(emb_size=model.EMBEDDING_DIM, dict_size=len(emb_dict), 48 | hid_size=model.HIDDEN_STATE_SIZE).to(device) 49 | 50 | writer = SummaryWriter(comment="-" + args.name) 51 | 52 | 
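In the epoch loop that follows, each sample is decoded either with teacher forcing (decode_teacher, fed the reference tokens) with probability TEACHER_PROB, or free-running with decode_chain_argmax; in both cases the per-step logits are concatenated and scored with a single cross-entropy call. The snippet below is a self-contained toy illustration of that loss computation only; the random tensors stand in for decoder outputs and are not produced by PhraseModel.

    # Toy illustration of the cross-entropy pattern used below.
    import torch
    import torch.nn.functional as F

    vocab_size = 10
    logits_per_sample = [torch.randn(4, vocab_size, requires_grad=True),   # 4 decoding steps
                         torch.randn(6, vocab_size, requires_grad=True)]   # 6 decoding steps
    targets_per_sample = [[1, 3, 5, 2], [4, 4, 7, 8, 9, 2]]                # reference token ids

    results_v = torch.cat(logits_per_sample)                               # (10, vocab_size)
    targets_v = torch.LongTensor([t for ts in targets_per_sample for t in ts])
    loss_v = F.cross_entropy(results_v, targets_v)
    loss_v.backward()   # gradients flow back into whatever produced the logits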
optimiser = optim.Adam(net.parameters(), lr=args.LEARNING_RATE) 53 | best_bleu = None 54 | for epoch in range(args.MAX_EPOCHES): 55 | losses = [] 56 | bleu_sum = 0.0 57 | bleu_count = 0 58 | for batch in data.iterate_batches(train_data, args.BATCH_SIZE): 59 | optimiser.zero_grad() 60 | if args.train_backward: 61 | input_seq, out_seq_list, _, out_idx = model.pack_backward_batch(batch, net.emb, device) 62 | else: 63 | input_seq, out_seq_list, _, out_idx = model.pack_batch(batch, net.emb, device) 64 | enc = net.encode(input_seq) 65 | 66 | net_results = [] 67 | net_targets = [] 68 | for idx, out_seq in enumerate(out_seq_list): 69 | ref_indices = out_idx[idx][1:] 70 | enc_item = net.get_encoded_item(enc, idx) 71 | if random.random() < args.TEACHER_PROB: 72 | r = net.decode_teacher(enc_item, out_seq) 73 | bleu_sum += model.seq_bleu(r, ref_indices) 74 | else: 75 | r, seq = net.decode_chain_argmax(enc_item, out_seq.data[0:1], 76 | len(ref_indices)) 77 | bleu_sum += utils.calc_bleu(seq, ref_indices) 78 | net_results.append(r) 79 | net_targets.extend(ref_indices) 80 | bleu_count += 1 81 | results_v = torch.cat(net_results) 82 | targets_v = torch.LongTensor(net_targets).to(device) 83 | loss_v = F.cross_entropy(results_v, targets_v) 84 | loss_v.backward() 85 | optimiser.step() 86 | 87 | losses.append(loss_v.item()) 88 | bleu = bleu_sum / bleu_count 89 | bleu_test = run_test(test_data, net, end_token, device) 90 | log.info("Epoch %d: mean loss %.3f, mean BLEU %.3f, test BLEU %.3f", 91 | epoch, np.mean(losses), bleu, bleu_test) 92 | writer.add_scalar("loss", np.mean(losses), epoch) 93 | writer.add_scalar("bleu", bleu, epoch) 94 | writer.add_scalar("bleu_test", bleu_test, epoch) 95 | if best_bleu is None or best_bleu < bleu_test: 96 | if best_bleu is not None: 97 | out_name = os.path.join(saves_path, "pre_bleu_%.3f_%02d.dat" % 98 | (bleu_test, epoch)) 99 | torch.save(net.state_dict(), out_name) 100 | log.info("Best BLEU updated %.3f", bleu_test) 101 | best_bleu = bleu_test 102 | 103 | if epoch % 10 == 0: 104 | out_name = os.path.join(saves_path, "epoch_%03d_%.3f_%.3f.dat" % 105 | (epoch, bleu, bleu_test)) 106 | torch.save(net.state_dict(), out_name) 107 | 108 | writer.close() 109 | -------------------------------------------------------------------------------- /train_rl_BLEU.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | import os 4 | import random 5 | import logging 6 | import numpy as np 7 | from tensorboardX import SummaryWriter 8 | import torch 9 | import torch.optim as optim 10 | import torch.nn.functional as F 11 | import argparse 12 | 13 | from libbots import data, model, utils 14 | from model_test import run_test 15 | 16 | if __name__ == '__main__': 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument('-SAVES_DIR', type=str, default='saves', help='Save directory') 19 | parser.add_argument('-name', type=str, default='RL_BLUE', help='Specific model saves directory') 20 | parser.add_argument('-BATCH_SIZE', type=int, default=16, help='Batch Size for training') 21 | parser.add_argument('-LEARNING_RATE', type=float, default=1e-4, help='Learning Rate') 22 | parser.add_argument('-MAX_EPOCHES', type=int, default=10000, help='Number of training iterations') 23 | parser.add_argument('-data', type=str, default='comedy', help='Genre to use - for data') 24 | parser.add_argument('-num_of_samples', type=int, default=4, help='Number of samples per per each example') 25 | parser.add_argument('-load_seq2seq_path', type=str, 
default='Final_Saves/seq2seq/epoch_090_0.800_0.107.dat', 26 | help='Pre-trained seq2seq model location') 27 | args = parser.parse_args() 28 | 29 | saves_path = os.path.join(args.SAVES_DIR, args.name) 30 | os.makedirs(saves_path, exist_ok=True) 31 | 32 | 33 | log = logging.getLogger("train") 34 | logging.basicConfig(format="%(asctime)-15s %(levelname)s %(message)s", level=logging.INFO) 35 | 36 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 37 | 38 | phrase_pairs, emb_dict = data.load_data(genre_filter=args.data) 39 | data.save_emb_dict(saves_path, emb_dict) 40 | train_data = data.encode_phrase_pairs(phrase_pairs, emb_dict) 41 | rand = np.random.RandomState(data.SHUFFLE_SEED) 42 | rand.shuffle(train_data) 43 | train_data, test_data = data.split_train_test(train_data) 44 | train_data = data.group_train_data(train_data) 45 | test_data = data.group_train_data(test_data) 46 | 47 | log.info("Obtained %d phrase pairs with %d uniq words", len(phrase_pairs), len(emb_dict)) 48 | log.info("Training data converted, got %d samples", len(train_data)) 49 | log.info("Train set has %d phrases, test %d", len(train_data), len(test_data)) 50 | 51 | rev_emb_dict = {idx: word for word, idx in emb_dict.items()} 52 | net = model.PhraseModel(emb_size=model.EMBEDDING_DIM, dict_size=len(emb_dict), 53 | hid_size=model.HIDDEN_STATE_SIZE).to(device) 54 | loaded_net = model.PhraseModel(emb_size=model.EMBEDDING_DIM, dict_size=len(emb_dict), 55 | hid_size=model.HIDDEN_STATE_SIZE).to(device) 56 | 57 | writer = SummaryWriter(comment="-" + args.name) 58 | net.load_state_dict(torch.load(args.load_seq2seq_path)) 59 | 60 | # BEGIN & END tokens 61 | beg_token = torch.LongTensor([emb_dict[data.BEGIN_TOKEN]]).to(device) 62 | end_token = emb_dict[data.END_TOKEN] 63 | 64 | 65 | optimiser = optim.Adam(net.parameters(), lr=args.LEARNING_RATE, eps=1e-3) 66 | batch_idx = 0 67 | best_bleu = None 68 | for epoch in range(args.MAX_EPOCHES): 69 | random.shuffle(train_data) 70 | dial_shown = False 71 | 72 | total_samples = 0 73 | bleus_argmax = [] 74 | bleus_sample = [] 75 | 76 | for batch in data.iterate_batches(train_data, args.BATCH_SIZE): 77 | batch_idx += 1 78 | optimiser.zero_grad() 79 | input_seq, input_batch, output_batch = model.pack_batch_no_out(batch, net.emb, device) 80 | enc = net.encode(input_seq) 81 | 82 | net_policies = [] 83 | net_actions = [] 84 | net_advantages = [] 85 | beg_embedding = net.emb(beg_token) 86 | 87 | for idx, inp_idx in enumerate(input_batch): 88 | total_samples += 1 89 | ref_indices = [indices[1:] for indices in output_batch[idx]] 90 | item_enc = net.get_encoded_item(enc, idx) 91 | r_argmax, actions = net.decode_chain_argmax(item_enc, beg_embedding, data.MAX_TOKENS, 92 | stop_at_token=end_token) 93 | argmax_bleu = utils.calc_bleu_many(actions, ref_indices) 94 | bleus_argmax.append(argmax_bleu) 95 | 96 | 97 | if not dial_shown: 98 | log.info("Input: %s", utils.untokenize(data.decode_words(inp_idx, rev_emb_dict))) 99 | ref_words = [utils.untokenize(data.decode_words(ref, rev_emb_dict)) for ref in ref_indices] 100 | log.info("Refer: %s", " ~~|~~ ".join(ref_words)) 101 | log.info("Argmax: %s, bleu=%.4f", 102 | utils.untokenize(data.decode_words(actions, rev_emb_dict)), 103 | argmax_bleu) 104 | 105 | for _ in range(args.num_of_samples): 106 | r_sample, actions = net.decode_chain_sampling(item_enc, beg_embedding, 107 | data.MAX_TOKENS, stop_at_token=end_token) 108 | sample_bleu = utils.calc_bleu_many(actions, ref_indices) 109 | 110 | if not dial_shown: 111 | log.info("Sample: %s, 
bleu=%.4f", 112 | utils.untokenize(data.decode_words(actions, rev_emb_dict)), 113 | sample_bleu) 114 | 115 | net_policies.append(r_sample) 116 | net_actions.extend(actions) 117 | net_advantages.extend([sample_bleu - argmax_bleu] * len(actions)) 118 | bleus_sample.append(sample_bleu) 119 | dial_shown = True 120 | 121 | if not net_policies: 122 | continue 123 | 124 | policies_v = torch.cat(net_policies) 125 | actions_t = torch.LongTensor(net_actions).to(device) 126 | adv_v = torch.FloatTensor(net_advantages).to(device) 127 | log_prob_v = F.log_softmax(policies_v, dim=1) 128 | log_prob_actions_v = adv_v * log_prob_v[range(len(net_actions)), actions_t] 129 | loss_policy_v = -log_prob_actions_v.mean() 130 | 131 | loss_v = loss_policy_v 132 | loss_v.backward() 133 | optimiser.step() 134 | 135 | bleu_test = run_test(test_data, net, end_token, device) 136 | bleu = np.mean(bleus_argmax) 137 | writer.add_scalar("bleu_test", bleu_test, batch_idx) 138 | writer.add_scalar("bleu_argmax", bleu, batch_idx) 139 | writer.add_scalar("bleu_sample", np.mean(bleus_sample), batch_idx) 140 | writer.add_scalar("epoch", batch_idx, epoch) 141 | 142 | log.info("Epoch %d, test BLEU: %.3f", epoch, bleu_test) 143 | if best_bleu is None or best_bleu < bleu_test: 144 | best_bleu = bleu_test 145 | log.info("Best bleu updated: %.4f", bleu_test) 146 | torch.save(net.state_dict(), os.path.join(saves_path, "bleu_%.3f_%02d.dat" % (bleu_test, epoch))) 147 | if epoch % 10 == 0: 148 | torch.save(net.state_dict(), 149 | os.path.join(saves_path, "epoch_%03d_%.3f_%.3f.dat" % (epoch, bleu, bleu_test))) 150 | 151 | writer.close() -------------------------------------------------------------------------------- /train_rl_MMI.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | import os 4 | import random 5 | import logging 6 | import numpy as np 7 | from tensorboardX import SummaryWriter 8 | import torch 9 | import torch.optim as optim 10 | import torch.nn.functional as F 11 | import argparse 12 | 13 | from libbots import data, model, utils 14 | from model_test import run_test_mutual 15 | 16 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 17 | 18 | if __name__ == '__main__': 19 | parser = argparse.ArgumentParser() 20 | parser.add_argument('-SAVES_DIR', type=str, default='saves', help='Save directory') 21 | parser.add_argument('-name', type=str, default='RL_Mutual', help='Specific model saves directory') 22 | parser.add_argument('-BATCH_SIZE', type=int, default=32, help='Batch Size for training') 23 | parser.add_argument('-LEARNING_RATE', type=float, default=1e-4, help='Learning Rate') 24 | parser.add_argument('-MAX_EPOCHES', type=int, default=10000, help='Number of training iterations') 25 | parser.add_argument('-CROSS_ENT_PROB', type=float, default=0.3, help='Probability to run a CE batch') 26 | parser.add_argument('-TEACHER_PROB', type=float, default=0.8, help='Probability to run an imitation batch in case ' 27 | 'of using CE') 28 | parser.add_argument('-data', type=str, default='comedy', help='Genre to use - for data') 29 | parser.add_argument('-num_of_samples', type=int, default=4, help='Number of samples per per each example') 30 | parser.add_argument('-load_seq2seq_path', type=str, default='Final_Saves/seq2seq/epoch_090_0.800_0.107.dat', 31 | help='Pre-trained seq2seq model location') 32 | parser.add_argument('-laod_b_seq2seq_path', type=str, default='Final_Saves/backward_seq2seq/epoch_080_0.780_0.104.dat', 33 | help='Pre-trained backward seq2seq model location') 
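The BLEU script above, and the MMI, perplexity and cosine scripts that follow, all share the same self-critical policy-gradient update: the greedy (argmax) decode acts as a baseline, and each sampled decode is weighted by how much its reward beats that baseline. A minimal standalone sketch of the update, with toy tensors and made-up rewards in place of PhraseModel outputs:

    # Self-critical REINFORCE step (sketch; rewards are made-up numbers).
    import torch
    import torch.nn.functional as F

    seq_len, vocab_size = 5, 10
    logits = torch.randn(seq_len, vocab_size, requires_grad=True)  # stands in for r_sample
    actions = torch.randint(0, vocab_size, (seq_len,))             # sampled token ids
    sample_reward, argmax_reward = 0.42, 0.30                      # e.g. BLEU of sample vs. greedy decode

    advantage = torch.full((seq_len,), sample_reward - argmax_reward)
    log_prob = F.log_softmax(logits, dim=1)
    log_prob_actions = advantage * log_prob[range(seq_len), actions]
    loss = -log_prob_actions.mean()   # minimising this ascends the expected reward
    loss.backward()

Using the greedy decode as the baseline, rather than a learned value head, keeps the gradient estimate reasonably low-variance without any extra parameters, which is why all four reward variants reuse the same pattern.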
34 | args = parser.parse_args() 35 | 36 | saves_path = os.path.join(args.SAVES_DIR, args.name) 37 | os.makedirs(saves_path, exist_ok=True) 38 | 39 | log = logging.getLogger("train") 40 | logging.basicConfig(format="%(asctime)-15s %(levelname)s %(message)s", level=logging.INFO) 41 | 42 | phrase_pairs, emb_dict = data.load_data(genre_filter=args.data) 43 | data.save_emb_dict(saves_path, emb_dict) 44 | train_data = data.encode_phrase_pairs(phrase_pairs, emb_dict) 45 | rand = np.random.RandomState(data.SHUFFLE_SEED) 46 | rand.shuffle(train_data) 47 | train_data, test_data = data.split_train_test(train_data) 48 | 49 | log.info("Obtained %d phrase pairs with %d uniq words", len(phrase_pairs), len(emb_dict)) 50 | log.info("Training data converted, got %d samples", len(train_data)) 51 | 52 | 53 | rev_emb_dict = {idx: word for word, idx in emb_dict.items()} 54 | 55 | # Load pre-trained nets 56 | net = model.PhraseModel(emb_size=model.EMBEDDING_DIM, dict_size=len(emb_dict), 57 | hid_size=model.HIDDEN_STATE_SIZE).to(device) 58 | net.load_state_dict(torch.load(args.load_seq2seq_path)) 59 | back_net = model.PhraseModel(emb_size=model.EMBEDDING_DIM, dict_size=len(emb_dict), 60 | hid_size=model.HIDDEN_STATE_SIZE).to(device) 61 | back_net.load_state_dict(torch.load(args.laod_b_seq2seq_path)) 62 | 63 | rl_net = model.PhraseModel(emb_size=model.EMBEDDING_DIM, dict_size=len(emb_dict), 64 | hid_size=model.HIDDEN_STATE_SIZE).to(device) 65 | rl_net.load_state_dict(torch.load(args.load_seq2seq_path)) 66 | 67 | writer = SummaryWriter(comment="-" + args.name) 68 | 69 | # BEGIN & END tokens 70 | beg_token = torch.LongTensor([emb_dict[data.BEGIN_TOKEN]]).to(device) 71 | end_token = emb_dict[data.END_TOKEN] 72 | 73 | 74 | 75 | optimiser = optim.Adam(rl_net.parameters(), lr=args.LEARNING_RATE, eps=1e-3) 76 | batch_idx = 0 77 | best_mutual = None 78 | for epoch in range(args.MAX_EPOCHES): 79 | dial_shown = False 80 | random.shuffle(train_data) 81 | 82 | total_samples = 0 83 | mutuals_argmax = [] 84 | mutuals_sample = [] 85 | 86 | for batch in data.iterate_batches(train_data, args.BATCH_SIZE): 87 | batch_idx += 1 88 | optimiser.zero_grad() 89 | input_seq, out_seq_list, input_batch, output_batch = model.pack_batch(batch, rl_net.emb, device) 90 | enc = rl_net.encode(input_seq) 91 | 92 | net_policies = [] 93 | net_actions = [] 94 | net_advantages = [] 95 | beg_embedding = rl_net.emb(beg_token) 96 | 97 | if random.random() < args.CROSS_ENT_PROB: 98 | net_results = [] 99 | net_targets = [] 100 | for idx, out_seq in enumerate(out_seq_list): 101 | ref_indices = output_batch[idx][1:] 102 | enc_item = rl_net.get_encoded_item(enc, idx) 103 | if random.random() < args.TEACHER_PROB: 104 | r = rl_net.decode_teacher(enc_item, out_seq) 105 | else: 106 | r, seq = rl_net.decode_chain_argmax(enc_item, out_seq.data[0:1], 107 | len(ref_indices)) 108 | net_results.append(r) 109 | net_targets.extend(ref_indices) 110 | results_v = torch.cat(net_results) 111 | targets_v = torch.LongTensor(net_targets).to(device) 112 | loss_v = F.cross_entropy(results_v, targets_v) 113 | loss_v.backward() 114 | for param in rl_net.parameters(): 115 | param.grad.data.clamp_(-0.2, 0.2) 116 | optimiser.step() 117 | else: 118 | for idx, inp_idx in enumerate(input_batch): 119 | total_samples += 1 120 | ref_indices = output_batch[idx][1:] 121 | item_enc = rl_net.get_encoded_item(enc, idx) 122 | r_argmax, actions = rl_net.decode_chain_argmax(item_enc, beg_embedding, data.MAX_TOKENS, 123 | stop_at_token=end_token) 124 | argmax_mutual = utils.calc_mutual(net, back_net, 
inp_idx, actions) 125 | mutuals_argmax.append(argmax_mutual) 126 | 127 | if not dial_shown: 128 | log.info("Input: %s", utils.untokenize(data.decode_words(inp_idx, rev_emb_dict))) 129 | ref_words = [utils.untokenize(data.decode_words([ref], rev_emb_dict)) for ref in ref_indices] 130 | log.info("Refer: %s", " ".join(ref_words)) 131 | log.info("Argmax: %s, mutual=%.4f", 132 | utils.untokenize(data.decode_words(actions, rev_emb_dict)), argmax_mutual) 133 | 134 | for _ in range(args.num_of_samples): 135 | r_sample, actions = rl_net.decode_chain_sampling(item_enc, beg_embedding, 136 | data.MAX_TOKENS, stop_at_token=end_token) 137 | sample_mutual = utils.calc_mutual(net, back_net, inp_idx, actions) 138 | if not dial_shown: 139 | log.info("Sample: %s, mutual=%.4f", 140 | utils.untokenize(data.decode_words(actions, rev_emb_dict)), sample_mutual) 141 | 142 | net_policies.append(r_sample) 143 | net_actions.extend(actions) 144 | net_advantages.extend([sample_mutual - argmax_mutual] * len(actions)) 145 | mutuals_sample.append(sample_mutual) 146 | dial_shown = True 147 | 148 | if not net_policies: 149 | continue 150 | 151 | policies_v = torch.cat(net_policies) 152 | actions_t = torch.LongTensor(net_actions).to(device) 153 | adv_v = torch.FloatTensor(net_advantages).to(device) 154 | log_prob_v = F.log_softmax(policies_v, dim=1) 155 | log_prob_actions_v = adv_v * log_prob_v[range(len(net_actions)), actions_t] 156 | loss_policy_v = -log_prob_actions_v.mean() 157 | 158 | loss_v = loss_policy_v 159 | loss_v.backward() 160 | for param in rl_net.parameters(): 161 | param.grad.data.clamp_(-0.2, 0.2) 162 | optimiser.step() 163 | 164 | mutual_test = run_test_mutual(test_data, rl_net, net, back_net, beg_token, end_token, device) 165 | mutual = np.mean(mutuals_argmax) 166 | writer.add_scalar("mutual_test", mutual_test, batch_idx) 167 | writer.add_scalar("mutual_argmax", mutual, batch_idx) 168 | writer.add_scalar("mutual_sample", np.mean(mutuals_sample), batch_idx) 169 | writer.add_scalar("epoch", batch_idx, epoch) 170 | log.info("Epoch %d, test mutual: %.3f", epoch, mutual_test) 171 | if best_mutual is None or best_mutual < mutual_test: 172 | best_mutual = mutual_test 173 | log.info("Best mutual updated: %.4f", best_mutual) 174 | torch.save(rl_net.state_dict(), os.path.join(saves_path, "mutual_%.3f_%02d.dat" % (mutual_test, epoch))) 175 | if epoch % 10 == 0: 176 | torch.save(rl_net.state_dict(), 177 | os.path.join(saves_path, "epoch_%03d_%.3f_%.3f.dat" % (epoch, mutual, mutual_test))) 178 | 179 | writer.close() -------------------------------------------------------------------------------- /train_rl_PREPLEXITY.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import os 3 | import random 4 | import logging 5 | import numpy as np 6 | from tensorboardX import SummaryWriter 7 | import torch 8 | import torch.optim as optim 9 | import torch.nn.functional as F 10 | import argparse 11 | 12 | from libbots import data, model, utils 13 | from model_test import run_test_preplexity 14 | 15 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 16 | 17 | if __name__ == '__main__': 18 | parser = argparse.ArgumentParser() 19 | parser.add_argument('-SAVES_DIR', type=str, default='saves', help='Save directory') 20 | parser.add_argument('-name', type=str, default='RL_PREPLEXITY', help='Specific model saves directory') 21 | parser.add_argument('-BATCH_SIZE', type=int, default=32, help='Batch Size for training') 22 | 
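The MMI update above (and the perplexity update further down) clamps every gradient element to [-0.2, 0.2] with an explicit loop over parameters before calling optimiser.step(). PyTorch's torch.nn.utils ships a helper with the same element-wise effect; a standalone sketch with a toy module:

    # Element-wise gradient clipping, equivalent in effect to the manual clamp_ loop above.
    import torch
    import torch.nn.utils as nn_utils

    toy_net = torch.nn.Linear(4, 2)
    loss = toy_net(torch.randn(3, 4)).sum()
    loss.backward()
    nn_utils.clip_grad_value_(toy_net.parameters(), clip_value=0.2)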
parser.add_argument('-LEARNING_RATE', type=float, default=1e-4, help='Learning Rate') 23 | parser.add_argument('-MAX_EPOCHES', type=int, default=10000, help='Number of training iterations') 24 | parser.add_argument('-CROSS_ENT_PROB', type=float, default=0.5, help='Probability to run a CE batch') 25 | parser.add_argument('-TEACHER_PROB', type=float, default=0.5, help='Probability to run an imitation batch in case ' 26 | 'of using CE') 27 | parser.add_argument('-data', type=str, default='comedy', help='Genre to use - for data') 28 | parser.add_argument('-num_of_samples', type=int, default=4, help='Number of samples per per each example') 29 | parser.add_argument('-load_seq2seq_path', type=str, default='Final_Saves/seq2seq/epoch_090_0.800_0.107.dat', 30 | help='Pre-trained seq2seq model location') 31 | args = parser.parse_args() 32 | 33 | saves_path = os.path.join(args.SAVES_DIR, args.name) 34 | os.makedirs(saves_path, exist_ok=True) 35 | 36 | log = logging.getLogger("train") 37 | logging.basicConfig(format="%(asctime)-15s %(levelname)s %(message)s", level=logging.INFO) 38 | 39 | phrase_pairs, emb_dict = data.load_data(genre_filter=args.data) 40 | data.save_emb_dict(saves_path, emb_dict) 41 | train_data = data.encode_phrase_pairs(phrase_pairs, emb_dict) 42 | rand = np.random.RandomState(data.SHUFFLE_SEED) 43 | rand.shuffle(train_data) 44 | train_data, test_data = data.split_train_test(train_data) 45 | 46 | log.info("Obtained %d phrase pairs with %d uniq words", len(phrase_pairs), len(emb_dict)) 47 | log.info("Training data converted, got %d samples", len(train_data)) 48 | log.info("Train set has %d phrases, test %d", len(train_data), len(test_data)) 49 | 50 | # Load pre-trained nets 51 | rev_emb_dict = {idx: word for word, idx in emb_dict.items()} 52 | per_net = model.PhraseModel(emb_size=model.EMBEDDING_DIM, dict_size=len(emb_dict), 53 | hid_size=model.HIDDEN_STATE_SIZE).to(device) 54 | net = model.PhraseModel(emb_size=model.EMBEDDING_DIM, dict_size=len(emb_dict), 55 | hid_size=model.HIDDEN_STATE_SIZE).to(device) 56 | per_net.load_state_dict(torch.load(args.load_seq2seq_path)) 57 | net.load_state_dict(torch.load(args.load_seq2seq_path)) 58 | 59 | writer = SummaryWriter(comment="-" + args.name) 60 | 61 | 62 | # BEGIN & END tokens 63 | end_token = emb_dict[data.END_TOKEN] 64 | beg_token = torch.LongTensor([emb_dict[data.BEGIN_TOKEN]]).to(device) 65 | 66 | optimiser = optim.Adam(per_net.parameters(), lr=args.LEARNING_RATE, eps=1e-3) 67 | batch_idx = 0 68 | best_preplexity = None 69 | for epoch in range(args.MAX_EPOCHES): 70 | dial_shown = False 71 | random.shuffle(train_data) 72 | 73 | total_samples = 0 74 | preplexity_argmax = [] 75 | preplexity_sample = [] 76 | 77 | for batch in data.iterate_batches(train_data, args.BATCH_SIZE): 78 | batch_idx += 1 79 | optimiser.zero_grad() 80 | input_seq, out_seq_list, input_batch, output_batch = model.pack_batch(batch, per_net.emb, device) 81 | enc = per_net.encode(input_seq) 82 | 83 | net_policies = [] 84 | net_actions = [] 85 | net_advantages = [] 86 | beg_embedding = per_net.emb(beg_token) 87 | 88 | if random.random() < args.CROSS_ENT_PROB: 89 | net_results = [] 90 | net_targets = [] 91 | for idx, out_seq in enumerate(out_seq_list): 92 | ref_indices = output_batch[idx][1:] 93 | enc_item = per_net.get_encoded_item(enc, idx) 94 | if random.random() < args.TEACHER_PROB: 95 | r = per_net.decode_teacher(enc_item, out_seq) 96 | else: 97 | r, seq = per_net.decode_chain_argmax(enc_item, out_seq.data[0:1], 98 | len(ref_indices)) 99 | net_results.append(r) 100 | 
net_targets.extend(ref_indices) 101 | results_v = torch.cat(net_results) 102 | targets_v = torch.LongTensor(net_targets).to(device) 103 | loss_v = F.cross_entropy(results_v, targets_v) 104 | loss_v.backward() 105 | for param in per_net.parameters(): 106 | param.grad.data.clamp_(-0.2, 0.2) 107 | optimiser.step() 108 | else: 109 | for idx, inp_idx in enumerate(input_batch): 110 | total_samples += 1 111 | ref_indices = output_batch[idx][1:] 112 | item_enc = per_net.get_encoded_item(enc, idx) 113 | r_argmax, actions = per_net.decode_chain_argmax(item_enc, beg_embedding, data.MAX_TOKENS, 114 | stop_at_token=end_token) 115 | r_net_argmax = net.get_logits(item_enc, beg_embedding, data.MAX_TOKENS, actions, 116 | stop_at_token=end_token) 117 | argmax_preplexity = utils.calc_preplexity_many(r_net_argmax, actions) 118 | preplexity_argmax.append(argmax_preplexity) 119 | 120 | if not dial_shown: 121 | log.info("Input: %s", utils.untokenize(data.decode_words(inp_idx, rev_emb_dict))) 122 | ref_words = [utils.untokenize(data.decode_words([ref], rev_emb_dict)) for ref in ref_indices] 123 | log.info("Refer: %s", " ".join(ref_words)) 124 | log.info("Argmax: %s, preplexity=%.4f", 125 | utils.untokenize(data.decode_words(actions, rev_emb_dict)), argmax_preplexity) 126 | 127 | for _ in range(args.num_of_samples): 128 | r_sample, actions = per_net.decode_chain_sampling(item_enc, beg_embedding, 129 | data.MAX_TOKENS, stop_at_token=end_token) 130 | r_net_sample = net.get_logits(item_enc, beg_embedding, data.MAX_TOKENS, actions, 131 | stop_at_token=end_token) 132 | sample_preplexity = utils.calc_preplexity_many(r_net_sample, actions) 133 | if not dial_shown: 134 | log.info("Sample: %s, preplexity=%.4f", 135 | utils.untokenize(data.decode_words(actions, rev_emb_dict)), sample_preplexity) 136 | 137 | net_policies.append(r_sample) 138 | net_actions.extend(actions) 139 | net_advantages.extend([sample_preplexity - argmax_preplexity] * len(actions)) 140 | preplexity_sample.append(sample_preplexity) 141 | dial_shown = True 142 | 143 | if not net_policies: 144 | continue 145 | 146 | policies_v = torch.cat(net_policies) 147 | actions_t = torch.LongTensor(net_actions).to(device) 148 | adv_v = torch.FloatTensor(net_advantages).to(device) 149 | log_prob_v = F.log_softmax(policies_v, dim=1) 150 | log_prob_actions_v = (-1) * adv_v * log_prob_v[range(len(net_actions)), actions_t] 151 | loss_policy_v = -log_prob_actions_v.mean() 152 | 153 | loss_v = loss_policy_v 154 | loss_v.backward() 155 | for param in per_net.parameters(): 156 | param.grad.data.clamp_(-0.2, 0.2) 157 | optimiser.step() 158 | 159 | 160 | preplexity_test = run_test_preplexity(test_data, per_net, net, end_token, device) 161 | preplexity = np.mean(preplexity_argmax) 162 | writer.add_scalar("preplexity_test", preplexity_test, batch_idx) 163 | writer.add_scalar("preplexity_argmax", preplexity, batch_idx) 164 | writer.add_scalar("preplexity_sample", np.mean(preplexity_sample), batch_idx) 165 | writer.add_scalar("epoch", batch_idx, epoch) 166 | log.info("Epoch %d, test PREPLEXITY: %.3f", epoch, preplexity_test) 167 | if best_preplexity is None or best_preplexity > preplexity_test: 168 | best_preplexity = preplexity_test 169 | log.info("Best preplexity updated: %.4f", preplexity_test) 170 | torch.save(per_net.state_dict(), os.path.join(saves_path, "preplexity_%.3f_%02d.dat" % (preplexity_test, epoch))) 171 | if epoch % 10 == 0: 172 | torch.save(per_net.state_dict(), 173 | os.path.join(saves_path, "epoch_%03d_%.3f_%.3f.dat" % (epoch, preplexity, preplexity_test))) 174 | 
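The reward in the loop above comes from utils.calc_preplexity_many, whose implementation is not reproduced in this excerpt; a sequence perplexity of this kind is usually the exponentiated average negative log-probability that the frozen seq2seq assigns to the chosen tokens, as in the standalone sketch below (perplexity_sketch is a hypothetical helper, toy tensors only). Because lower perplexity is better, the advantage above is multiplied by (-1) and the best checkpoint is the one with the smallest test value.

    # Hypothetical helper; the logits play the role of net.get_logits output.
    import torch
    import torch.nn.functional as F

    def perplexity_sketch(logits, token_ids):
        # logits: (steps, vocab_size) decoder outputs; token_ids: the chosen ids per step
        log_probs = F.log_softmax(logits, dim=1)
        picked = log_probs[range(len(token_ids)), torch.LongTensor(token_ids)]
        return torch.exp(-picked.mean()).item()

    print(perplexity_sketch(torch.randn(4, 10), [1, 3, 3, 2]))   # toy call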
175 | writer.close() -------------------------------------------------------------------------------- /train_rl_cosine.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | import os 4 | import random 5 | import logging 6 | import numpy as np 7 | from tensorboardX import SummaryWriter 8 | from libbots import data, model, utils 9 | import torch 10 | import torch.optim as optim 11 | import torch.nn.functional as F 12 | import argparse 13 | 14 | from model_test import run_test_cosine 15 | 16 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 17 | 18 | if __name__ == '__main__': 19 | parser = argparse.ArgumentParser() 20 | parser.add_argument('-SAVES_DIR', type=str, default='saves', help='Save directory') 21 | parser.add_argument('-name', type=str, default='RL_COSINE', help='Specific model saves directory') 22 | parser.add_argument('-BATCH_SIZE', type=int, default=16, help='Batch Size for training') 23 | parser.add_argument('-LEARNING_RATE', type=float, default=1e-4, help='Learning Rate') 24 | parser.add_argument('-MAX_EPOCHES', type=int, default=10000, help='Number of training iterations') 25 | parser.add_argument('-data', type=str, default='comedy', help='Genre to use - for data') 26 | parser.add_argument('-num_of_samples', type=int, default=4, help='Number of samples per per each example') 27 | parser.add_argument('-load_seq2seq_path', type=str, default='Final_Saves/seq2seq/epoch_090_0.800_0.107.dat', 28 | help='Pre-trained seq2seq model location') 29 | args = parser.parse_args() 30 | 31 | saves_path = os.path.join(args.SAVES_DIR, args.name) 32 | os.makedirs(saves_path, exist_ok=True) 33 | 34 | log = logging.getLogger("train") 35 | logging.basicConfig(format="%(asctime)-15s %(levelname)s %(message)s", level=logging.INFO) 36 | 37 | phrase_pairs, emb_dict = data.load_data(genre_filter=args.data) 38 | data.save_emb_dict(saves_path, emb_dict) 39 | train_data = data.encode_phrase_pairs(phrase_pairs, emb_dict) 40 | rand = np.random.RandomState(data.SHUFFLE_SEED) 41 | rand.shuffle(train_data) 42 | train_data, test_data = data.split_train_test(train_data) 43 | train_data = data.group_train_data(train_data) 44 | test_data = data.group_train_data(test_data) 45 | 46 | log.info("Obtained %d phrase pairs with %d uniq words", len(phrase_pairs), len(emb_dict)) 47 | log.info("Training data converted, got %d samples", len(train_data)) 48 | log.info("Train set has %d phrases, test %d", len(train_data), len(test_data)) 49 | 50 | # Load pre-trained seq2seq net 51 | rev_emb_dict = {idx: word for word, idx in emb_dict.items()} 52 | net = model.PhraseModel(emb_size=model.EMBEDDING_DIM, dict_size=len(emb_dict), 53 | hid_size=model.HIDDEN_STATE_SIZE).to(device) 54 | cos_net = model.PhraseModel(emb_size=model.EMBEDDING_DIM, dict_size=len(emb_dict), 55 | hid_size=model.HIDDEN_STATE_SIZE).to(device) 56 | net.load_state_dict(torch.load(args.load_seq2seq_path)) 57 | cos_net.load_state_dict(torch.load(args.load_seq2seq_path)) 58 | 59 | writer = SummaryWriter(comment="-" + args.name) 60 | 61 | # BEGIN & END tokens 62 | beg_token = torch.LongTensor([emb_dict[data.BEGIN_TOKEN]]).to(device) 63 | end_token = emb_dict[data.END_TOKEN] 64 | 65 | 66 | optimiser = optim.Adam(cos_net.parameters(), lr=args.LEARNING_RATE, eps=1e-3) 67 | batch_idx = 0 68 | best_cosine = None 69 | for epoch in range(args.MAX_EPOCHES): 70 | random.shuffle(train_data) 71 | dial_shown = False 72 | 73 | total_samples = 0 74 | skipped_samples = 0 75 | cosine_argmax = [] 76 | cosine_sample = [] 
77 | 78 | for batch in data.iterate_batches(train_data, args.BATCH_SIZE): 79 | batch_idx += 1 80 | optimiser.zero_grad() 81 | input_seq, input_batch, output_batch = model.pack_batch_no_out(batch, cos_net.emb, device) 82 | enc = cos_net.encode(input_seq) 83 | 84 | net_policies = [] 85 | net_actions = [] 86 | net_advantages = [] 87 | beg_embedding = cos_net.emb(beg_token) 88 | 89 | for idx, inp_idx in enumerate(input_batch): 90 | total_samples += 1 91 | ref_indices = [indices[1:] for indices in output_batch[idx]] 92 | item_enc = cos_net.get_encoded_item(enc, idx) 93 | r_argmax, actions = cos_net.decode_chain_argmax(item_enc, beg_embedding, data.MAX_TOKENS, 94 | stop_at_token=end_token) 95 | mean_emb_max = net.get_mean_emb(beg_embedding, actions) 96 | mean_emb_ref_list = [] 97 | for iRef in ref_indices: 98 | mean_emb_ref_list.append(net.get_mean_emb(beg_embedding, iRef)) 99 | mean_emb_ref = sum(mean_emb_ref_list)/len(mean_emb_ref_list) 100 | argmax_cosine = utils.calc_cosine_many(mean_emb_max, mean_emb_ref) 101 | cosine_argmax.append(float(argmax_cosine)) 102 | 103 | 104 | if not dial_shown: 105 | log.info("Input: %s", utils.untokenize(data.decode_words(inp_idx, rev_emb_dict))) 106 | ref_words = [utils.untokenize(data.decode_words(ref, rev_emb_dict)) for ref in ref_indices] 107 | log.info("Refer: %s", " ~~|~~ ".join(ref_words)) 108 | log.info("Argmax: %s, cosine=%.4f", 109 | utils.untokenize(data.decode_words(actions, rev_emb_dict)), 110 | argmax_cosine) 111 | 112 | for _ in range(args.num_of_samples): 113 | r_sample, actions = cos_net.decode_chain_sampling(item_enc, beg_embedding, data.MAX_TOKENS, 114 | stop_at_token=end_token) 115 | mean_emb_samp = net.get_mean_emb(beg_embedding, actions) 116 | sample_cosine = utils.calc_cosine_many(mean_emb_samp, mean_emb_ref) 117 | 118 | if not dial_shown: 119 | log.info("Sample: %s, cosine=%.4f", 120 | utils.untokenize(data.decode_words(actions, rev_emb_dict)), 121 | sample_cosine) 122 | 123 | net_policies.append(r_sample) 124 | net_actions.extend(actions) 125 | net_advantages.extend([float(sample_cosine) - argmax_cosine] * len(actions)) 126 | cosine_sample.append(float(sample_cosine)) 127 | dial_shown = True 128 | 129 | if not net_policies: 130 | continue 131 | 132 | policies_v = torch.cat(net_policies) 133 | actions_t = torch.LongTensor(net_actions).to(device) 134 | adv_v = torch.FloatTensor(net_advantages).to(device) 135 | log_prob_v = F.log_softmax(policies_v, dim=1) 136 | log_prob_actions_v = adv_v * log_prob_v[range(len(net_actions)), actions_t] 137 | loss_policy_v = -log_prob_actions_v.mean() 138 | 139 | loss_v = loss_policy_v 140 | loss_v.backward() 141 | optimiser.step() 142 | 143 | cosine_test = run_test_cosine(test_data, cos_net, net, beg_token, end_token, device="cuda") 144 | cosine = np.mean(cosine_argmax) 145 | writer.add_scalar("cosine_test", cosine_test, batch_idx) 146 | writer.add_scalar("cosine_argmax", cosine, batch_idx) 147 | writer.add_scalar("cosine_sample", np.mean(cosine_sample), batch_idx) 148 | writer.add_scalar("epoch", batch_idx, epoch) 149 | log.info("Epoch %d, test COSINE: %.3f", epoch, cosine_test) 150 | if best_cosine is None or best_cosine < cosine_test: 151 | best_cosine = cosine_test 152 | log.info("Best cosine updated: %.4f", cosine_test) 153 | torch.save(cos_net.state_dict(), os.path.join(saves_path, "cosine_%.3f_%02d.dat" % (cosine_test, epoch))) 154 | if epoch % 10 == 0: 155 | torch.save(cos_net.state_dict(), 156 | os.path.join(saves_path, "epoch_%03d_%.3f_%.3f.dat" % (epoch, cosine, cosine_test))) 157 | 158 | 
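The cosine reward computed in the loop above averages token embeddings on both sides (which is what get_mean_emb appears to provide; its implementation is not shown in this excerpt) and compares the two mean vectors with calc_cosine_many, whose batched version appears near the top of this listing. For a single prediction/reference pair, the same quantity can be sketched with a toy embedding layer standing in for PhraseModel.emb:

    # Standalone sketch of the mean-embedding cosine reward (toy embedding, toy token ids).
    import torch
    import torch.nn.functional as F

    emb = torch.nn.Embedding(10, 8)                        # stand-in for PhraseModel.emb
    pred_ids = torch.LongTensor([1, 4, 7, 2])
    ref_ids = torch.LongTensor([1, 5, 7, 2])

    mean_pred = emb(pred_ids).mean(dim=0, keepdim=True)    # (1, emb_size)
    mean_ref = emb(ref_ids).mean(dim=0, keepdim=True)      # (1, emb_size)
    reward = F.cosine_similarity(mean_pred, mean_ref).item()
    print(reward)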
writer.close() -------------------------------------------------------------------------------- /use_model.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | import os 4 | import logging 5 | import torch 6 | 7 | from libbots import data, model, utils 8 | 9 | 10 | log = logging.getLogger("use") 11 | k_sentences = 100 12 | 13 | def words_to_words(words, emb_dict, rev_emb_dict, net, use_sampling=False): 14 | tokens = data.encode_words(words, emb_dict) 15 | input_seq = model.pack_input(tokens, net.emb) 16 | enc = net.encode(input_seq) 17 | end_token = emb_dict[data.END_TOKEN] 18 | if use_sampling: 19 | _, out_tokens = net.decode_chain_sampling(enc, input_seq.data[0:1], seq_len=data.MAX_TOKENS, 20 | stop_at_token=end_token) 21 | else: 22 | _, out_tokens = net.decode_chain_argmax(enc, input_seq.data[0:1], seq_len=data.MAX_TOKENS, 23 | stop_at_token=end_token) 24 | if out_tokens[-1] == end_token: 25 | out_tokens = out_tokens[:-1] 26 | out_words = data.decode_words(out_tokens, rev_emb_dict) 27 | return out_words 28 | 29 | def mutual_words_to_words(words, emb_dict, rev_emb_dict, back_emb_dict, net, back_net): 30 | # Forward 31 | tokens = data.encode_words(words, emb_dict) 32 | source_seq = model.pack_input(tokens, net.emb) 33 | enc = net.encode(source_seq) 34 | end_token = emb_dict[data.END_TOKEN] 35 | list_of_out_tokens, probs = net.decode_k_best(enc, source_seq.data[0:1], k_sentences, seq_len=data.MAX_TOKENS, 36 | stop_at_token=end_token) 37 | list_of_out_words = [] 38 | for iTokens in range(len(list_of_out_tokens)): 39 | if list_of_out_tokens[iTokens][-1] == end_token: 40 | list_of_out_tokens[iTokens] = list_of_out_tokens[iTokens][:-1] 41 | list_of_out_words.append(data.decode_words(list_of_out_tokens[iTokens], rev_emb_dict)) 42 | 43 | # Backward 44 | back_seq2seq_prob = [] 45 | for iTarget in range(len(list_of_out_words)): 46 | b_tokens = data.encode_words(list_of_out_words[iTarget], back_emb_dict) 47 | target_seq = model.pack_input(b_tokens, back_net.emb) 48 | b_enc = back_net.encode(target_seq) 49 | back_seq2seq_prob.append(back_net.get_qp_prob(b_enc, target_seq.data[0:1], tokens[1:])) 50 | 51 | mutual_prob = [] 52 | for i in range(len(probs)): 53 | mutual_prob.append(probs[i] + back_seq2seq_prob[i]) 54 | most_prob_mutual_sen_id = sorted(range(len(mutual_prob)), key=lambda s: mutual_prob[s])[-1:][0] 55 | 56 | return list_of_out_words[most_prob_mutual_sen_id], mutual_prob[most_prob_mutual_sen_id] 57 | 58 | def process_string(words, emb_dict, rev_emb_dict, net, use_sampling=False): 59 | out_words = words_to_words(words, emb_dict, rev_emb_dict, net, use_sampling=use_sampling) 60 | print(" ".join(out_words)) 61 | 62 | 63 | logging.basicConfig(format="%(asctime)-15s %(levelname)s %(message)s", level=logging.INFO) 64 | 65 | load_seq2seq_path = 'Final_Saves/seq2seq/epoch_090_0.800_0.107.dat' 66 | laod_b_seq2seq_path = 'Final_Saves/backward_seq2seq/epoch_080_0.780_0.104.dat' 67 | bleu_model_path = 'Final_Saves/RL_BLUE/bleu_0.135_177.dat' 68 | mutual_model_path = 'Final_Saves/RL_Mutual/epoch_180_-4.325_-7.192.dat' 69 | prep_model_path = 'Final_Saves/RL_Perplexity/epoch_050_1.463_3.701.dat' 70 | cos_model_path = 'Final_Saves/RL_COSINE/cosine_0.621_03.dat' 71 | 72 | input_sentences = [] 73 | input_sentences.append("Do you want to stay?") 74 | input_sentences.append("What's your full name?") 75 | input_sentences.append("Where are you going?") 76 | input_sentences.append("How old are you?") 77 | input_sentences.append("Where are you?") 78 | 
input_sentences.append("hi, joey.") 79 | input_sentences.append("let's go.") 80 | input_sentences.append("excuse me?") 81 | input_sentences.append("what's that?") 82 | input_sentences.append("Stop!") 83 | input_sentences.append("where ya goin?") 84 | input_sentences.append("what's this?") 85 | input_sentences.append("Do you play football?") 86 | input_sentences.append("who is she?") 87 | input_sentences.append("who is he?") 88 | input_sentences.append("Are you sure?") 89 | input_sentences.append("Did you see that?") 90 | input_sentences.append("Hello.") 91 | 92 | sample = False 93 | mutual = True 94 | RL = True 95 | self = 1 96 | device = torch.device("cuda") # "cuda"/"cpu" 97 | 98 | # Load Seq2Seq 99 | seq2seq_emb_dict = data.load_emb_dict(os.path.dirname(load_seq2seq_path)) 100 | seq2seq_rev_emb_dict = {idx: word for word, idx in seq2seq_emb_dict.items()} 101 | net = model.PhraseModel(emb_size=model.EMBEDDING_DIM, dict_size=len(seq2seq_emb_dict), hid_size=model.HIDDEN_STATE_SIZE).to(device) 102 | net.load_state_dict(torch.load(load_seq2seq_path)) 103 | 104 | # Load Back Seq2Seq 105 | b_seq2seq_emb_dict = data.load_emb_dict(os.path.dirname(laod_b_seq2seq_path)) 106 | b_seq2seq_rev_emb_dict = {idx: word for word, idx in b_seq2seq_emb_dict.items()} 107 | b_net = model.PhraseModel(emb_size=model.EMBEDDING_DIM, dict_size=len(b_seq2seq_emb_dict), hid_size=model.HIDDEN_STATE_SIZE).to(device) 108 | b_net.load_state_dict(torch.load(laod_b_seq2seq_path)) 109 | 110 | # Load BLEU 111 | bleu_emb_dict = data.load_emb_dict(os.path.dirname(bleu_model_path)) 112 | bleu_rev_emb_dict = {idx: word for word, idx in bleu_emb_dict.items()} 113 | bleu_net = model.PhraseModel(emb_size=model.EMBEDDING_DIM, dict_size=len(bleu_emb_dict), hid_size=model.HIDDEN_STATE_SIZE).to(device) 114 | bleu_net.load_state_dict(torch.load(bleu_model_path)) 115 | 116 | # Load Mutual 117 | mutual_emb_dict = data.load_emb_dict(os.path.dirname(mutual_model_path)) 118 | mutual_rev_emb_dict = {idx: word for word, idx in mutual_emb_dict.items()} 119 | mutual_net = model.PhraseModel(emb_size=model.EMBEDDING_DIM, dict_size=len(mutual_emb_dict), hid_size=model.HIDDEN_STATE_SIZE).to(device) 120 | mutual_net.load_state_dict(torch.load(mutual_model_path)) 121 | 122 | 123 | # Load Preplexity 124 | prep_emb_dict = data.load_emb_dict(os.path.dirname(prep_model_path)) 125 | prep_rev_emb_dict = {idx: word for word, idx in prep_emb_dict.items()} 126 | prep_net = model.PhraseModel(emb_size=model.EMBEDDING_DIM, dict_size=len(prep_emb_dict), hid_size=model.HIDDEN_STATE_SIZE).to(device) 127 | prep_net.load_state_dict(torch.load(prep_model_path)) 128 | 129 | # Load Preplexity 130 | prep1_emb_dict = data.load_emb_dict(os.path.dirname(prep_model_path)) 131 | prep1_rev_emb_dict = {idx: word for word, idx in prep1_emb_dict.items()} 132 | prep1_net = model.PhraseModel(emb_size=model.EMBEDDING_DIM, dict_size=len(prep1_emb_dict), hid_size=model.HIDDEN_STATE_SIZE).to(device) 133 | prep1_net.load_state_dict(torch.load(prep_model_path)) 134 | 135 | # Load Cosine Similarity 136 | cos_emb_dict = data.load_emb_dict(os.path.dirname(cos_model_path)) 137 | cos_rev_emb_dict = {idx: word for word, idx in cos_emb_dict.items()} 138 | cos_net = model.PhraseModel(emb_size=model.EMBEDDING_DIM, dict_size=len(cos_emb_dict), hid_size=model.HIDDEN_STATE_SIZE).to(device) 139 | cos_net.load_state_dict(torch.load(cos_model_path)) 140 | 141 | while True: 142 | input_sentence = input_sentences[0] 143 | if input_sentence: 144 | input_string = input_sentence 145 | else: 146 | 
input_string = input(">>> ") 147 | if not input_string: 148 | break 149 | 150 | words = utils.tokenize(input_string) 151 | for _ in range(self): 152 | if RL: 153 | words_seq2seq = words_to_words(words, seq2seq_emb_dict, seq2seq_rev_emb_dict, net, use_sampling=sample) 154 | words_bleu = words_to_words(words, bleu_emb_dict, bleu_rev_emb_dict, bleu_net, use_sampling=sample) 155 | words_mutual_RL = words_to_words(words, mutual_emb_dict, mutual_rev_emb_dict, mutual_net, use_sampling=sample) 156 | words_mutual, _ = mutual_words_to_words(words, seq2seq_emb_dict, seq2seq_rev_emb_dict, b_seq2seq_emb_dict, 157 | net, b_net) 158 | 159 | words_prep = words_to_words(words, prep_emb_dict, prep_rev_emb_dict, prep_net, use_sampling=sample) 160 | words_prep = words_to_words(words, prep1_emb_dict, prep1_rev_emb_dict, prep1_net, use_sampling=sample) 161 | words_cosine = words_to_words(words, cos_emb_dict, cos_rev_emb_dict, cos_net, use_sampling=sample) 162 | 163 | print('Seq2Seq: ', utils.untokenize(words_seq2seq)) 164 | print('BLEU: ', utils.untokenize(words_bleu)) 165 | print('Mutual Information (RL): ', utils.untokenize(words_mutual_RL)) 166 | print('Mutual Information: ', utils.untokenize(words_mutual)) 167 | print('Perplexity: ', utils.untokenize(words_prep)) 168 | print('Perplexity: ', utils.untokenize(words_prep)) 169 | print('Cosine Similarity: ', utils.untokenize(words_cosine)) 170 | else: 171 | if mutual: 172 | words, _ = mutual_words_to_words(words, seq2seq_emb_dict, seq2seq_rev_emb_dict, b_seq2seq_emb_dict, 173 | net, b_net) 174 | else: 175 | words = words_to_words(words, seq2seq_emb_dict, seq2seq_rev_emb_dict, net, use_sampling=sample) 176 | 177 | 178 | if input_string: 179 | break 180 | pass 181 | --------------------------------------------------------------------------------
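For reference, the mutual-information decoding performed by mutual_words_to_words in use_model.py (and, presumably, the reward computed by utils.calc_mutual during training, whose implementation is not shown here) boils down to scoring each of the k candidate replies with the forward model's log-probability plus the backward model's log-probability of reconstructing the source, then keeping the best-scoring candidate. A standalone sketch with made-up scores:

    # Toy numbers only -- in use_model.py these come from decode_k_best and get_qp_prob.
    forward_log_probs = [-3.2, -2.9, -4.1]    # log p(reply_i | source) from the forward seq2seq
    backward_log_probs = [-5.0, -6.4, -3.3]   # log p(source | reply_i) from the backward seq2seq
    candidates = ["i am fine .", "fine .", "i am fine , thanks ."]

    mutual = [f + b for f, b in zip(forward_log_probs, backward_log_probs)]
    best = max(range(len(mutual)), key=lambda i: mutual[i])
    print(candidates[best], mutual[best])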