├── .gitignore
├── Deep Reinforcement Learning For Language Models.pdf
├── Final_Saves
│   ├── RL_BLUE
│   │   ├── bleu_0.135_177.dat
│   │   └── emb_dict.dat
│   ├── RL_COSINE
│   │   ├── cosine_0.621_03.dat
│   │   └── emb_dict.dat
│   ├── RL_Mutual
│   │   ├── emb_dict.dat
│   │   └── epoch_180_-4.325_-7.192.dat
│   ├── RL_Perplexity
│   │   ├── emb_dict.dat
│   │   └── epoch_050_1.463_3.701.dat
│   ├── backward_seq2seq
│   │   ├── emb_dict.dat
│   │   └── epoch_080_0.780_0.104.dat
│   └── seq2seq
│       ├── emb_dict.dat
│       └── epoch_090_0.800_0.107.dat
├── Images
│   ├── MMI comparison.PNG
│   ├── Training BLEU RL.PNG
│   ├── Training CE Seq2Seq.PNG
│   ├── Training RL Cosine.PNG
│   ├── Training RL MMI.PNG
│   ├── Training RL Perplexity.PNG
│   ├── automatic evaluations.PNG
│   └── many generated sequences.PNG
├── README.md
├── cur_reader.py
├── data_test.py
├── libbots
│   ├── cornell.py
│   ├── data.py
│   ├── model.py
│   └── utils.py
├── model_test.py
├── test_all_modells.py
├── tests
│   ├── test_data.py
│   └── test_subtitles.py
├── train_crossent.py
├── train_rl_BLEU.py
├── train_rl_MMI.py
├── train_rl_PREPLEXITY.py
├── train_rl_cosine.py
└── use_model.py
/.gitignore:
--------------------------------------------------------------------------------
1 | cornell
2 | data
--------------------------------------------------------------------------------
/Deep Reinforcement Learning For Language Models.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/eyalbd2/RL-based-Language-Modeling/b3467430be102ba23b1ab7e7a9515496c140f524/Deep Reinforcement Learning For Language Models.pdf
--------------------------------------------------------------------------------
/Final_Saves/RL_BLUE/bleu_0.135_177.dat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/eyalbd2/RL-based-Language-Modeling/b3467430be102ba23b1ab7e7a9515496c140f524/Final_Saves/RL_BLUE/bleu_0.135_177.dat
--------------------------------------------------------------------------------
/Final_Saves/RL_BLUE/emb_dict.dat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/eyalbd2/RL-based-Language-Modeling/b3467430be102ba23b1ab7e7a9515496c140f524/Final_Saves/RL_BLUE/emb_dict.dat
--------------------------------------------------------------------------------
/Final_Saves/RL_COSINE/cosine_0.621_03.dat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/eyalbd2/RL-based-Language-Modeling/b3467430be102ba23b1ab7e7a9515496c140f524/Final_Saves/RL_COSINE/cosine_0.621_03.dat
--------------------------------------------------------------------------------
/Final_Saves/RL_COSINE/emb_dict.dat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/eyalbd2/RL-based-Language-Modeling/b3467430be102ba23b1ab7e7a9515496c140f524/Final_Saves/RL_COSINE/emb_dict.dat
--------------------------------------------------------------------------------
/Final_Saves/RL_Mutual/emb_dict.dat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/eyalbd2/RL-based-Language-Modeling/b3467430be102ba23b1ab7e7a9515496c140f524/Final_Saves/RL_Mutual/emb_dict.dat
--------------------------------------------------------------------------------
/Final_Saves/RL_Mutual/epoch_180_-4.325_-7.192.dat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/eyalbd2/RL-based-Language-Modeling/b3467430be102ba23b1ab7e7a9515496c140f524/Final_Saves/RL_Mutual/epoch_180_-4.325_-7.192.dat
--------------------------------------------------------------------------------
/Final_Saves/RL_Perplexity/emb_dict.dat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/eyalbd2/RL-based-Language-Modeling/b3467430be102ba23b1ab7e7a9515496c140f524/Final_Saves/RL_Perplexity/emb_dict.dat
--------------------------------------------------------------------------------
/Final_Saves/RL_Perplexity/epoch_050_1.463_3.701.dat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/eyalbd2/RL-based-Language-Modeling/b3467430be102ba23b1ab7e7a9515496c140f524/Final_Saves/RL_Perplexity/epoch_050_1.463_3.701.dat
--------------------------------------------------------------------------------
/Final_Saves/backward_seq2seq/emb_dict.dat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/eyalbd2/RL-based-Language-Modeling/b3467430be102ba23b1ab7e7a9515496c140f524/Final_Saves/backward_seq2seq/emb_dict.dat
--------------------------------------------------------------------------------
/Final_Saves/backward_seq2seq/epoch_080_0.780_0.104.dat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/eyalbd2/RL-based-Language-Modeling/b3467430be102ba23b1ab7e7a9515496c140f524/Final_Saves/backward_seq2seq/epoch_080_0.780_0.104.dat
--------------------------------------------------------------------------------
/Final_Saves/seq2seq/emb_dict.dat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/eyalbd2/RL-based-Language-Modeling/b3467430be102ba23b1ab7e7a9515496c140f524/Final_Saves/seq2seq/emb_dict.dat
--------------------------------------------------------------------------------
/Final_Saves/seq2seq/epoch_090_0.800_0.107.dat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/eyalbd2/RL-based-Language-Modeling/b3467430be102ba23b1ab7e7a9515496c140f524/Final_Saves/seq2seq/epoch_090_0.800_0.107.dat
--------------------------------------------------------------------------------
/Images/MMI comparison.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/eyalbd2/RL-based-Language-Modeling/b3467430be102ba23b1ab7e7a9515496c140f524/Images/MMI comparison.PNG
--------------------------------------------------------------------------------
/Images/Training BLEU RL.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/eyalbd2/RL-based-Language-Modeling/b3467430be102ba23b1ab7e7a9515496c140f524/Images/Training BLEU RL.PNG
--------------------------------------------------------------------------------
/Images/Training CE Seq2Seq.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/eyalbd2/RL-based-Language-Modeling/b3467430be102ba23b1ab7e7a9515496c140f524/Images/Training CE Seq2Seq.PNG
--------------------------------------------------------------------------------
/Images/Training RL Cosine.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/eyalbd2/RL-based-Language-Modeling/b3467430be102ba23b1ab7e7a9515496c140f524/Images/Training RL Cosine.PNG
--------------------------------------------------------------------------------
/Images/Training RL MMI.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/eyalbd2/RL-based-Language-Modeling/b3467430be102ba23b1ab7e7a9515496c140f524/Images/Training RL MMI.PNG
--------------------------------------------------------------------------------
/Images/Training RL Perplexity.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/eyalbd2/RL-based-Language-Modeling/b3467430be102ba23b1ab7e7a9515496c140f524/Images/Training RL Perplexity.PNG
--------------------------------------------------------------------------------
/Images/automatic evaluations.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/eyalbd2/RL-based-Language-Modeling/b3467430be102ba23b1ab7e7a9515496c140f524/Images/automatic evaluations.PNG
--------------------------------------------------------------------------------
/Images/many generated sequences.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/eyalbd2/RL-based-Language-Modeling/b3467430be102ba23b1ab7e7a9515496c140f524/Images/many generated sequences.PNG
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Improving Language Models Using Reinforcement Learning
2 |
3 | In this work, we show how traditional NLP evaluation methods can be used to
4 | improve a language model. We construct several models, each trained with a
5 | different objective, compare them, and offer some intuition about the
6 | results.
7 |
8 | The project implementation contains:
9 | 1. Basic cross-entropy seq2seq model training - using a couple of tricks to speed up training convergence
10 | 2. Four policy gradient models - using 4 different rewards (for further explanation, please refer to "Deep Reinforcement Learning For Language Models.pdf", which is in this repository; a minimal sketch of the shared update follows this list):
11 | - [ ] BLEU score
12 | - [ ] Maximum Mutual Information score
13 | - [ ] Perplexity score
14 | - [ ] Cosine Similarity score
15 | 3. Testing generated dialogues - we test each of our models and sum up the results (for all results, please refer to "Deep Reinforcement Learning For Language Models.pdf", which is in this repository).
16 | 4. Automatic evaluation - we evaluate each of the models using traditional NLP evaluation methods.
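
To make the policy-gradient item concrete, below is a minimal sketch of the REINFORCE-style update that the four train_rl_* scripts share, shown here with BLEU as the reward. This is an illustration, not the repository code: the helper names and tensor shapes are assumptions, and the actual scripts also handle batching, multiple samples per example and, for some rewards, occasional cross-entropy batches (the CROSS_ENT_PROB parameter below).

```
# Hedged sketch of one policy-gradient step: sample a reply from the seq2seq decoder,
# score it with a scalar reward (e.g. BLEU against the reference), and scale the
# log-probabilities of the sampled tokens by the advantage over a greedy baseline.
import torch
import torch.nn.functional as F

def reinforce_step(sample_logits, sample_tokens, reward, baseline, optimizer):
    """sample_logits: (seq_len, vocab_size) decoder outputs for the sampled reply.
    sample_tokens: list[int] of sampled token ids.
    reward, baseline: floats, e.g. BLEU(sample) and BLEU(argmax decode)."""
    log_probs = F.log_softmax(sample_logits, dim=1)
    positions = torch.arange(len(sample_tokens))
    chosen_log_probs = log_probs[positions, torch.tensor(sample_tokens)]
    loss = -(reward - baseline) * chosen_log_probs.sum()  # maximize expected reward
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
```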
17 |
18 |
19 |
20 |
21 |
22 | Here we present a comparison of the target sentences generated for a given source sequence by the Perplexity, MMI, BLEU and Seq2Seq models.
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 | This is a comparison of the target sentences generated for a given source sequence by an MMI model trained with reinforcement learning (policy gradient) and a model that applies MMI re-ranking only at test time.
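
Concretely, the test-time MMI variant scores every candidate reply T for a source S with log P(T|S) from the forward seq2seq plus log P(S|T) from the backward seq2seq, and keeps the best-scoring candidate (this mirrors mutual_words_to_words in libbots/model.py). A minimal sketch, with the two scoring callables assumed rather than taken from the repository:

```
# Hedged sketch of test-time MMI re-ranking: choose the candidate reply that
# maximizes log P(T|S) + log P(S|T). The scoring callables stand in for the
# forward and backward seq2seq models and are not actual repository functions.
def mmi_rerank(source_tokens, candidate_replies, forward_logprob, backward_logprob):
    best_reply, best_score = None, float("-inf")
    for reply in candidate_replies:
        score = forward_logprob(source_tokens, reply) + backward_logprob(reply, source_tokens)
        if score > best_score:
            best_reply, best_score = reply, score
    return best_reply
```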
31 |
32 |
33 | ## Getting Started
34 |
35 | - [ ] Install and set up an environment that meets the prerequisites (see the subsection below)
36 | - [ ] PyCharm is recommended for code development
37 | - [ ] Clone this repository to your computer
38 | - [ ] Train a cross-entropy seq2seq model using the 'train_crossent' script
39 | - [ ] Train RL models using the appropriate scripts (e.g. use 'train_rl_MMI' for the MMI policy-gradient criterion)
40 | - [ ] Test the models with automatic evaluations using 'test_all_modells'
41 | - [ ] Generate responses using 'use_model' (a typical command sequence is sketched right after this list)
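
As a quick orientation, a typical end-to-end run with mostly default arguments might look like the sketch below; the per-script parameter tables in the next section are the authoritative reference, and response generation with 'use_model' is easiest from the PyCharm interface (see the testing section).

```
# 1) train the forward and backward cross-entropy seq2seq models
python train_crossent.py -data comedy
python train_crossent.py -name backward_seq2seq -data comedy -train_backward

# 2) fine-tune with a policy-gradient reward, e.g. the MMI criterion
python train_rl_MMI.py -data comedy

# 3) run the automatic evaluations on all saved models
python test_all_modells.py -data comedy
```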
42 |
43 | ### Prerequisites
44 |
45 | Before starting, make sure you already have the following installed; if you don't, install it before running this code:
46 | - [ ] Python 3.6 (the code was tested only on this version)
47 | - [ ] numpy
48 | - [ ] scipy
49 | - [ ] pickle
50 | - [ ] pytorch
51 | - [ ] TensorboardX
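
A minimal environment setup might look like the following; the package names are the usual PyPI ones and versions are not pinned in this repository. Note that pickle ships with Python, and that libbots/utils.py additionally imports nltk (for BLEU and tokenization), so it is worth installing as well:

```
pip install numpy scipy torch tensorboardX nltk
```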
52 |
53 |
54 | ## Running the Code
55 |
56 | ### Training all models
57 |
58 | #### Train Cross-entropy Seq2Seq and backward Seq2Seq
59 | Parameters for calling 'train_crossent' (also shown by running the script with --help):
60 |
61 | | Name Of Param | Description | Default Value | Type |
62 | | --- | --- | --- | --- |
63 | | SAVES_DIR | Save directory | 'saves' | str |
64 | | name | Specific model saves directory | 'seq2seq' | str |
65 | | BATCH_SIZE | Batch Size for training | 32 | int |
66 | | LEARNING_RATE | Learning Rate | 1e-3 | float |
67 | | MAX_EPOCHES | Number of training iterations | 100 | int |
68 | | TEACHER_PROB | Probability to force reference inputs | 0.5 | float |
69 | | data | Genre to use - for data | 'comedy' | str |
70 | | train_backward | Choose - train backward/forward model | False | bool |
71 |
72 | Run Seq2Seq using:
73 | ```
74 | python train_crossent.py -SAVES_DIR <'saves'> -name <'seq2seq'> -BATCH_SIZE <32> -LEARNING_RATE <0.001> -MAX_EPOCHES <100> -TEACHER_PROB <0.5> -data <'comedy'>
75 | ```
76 |
77 | Run backward Seq2Seq using:
78 | ```
79 | python train_crossent.py -SAVES_DIR <'saves'> -name <'backward_seq2seq'> -BATCH_SIZE <32> -LEARNING_RATE <0.001> -MAX_EPOCHES <100> -TEACHER_PROB <0.5> -data <'comedy'> -train_backward
80 | ```
81 |
82 | We present below some training graphs of the CE Seq2Seq and backward Seq2Seq (loss and BLEU score for both models)
83 |
84 |
85 |
86 |
87 |
88 | #### Train Policy Gradient using BLEU reward
89 | Parameters for calling 'train_rl_BLEU' (also shown by running the script with --help):
90 |
91 | | Name Of Param | Description | Default Value | Type |
92 | | --- | --- | --- | --- |
93 | | SAVES_DIR | Save directory | 'saves' | str |
94 | | name | Specific model saves directory | 'RL_BLUE' | str |
95 | | BATCH_SIZE | Batch Size for training | 16 | int |
96 | | LEARNING_RATE | Learning Rate | 1e-4 | float |
97 | | MAX_EPOCHES | Number of training iterations | 10000 | int |
98 | | data | Genre to use - for data | 'comedy' | str |
99 | | num_of_samples | Number of samples per example | 4 | int |
100 | | load_seq2seq_path | Pre-trained seq2seq model location | 'Final_Saves/seq2seq/epoch_090_0.800_0.107.dat' | str |
101 |
102 | Run using:
103 | ```
104 | python train_rl_BLEU.py -SAVES_DIR <'saves'> -name <'RL_BLUE'> -BATCH_SIZE <16> -LEARNING_RATE <0.0001> -MAX_EPOCHES <10000> -data <'comedy'> -num_of_samples <4> -load_seq2seq_path <'Final_Saves/seq2seq/epoch_090_0.800_0.107.dat'>
105 | ```
106 |
107 | BLEU agent training graph:
108 |
109 |
110 |
111 |
112 | #### Train Policy Gradient using MMI reward
113 | Parameters for calling 'train_rl_MMI' (also shown by running the script with --help):
114 |
115 | | Name Of Param | Description | Default Value | Type |
116 | | --- | --- | --- | --- |
117 | | SAVES_DIR | Save directory | 'saves' | str |
118 | | name | Specific model saves directory | 'RL_Mutual' | str |
119 | | BATCH_SIZE | Batch Size for training | 32 | int |
120 | | LEARNING_RATE | Learning Rate | 1e-4 | float |
121 | | MAX_EPOCHES | Number of training iterations | 10000 | int |
122 | | CROSS_ENT_PROB | Probability to run a CE batch | 0.3 | float |
123 | | TEACHER_PROB | Probability to run an imitation (teacher-forcing) batch when a CE batch is used | 0.8 | float |
124 | | data | Genre to use - for data | 'comedy' | str |
125 | | num_of_samples | Number of samples per example | 4 | int |
126 | | load_seq2seq_path | Pre-trained seq2seq model location | 'Final_Saves/seq2seq/epoch_090_0.800_0.107.dat' | str |
127 | | laod_b_seq2seq_path | Pre-trained backward seq2seq model location | 'Final_Saves/backward_seq2seq/epoch_080_0.780_0.104.dat' | str |
128 |
129 | Run using:
130 | ```
131 | python train_rl_MMI.py -SAVES_DIR <'saves'> -name <'RL_Mutual'> -BATCH_SIZE <32> -LEARNING_RATE <0.0001> -MAX_EPOCHES <10000> -CROSS_ENT_PROB <0.3> -TEACHER_PROB <0.8> -data <'comedy'> -num_of_samples <4> -load_seq2seq_path <'Final_Saves/seq2seq/epoch_090_0.800_0.107.dat'> -laod_b_seq2seq_path <'Final_Saves/backward_seq2seq/epoch_080_0.780_0.104.dat'>
132 | ```
133 |
134 | Training graph of the MMI agent:
135 |
136 |
137 |
138 |
139 | #### Train Policy Gradient using Perplexity reward
140 | Parameters for calling 'train_rl_PREPLEXITY' (also shown by running the script with --help):
141 |
142 | | Name Of Param | Description | Default Value | Type |
143 | | --- | --- | --- | --- |
144 | | SAVES_DIR | Save directory | 'saves' | str |
145 | | name | Specific model saves directory | 'RL_PREPLEXITY' | str |
146 | | BATCH_SIZE | Batch Size for training | 32 | int |
147 | | LEARNING_RATE | Learning Rate | 1e-4 | float |
148 | | MAX_EPOCHES | Number of training iterations | 10000 | int |
149 | | CROSS_ENT_PROB | Probability to run a CE batch | 0.5 | float |
150 | | TEACHER_PROB | Probability to run an imitation (teacher-forcing) batch when a CE batch is used | 0.5 | float |
151 | | data | Genre to use - for data | 'comedy' | str |
152 | | num_of_samples | Number of samples per example | 4 | int |
153 | | load_seq2seq_path | Pre-trained seq2seq model location | 'Final_Saves/seq2seq/epoch_090_0.800_0.107.dat' | str |
154 |
155 | Run using:
156 | ```
157 | python train_rl_PREPLEXITY.py -SAVES_DIR <'saves'> -name <'RL_PREPLEXITY'> -BATCH_SIZE <32> -LEARNING_RATE <0.0001> -MAX_EPOCHES <10000> -CROSS_ENT_PROB <0.5> -TEACHER_PROB <0.5> -data <'comedy'> -num_of_samples <4> -load_seq2seq_path <'Final_Saves/seq2seq/epoch_090_0.800_0.107.dat'>
158 | ```
159 |
160 | Below we present the training graph of the perplexity agent:
161 |
162 |
163 |
164 |
165 | #### Train Policy Gradient using Cosine Similarity reward
166 | Parameters for calling 'train_rl_cosine' (also shown by running the script with --help):
167 |
168 | | Name Of Param | Description | Default Value | Type |
169 | | --- | --- | --- | --- |
170 | | SAVES_DIR | Save directory | 'saves' | str |
171 | | name | Specific model saves directory | 'RL_COSINE' | str |
172 | | BATCH_SIZE | Batch Size for training | 16 | int |
173 | | LEARNING_RATE | Learning Rate | 1e-4 | float |
174 | | MAX_EPOCHES | Number of training iterations | 10000 | int |
175 | | data | Genre to use - for data | 'comedy' | str |
176 | | num_of_samples | Number of samples per example | 4 | int |
177 | | load_seq2seq_path | Pre-trained seq2seq model location | 'Final_Saves/seq2seq/epoch_090_0.800_0.107.dat' | str |
178 |
179 | Run using:
180 | ```
181 | python train_rl_cosine.py -SAVES_DIR <'saves'> -name <'RL_COSINE'> -BATCH_SIZE <16> -LEARNING_RATE <0.0001> -MAX_EPOCHES <10000> -data <'comedy'> -num_of_samples <4> -load_seq2seq_path <'Final_Saves/seq2seq/epoch_090_0.800_0.107.dat'>
182 | ```
183 |
184 | Below we present the training graph of the cosine agent:
185 |
186 |
187 |
188 |
189 | ### Testing all models
190 | There are two types of test: automatic and qualitative.
191 |
192 | For a qualitative test, it is recommended to run the 'use_model' script from the PyCharm interface. Many suggested source sentences can be found in the code.
193 |
194 | For the automatic test, run the 'test_all_modells' script with the following parameters:
195 |
196 | | Name Of Param | Description | Default Value | Type |
197 | | --- | --- | --- | --- |
198 | | SAVES_DIR | Save directory | 'saves' | str |
199 | | BATCH_SIZE | Batch Size for training | 32 | int |
200 | | data | Genre to use - for data | 'comedy' | str |
201 | | load_seq2seq_path | Pre-trained seq2seq model location | 'Final_Saves/seq2seq/epoch_090_0.800_0.107.dat' | str |
202 | | laod_b_seq2seq_path | Pre-trained backward seq2seq model location | 'Final_Saves/backward_seq2seq/epoch_080_0.780_0.104.dat' | str |
203 | | bleu_model_path | Pre-trained BLEU model location | 'Final_Saves/RL_BLUE/bleu_0.135_177.dat' | str |
204 | | mutual_model_path | Pre-trained MMI model location | 'Final_Saves/RL_Mutual/epoch_180_-4.325_-7.192.dat' | str |
205 | | prep_model_path | Pre-trained Perplexity model location | 'Final_Saves/RL_Perplexity/epoch_050_1.463_3.701.dat' | str |
206 | | cos_model_path | Pre-trained Cosine Similarity model location | 'Final_Saves/RL_COSINE/cosine_0.621_03.dat' | str |
207 |
208 |
209 | Run using:
210 | ```
211 | python test_all_modells.py -SAVES_DIR <'saves'> -BATCH_SIZE <16> -data <'comedy'> -load_seq2seq_path <'Final_Saves/seq2seq/epoch_090_0.800_0.107.dat'> -laod_b_seq2seq_path <'Final_Saves/backward_seq2seq/epoch_080_0.780_0.104.dat'> -bleu_model_path <'Final_Saves/RL_BLUE/bleu_0.135_177.dat'> -mutual_model_path <'Final_Saves/RL_Mutual/epoch_180_-4.325_-7.192.dat'> -prep_model_path <'Final_Saves/RL_Perplexity/epoch_050_1.463_3.701.dat'> -cos_model_path <'Final_Saves/RL_COSINE/cosine_0.621_03.dat'>
212 | ```
213 |
214 |
215 | ## Results for tests
216 | We present below a table with automatic test results on 'comedy' movies. For more results and conclusions, please review the 'Deep Reinforcement Learning For Language Models.pdf' file in this repository.
217 |
218 |
219 |
220 |
221 |
222 | ## Contributing
223 |
224 | Please read [CONTRIBUTING.md](https://gist.github.com/PurpleBooth/b24679402957c63ec426) for details on our code of conduct, and the process for submitting pull requests to us.
225 |
226 |
227 |
228 | ## Authors
229 |
230 | * **Eyal Ben David**
231 |
232 |
233 | ## License
234 |
235 | This project is licensed under the MIT License
236 |
237 | ## Acknowledgments
238 |
239 | * Inspiration for using RL in text tasks - "Deep Reinforcement Learning for Dialogue Generation" (Li et al., 2016)
240 | * Policy gradient implementation basics for dialogue agents - "Deep Reinforcement Learning Hands-On" by Maxim Lapan (Chapter 12)
--------------------------------------------------------------------------------
/cur_reader.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | import argparse
3 | import collections
4 |
5 | from libbots import cornell, data
6 |
7 |
8 | if __name__ == "__main__":
9 | parser = argparse.ArgumentParser()
10 | parser.add_argument("-g", "--genre", default='', help="Genre to show dialogs from")
11 | parser.add_argument("--show-genres", action='store_true', default=False, help="Display genres stats")
12 | parser.add_argument("--show-dials", action='store_true', default=False, help="Display dialogs")
13 | parser.add_argument("--show-train", action='store_true', default=False, help="Display training pairs")
14 | parser.add_argument("--show-dict-freq", action='store_true', default=False, help="Display dictionary frequency")
15 | args = parser.parse_args()
16 |
17 | if args.show_genres:
18 | genre_counts = collections.Counter()
19 | genres = cornell.read_genres(cornell.DATA_DIR)
20 | for movie, g_list in genres.items():
21 | for g in g_list:
22 | genre_counts[g] += 1
23 | print("Genres:")
24 | for g, count in genre_counts.most_common():
25 | print("%s: %d" % (g, count))
26 |
27 | if args.show_dials:
28 | dials = cornell.load_dialogues(genre_filter=args.genre)
29 | for d_idx, dial in enumerate(dials):
30 | print("Dialog %d with %d phrases:" % (d_idx, len(dial)))
31 | for p in dial:
32 | print(" ".join(p))
33 | print()
34 |
35 | if args.show_train or args.show_dict_freq:
36 | phrase_pairs, emb_dict = data.load_data(genre_filter=args.genre)
37 |
38 | if args.show_train:
39 | rev_emb_dict = {idx: word for word, idx in emb_dict.items()}
40 | train_data = data.encode_phrase_pairs(phrase_pairs, emb_dict)
41 | train_data = data.group_train_data(train_data)
42 | unk_token = emb_dict[data.UNKNOWN_TOKEN]
43 |
44 | print("Training pairs (%d total)" % len(train_data))
45 | train_data.sort(key=lambda p: len(p[1]), reverse=True)
46 | for idx, (p1, p2_group) in enumerate(train_data):
47 | w1 = data.decode_words(p1, rev_emb_dict)
48 | w2_group = [data.decode_words(p2, rev_emb_dict) for p2 in p2_group]
49 | print("%d:" % idx, " ".join(w1))
50 | for w2 in w2_group:
51 | print("%s:" % (" " * len(str(idx))), " ".join(w2))
52 |
53 | if args.show_dict_freq:
54 | words_stat = collections.Counter()
55 | for p1, p2 in phrase_pairs:
56 | words_stat.update(p1)
57 | print("Frequency stats for %d tokens in the dict" % len(emb_dict))
58 | for token, count in words_stat.most_common():
59 | print("%s: %d" % (token, count))
60 | pass
--------------------------------------------------------------------------------
/data_test.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | import argparse
3 | import logging
4 |
5 | from libbots import data, model, utils
6 |
7 | import torch
8 |
9 | log = logging.getLogger("data_test")
10 |
11 | if __name__ == "__main__":
12 | logging.basicConfig(format="%(asctime)-15s %(levelname)s %(message)s", level=logging.INFO)
13 | parser = argparse.ArgumentParser()
14 | parser.add_argument("--data", required=True,
15 | help="Category to use for training. Empty string to train on full dataset")
16 | parser.add_argument("-m", "--model", required=True, help="Model name to load")
17 | args = parser.parse_args()
18 |
19 | phrase_pairs, emb_dict = data.load_data(args.data)
20 | log.info("Obtained %d phrase pairs with %d uniq words", len(phrase_pairs), len(emb_dict))
21 | train_data = data.encode_phrase_pairs(phrase_pairs, emb_dict)
22 | train_data = data.group_train_data(train_data)
23 | rev_emb_dict = {idx: word for word, idx in emb_dict.items()}
24 |
25 | net = model.PhraseModel(emb_size=model.EMBEDDING_DIM, dict_size=len(emb_dict), hid_size=model.HIDDEN_STATE_SIZE)
26 | net.load_state_dict(torch.load(args.model))
27 |
28 | end_token = emb_dict[data.END_TOKEN]
29 |
30 | seq_count = 0
31 | sum_bleu = 0.0
32 |
33 | for seq_1, targets in train_data:
34 | input_seq = model.pack_input(seq_1, net.emb)
35 | enc = net.encode(input_seq)
36 | _, tokens = net.decode_chain_argmax(enc, input_seq.data[0:1],
37 | seq_len=data.MAX_TOKENS, stop_at_token=end_token)
38 | references = [seq[1:] for seq in targets]
39 | bleu = utils.calc_bleu_many(tokens, references)
40 | sum_bleu += bleu
41 | seq_count += 1
42 |
43 | log.info("Processed %d phrases, mean BLEU = %.4f", seq_count, sum_bleu / seq_count)
--------------------------------------------------------------------------------
/libbots/cornell.py:
--------------------------------------------------------------------------------
1 | """
2 | Cornell Movie-Dialogs Corpus
3 | https://www.cs.cornell.edu/~cristian/Cornell_Movie-Dialogs_Corpus.html
4 | """
5 | import os
6 | import logging
7 |
8 | from . import utils
9 |
10 | log = logging.getLogger("cornell")
11 | DATA_DIR = "data/cornell"
12 | SEPARATOR = "+++$+++"
13 |
14 |
15 | def load_dialogues(data_dir=DATA_DIR, genre_filter=''):
16 | """
17 | Load dialogues from cornell data
18 | :return: list of list of list of words
19 | """
20 | movie_set = None
21 | if genre_filter:
22 | movie_set = read_movie_set(data_dir, genre_filter)
23 | log.info("Loaded %d movies with genre %s", len(movie_set), genre_filter)
24 | log.info("Read and tokenise phrases...")
25 | lines = read_phrases(data_dir, movies=movie_set)
26 | log.info("Loaded %d phrases", len(lines))
27 | dialogues = load_conversations(data_dir, lines, movie_set)
28 | return dialogues
29 |
30 |
31 | def iterate_entries(data_dir, file_name):
32 | with open(os.path.join(data_dir, file_name), "rb") as fd:
33 | for l in fd:
34 | l = str(l, encoding='utf-8', errors='ignore')
35 | yield list(map(str.strip, l.split(SEPARATOR)))
36 |
37 |
38 | def read_movie_set(data_dir, genre_filter):
39 | res = set()
40 | for parts in iterate_entries(data_dir, "movie_titles_metadata.txt"):
41 | m_id, m_genres = parts[0], parts[5]
42 | if m_genres.find(genre_filter) != -1:
43 | res.add(m_id)
44 | return res
45 |
46 |
47 | def read_phrases(data_dir, movies=None):
48 | res = {}
49 | for parts in iterate_entries(data_dir, "movie_lines.txt"):
50 | l_id, m_id, l_str = parts[0], parts[2], parts[4]
51 | if movies and m_id not in movies:
52 | continue
53 | tokens = utils.tokenize(l_str)
54 | if tokens:
55 | res[l_id] = tokens
56 | return res
57 |
58 |
59 | def load_conversations(data_dir, lines, movies=None):
60 | res = []
61 | for parts in iterate_entries(data_dir, "movie_conversations.txt"):
62 | m_id, dial_s = parts[2], parts[3]
63 | if movies and m_id not in movies:
64 | continue
65 | l_ids = dial_s.strip("[]").split(", ")
66 | l_ids = list(map(lambda s: s.strip("'"), l_ids))
67 | dial = [lines[l_id] for l_id in l_ids if l_id in lines]
68 | if dial:
69 | res.append(dial)
70 | return res
71 |
72 |
73 | def read_genres(data_dir):
74 | res = {}
75 | for parts in iterate_entries(data_dir, "movie_titles_metadata.txt"):
76 | m_id, m_genres = parts[0], parts[5]
77 | l_genres = m_genres.strip("[]").split(", ")
78 | l_genres = list(map(lambda s: s.strip("'"), l_genres))
79 | res[m_id] = l_genres
80 | return res
--------------------------------------------------------------------------------
/libbots/data.py:
--------------------------------------------------------------------------------
1 | import collections
2 | import os
3 | import sys
4 | import logging
5 | import itertools
6 | import pickle
7 |
8 | from . import cornell
9 |
10 | UNKNOWN_TOKEN = '#UNK'
11 | BEGIN_TOKEN = "#BEG"
12 | END_TOKEN = "#END"
13 | MAX_TOKENS = 20
14 | MIN_TOKEN_FEQ = 10
15 | SHUFFLE_SEED = 5871
16 |
17 | EMB_DICT_NAME = "emb_dict.dat"
18 | EMB_NAME = "emb.npy"
19 |
20 | log = logging.getLogger("data")
21 |
22 |
23 | def save_emb_dict(dir_name, emb_dict):
24 | with open(os.path.join(dir_name, EMB_DICT_NAME), "wb") as fd:
25 | pickle.dump(emb_dict, fd)
26 |
27 |
28 | def load_emb_dict(dir_name):
29 | with open(os.path.join(dir_name, EMB_DICT_NAME), "rb") as fd:
30 | return pickle.load(fd)
31 |
32 |
33 | def encode_words(words, emb_dict):
34 | """
35 | Convert list of words into list of embeddings indices, adding our tokens
36 | :param words: list of strings
37 | :param emb_dict: embeddings dictionary
38 | :return: list of IDs
39 | """
40 | res = [emb_dict[BEGIN_TOKEN]]
41 | unk_idx = emb_dict[UNKNOWN_TOKEN]
42 | for w in words:
43 | idx = emb_dict.get(w.lower(), unk_idx)
44 | res.append(idx)
45 | res.append(emb_dict[END_TOKEN])
46 | return res
47 |
48 |
49 | def encode_phrase_pairs(phrase_pairs, emb_dict, filter_unknows=True):
50 | """
51 | Convert list of phrase pairs to training data
52 | :param phrase_pairs: list of (phrase, phrase)
53 | :param emb_dict: embeddings dictionary (word -> id)
54 | :return: list of tuples ([input_id_seq], [output_id_seq])
55 | """
56 | unk_token = emb_dict[UNKNOWN_TOKEN]
57 | result = []
58 | for q, p in phrase_pairs:
59 | pq = encode_words(q, emb_dict), encode_words(p, emb_dict)
60 |         if filter_unknows and (unk_token in pq[0] or unk_token in pq[1]):
61 | continue
62 | result.append(pq)
63 | return result
64 |
65 |
66 | def group_train_data(training_data):
67 | """
68 | Group training pairs by first phrase
69 | :param training_data: list of (seq1, seq2) pairs
70 | :return: list of (seq1, [seq*]) pairs
71 | """
72 | groups = collections.defaultdict(list)
73 | for p1, p2 in training_data:
74 | l = groups[tuple(p1)]
75 | l.append(p2)
76 | return list(groups.items())
77 |
78 |
79 | def iterate_batches(data, batch_size):
80 | assert isinstance(data, list)
81 | assert isinstance(batch_size, int)
82 |
83 | ofs = 0
84 | while True:
85 | batch = data[ofs*batch_size:(ofs+1)*batch_size]
86 | if len(batch) <= 1:
87 | break
88 | yield batch
89 | ofs += 1
90 |
91 | def yield_samples(data):
92 | assert isinstance(data, list)
93 |
94 | ofs = 0
95 | while True:
96 | batch = data[ofs:(ofs+1)]
97 |         if not batch:
98 | break
99 | yield batch
100 | ofs += 1
101 |
102 |
103 | def load_data(genre_filter, max_tokens=MAX_TOKENS, min_token_freq=MIN_TOKEN_FEQ):
104 | dialogues = cornell.load_dialogues(genre_filter=genre_filter)
105 | if not dialogues:
106 | log.error("No dialogues found, exit!")
107 | sys.exit()
108 | log.info("Loaded %d dialogues with %d phrases, generating training pairs",
109 | len(dialogues), sum(map(len, dialogues)))
110 | phrase_pairs = dialogues_to_pairs(dialogues, max_tokens=max_tokens)
111 | log.info("Counting freq of words...")
112 | word_counts = collections.Counter()
113 | for dial in dialogues:
114 | for p in dial:
115 | word_counts.update(p)
116 | freq_set = set(map(lambda p: p[0], filter(lambda p: p[1] >= min_token_freq, word_counts.items())))
117 | log.info("Data has %d uniq words, %d of them occur more than %d",
118 | len(word_counts), len(freq_set), min_token_freq)
119 | phrase_dict = phrase_pairs_dict(phrase_pairs, freq_set)
120 | return phrase_pairs, phrase_dict
121 |
122 |
123 | def phrase_pairs_dict(phrase_pairs, freq_set):
124 | """
125 | Return the dict of words in the dialogues mapped to their IDs
126 | :param phrase_pairs: list of (phrase, phrase) pairs
127 | :return: dict
128 | """
129 | res = {UNKNOWN_TOKEN: 0, BEGIN_TOKEN: 1, END_TOKEN: 2}
130 | next_id = 3
131 | for p1, p2 in phrase_pairs:
132 | for w in map(str.lower, itertools.chain(p1, p2)):
133 | if w not in res and w in freq_set:
134 | res[w] = next_id
135 | next_id += 1
136 | return res
137 |
138 |
139 | def dialogues_to_pairs(dialogues, max_tokens=None):
140 | """
141 | Convert dialogues to training pairs of phrases
142 | :param dialogues:
143 | :param max_tokens: limit of tokens in both question and reply
144 | :return: list of (phrase, phrase) pairs
145 | """
146 | result = []
147 | for dial in dialogues:
148 | prev_phrase = None
149 | for phrase in dial:
150 | if prev_phrase is not None:
151 | if max_tokens is None or (len(prev_phrase) <= max_tokens and len(phrase) <= max_tokens):
152 | result.append((prev_phrase, phrase))
153 | prev_phrase = phrase
154 | return result
155 |
156 |
157 | def decode_words(indices, rev_emb_dict):
158 | return [rev_emb_dict.get(int(idx), UNKNOWN_TOKEN) for idx in indices]
159 |
160 |
161 | def trim_tokens_seq(tokens, end_token):
162 | res = []
163 | for t in tokens:
164 | res.append(t)
165 | if t == end_token:
166 | break
167 | return res
168 |
169 |
170 | def split_train_test(data, train_ratio=0.95):
171 | count = int(len(data) * train_ratio)
172 | return data[:count], data[count:]
--------------------------------------------------------------------------------
/libbots/model.py:
--------------------------------------------------------------------------------
1 |
2 |
3 | import numpy as np
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.utils.rnn as rnn_utils
7 | import torch.nn.functional as F
8 | from . import utils, data
9 | from torch.autograd import Variable
10 |
11 |
12 | HIDDEN_STATE_SIZE = 512
13 | EMBEDDING_DIM = 50
14 |
15 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
16 |
17 | class PhraseModel(nn.Module):
18 | def __init__(self, emb_size, dict_size, hid_size):
19 | super(PhraseModel, self).__init__()
20 |
21 | self.emb = nn.Embedding(num_embeddings=dict_size, embedding_dim=emb_size)
22 | self.encoder = nn.LSTM(input_size=emb_size, hidden_size=hid_size,
23 | num_layers=1, batch_first=True)
24 | self.decoder = nn.LSTM(input_size=emb_size, hidden_size=hid_size,
25 | num_layers=1, batch_first=True)
26 | self.output = nn.Sequential(nn.Linear(hid_size, dict_size))
27 |
28 | def encode(self, x):
29 | _, hid = self.encoder(x)
30 | return hid
31 |
32 | def get_encoded_item(self, encoded, index):
33 | # For RNN
34 | # return encoded[:, index:index+1]
35 | # For LSTM
36 | return encoded[0][:, index:index+1].contiguous(), \
37 | encoded[1][:, index:index+1].contiguous()
38 |
39 | def decode_teacher(self, hid, input_seq, detach=False):
40 | # Method assumes batch of size=1
41 | if detach:
42 | out, _ = self.decoder(input_seq, hid)
43 | out = self.output(out.data).detach()
44 | else:
45 | out, _ = self.decoder(input_seq, hid)
46 | out = self.output(out.data)
47 | return out
48 |
49 | def decode_one(self, hid, input_x, detach=False):
50 | if detach:
51 | out, new_hid = self.decoder(input_x.unsqueeze(0), hid)
52 | out = self.output(out).detach()
53 | else:
54 | out, new_hid = self.decoder(input_x.unsqueeze(0), hid)
55 | out = self.output(out)
56 | return out.squeeze(dim=0), new_hid
57 |
58 | def decode_chain_argmax(self, hid, begin_emb, seq_len, stop_at_token=None):
59 | """
60 | Decode sequence by feeding predicted token to the net again. Act greedily
61 | """
62 | res_logits = []
63 | res_tokens = []
64 | cur_emb = begin_emb
65 |
66 | for _ in range(seq_len):
67 | out_logits, hid = self.decode_one(hid, cur_emb)
68 | out_token_v = torch.max(out_logits, dim=1)[1]
69 | out_token = out_token_v.data.cpu().numpy()[0]
70 |
71 | cur_emb = self.emb(out_token_v)
72 |
73 | res_logits.append(out_logits)
74 | res_tokens.append(out_token)
75 | if stop_at_token is not None and out_token == stop_at_token:
76 | break
77 | return torch.cat(res_logits), res_tokens
78 |
79 | def decode_chain_sampling(self, hid, begin_emb, seq_len, stop_at_token=None):
80 | """
81 | Decode sequence by feeding predicted token to the net again.
82 | Act according to probabilities
83 | """
84 | res_logits = []
85 | res_actions = []
86 | cur_emb = begin_emb
87 |
88 | for _ in range(seq_len):
89 | out_logits, hid = self.decode_one(hid, cur_emb)
90 | out_probs_v = F.softmax(out_logits, dim=1)
91 | out_probs = out_probs_v.data.cpu().numpy()[0]
92 | action = np.random.choice(out_probs.shape[0], p=out_probs)
93 | action_v = torch.LongTensor([action]).to(begin_emb.device)
94 | cur_emb = self.emb(action_v)
95 |
96 | res_logits.append(out_logits)
97 | res_actions.append(action)
98 | if stop_at_token is not None and action == stop_at_token:
99 | break
100 | return torch.cat(res_logits), res_actions
101 |
102 | def decode_rl_chain_argmax(self, hid, begin_emb, seq_len, stop_at_token=None):
103 | """
104 | Decode sequence by feeding predicted token to the net again. Act greedily
105 | """
106 | q_list = []
107 | res_tokens = []
108 | hidden_list = []
109 | emb_list = []
110 | cur_emb = begin_emb
111 | hidden_list.append(hid)
112 | emb_list.append(cur_emb)
113 |
114 |
115 | for _ in range(seq_len):
116 | out_q, hid = self.decode_one(hid, cur_emb)
117 | out_token_v = torch.max(out_q, dim=1)[1]
118 | out_token = out_token_v.data.cpu().numpy()[0]
119 |
120 | cur_emb = self.emb(out_token_v)
121 |
122 |
123 | hidden_list.append(hid)
124 | emb_list.append(cur_emb)
125 | q_list.append(out_q[0][out_token])
126 | res_tokens.append(out_token)
127 | if stop_at_token is not None and out_token == stop_at_token:
128 | break
129 | return q_list, hidden_list, torch.cat(emb_list), res_tokens
130 |
131 | def decode_rl_chain_sampling(self, hid, begin_emb, seq_len, stop_at_token=None):
132 | """
133 | Decode sequence by feeding predicted token to the net again.
134 | Act according to probabilities
135 | """
136 | q_list = []
137 | res_tokens = []
138 | hidden_list = []
139 | emb_list = []
140 |
141 | cur_emb = begin_emb
142 | hidden_list.append(hid)
143 | emb_list.append(cur_emb)
144 |
145 | for _ in range(seq_len):
146 | out_logits, hid = self.decode_one(hid, cur_emb)
147 | out_probs_v = F.softmax(out_logits, dim=1)
148 | out_probs = out_probs_v.data.cpu().numpy()[0]
149 | action = int(np.random.choice(out_probs.shape[0], p=out_probs))
150 | action_v = torch.LongTensor([action]).to(begin_emb.device)
151 | cur_emb = self.emb(action_v)
152 |
153 | q_list.append(out_logits[0][action])
154 | res_tokens.append(action)
155 | hidden_list.append(hid)
156 | emb_list.append(cur_emb)
157 | if stop_at_token is not None and action == stop_at_token:
158 | break
159 | return q_list, hidden_list, torch.cat(emb_list), res_tokens
160 |
161 | def decode_batch(self, state):
162 | """
163 | Decode sequence by feeding predicted token to the net again. Act greedily
164 | """
165 | res_logits = []
166 | for idx in range(len(state)):
167 | cur_q, _ = self.decode_one(state[idx][0], state[idx][1].unsqueeze(0))
168 | res_logits.append(cur_q)
169 |
170 | return torch.cat(res_logits)
171 |
172 | def decode_k_best(self, hid, begin_emb, k, seq_len, stop_at_token=None):
173 | """
174 |         Beam-search decoding: keep the k most probable partial sequences at
175 |         each step and return the k best token sequences with their scores
176 | """
177 | cur_emb = begin_emb
178 |
179 | out_logits, hid = self.decode_one(hid, cur_emb)
180 | out_probs_v = F.log_softmax(out_logits, dim=1)
181 | highest_values = torch.topk(out_probs_v, k)
182 | first_beam_probs = highest_values[0]
183 | first_beam_vals = highest_values[1]
184 | beam_emb = []
185 | beam_hid = []
186 | beam_vals = []
187 | beam_prob = []
188 | for i in range(k):
189 | beam_vals.append([])
190 | for i in range(k):
191 | beam_emb.append(self.emb(first_beam_vals)[:, i, :])
192 | beam_hid.append(hid)
193 | beam_vals[i].append(int(first_beam_vals[0][i].data.cpu().numpy()))
194 | beam_prob.append(first_beam_probs[0][i].data.cpu().numpy())
195 |
196 | for num_words in range(1, seq_len):
197 | possible_sentences = []
198 | possible_probs = []
199 | for i in range(k):
200 | cur_emb = beam_emb[i]
201 | cur_hid = beam_hid[i]
202 | cur_prob = beam_prob[i]
203 | if stop_at_token is not None and beam_vals[i][-1] == stop_at_token:
204 | possible_probs.append(cur_prob)
205 | possible_sentences.append((beam_vals[i], beam_hid[i], beam_emb[i]))
206 | continue
207 |
208 | else:
209 | out_logits, hid = self.decode_one(cur_hid, cur_emb)
210 | log_probs = torch.tensor(F.log_softmax(out_logits, dim=1)).to(device)
211 | highest_values = torch.topk(log_probs, k)
212 | probs = highest_values[0]
213 | vals = highest_values[1]
214 | for j in range(k):
215 | prob = (probs[0][j].data.cpu().numpy() + cur_prob) * (num_words / (num_words + 1))
216 | possible_probs.append(prob)
217 | emb = self.emb(vals)[:, j, :]
218 | val = vals[0][j].data.cpu().numpy()
219 | cur_seq_vals = beam_vals[i][:]
220 | cur_seq_vals.append(int(val))
221 | possible_sentences.append((cur_seq_vals, hid, emb))
222 |
223 | max_indices = sorted(range(len(possible_probs)), key=lambda s: possible_probs[s])[-k:]
224 | beam_emb = []
225 | beam_hid = []
226 | beam_vals = []
227 | beam_prob = []
228 | for i in range(k):
229 | (seq_vals, hid, emb) = possible_sentences[max_indices[i]]
230 | beam_emb.append(emb)
231 | beam_hid.append(hid)
232 | beam_vals.append(seq_vals)
233 | beam_prob.append(possible_probs[max_indices[i]])
234 | t_beam_prob = []
235 | for i in range(k):
236 | t_beam_prob.append(Variable(torch.tensor(beam_prob[i]).to(device), requires_grad=True))
237 |
238 | return beam_vals, t_beam_prob
239 |
240 | def decode_k_sampling(self, hid, begin_emb, k, seq_len, stop_at_token=None):
241 | list_res_logit = []
242 | list_res_actions = []
243 | for i in range(k):
244 | res_logits = []
245 | res_actions = []
246 | cur_emb = begin_emb
247 |
248 | for _ in range(seq_len):
249 | out_logits, hid = self.decode_one(hid, cur_emb)
250 | out_probs_v = F.softmax(out_logits, dim=1)
251 | out_probs = out_probs_v.data.cpu().numpy()[0]
252 | action = int(np.random.choice(out_probs.shape[0], p=out_probs))
253 | action_v = torch.LongTensor([action]).to(begin_emb.device)
254 | cur_emb = self.emb(action_v)
255 |
256 | res_logits.append(out_logits)
257 | res_actions.append(action)
258 | if stop_at_token is not None and action == stop_at_token:
259 | break
260 |
261 | list_res_logit.append(torch.cat(res_logits))
262 | list_res_actions.append(res_actions)
263 |
264 | return list_res_actions, list_res_logit
265 |
266 | def get_qp_prob(self, hid, begin_emb, p_tokens):
267 | """
268 | Find the probability of P(q|p) - the question asked (q, source) given the answer (p, target)
269 | """
270 | cur_emb = begin_emb
271 | prob = 0
272 | for word_num in range(len(p_tokens)):
273 | out_logits, hid = self.decode_one(hid, cur_emb)
274 | out_token_v = F.log_softmax(out_logits, dim=1)[0]
275 | prob = prob + out_token_v[p_tokens[word_num]].data.cpu().numpy()
276 |
277 | out_token = p_tokens[word_num]
278 | action_v = torch.LongTensor([out_token]).to(begin_emb.device)
279 | cur_emb = self.emb(action_v)
280 | if (len(p_tokens)) == 0:
281 | return -20
282 | else:
283 | return prob/(len(p_tokens))
284 |
285 | def get_mean_emb(self, begin_emb, p_tokens):
286 | """
287 |         Find the mean embedding of the entire sentence
288 | """
289 | emb_list = []
290 | cur_emb = begin_emb
291 | emb_list.append(cur_emb)
292 | if isinstance(p_tokens, list):
293 | if len(p_tokens) == 1:
294 | p_tokens = p_tokens[0]
295 | for word_num in range(len(p_tokens)):
296 | out_token = p_tokens[word_num]
297 | action_v = torch.LongTensor([out_token]).to(begin_emb.device)
298 | cur_emb = self.emb(action_v).detach()
299 | emb_list.append(cur_emb)
300 | else:
301 | out_token = p_tokens
302 | action_v = torch.LongTensor([out_token]).to(begin_emb.device)
303 | cur_emb = self.emb(action_v).detach()
304 | emb_list.append(cur_emb)
305 |
306 | return sum(emb_list)/len(emb_list)
307 |
308 | def get_beam_sentences(self, tokens, emb_dict, k_sentences):
309 | # Forward
310 | source_seq = pack_input(tokens, self.emb)
311 | enc = self.encode(source_seq)
312 | end_token = emb_dict[data.END_TOKEN]
313 | probs, list_of_out_tokens = self.decode_k_sampling(enc, source_seq.data[0:1], k_sentences,
314 | seq_len=data.MAX_TOKENS, stop_at_token=end_token)
315 | # list_of_out_tokens, probs = self.decode_k_sent(enc, source_seq.data[0:1], k_sentences, seq_len=data.MAX_TOKENS,
316 | # stop_at_token=end_token)
317 | return list_of_out_tokens, probs
318 |
319 | def get_action_prob(self, source, target, emb_dict):
320 |         source_seq = pack_input(source, self.emb, device)
321 | hid = self.encode(source_seq)
322 | end_token = emb_dict[data.END_TOKEN]
323 | cur_emb = source_seq.data[0:1]
324 | probs = []
325 | for word_num in range(len(target)):
326 | out_logits, hid = self.decode_one(hid, cur_emb)
327 | out_token_v = F.log_softmax(out_logits, dim=1)[0]
328 |
329 | probs.append(out_token_v[target[word_num]])
330 | out_token = target[word_num]
331 | action_v = torch.LongTensor([out_token]).to(device)
332 | cur_emb = self.emb(action_v)
333 | t_prob = probs[0]
334 | for i in range(1, len(probs)):
335 | t_prob = t_prob + probs[i]
336 |
337 | return t_prob/len(probs)
338 |
339 | def get_logits(self, hid, begin_emb, seq_len, res_action, stop_at_token=None):
340 | """
341 | Decode sequence by feeding predicted token to the net again. Act greedily
342 | """
343 | res_logits = []
344 | cur_emb = begin_emb
345 |
346 | for i in range(seq_len):
347 | out_logits, hid = self.decode_one(hid, cur_emb, True)
348 | out_token_v = torch.tensor([res_action[i]]).to(device)
349 | out_token = out_token_v.data.cpu().numpy()[0]
350 |
351 | cur_emb = self.emb(out_token_v).detach()
352 |
353 | res_logits.append(out_logits)
354 | if stop_at_token is not None and out_token == stop_at_token:
355 | break
356 | return torch.cat(res_logits)
357 |
358 |
359 |
360 | def pack_batch_no_out(batch, embeddings, device=device):
361 | assert isinstance(batch, list)
362 | # Sort descending (CuDNN requirements)
363 | batch.sort(key=lambda s: len(s[0]), reverse=True)
364 | input_idx, output_idx = zip(*batch)
365 | # create padded matrix of inputs
366 | lens = list(map(len, input_idx))
367 | input_mat = np.zeros((len(batch), lens[0]), dtype=np.int64)
368 | for idx, x in enumerate(input_idx):
369 | input_mat[idx, :len(x)] = x
370 | input_v = torch.tensor(input_mat).to(device)
371 | input_seq = rnn_utils.pack_padded_sequence(input_v, lens, batch_first=True)
372 | # lookup embeddings
373 | r = embeddings(input_seq.data)
374 | emb_input_seq = rnn_utils.PackedSequence(r, input_seq.batch_sizes)
375 | return emb_input_seq, input_idx, output_idx
376 |
377 | def pack_batch_no_in(batch, embeddings, device=device):
378 | assert isinstance(batch, list)
379 | # Sort descending (CuDNN requirements)
380 | batch.sort(key=lambda s: len(s[1]), reverse=True)
381 | input_idx, output_idx = zip(*batch)
382 | # create padded matrix of inputs
383 | lens = list(map(len, output_idx))
384 | output_mat = np.zeros((len(batch), lens[0]), dtype=np.int64)
385 | for idx, x in enumerate(output_idx):
386 | output_mat[idx, :len(x)] = x
387 |     output_v = torch.tensor(output_mat).to(device)
388 |     output_seq = rnn_utils.pack_padded_sequence(output_v, lens, batch_first=True)
389 | # lookup embeddings
390 | r = embeddings(output_seq.data)
391 | emb_input_seq = rnn_utils.PackedSequence(r, output_seq.batch_sizes)
392 | return emb_input_seq, output_idx, input_idx
393 |
394 | def pack_input(input_data, embeddings, device=device, detach=False):
395 | input_v = torch.LongTensor([input_data]).to(device)
396 | if detach:
397 | r = embeddings(input_v).detach()
398 | else:
399 | r = embeddings(input_v)
400 | return rnn_utils.pack_padded_sequence(r, [len(input_data)], batch_first=True)
401 |
402 | def pack_batch(batch, embeddings, device=device, detach=False):
403 | emb_input_seq, input_idx, output_idx = pack_batch_no_out(batch, embeddings, device)
404 |
405 | # prepare output sequences, with end token stripped
406 | output_seq_list = []
407 | for out in output_idx:
408 | if len(out) == 1:
409 | out = out[0]
410 | output_seq_list.append(pack_input(out[:-1], embeddings, device, detach))
411 | return emb_input_seq, output_seq_list, input_idx, output_idx
412 |
413 | def pack_backward_batch(batch, embeddings, device=device):
414 | emb_target_seq, target_idx, source_idx = pack_batch_no_in(batch, embeddings, device)
415 |
416 | # prepare output sequences, with end token stripped
417 | source_seq = []
418 | for src in source_idx:
419 | source_seq.append(pack_input(src[:-1], embeddings, device))
420 | backward_input = emb_target_seq
421 | backward_output = source_seq
422 | return backward_input, backward_output, target_idx, source_idx
423 |
424 | def seq_bleu(model_out, ref_seq):
425 | model_seq = torch.max(model_out.data, dim=1)[1]
426 | model_seq = model_seq.cpu().numpy()
427 | return utils.calc_bleu(model_seq, ref_seq)
428 |
429 | def mutual_words_to_words(words, data, k, emb_dict, rev_emb_dict, net, back_net):
430 | # Forward
431 | tokens = data.encode_words(words, emb_dict)
432 |     source_seq = pack_input(tokens, net.emb, device)
433 | enc = net.encode(source_seq)
434 | end_token = emb_dict[data.END_TOKEN]
435 | list_of_out_tokens, probs = net.decode_k_best(enc, source_seq.data[0:1], k, seq_len=data.MAX_TOKENS,
436 | stop_at_token=end_token)
437 | list_of_out_words = []
438 | for iTokens in range(len(list_of_out_tokens)):
439 | if list_of_out_tokens[iTokens][-1] == end_token:
440 | list_of_out_tokens[iTokens] = list_of_out_tokens[iTokens][:-1]
441 | list_of_out_words.append(data.decode_words(list_of_out_tokens[iTokens], rev_emb_dict))
442 |
443 | # Backward
444 | back_seq2seq_prob = []
445 | for iTarget in range(len(list_of_out_words)):
446 | b_tokens = data.encode_words(list_of_out_words[iTarget], emb_dict)
447 |         target_seq = pack_input(b_tokens, back_net.emb, device)
448 | b_enc = back_net.encode(target_seq)
449 | back_seq2seq_prob.append(back_net.get_qp_prob(b_enc, target_seq.data[0:1], tokens[1:]))
450 |
451 | mutual_prob = []
452 | for i in range(len(probs)):
453 | mutual_prob.append(probs[i] + back_seq2seq_prob[i])
454 |
455 | return list_of_out_words, mutual_prob
456 |
457 |
458 |
459 |
460 |
461 |
462 |
--------------------------------------------------------------------------------
/libbots/utils.py:
--------------------------------------------------------------------------------
1 |
2 |
3 | import string
4 | from nltk.translate import bleu_score
5 | from nltk.tokenize import TweetTokenizer
6 | import numpy as np
7 | from . import model
8 | import torch
9 | import torch.nn as nn
10 |
11 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
12 |
13 |
14 | def calc_bleu_many(cand_seq, ref_sequences):
15 | sf = bleu_score.SmoothingFunction()
16 | return bleu_score.sentence_bleu(ref_sequences, cand_seq,
17 | smoothing_function=sf.method1,
18 | weights=(0.5, 0.5))
19 |
20 | def calc_preplexity_many(probs, actions):
21 | criterion = nn.CrossEntropyLoss()
22 | loss = criterion(torch.tensor(probs).to(device), torch.tensor(actions).to(device))
23 | return np.exp(loss.item())
24 |
25 | def calc_bleu(cand_seq, ref_seq):
26 | return calc_bleu_many(cand_seq, [ref_seq])
27 |
28 | def tokenize(s):
29 | return TweetTokenizer(preserve_case=False).tokenize(s)
30 |
31 | def untokenize(words):
32 | return "".join([" " + i if not i.startswith("'") and i not in string.punctuation else i for i in words]).strip()
33 |
34 | def calc_mutual(net, back_net, p1, p2):
35 | # Pack forward and backward
36 | criterion = nn.CrossEntropyLoss()
37 | p2 = [1] + p2
38 | input_seq_forward = model.pack_input(p1, net.emb, device)
39 | output_seq_forward = model.pack_input(p2[:-1], net.emb, device)
40 | input_seq_backward = model.pack_input(tuple(p2), back_net.emb, device)
41 | output_seq_backward = model.pack_input(list(p1)[:-1], back_net.emb, device)
42 | # Enc forward and backward
43 | enc_forward = net.encode(input_seq_forward)
44 | enc_backward = back_net.encode(input_seq_backward)
45 |
46 | r_forward = net.decode_teacher(enc_forward, output_seq_forward).detach()
47 | r_backward = back_net.decode_teacher(enc_backward, output_seq_backward).detach()
48 | fw = criterion(torch.tensor(r_forward).to(device), torch.tensor(p2[1:]).to(device)).detach()
49 | bw = criterion(torch.tensor(r_backward).to(device), torch.tensor(p1[1:]).to(device)).detach()
50 | return (-1)*float(fw + bw)
51 |
52 | def calc_cosine_many(mean_emb_pred, mean_emb_ref):
53 | norm_pred = mean_emb_pred / mean_emb_pred.norm(dim=1)[:, None]
54 | norm_ref = mean_emb_ref / mean_emb_ref.norm(dim=1)[:, None]
55 | return torch.mm(norm_pred, norm_ref.transpose(1, 0))
56 |
57 |
--------------------------------------------------------------------------------
/model_test.py:
--------------------------------------------------------------------------------
1 |
2 | from libbots import model, utils, data
3 |
4 | def run_test(test_data, per_net, end_token, device="cuda"):
5 | bleu_sum = 0.0
6 | bleu_count = 0
7 | for p1, p2 in test_data:
8 | input_seq = model.pack_input(p1, per_net.emb, device)
9 | enc = per_net.encode(input_seq)
10 | _, tokens = per_net.decode_chain_argmax(enc, input_seq.data[0:1], seq_len=data.MAX_TOKENS,
11 | stop_at_token=end_token)
12 | ref_indices = p2[1:]
13 | bleu_sum += utils.calc_bleu_many(tokens, [ref_indices])
14 | bleu_count += 1
15 | return bleu_sum / bleu_count
16 |
17 | def run_test_preplexity(test_data, per_net, net, end_token, device="cpu"):
18 | preplexity_sum = 0.0
19 | preplexity_count = 0
20 | for p1, p2 in test_data:
21 | input_seq = model.pack_input(p1, per_net.emb, device)
22 | enc = per_net.encode(input_seq)
23 | logits, tokens = per_net.decode_chain_argmax(enc, input_seq.data[0:1], seq_len=data.MAX_TOKENS,
24 | stop_at_token=end_token)
25 | r = net.get_logits(enc, input_seq.data[0:1], data.MAX_TOKENS, tokens,
26 | stop_at_token=end_token)
27 | preplexity_sum += utils.calc_preplexity_many(r, tokens)
28 | preplexity_count += 1
29 | return preplexity_sum / preplexity_count
30 |
31 | def run_test_mutual(test_data, rl_net, net, back_net, end_token, device="cuda"):
32 | mutual_sum = 0.0
33 | mutual_count = 0
34 | for p1, p2 in test_data:
35 | input_seq_bleu = model.pack_input(p1, rl_net.emb, device)
36 | enc_bleu = rl_net.encode(input_seq_bleu)
37 | _, actions = rl_net.decode_chain_argmax(enc_bleu, input_seq_bleu.data[0:1], seq_len=data.MAX_TOKENS,
38 | stop_at_token=end_token)
39 |
40 | mutual_sum += utils.calc_mutual(net, back_net, p1, actions)
41 | mutual_count += 1
42 | return mutual_sum / mutual_count
43 |
44 | def run_test_cosine(test_data, rl_net, net, beg_token, end_token, device="cuda"):
45 | cosine_sum = 0.0
46 | cosine_count = 0
47 | beg_embedding = rl_net.emb(beg_token)
48 | for p1, p2 in test_data:
49 | input_seq_cosine = model.pack_input(p1, rl_net.emb, device)
50 | enc_cosine = rl_net.encode(input_seq_cosine)
51 | r_argmax, actions = rl_net.decode_chain_argmax(enc_cosine, beg_embedding, data.MAX_TOKENS,
52 | stop_at_token=end_token)
53 | mean_emb_max = net.get_mean_emb(beg_embedding, actions)
54 | mean_emb_ref_list = []
55 | for iRef in p2:
56 | mean_emb_ref_list.append(rl_net.get_mean_emb(beg_embedding, iRef))
57 | mean_emb_ref = sum(mean_emb_ref_list) / len(mean_emb_ref_list)
58 | cosine_sum += utils.calc_cosine_many(mean_emb_max, mean_emb_ref)
59 | cosine_count += 1
60 | return cosine_sum / cosine_count
--------------------------------------------------------------------------------
/test_all_modells.py:
--------------------------------------------------------------------------------
1 |
2 | import os
3 | import logging
4 | import numpy as np
5 | import torch
6 | import argparse
7 |
8 | from libbots import data, model
9 | from model_test import run_test, run_test_mutual, run_test_preplexity, run_test_cosine
10 |
11 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
12 |
13 | if __name__ == '__main__':
14 | parser = argparse.ArgumentParser()
15 | parser.add_argument('-SAVES_DIR', type=str, default='saves', help='Save directory')
16 | parser.add_argument('-BATCH_SIZE', type=int, default=32, help='Batch Size for training')
17 | parser.add_argument('-data', type=str, default='comedy', help='Genre to use - for data')
18 | parser.add_argument('-load_seq2seq_path', type=str, default='Final_Saves/seq2seq/epoch_090_0.800_0.107.dat',
19 | help='Pre-trained seq2seq model location')
20 | parser.add_argument('-laod_b_seq2seq_path', type=str,
21 | default='Final_Saves/backward_seq2seq/epoch_080_0.780_0.104.dat',
22 | help='Pre-trained backward seq2seq model location')
23 | parser.add_argument('-bleu_model_path', type=str, default='Final_Saves/RL_BLUE/bleu_0.135_177.dat',
24 | help='Pre-trained BLEU model location')
25 | parser.add_argument('-mutual_model_path', type=str, default='Final_Saves/RL_Mutual/epoch_180_-4.325_-7.192.dat',
26 | help='Pre-trained MMI model location')
27 | parser.add_argument('-prep_model_path', type=str, default='Final_Saves/RL_Perplexity/epoch_050_1.463_3.701.dat',
28 | help='Pre-trained Perplexity model location')
29 | parser.add_argument('-cos_model_path', type=str, default='Final_Saves/RL_COSINE/cosine_0.621_03.dat',
30 | help='Pre-trained Cosine Similarity model location')
31 | args = parser.parse_args()
32 |
33 |
34 | log = logging.getLogger("test")
35 |
36 | logging.basicConfig(format="%(asctime)-15s %(levelname)s %(message)s", level=logging.INFO)
37 |     arg_data = args.data
38 |
39 | # Load Seq2Seq
40 | seq2seq_emb_dict = data.load_emb_dict(os.path.dirname(args.load_seq2seq_path))
41 | seq2seq_rev_emb_dict = {idx: word for word, idx in seq2seq_emb_dict.items()}
42 | net = model.PhraseModel(emb_size=model.EMBEDDING_DIM, dict_size=len(seq2seq_emb_dict),
43 | hid_size=model.HIDDEN_STATE_SIZE).to(device)
44 | net.load_state_dict(torch.load(args.load_seq2seq_path))
45 |
46 | # Load Back Seq2Seq
47 |     b_seq2seq_emb_dict = data.load_emb_dict(os.path.dirname(args.load_b_seq2seq_path))
48 | b_seq2seq_rev_emb_dict = {idx: word for word, idx in b_seq2seq_emb_dict.items()}
49 | b_net = model.PhraseModel(emb_size=model.EMBEDDING_DIM, dict_size=len(b_seq2seq_emb_dict),
50 | hid_size=model.HIDDEN_STATE_SIZE).to(device)
51 |     b_net.load_state_dict(torch.load(args.load_b_seq2seq_path))
52 |
53 | # Load BLEU
54 | bleu_emb_dict = data.load_emb_dict(os.path.dirname(args.bleu_model_path))
55 | bleu_rev_emb_dict = {idx: word for word, idx in bleu_emb_dict.items()}
56 | bleu_net = model.PhraseModel(emb_size=model.EMBEDDING_DIM, dict_size=len(bleu_emb_dict),
57 | hid_size=model.HIDDEN_STATE_SIZE).to(device)
58 | bleu_net.load_state_dict(torch.load(args.bleu_model_path))
59 |
60 | # Load Mutual
61 | mutual_emb_dict = data.load_emb_dict(os.path.dirname(args.mutual_model_path))
62 | mutual_rev_emb_dict = {idx: word for word, idx in mutual_emb_dict.items()}
63 | mutual_net = model.PhraseModel(emb_size=model.EMBEDDING_DIM, dict_size=len(mutual_emb_dict),
64 | hid_size=model.HIDDEN_STATE_SIZE).to(device)
65 | mutual_net.load_state_dict(torch.load(args.mutual_model_path))
66 |
67 |
68 |     # Load Perplexity
69 | prep_emb_dict = data.load_emb_dict(os.path.dirname(args.prep_model_path))
70 | prep_rev_emb_dict = {idx: word for word, idx in prep_emb_dict.items()}
71 | prep_net = model.PhraseModel(emb_size=model.EMBEDDING_DIM, dict_size=len(prep_emb_dict),
72 | hid_size=model.HIDDEN_STATE_SIZE).to(device)
73 | prep_net.load_state_dict(torch.load(args.prep_model_path))
74 |
75 |
76 | # Load Cosine Similarity
77 | cos_emb_dict = data.load_emb_dict(os.path.dirname(args.cos_model_path))
78 | cos_rev_emb_dict = {idx: word for word, idx in cos_emb_dict.items()}
79 | cos_net = model.PhraseModel(emb_size=model.EMBEDDING_DIM, dict_size=len(cos_emb_dict),
80 | hid_size=model.HIDDEN_STATE_SIZE).to(device)
81 | cos_net.load_state_dict(torch.load(args.cos_model_path))
82 |
83 |
84 | phrase_pairs, emb_dict = data.load_data(genre_filter=arg_data)
85 | end_token = emb_dict[data.END_TOKEN]
86 | train_data = data.encode_phrase_pairs(phrase_pairs, emb_dict)
87 | rand = np.random.RandomState(data.SHUFFLE_SEED)
88 | rand.shuffle(train_data)
89 | train_data, test_data = data.split_train_test(train_data)
90 |
91 | # BEGIN token
92 | beg_token = torch.LongTensor([emb_dict[data.BEGIN_TOKEN]]).to(device)
93 |
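    # Evaluate every model (Seq2Seq, BLEU, MMI, Perplexity, Cosine) with each automatic metric
    # on the held-out test split.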
94 |     # Test Seq2Seq model
95 |     bleu_test_seq2seq = run_test(test_data, net, end_token, device)
96 |     mutual_test_seq2seq = run_test_mutual(test_data, net, net, b_net, beg_token, end_token, device)
97 |     preplexity_test_seq2seq = run_test_preplexity(test_data, net, net, end_token, device)
98 |     cosine_test_seq2seq = run_test_cosine(test_data, net, net, beg_token, end_token, device)
99 |
100 |     # Test BLEU model
101 | bleu_test_bleu = run_test(test_data, bleu_net, end_token, device)
102 | mutual_test_bleu = run_test_mutual(test_data, bleu_net, net, b_net, beg_token, end_token, device)
103 | preplexity_test_bleu = run_test_preplexity(test_data, bleu_net, net, end_token, device)
104 |     cosine_test_bleu = run_test_cosine(test_data, bleu_net, net, beg_token, end_token, device)
105 |
106 |     # Test Mutual Information model
107 | bleu_test_mutual = run_test(test_data, mutual_net, end_token, device)
108 | mutual_test_mutual = run_test_mutual(test_data, mutual_net, net, b_net, beg_token, end_token, device)
109 | preplexity_test_mutual = run_test_preplexity(test_data, mutual_net, net, end_token, device)
110 |     cosine_test_mutual = run_test_cosine(test_data, mutual_net, net, beg_token, end_token, device)
111 |
112 | # Test Perplexity model
113 | bleu_test_per = run_test(test_data, prep_net, end_token, device)
114 | mutual_test_per = run_test_mutual(test_data, prep_net, net, b_net, beg_token, end_token, device)
115 | preplexity_test_per = run_test_preplexity(test_data, prep_net, net, end_token, device)
116 |     cosine_test_per = run_test_cosine(test_data, prep_net, net, beg_token, end_token, device)
117 |
118 | # Test Cosine Similarity model
119 | bleu_test_cos = run_test(test_data, cos_net, end_token, device)
120 | mutual_test_cos = run_test_mutual(test_data, cos_net, net, b_net, beg_token, end_token, device)
121 | preplexity_test_cos = run_test_preplexity(test_data, cos_net, net, end_token, device)
122 |     cosine_test_cos = run_test_cosine(test_data, cos_net, net, beg_token, end_token, device)
123 |
124 | log.info("Obtained %d phrase pairs with %d uniq words", len(phrase_pairs), len(emb_dict))
125 | log.info("-----------------------------------------------")
126 | log.info("BLEU scores:")
127 | log.info(" Seq2Seq - %.3f", bleu_test_seq2seq)
128 | log.info(" BLEU - %.3f", bleu_test_bleu)
129 | log.info(" MMI - %.3f", bleu_test_mutual)
130 | log.info(" Perplexity - %.3f", bleu_test_per)
131 |     log.info("    Cosine Similarity - %.3f", bleu_test_cos)
132 | log.info("-----------------------------------------------")
133 | log.info("Max Mutual Information scores:")
134 | log.info(" Seq2Seq - %.3f", mutual_test_seq2seq)
135 | log.info(" BLEU - %.3f", mutual_test_bleu)
136 | log.info(" MMI - %.3f", mutual_test_mutual)
137 |     log.info("    Perplexity - %.3f", mutual_test_per)
138 |     log.info("    Cosine Similarity - %.3f", mutual_test_cos)
139 | log.info("-----------------------------------------------")
140 | log.info("Perplexity scores:")
141 | log.info(" Seq2Seq - %.3f", preplexity_test_seq2seq)
142 | log.info(" BLEU - %.3f", preplexity_test_bleu)
143 | log.info(" MMI - %.3f", preplexity_test_mutual)
144 | log.info(" Perplexity - %.3f", preplexity_test_per)
145 |     log.info("    Cosine Similarity - %.3f", preplexity_test_cos)
146 | log.info("-----------------------------------------------")
147 | log.info("Cosine Similarity scores:")
148 | log.info(" Seq2Seq - %.3f", cosine_test_seq2seq)
149 | log.info(" BLEU - %.3f", cosine_test_bleu)
150 |     log.info("    MMI - %.3f", cosine_test_mutual)
151 |     log.info("    Perplexity - %.3f", cosine_test_per)
152 |     log.info("    Cosine Similarity - %.3f", cosine_test_cos)
153 | log.info("-----------------------------------------------")
154 |
155 |
--------------------------------------------------------------------------------
/tests/test_data.py:
--------------------------------------------------------------------------------
1 | from unittest import TestCase
2 |
3 | import libbots.data
4 | from libbots import data, subtitles
5 |
6 |
7 | class TestData(TestCase):
8 | emb_dict = {
9 | data.BEGIN_TOKEN: 0,
10 | data.END_TOKEN: 1,
11 | data.UNKNOWN_TOKEN: 2,
12 | 'a': 3,
13 | 'b': 4
14 | }
15 |
16 | def test_encode_words(self):
17 | res = data.encode_words(['a', 'b', 'c'], self.emb_dict)
18 | self.assertEqual(res, [0, 3, 4, 2, 1])
19 |
20 | # def test_dialogues_to_train(self):
21 | # dialogues = [
22 | # [
23 | # libbots.data.Phrase(words=['a', 'b'], time_start=0, time_stop=1),
24 | # libbots.data.Phrase(words=['b', 'a'], time_start=2, time_stop=3),
25 | # libbots.data.Phrase(words=['b', 'a'], time_start=2, time_stop=3),
26 | # ],
27 | # [
28 | # libbots.data.Phrase(words=['a', 'b'], time_start=0, time_stop=1),
29 | # ]
30 | # ]
31 | #
32 | # res = data.dialogues_to_train(dialogues, self.emb_dict)
33 | # self.assertEqual(res, [
34 | # ([0, 3, 4, 1], [0, 4, 3, 1]),
35 | # ([0, 4, 3, 1], [0, 4, 3, 1]),
36 | # ])
--------------------------------------------------------------------------------
/tests/test_subtitles.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | from unittest import TestCase
3 |
4 | import libbots.data
5 | from libbots import subtitles
6 |
7 |
8 | class TestPhrases(TestCase):
9 | def test_split_phrase(self):
10 | phrase = libbots.data.Phrase(words=["a", "b", "c"], time_start=datetime.timedelta(seconds=0),
11 | time_stop=datetime.timedelta(seconds=10))
12 | res = subtitles.split_phrase(phrase)
13 | self.assertIsInstance(res, list)
14 | self.assertEqual(len(res), 1)
15 | self.assertEqual(res[0], phrase)
16 |
17 | phrase = libbots.data.Phrase(words=["a", "b", "-", "c"], time_start=datetime.timedelta(seconds=0),
18 | time_stop=datetime.timedelta(seconds=10))
19 | res = subtitles.split_phrase(phrase)
20 | self.assertEqual(len(res), 2)
21 | self.assertEqual(res[0].words, ["a", "b"])
22 | self.assertEqual(res[1].words, ["c"])
23 | self.assertAlmostEqual(res[0].time_start.total_seconds(), 0)
24 | self.assertAlmostEqual(res[0].time_stop.total_seconds(), 5)
25 | self.assertAlmostEqual(res[1].time_start.total_seconds(), 5)
26 | self.assertAlmostEqual(res[1].time_stop.total_seconds(), 10)
27 |
28 | phrase = libbots.data.Phrase(words=['-', 'Wait', 'a', 'sec', '.', '-'], time_start=datetime.timedelta(0, 588, 204000),
29 | time_stop=datetime.timedelta(0, 590, 729000))
30 | res = subtitles.split_phrase(phrase)
31 | self.assertEqual(res[0].words, ["Wait", "a", "sec", "."])
32 |
33 |
34 | class TestUtils(TestCase):
35 | def test_parse_time(self):
36 | self.assertEqual(subtitles.parse_time("00:00:33,074"),
37 | datetime.timedelta(seconds=33, milliseconds=74))
38 |
39 | def test_remove_braced_words(self):
40 | self.assertEqual(subtitles.remove_braced_words(['a', 'b', 'c']),
41 | ['a', 'b', 'c'])
42 | self.assertEqual(subtitles.remove_braced_words(['a', '[', 'b', ']', 'c']),
43 | ['a', 'c'])
44 | self.assertEqual(subtitles.remove_braced_words(['a', '[', 'b', 'c']),
45 | ['a'])
46 | self.assertEqual(subtitles.remove_braced_words(['a', ']', 'b', 'c']),
47 | ['a', 'b', 'c'])
48 | self.assertEqual(subtitles.remove_braced_words(['a', '(', 'b', ']', 'c']),
49 | ['a', 'c'])
50 | self.assertEqual(subtitles.remove_braced_words(['a', '(', 'b', 'c']),
51 | ['a'])
52 | self.assertEqual(subtitles.remove_braced_words(['a', ')', 'b', 'c']),
53 | ['a', 'b', 'c'])
--------------------------------------------------------------------------------
/train_crossent.py:
--------------------------------------------------------------------------------
1 |
2 | import os
3 | import random
4 | import logging
5 | import numpy as np
6 | from tensorboardX import SummaryWriter
7 | import torch
8 | import torch.optim as optim
9 | import torch.nn.functional as F
10 | import argparse
11 |
12 | from libbots import data, model, utils
13 | from model_test import run_test
14 |
15 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
16 |
17 | if __name__ == '__main__':
18 | parser = argparse.ArgumentParser()
19 | parser.add_argument('-SAVES_DIR', type=str, default='saves', help='Save directory')
20 | parser.add_argument('-name', type=str, default='seq2seq', help='Specific model saves directory')
21 | parser.add_argument('-BATCH_SIZE', type=int, default=32, help='Batch Size for training')
22 | parser.add_argument('-LEARNING_RATE', type=float, default=1e-3, help='Learning Rate')
23 |     parser.add_argument('-MAX_EPOCHES', type=int, default=100, help='Number of training epochs')
24 | parser.add_argument('-TEACHER_PROB', type=float, default=0.5, help='Probability to force reference inputs')
25 | parser.add_argument('-data', type=str, default='comedy', help='Genre to use - for data')
26 | parser.add_argument('-train_backward', type=bool, default=False, help='Choose - train backward/forward model')
27 | args = parser.parse_args()
28 |
29 | saves_path = os.path.join(args.SAVES_DIR, args.name)
30 | os.makedirs(saves_path, exist_ok=True)
31 |
32 | log = logging.getLogger("train")
33 | logging.basicConfig(format="%(asctime)-15s %(levelname)s %(message)s", level=logging.INFO)
34 |
35 | phrase_pairs, emb_dict = data.load_data(genre_filter=args.data)
36 | data.save_emb_dict(saves_path, emb_dict)
37 | end_token = emb_dict[data.END_TOKEN]
38 | train_data = data.encode_phrase_pairs(phrase_pairs, emb_dict)
39 | rand = np.random.RandomState(data.SHUFFLE_SEED)
40 | rand.shuffle(train_data)
41 | train_data, test_data = data.split_train_test(train_data)
42 |
43 | log.info("Obtained %d phrase pairs with %d uniq words", len(phrase_pairs), len(emb_dict))
44 | log.info("Training data converted, got %d samples", len(train_data))
45 | log.info("Train set has %d phrases, test %d", len(train_data), len(test_data))
46 |
47 | net = model.PhraseModel(emb_size=model.EMBEDDING_DIM, dict_size=len(emb_dict),
48 | hid_size=model.HIDDEN_STATE_SIZE).to(device)
49 |
50 | writer = SummaryWriter(comment="-" + args.name)
51 |
52 | optimiser = optim.Adam(net.parameters(), lr=args.LEARNING_RATE)
53 | best_bleu = None
54 | for epoch in range(args.MAX_EPOCHES):
55 | losses = []
56 | bleu_sum = 0.0
57 | bleu_count = 0
58 | for batch in data.iterate_batches(train_data, args.BATCH_SIZE):
59 | optimiser.zero_grad()
60 | if args.train_backward:
61 | input_seq, out_seq_list, _, out_idx = model.pack_backward_batch(batch, net.emb, device)
62 | else:
63 | input_seq, out_seq_list, _, out_idx = model.pack_batch(batch, net.emb, device)
64 | enc = net.encode(input_seq)
65 |
66 | net_results = []
67 | net_targets = []
68 | for idx, out_seq in enumerate(out_seq_list):
69 | ref_indices = out_idx[idx][1:]
70 | enc_item = net.get_encoded_item(enc, idx)
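                # Scheduled teacher forcing: with probability TEACHER_PROB the decoder is fed the
                # reference tokens (decode_teacher); otherwise it decodes greedily on its own
                # predictions, and either way the output is scored with BLEU against the reference.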
71 | if random.random() < args.TEACHER_PROB:
72 | r = net.decode_teacher(enc_item, out_seq)
73 | bleu_sum += model.seq_bleu(r, ref_indices)
74 | else:
75 | r, seq = net.decode_chain_argmax(enc_item, out_seq.data[0:1],
76 | len(ref_indices))
77 | bleu_sum += utils.calc_bleu(seq, ref_indices)
78 | net_results.append(r)
79 | net_targets.extend(ref_indices)
80 | bleu_count += 1
81 | results_v = torch.cat(net_results)
82 | targets_v = torch.LongTensor(net_targets).to(device)
83 | loss_v = F.cross_entropy(results_v, targets_v)
84 | loss_v.backward()
85 | optimiser.step()
86 |
87 | losses.append(loss_v.item())
88 | bleu = bleu_sum / bleu_count
89 | bleu_test = run_test(test_data, net, end_token, device)
90 | log.info("Epoch %d: mean loss %.3f, mean BLEU %.3f, test BLEU %.3f",
91 | epoch, np.mean(losses), bleu, bleu_test)
92 | writer.add_scalar("loss", np.mean(losses), epoch)
93 | writer.add_scalar("bleu", bleu, epoch)
94 | writer.add_scalar("bleu_test", bleu_test, epoch)
95 | if best_bleu is None or best_bleu < bleu_test:
96 | if best_bleu is not None:
97 | out_name = os.path.join(saves_path, "pre_bleu_%.3f_%02d.dat" %
98 | (bleu_test, epoch))
99 | torch.save(net.state_dict(), out_name)
100 | log.info("Best BLEU updated %.3f", bleu_test)
101 | best_bleu = bleu_test
102 |
103 | if epoch % 10 == 0:
104 | out_name = os.path.join(saves_path, "epoch_%03d_%.3f_%.3f.dat" %
105 | (epoch, bleu, bleu_test))
106 | torch.save(net.state_dict(), out_name)
107 |
108 | writer.close()
109 |
--------------------------------------------------------------------------------
/train_rl_BLEU.py:
--------------------------------------------------------------------------------
1 |
2 |
3 | import os
4 | import random
5 | import logging
6 | import numpy as np
7 | from tensorboardX import SummaryWriter
8 | import torch
9 | import torch.optim as optim
10 | import torch.nn.functional as F
11 | import argparse
12 |
13 | from libbots import data, model, utils
14 | from model_test import run_test
15 |
16 | if __name__ == '__main__':
17 | parser = argparse.ArgumentParser()
18 | parser.add_argument('-SAVES_DIR', type=str, default='saves', help='Save directory')
19 | parser.add_argument('-name', type=str, default='RL_BLUE', help='Specific model saves directory')
20 | parser.add_argument('-BATCH_SIZE', type=int, default=16, help='Batch Size for training')
21 | parser.add_argument('-LEARNING_RATE', type=float, default=1e-4, help='Learning Rate')
22 |     parser.add_argument('-MAX_EPOCHES', type=int, default=10000, help='Number of training epochs')
23 | parser.add_argument('-data', type=str, default='comedy', help='Genre to use - for data')
24 |     parser.add_argument('-num_of_samples', type=int, default=4, help='Number of samples per example')
25 | parser.add_argument('-load_seq2seq_path', type=str, default='Final_Saves/seq2seq/epoch_090_0.800_0.107.dat',
26 | help='Pre-trained seq2seq model location')
27 | args = parser.parse_args()
28 |
29 | saves_path = os.path.join(args.SAVES_DIR, args.name)
30 | os.makedirs(saves_path, exist_ok=True)
31 |
32 |
33 | log = logging.getLogger("train")
34 | logging.basicConfig(format="%(asctime)-15s %(levelname)s %(message)s", level=logging.INFO)
35 |
36 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
37 |
38 | phrase_pairs, emb_dict = data.load_data(genre_filter=args.data)
39 | data.save_emb_dict(saves_path, emb_dict)
40 | train_data = data.encode_phrase_pairs(phrase_pairs, emb_dict)
41 | rand = np.random.RandomState(data.SHUFFLE_SEED)
42 | rand.shuffle(train_data)
43 | train_data, test_data = data.split_train_test(train_data)
44 | train_data = data.group_train_data(train_data)
45 | test_data = data.group_train_data(test_data)
46 |
47 | log.info("Obtained %d phrase pairs with %d uniq words", len(phrase_pairs), len(emb_dict))
48 | log.info("Training data converted, got %d samples", len(train_data))
49 | log.info("Train set has %d phrases, test %d", len(train_data), len(test_data))
50 |
51 | rev_emb_dict = {idx: word for word, idx in emb_dict.items()}
52 | net = model.PhraseModel(emb_size=model.EMBEDDING_DIM, dict_size=len(emb_dict),
53 | hid_size=model.HIDDEN_STATE_SIZE).to(device)
54 | loaded_net = model.PhraseModel(emb_size=model.EMBEDDING_DIM, dict_size=len(emb_dict),
55 | hid_size=model.HIDDEN_STATE_SIZE).to(device)
56 |
57 | writer = SummaryWriter(comment="-" + args.name)
58 | net.load_state_dict(torch.load(args.load_seq2seq_path))
59 |
60 | # BEGIN & END tokens
61 | beg_token = torch.LongTensor([emb_dict[data.BEGIN_TOKEN]]).to(device)
62 | end_token = emb_dict[data.END_TOKEN]
63 |
64 |
65 | optimiser = optim.Adam(net.parameters(), lr=args.LEARNING_RATE, eps=1e-3)
66 | batch_idx = 0
67 | best_bleu = None
68 | for epoch in range(args.MAX_EPOCHES):
69 | random.shuffle(train_data)
70 | dial_shown = False
71 |
72 | total_samples = 0
73 | bleus_argmax = []
74 | bleus_sample = []
75 |
76 | for batch in data.iterate_batches(train_data, args.BATCH_SIZE):
77 | batch_idx += 1
78 | optimiser.zero_grad()
79 | input_seq, input_batch, output_batch = model.pack_batch_no_out(batch, net.emb, device)
80 | enc = net.encode(input_seq)
81 |
82 | net_policies = []
83 | net_actions = []
84 | net_advantages = []
85 | beg_embedding = net.emb(beg_token)
86 |
87 | for idx, inp_idx in enumerate(input_batch):
88 | total_samples += 1
89 | ref_indices = [indices[1:] for indices in output_batch[idx]]
90 | item_enc = net.get_encoded_item(enc, idx)
91 | r_argmax, actions = net.decode_chain_argmax(item_enc, beg_embedding, data.MAX_TOKENS,
92 | stop_at_token=end_token)
93 | argmax_bleu = utils.calc_bleu_many(actions, ref_indices)
94 | bleus_argmax.append(argmax_bleu)
95 |
96 |
97 | if not dial_shown:
98 | log.info("Input: %s", utils.untokenize(data.decode_words(inp_idx, rev_emb_dict)))
99 | ref_words = [utils.untokenize(data.decode_words(ref, rev_emb_dict)) for ref in ref_indices]
100 | log.info("Refer: %s", " ~~|~~ ".join(ref_words))
101 | log.info("Argmax: %s, bleu=%.4f",
102 | utils.untokenize(data.decode_words(actions, rev_emb_dict)),
103 | argmax_bleu)
104 |
105 | for _ in range(args.num_of_samples):
106 | r_sample, actions = net.decode_chain_sampling(item_enc, beg_embedding,
107 | data.MAX_TOKENS, stop_at_token=end_token)
108 | sample_bleu = utils.calc_bleu_many(actions, ref_indices)
109 |
110 | if not dial_shown:
111 | log.info("Sample: %s, bleu=%.4f",
112 | utils.untokenize(data.decode_words(actions, rev_emb_dict)),
113 | sample_bleu)
114 |
115 | net_policies.append(r_sample)
116 | net_actions.extend(actions)
117 | net_advantages.extend([sample_bleu - argmax_bleu] * len(actions))
118 | bleus_sample.append(sample_bleu)
119 | dial_shown = True
120 |
121 | if not net_policies:
122 | continue
123 |
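            # Self-critical policy gradient: the advantage of every sampled reply is its BLEU
            # minus the BLEU of the greedy (argmax) reply, so the loss below raises the
            # log-probabilities of samples that beat the greedy baseline and lowers the rest.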
124 | policies_v = torch.cat(net_policies)
125 | actions_t = torch.LongTensor(net_actions).to(device)
126 | adv_v = torch.FloatTensor(net_advantages).to(device)
127 | log_prob_v = F.log_softmax(policies_v, dim=1)
128 | log_prob_actions_v = adv_v * log_prob_v[range(len(net_actions)), actions_t]
129 | loss_policy_v = -log_prob_actions_v.mean()
130 |
131 | loss_v = loss_policy_v
132 | loss_v.backward()
133 | optimiser.step()
134 |
135 | bleu_test = run_test(test_data, net, end_token, device)
136 | bleu = np.mean(bleus_argmax)
137 | writer.add_scalar("bleu_test", bleu_test, batch_idx)
138 | writer.add_scalar("bleu_argmax", bleu, batch_idx)
139 | writer.add_scalar("bleu_sample", np.mean(bleus_sample), batch_idx)
140 | writer.add_scalar("epoch", batch_idx, epoch)
141 |
142 | log.info("Epoch %d, test BLEU: %.3f", epoch, bleu_test)
143 | if best_bleu is None or best_bleu < bleu_test:
144 | best_bleu = bleu_test
145 | log.info("Best bleu updated: %.4f", bleu_test)
146 | torch.save(net.state_dict(), os.path.join(saves_path, "bleu_%.3f_%02d.dat" % (bleu_test, epoch)))
147 | if epoch % 10 == 0:
148 | torch.save(net.state_dict(),
149 | os.path.join(saves_path, "epoch_%03d_%.3f_%.3f.dat" % (epoch, bleu, bleu_test)))
150 |
151 | writer.close()
--------------------------------------------------------------------------------
/train_rl_MMI.py:
--------------------------------------------------------------------------------
1 |
2 |
3 | import os
4 | import random
5 | import logging
6 | import numpy as np
7 | from tensorboardX import SummaryWriter
8 | import torch
9 | import torch.optim as optim
10 | import torch.nn.functional as F
11 | import argparse
12 |
13 | from libbots import data, model, utils
14 | from model_test import run_test_mutual
15 |
16 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
17 |
18 | if __name__ == '__main__':
19 | parser = argparse.ArgumentParser()
20 | parser.add_argument('-SAVES_DIR', type=str, default='saves', help='Save directory')
21 | parser.add_argument('-name', type=str, default='RL_Mutual', help='Specific model saves directory')
22 | parser.add_argument('-BATCH_SIZE', type=int, default=32, help='Batch Size for training')
23 | parser.add_argument('-LEARNING_RATE', type=float, default=1e-4, help='Learning Rate')
24 |     parser.add_argument('-MAX_EPOCHES', type=int, default=10000, help='Number of training epochs')
25 | parser.add_argument('-CROSS_ENT_PROB', type=float, default=0.3, help='Probability to run a CE batch')
26 | parser.add_argument('-TEACHER_PROB', type=float, default=0.8, help='Probability to run an imitation batch in case '
27 | 'of using CE')
28 | parser.add_argument('-data', type=str, default='comedy', help='Genre to use - for data')
29 |     parser.add_argument('-num_of_samples', type=int, default=4, help='Number of samples per example')
30 | parser.add_argument('-load_seq2seq_path', type=str, default='Final_Saves/seq2seq/epoch_090_0.800_0.107.dat',
31 | help='Pre-trained seq2seq model location')
32 |     parser.add_argument('-load_b_seq2seq_path', type=str, default='Final_Saves/backward_seq2seq/epoch_080_0.780_0.104.dat',
33 | help='Pre-trained backward seq2seq model location')
34 | args = parser.parse_args()
35 |
36 | saves_path = os.path.join(args.SAVES_DIR, args.name)
37 | os.makedirs(saves_path, exist_ok=True)
38 |
39 | log = logging.getLogger("train")
40 | logging.basicConfig(format="%(asctime)-15s %(levelname)s %(message)s", level=logging.INFO)
41 |
42 | phrase_pairs, emb_dict = data.load_data(genre_filter=args.data)
43 | data.save_emb_dict(saves_path, emb_dict)
44 | train_data = data.encode_phrase_pairs(phrase_pairs, emb_dict)
45 | rand = np.random.RandomState(data.SHUFFLE_SEED)
46 | rand.shuffle(train_data)
47 | train_data, test_data = data.split_train_test(train_data)
48 |
49 | log.info("Obtained %d phrase pairs with %d uniq words", len(phrase_pairs), len(emb_dict))
50 | log.info("Training data converted, got %d samples", len(train_data))
51 |
52 |
53 | rev_emb_dict = {idx: word for word, idx in emb_dict.items()}
54 |
55 | # Load pre-trained nets
56 | net = model.PhraseModel(emb_size=model.EMBEDDING_DIM, dict_size=len(emb_dict),
57 | hid_size=model.HIDDEN_STATE_SIZE).to(device)
58 | net.load_state_dict(torch.load(args.load_seq2seq_path))
59 | back_net = model.PhraseModel(emb_size=model.EMBEDDING_DIM, dict_size=len(emb_dict),
60 | hid_size=model.HIDDEN_STATE_SIZE).to(device)
61 |     back_net.load_state_dict(torch.load(args.load_b_seq2seq_path))
62 |
63 | rl_net = model.PhraseModel(emb_size=model.EMBEDDING_DIM, dict_size=len(emb_dict),
64 | hid_size=model.HIDDEN_STATE_SIZE).to(device)
65 | rl_net.load_state_dict(torch.load(args.load_seq2seq_path))
66 |
67 | writer = SummaryWriter(comment="-" + args.name)
68 |
69 | # BEGIN & END tokens
70 | beg_token = torch.LongTensor([emb_dict[data.BEGIN_TOKEN]]).to(device)
71 | end_token = emb_dict[data.END_TOKEN]
72 |
73 |
74 |
75 | optimiser = optim.Adam(rl_net.parameters(), lr=args.LEARNING_RATE, eps=1e-3)
76 | batch_idx = 0
77 | best_mutual = None
78 | for epoch in range(args.MAX_EPOCHES):
79 | dial_shown = False
80 | random.shuffle(train_data)
81 |
82 | total_samples = 0
83 | mutuals_argmax = []
84 | mutuals_sample = []
85 |
86 | for batch in data.iterate_batches(train_data, args.BATCH_SIZE):
87 | batch_idx += 1
88 | optimiser.zero_grad()
89 | input_seq, out_seq_list, input_batch, output_batch = model.pack_batch(batch, rl_net.emb, device)
90 | enc = rl_net.encode(input_seq)
91 |
92 | net_policies = []
93 | net_actions = []
94 | net_advantages = []
95 | beg_embedding = rl_net.emb(beg_token)
96 |
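            # With probability CROSS_ENT_PROB this is an imitation (cross-entropy) batch;
            # otherwise it is a policy-gradient batch whose reward is the mutual-information
            # score of the generated reply, with the greedy (argmax) reply as the baseline.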
97 | if random.random() < args.CROSS_ENT_PROB:
98 | net_results = []
99 | net_targets = []
100 | for idx, out_seq in enumerate(out_seq_list):
101 | ref_indices = output_batch[idx][1:]
102 | enc_item = rl_net.get_encoded_item(enc, idx)
103 | if random.random() < args.TEACHER_PROB:
104 | r = rl_net.decode_teacher(enc_item, out_seq)
105 | else:
106 | r, seq = rl_net.decode_chain_argmax(enc_item, out_seq.data[0:1],
107 | len(ref_indices))
108 | net_results.append(r)
109 | net_targets.extend(ref_indices)
110 | results_v = torch.cat(net_results)
111 | targets_v = torch.LongTensor(net_targets).to(device)
112 | loss_v = F.cross_entropy(results_v, targets_v)
113 | loss_v.backward()
114 | for param in rl_net.parameters():
115 | param.grad.data.clamp_(-0.2, 0.2)
116 | optimiser.step()
117 | else:
118 | for idx, inp_idx in enumerate(input_batch):
119 | total_samples += 1
120 | ref_indices = output_batch[idx][1:]
121 | item_enc = rl_net.get_encoded_item(enc, idx)
122 | r_argmax, actions = rl_net.decode_chain_argmax(item_enc, beg_embedding, data.MAX_TOKENS,
123 | stop_at_token=end_token)
124 | argmax_mutual = utils.calc_mutual(net, back_net, inp_idx, actions)
125 | mutuals_argmax.append(argmax_mutual)
126 |
127 | if not dial_shown:
128 | log.info("Input: %s", utils.untokenize(data.decode_words(inp_idx, rev_emb_dict)))
129 | ref_words = [utils.untokenize(data.decode_words([ref], rev_emb_dict)) for ref in ref_indices]
130 | log.info("Refer: %s", " ".join(ref_words))
131 | log.info("Argmax: %s, mutual=%.4f",
132 | utils.untokenize(data.decode_words(actions, rev_emb_dict)), argmax_mutual)
133 |
134 | for _ in range(args.num_of_samples):
135 | r_sample, actions = rl_net.decode_chain_sampling(item_enc, beg_embedding,
136 | data.MAX_TOKENS, stop_at_token=end_token)
137 | sample_mutual = utils.calc_mutual(net, back_net, inp_idx, actions)
138 | if not dial_shown:
139 | log.info("Sample: %s, mutual=%.4f",
140 | utils.untokenize(data.decode_words(actions, rev_emb_dict)), sample_mutual)
141 |
142 | net_policies.append(r_sample)
143 | net_actions.extend(actions)
144 | net_advantages.extend([sample_mutual - argmax_mutual] * len(actions))
145 | mutuals_sample.append(sample_mutual)
146 | dial_shown = True
147 |
148 | if not net_policies:
149 | continue
150 |
151 | policies_v = torch.cat(net_policies)
152 | actions_t = torch.LongTensor(net_actions).to(device)
153 | adv_v = torch.FloatTensor(net_advantages).to(device)
154 | log_prob_v = F.log_softmax(policies_v, dim=1)
155 | log_prob_actions_v = adv_v * log_prob_v[range(len(net_actions)), actions_t]
156 | loss_policy_v = -log_prob_actions_v.mean()
157 |
158 | loss_v = loss_policy_v
159 | loss_v.backward()
160 | for param in rl_net.parameters():
161 | param.grad.data.clamp_(-0.2, 0.2)
162 | optimiser.step()
163 |
164 | mutual_test = run_test_mutual(test_data, rl_net, net, back_net, beg_token, end_token, device)
165 | mutual = np.mean(mutuals_argmax)
166 | writer.add_scalar("mutual_test", mutual_test, batch_idx)
167 | writer.add_scalar("mutual_argmax", mutual, batch_idx)
168 | writer.add_scalar("mutual_sample", np.mean(mutuals_sample), batch_idx)
169 | writer.add_scalar("epoch", batch_idx, epoch)
170 | log.info("Epoch %d, test mutual: %.3f", epoch, mutual_test)
171 | if best_mutual is None or best_mutual < mutual_test:
172 | best_mutual = mutual_test
173 | log.info("Best mutual updated: %.4f", best_mutual)
174 | torch.save(rl_net.state_dict(), os.path.join(saves_path, "mutual_%.3f_%02d.dat" % (mutual_test, epoch)))
175 | if epoch % 10 == 0:
176 | torch.save(rl_net.state_dict(),
177 | os.path.join(saves_path, "epoch_%03d_%.3f_%.3f.dat" % (epoch, mutual, mutual_test)))
178 |
179 | writer.close()
--------------------------------------------------------------------------------
/train_rl_PREPLEXITY.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | import os
3 | import random
4 | import logging
5 | import numpy as np
6 | from tensorboardX import SummaryWriter
7 | import torch
8 | import torch.optim as optim
9 | import torch.nn.functional as F
10 | import argparse
11 |
12 | from libbots import data, model, utils
13 | from model_test import run_test_preplexity
14 |
15 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
16 |
17 | if __name__ == '__main__':
18 | parser = argparse.ArgumentParser()
19 | parser.add_argument('-SAVES_DIR', type=str, default='saves', help='Save directory')
20 | parser.add_argument('-name', type=str, default='RL_PREPLEXITY', help='Specific model saves directory')
21 | parser.add_argument('-BATCH_SIZE', type=int, default=32, help='Batch Size for training')
22 | parser.add_argument('-LEARNING_RATE', type=float, default=1e-4, help='Learning Rate')
23 |     parser.add_argument('-MAX_EPOCHES', type=int, default=10000, help='Number of training epochs')
24 | parser.add_argument('-CROSS_ENT_PROB', type=float, default=0.5, help='Probability to run a CE batch')
25 | parser.add_argument('-TEACHER_PROB', type=float, default=0.5, help='Probability to run an imitation batch in case '
26 | 'of using CE')
27 | parser.add_argument('-data', type=str, default='comedy', help='Genre to use - for data')
28 |     parser.add_argument('-num_of_samples', type=int, default=4, help='Number of samples per example')
29 | parser.add_argument('-load_seq2seq_path', type=str, default='Final_Saves/seq2seq/epoch_090_0.800_0.107.dat',
30 | help='Pre-trained seq2seq model location')
31 | args = parser.parse_args()
32 |
33 | saves_path = os.path.join(args.SAVES_DIR, args.name)
34 | os.makedirs(saves_path, exist_ok=True)
35 |
36 | log = logging.getLogger("train")
37 | logging.basicConfig(format="%(asctime)-15s %(levelname)s %(message)s", level=logging.INFO)
38 |
39 | phrase_pairs, emb_dict = data.load_data(genre_filter=args.data)
40 | data.save_emb_dict(saves_path, emb_dict)
41 | train_data = data.encode_phrase_pairs(phrase_pairs, emb_dict)
42 | rand = np.random.RandomState(data.SHUFFLE_SEED)
43 | rand.shuffle(train_data)
44 | train_data, test_data = data.split_train_test(train_data)
45 |
46 | log.info("Obtained %d phrase pairs with %d uniq words", len(phrase_pairs), len(emb_dict))
47 | log.info("Training data converted, got %d samples", len(train_data))
48 | log.info("Train set has %d phrases, test %d", len(train_data), len(test_data))
49 |
50 | # Load pre-trained nets
51 | rev_emb_dict = {idx: word for word, idx in emb_dict.items()}
52 | per_net = model.PhraseModel(emb_size=model.EMBEDDING_DIM, dict_size=len(emb_dict),
53 | hid_size=model.HIDDEN_STATE_SIZE).to(device)
54 | net = model.PhraseModel(emb_size=model.EMBEDDING_DIM, dict_size=len(emb_dict),
55 | hid_size=model.HIDDEN_STATE_SIZE).to(device)
56 | per_net.load_state_dict(torch.load(args.load_seq2seq_path))
57 | net.load_state_dict(torch.load(args.load_seq2seq_path))
58 |
59 | writer = SummaryWriter(comment="-" + args.name)
60 |
61 |
62 | # BEGIN & END tokens
63 | end_token = emb_dict[data.END_TOKEN]
64 | beg_token = torch.LongTensor([emb_dict[data.BEGIN_TOKEN]]).to(device)
65 |
66 | optimiser = optim.Adam(per_net.parameters(), lr=args.LEARNING_RATE, eps=1e-3)
67 | batch_idx = 0
68 | best_preplexity = None
69 | for epoch in range(args.MAX_EPOCHES):
70 | dial_shown = False
71 | random.shuffle(train_data)
72 |
73 | total_samples = 0
74 | preplexity_argmax = []
75 | preplexity_sample = []
76 |
77 | for batch in data.iterate_batches(train_data, args.BATCH_SIZE):
78 | batch_idx += 1
79 | optimiser.zero_grad()
80 | input_seq, out_seq_list, input_batch, output_batch = model.pack_batch(batch, per_net.emb, device)
81 | enc = per_net.encode(input_seq)
82 |
83 | net_policies = []
84 | net_actions = []
85 | net_advantages = []
86 | beg_embedding = per_net.emb(beg_token)
87 |
88 | if random.random() < args.CROSS_ENT_PROB:
89 | net_results = []
90 | net_targets = []
91 | for idx, out_seq in enumerate(out_seq_list):
92 | ref_indices = output_batch[idx][1:]
93 | enc_item = per_net.get_encoded_item(enc, idx)
94 | if random.random() < args.TEACHER_PROB:
95 | r = per_net.decode_teacher(enc_item, out_seq)
96 | else:
97 | r, seq = per_net.decode_chain_argmax(enc_item, out_seq.data[0:1],
98 | len(ref_indices))
99 | net_results.append(r)
100 | net_targets.extend(ref_indices)
101 | results_v = torch.cat(net_results)
102 | targets_v = torch.LongTensor(net_targets).to(device)
103 | loss_v = F.cross_entropy(results_v, targets_v)
104 | loss_v.backward()
105 | for param in per_net.parameters():
106 | param.grad.data.clamp_(-0.2, 0.2)
107 | optimiser.step()
108 | else:
109 | for idx, inp_idx in enumerate(input_batch):
110 | total_samples += 1
111 | ref_indices = output_batch[idx][1:]
112 | item_enc = per_net.get_encoded_item(enc, idx)
113 | r_argmax, actions = per_net.decode_chain_argmax(item_enc, beg_embedding, data.MAX_TOKENS,
114 | stop_at_token=end_token)
115 | r_net_argmax = net.get_logits(item_enc, beg_embedding, data.MAX_TOKENS, actions,
116 | stop_at_token=end_token)
117 | argmax_preplexity = utils.calc_preplexity_many(r_net_argmax, actions)
118 | preplexity_argmax.append(argmax_preplexity)
119 |
120 | if not dial_shown:
121 | log.info("Input: %s", utils.untokenize(data.decode_words(inp_idx, rev_emb_dict)))
122 | ref_words = [utils.untokenize(data.decode_words([ref], rev_emb_dict)) for ref in ref_indices]
123 | log.info("Refer: %s", " ".join(ref_words))
124 |                         log.info("Argmax: %s, perplexity=%.4f",
125 | utils.untokenize(data.decode_words(actions, rev_emb_dict)), argmax_preplexity)
126 |
127 | for _ in range(args.num_of_samples):
128 | r_sample, actions = per_net.decode_chain_sampling(item_enc, beg_embedding,
129 | data.MAX_TOKENS, stop_at_token=end_token)
130 | r_net_sample = net.get_logits(item_enc, beg_embedding, data.MAX_TOKENS, actions,
131 | stop_at_token=end_token)
132 | sample_preplexity = utils.calc_preplexity_many(r_net_sample, actions)
133 | if not dial_shown:
134 |                             log.info("Sample: %s, perplexity=%.4f",
135 | utils.untokenize(data.decode_words(actions, rev_emb_dict)), sample_preplexity)
136 |
137 | net_policies.append(r_sample)
138 | net_actions.extend(actions)
139 | net_advantages.extend([sample_preplexity - argmax_preplexity] * len(actions))
140 | preplexity_sample.append(sample_preplexity)
141 | dial_shown = True
142 |
143 | if not net_policies:
144 | continue
145 |
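            # Perplexity is a lower-is-better reward, so the advantage is multiplied by -1 below:
            # combined with the final negation in loss_policy_v, log-probabilities are pushed up
            # only for samples whose perplexity beats the greedy (argmax) baseline.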
146 | policies_v = torch.cat(net_policies)
147 | actions_t = torch.LongTensor(net_actions).to(device)
148 | adv_v = torch.FloatTensor(net_advantages).to(device)
149 | log_prob_v = F.log_softmax(policies_v, dim=1)
150 | log_prob_actions_v = (-1) * adv_v * log_prob_v[range(len(net_actions)), actions_t]
151 | loss_policy_v = -log_prob_actions_v.mean()
152 |
153 | loss_v = loss_policy_v
154 | loss_v.backward()
155 | for param in per_net.parameters():
156 | param.grad.data.clamp_(-0.2, 0.2)
157 | optimiser.step()
158 |
159 |
160 | preplexity_test = run_test_preplexity(test_data, per_net, net, end_token, device)
161 | preplexity = np.mean(preplexity_argmax)
162 | writer.add_scalar("preplexity_test", preplexity_test, batch_idx)
163 | writer.add_scalar("preplexity_argmax", preplexity, batch_idx)
164 | writer.add_scalar("preplexity_sample", np.mean(preplexity_sample), batch_idx)
165 | writer.add_scalar("epoch", batch_idx, epoch)
166 |         log.info("Epoch %d, test perplexity: %.3f", epoch, preplexity_test)
167 | if best_preplexity is None or best_preplexity > preplexity_test:
168 | best_preplexity = preplexity_test
169 |             log.info("Best perplexity updated: %.4f", preplexity_test)
170 | torch.save(per_net.state_dict(), os.path.join(saves_path, "preplexity_%.3f_%02d.dat" % (preplexity_test, epoch)))
171 | if epoch % 10 == 0:
172 | torch.save(per_net.state_dict(),
173 | os.path.join(saves_path, "epoch_%03d_%.3f_%.3f.dat" % (epoch, preplexity, preplexity_test)))
174 |
175 | writer.close()
--------------------------------------------------------------------------------
/train_rl_cosine.py:
--------------------------------------------------------------------------------
1 |
2 |
3 | import os
4 | import random
5 | import logging
6 | import numpy as np
7 | from tensorboardX import SummaryWriter
8 | from libbots import data, model, utils
9 | import torch
10 | import torch.optim as optim
11 | import torch.nn.functional as F
12 | import argparse
13 |
14 | from model_test import run_test_cosine
15 |
16 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
17 |
18 | if __name__ == '__main__':
19 | parser = argparse.ArgumentParser()
20 | parser.add_argument('-SAVES_DIR', type=str, default='saves', help='Save directory')
21 | parser.add_argument('-name', type=str, default='RL_COSINE', help='Specific model saves directory')
22 | parser.add_argument('-BATCH_SIZE', type=int, default=16, help='Batch Size for training')
23 | parser.add_argument('-LEARNING_RATE', type=float, default=1e-4, help='Learning Rate')
24 |     parser.add_argument('-MAX_EPOCHES', type=int, default=10000, help='Number of training epochs')
25 | parser.add_argument('-data', type=str, default='comedy', help='Genre to use - for data')
26 |     parser.add_argument('-num_of_samples', type=int, default=4, help='Number of samples per example')
27 | parser.add_argument('-load_seq2seq_path', type=str, default='Final_Saves/seq2seq/epoch_090_0.800_0.107.dat',
28 | help='Pre-trained seq2seq model location')
29 | args = parser.parse_args()
30 |
31 | saves_path = os.path.join(args.SAVES_DIR, args.name)
32 | os.makedirs(saves_path, exist_ok=True)
33 |
34 | log = logging.getLogger("train")
35 | logging.basicConfig(format="%(asctime)-15s %(levelname)s %(message)s", level=logging.INFO)
36 |
37 | phrase_pairs, emb_dict = data.load_data(genre_filter=args.data)
38 | data.save_emb_dict(saves_path, emb_dict)
39 | train_data = data.encode_phrase_pairs(phrase_pairs, emb_dict)
40 | rand = np.random.RandomState(data.SHUFFLE_SEED)
41 | rand.shuffle(train_data)
42 | train_data, test_data = data.split_train_test(train_data)
43 | train_data = data.group_train_data(train_data)
44 | test_data = data.group_train_data(test_data)
45 |
46 | log.info("Obtained %d phrase pairs with %d uniq words", len(phrase_pairs), len(emb_dict))
47 | log.info("Training data converted, got %d samples", len(train_data))
48 | log.info("Train set has %d phrases, test %d", len(train_data), len(test_data))
49 |
50 | # Load pre-trained seq2seq net
51 | rev_emb_dict = {idx: word for word, idx in emb_dict.items()}
52 | net = model.PhraseModel(emb_size=model.EMBEDDING_DIM, dict_size=len(emb_dict),
53 | hid_size=model.HIDDEN_STATE_SIZE).to(device)
54 | cos_net = model.PhraseModel(emb_size=model.EMBEDDING_DIM, dict_size=len(emb_dict),
55 | hid_size=model.HIDDEN_STATE_SIZE).to(device)
56 | net.load_state_dict(torch.load(args.load_seq2seq_path))
57 | cos_net.load_state_dict(torch.load(args.load_seq2seq_path))
58 |
59 | writer = SummaryWriter(comment="-" + args.name)
60 |
61 | # BEGIN & END tokens
62 | beg_token = torch.LongTensor([emb_dict[data.BEGIN_TOKEN]]).to(device)
63 | end_token = emb_dict[data.END_TOKEN]
64 |
65 |
66 | optimiser = optim.Adam(cos_net.parameters(), lr=args.LEARNING_RATE, eps=1e-3)
67 | batch_idx = 0
68 | best_cosine = None
69 | for epoch in range(args.MAX_EPOCHES):
70 | random.shuffle(train_data)
71 | dial_shown = False
72 |
73 | total_samples = 0
74 | skipped_samples = 0
75 | cosine_argmax = []
76 | cosine_sample = []
77 |
78 | for batch in data.iterate_batches(train_data, args.BATCH_SIZE):
79 | batch_idx += 1
80 | optimiser.zero_grad()
81 | input_seq, input_batch, output_batch = model.pack_batch_no_out(batch, cos_net.emb, device)
82 | enc = cos_net.encode(input_seq)
83 |
84 | net_policies = []
85 | net_actions = []
86 | net_advantages = []
87 | beg_embedding = cos_net.emb(beg_token)
88 |
89 | for idx, inp_idx in enumerate(input_batch):
90 | total_samples += 1
91 | ref_indices = [indices[1:] for indices in output_batch[idx]]
92 | item_enc = cos_net.get_encoded_item(enc, idx)
93 | r_argmax, actions = cos_net.decode_chain_argmax(item_enc, beg_embedding, data.MAX_TOKENS,
94 | stop_at_token=end_token)
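                # Reward: cosine similarity between the mean embedding of the generated reply and
                # the average of the reference replies' mean embeddings, both computed with the
                # pre-trained seq2seq net, which is not optimised here.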
95 | mean_emb_max = net.get_mean_emb(beg_embedding, actions)
96 | mean_emb_ref_list = []
97 | for iRef in ref_indices:
98 | mean_emb_ref_list.append(net.get_mean_emb(beg_embedding, iRef))
99 | mean_emb_ref = sum(mean_emb_ref_list)/len(mean_emb_ref_list)
100 | argmax_cosine = utils.calc_cosine_many(mean_emb_max, mean_emb_ref)
101 | cosine_argmax.append(float(argmax_cosine))
102 |
103 |
104 | if not dial_shown:
105 | log.info("Input: %s", utils.untokenize(data.decode_words(inp_idx, rev_emb_dict)))
106 | ref_words = [utils.untokenize(data.decode_words(ref, rev_emb_dict)) for ref in ref_indices]
107 | log.info("Refer: %s", " ~~|~~ ".join(ref_words))
108 | log.info("Argmax: %s, cosine=%.4f",
109 | utils.untokenize(data.decode_words(actions, rev_emb_dict)),
110 | argmax_cosine)
111 |
112 | for _ in range(args.num_of_samples):
113 | r_sample, actions = cos_net.decode_chain_sampling(item_enc, beg_embedding, data.MAX_TOKENS,
114 | stop_at_token=end_token)
115 | mean_emb_samp = net.get_mean_emb(beg_embedding, actions)
116 | sample_cosine = utils.calc_cosine_many(mean_emb_samp, mean_emb_ref)
117 |
118 | if not dial_shown:
119 | log.info("Sample: %s, cosine=%.4f",
120 | utils.untokenize(data.decode_words(actions, rev_emb_dict)),
121 | sample_cosine)
122 |
123 | net_policies.append(r_sample)
124 | net_actions.extend(actions)
125 | net_advantages.extend([float(sample_cosine) - argmax_cosine] * len(actions))
126 | cosine_sample.append(float(sample_cosine))
127 | dial_shown = True
128 |
129 | if not net_policies:
130 | continue
131 |
132 | policies_v = torch.cat(net_policies)
133 | actions_t = torch.LongTensor(net_actions).to(device)
134 | adv_v = torch.FloatTensor(net_advantages).to(device)
135 | log_prob_v = F.log_softmax(policies_v, dim=1)
136 | log_prob_actions_v = adv_v * log_prob_v[range(len(net_actions)), actions_t]
137 | loss_policy_v = -log_prob_actions_v.mean()
138 |
139 | loss_v = loss_policy_v
140 | loss_v.backward()
141 | optimiser.step()
142 |
143 |         cosine_test = run_test_cosine(test_data, cos_net, net, beg_token, end_token, device)
144 | cosine = np.mean(cosine_argmax)
145 | writer.add_scalar("cosine_test", cosine_test, batch_idx)
146 | writer.add_scalar("cosine_argmax", cosine, batch_idx)
147 | writer.add_scalar("cosine_sample", np.mean(cosine_sample), batch_idx)
148 | writer.add_scalar("epoch", batch_idx, epoch)
149 | log.info("Epoch %d, test COSINE: %.3f", epoch, cosine_test)
150 | if best_cosine is None or best_cosine < cosine_test:
151 | best_cosine = cosine_test
152 | log.info("Best cosine updated: %.4f", cosine_test)
153 | torch.save(cos_net.state_dict(), os.path.join(saves_path, "cosine_%.3f_%02d.dat" % (cosine_test, epoch)))
154 | if epoch % 10 == 0:
155 | torch.save(cos_net.state_dict(),
156 | os.path.join(saves_path, "epoch_%03d_%.3f_%.3f.dat" % (epoch, cosine, cosine_test)))
157 |
158 | writer.close()
--------------------------------------------------------------------------------
/use_model.py:
--------------------------------------------------------------------------------
1 |
2 |
3 | import os
4 | import logging
5 | import torch
6 |
7 | from libbots import data, model, utils
8 |
9 |
10 | log = logging.getLogger("use")
11 | k_sentences = 100
12 |
13 | def words_to_words(words, emb_dict, rev_emb_dict, net, use_sampling=False):
14 | tokens = data.encode_words(words, emb_dict)
15 | input_seq = model.pack_input(tokens, net.emb)
16 | enc = net.encode(input_seq)
17 | end_token = emb_dict[data.END_TOKEN]
18 | if use_sampling:
19 | _, out_tokens = net.decode_chain_sampling(enc, input_seq.data[0:1], seq_len=data.MAX_TOKENS,
20 | stop_at_token=end_token)
21 | else:
22 | _, out_tokens = net.decode_chain_argmax(enc, input_seq.data[0:1], seq_len=data.MAX_TOKENS,
23 | stop_at_token=end_token)
24 | if out_tokens[-1] == end_token:
25 | out_tokens = out_tokens[:-1]
26 | out_words = data.decode_words(out_tokens, rev_emb_dict)
27 | return out_words
28 |
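# MMI re-ranking: the forward model proposes the k_sentences most likely replies, the backward
# model scores how well the original input can be reconstructed from each reply, and the reply
# with the highest combined forward + backward score is returned.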
29 | def mutual_words_to_words(words, emb_dict, rev_emb_dict, back_emb_dict, net, back_net):
30 | # Forward
31 | tokens = data.encode_words(words, emb_dict)
32 | source_seq = model.pack_input(tokens, net.emb)
33 | enc = net.encode(source_seq)
34 | end_token = emb_dict[data.END_TOKEN]
35 | list_of_out_tokens, probs = net.decode_k_best(enc, source_seq.data[0:1], k_sentences, seq_len=data.MAX_TOKENS,
36 | stop_at_token=end_token)
37 | list_of_out_words = []
38 | for iTokens in range(len(list_of_out_tokens)):
39 | if list_of_out_tokens[iTokens][-1] == end_token:
40 | list_of_out_tokens[iTokens] = list_of_out_tokens[iTokens][:-1]
41 | list_of_out_words.append(data.decode_words(list_of_out_tokens[iTokens], rev_emb_dict))
42 |
43 | # Backward
44 | back_seq2seq_prob = []
45 | for iTarget in range(len(list_of_out_words)):
46 | b_tokens = data.encode_words(list_of_out_words[iTarget], back_emb_dict)
47 | target_seq = model.pack_input(b_tokens, back_net.emb)
48 | b_enc = back_net.encode(target_seq)
49 | back_seq2seq_prob.append(back_net.get_qp_prob(b_enc, target_seq.data[0:1], tokens[1:]))
50 |
51 | mutual_prob = []
52 | for i in range(len(probs)):
53 | mutual_prob.append(probs[i] + back_seq2seq_prob[i])
54 |     most_prob_mutual_sen_id = max(range(len(mutual_prob)), key=lambda s: mutual_prob[s])
55 |
56 | return list_of_out_words[most_prob_mutual_sen_id], mutual_prob[most_prob_mutual_sen_id]
57 |
58 | def process_string(words, emb_dict, rev_emb_dict, net, use_sampling=False):
59 | out_words = words_to_words(words, emb_dict, rev_emb_dict, net, use_sampling=use_sampling)
60 | print(" ".join(out_words))
61 |
62 |
63 | logging.basicConfig(format="%(asctime)-15s %(levelname)s %(message)s", level=logging.INFO)
64 |
65 | load_seq2seq_path = 'Final_Saves/seq2seq/epoch_090_0.800_0.107.dat'
66 | load_b_seq2seq_path = 'Final_Saves/backward_seq2seq/epoch_080_0.780_0.104.dat'
67 | bleu_model_path = 'Final_Saves/RL_BLUE/bleu_0.135_177.dat'
68 | mutual_model_path = 'Final_Saves/RL_Mutual/epoch_180_-4.325_-7.192.dat'
69 | prep_model_path = 'Final_Saves/RL_Perplexity/epoch_050_1.463_3.701.dat'
70 | cos_model_path = 'Final_Saves/RL_COSINE/cosine_0.621_03.dat'
71 |
72 | input_sentences = []
73 | input_sentences.append("Do you want to stay?")
74 | input_sentences.append("What's your full name?")
75 | input_sentences.append("Where are you going?")
76 | input_sentences.append("How old are you?")
77 | input_sentences.append("Where are you?")
78 | input_sentences.append("hi, joey.")
79 | input_sentences.append("let's go.")
80 | input_sentences.append("excuse me?")
81 | input_sentences.append("what's that?")
82 | input_sentences.append("Stop!")
83 | input_sentences.append("where ya goin?")
84 | input_sentences.append("what's this?")
85 | input_sentences.append("Do you play football?")
86 | input_sentences.append("who is she?")
87 | input_sentences.append("who is he?")
88 | input_sentences.append("Are you sure?")
89 | input_sentences.append("Did you see that?")
90 | input_sentences.append("Hello.")
91 |
92 | sample = False
93 | mutual = True
94 | RL = True
95 | num_turns = 1  # number of response turns generated for each input sentence
96 | device = torch.device("cuda") # "cuda"/"cpu"
97 |
98 | # Load Seq2Seq
99 | seq2seq_emb_dict = data.load_emb_dict(os.path.dirname(load_seq2seq_path))
100 | seq2seq_rev_emb_dict = {idx: word for word, idx in seq2seq_emb_dict.items()}
101 | net = model.PhraseModel(emb_size=model.EMBEDDING_DIM, dict_size=len(seq2seq_emb_dict), hid_size=model.HIDDEN_STATE_SIZE).to(device)
102 | net.load_state_dict(torch.load(load_seq2seq_path))
103 |
104 | # Load Back Seq2Seq
105 | b_seq2seq_emb_dict = data.load_emb_dict(os.path.dirname(load_b_seq2seq_path))
106 | b_seq2seq_rev_emb_dict = {idx: word for word, idx in b_seq2seq_emb_dict.items()}
107 | b_net = model.PhraseModel(emb_size=model.EMBEDDING_DIM, dict_size=len(b_seq2seq_emb_dict), hid_size=model.HIDDEN_STATE_SIZE).to(device)
108 | b_net.load_state_dict(torch.load(load_b_seq2seq_path))
109 |
110 | # Load BLEU
111 | bleu_emb_dict = data.load_emb_dict(os.path.dirname(bleu_model_path))
112 | bleu_rev_emb_dict = {idx: word for word, idx in bleu_emb_dict.items()}
113 | bleu_net = model.PhraseModel(emb_size=model.EMBEDDING_DIM, dict_size=len(bleu_emb_dict), hid_size=model.HIDDEN_STATE_SIZE).to(device)
114 | bleu_net.load_state_dict(torch.load(bleu_model_path))
115 |
116 | # Load Mutual
117 | mutual_emb_dict = data.load_emb_dict(os.path.dirname(mutual_model_path))
118 | mutual_rev_emb_dict = {idx: word for word, idx in mutual_emb_dict.items()}
119 | mutual_net = model.PhraseModel(emb_size=model.EMBEDDING_DIM, dict_size=len(mutual_emb_dict), hid_size=model.HIDDEN_STATE_SIZE).to(device)
120 | mutual_net.load_state_dict(torch.load(mutual_model_path))
121 |
122 |
123 | # Load Perplexity
124 | prep_emb_dict = data.load_emb_dict(os.path.dirname(prep_model_path))
125 | prep_rev_emb_dict = {idx: word for word, idx in prep_emb_dict.items()}
126 | prep_net = model.PhraseModel(emb_size=model.EMBEDDING_DIM, dict_size=len(prep_emb_dict), hid_size=model.HIDDEN_STATE_SIZE).to(device)
127 | prep_net.load_state_dict(torch.load(prep_model_path))
128 |
135 | # Load Cosine Similarity
136 | cos_emb_dict = data.load_emb_dict(os.path.dirname(cos_model_path))
137 | cos_rev_emb_dict = {idx: word for word, idx in cos_emb_dict.items()}
138 | cos_net = model.PhraseModel(emb_size=model.EMBEDDING_DIM, dict_size=len(cos_emb_dict), hid_size=model.HIDDEN_STATE_SIZE).to(device)
139 | cos_net.load_state_dict(torch.load(cos_model_path))
140 |
141 | while True:
142 |     # Work through the prepared sentences first, then fall back to interactive input
143 |     if input_sentences:
144 |         input_string = input_sentences.pop(0)
145 |     else:
146 |         input_string = input(">>> ")
147 |     if not input_string:
148 |         break
149 |
150 | words = utils.tokenize(input_string)
151 |     for _ in range(num_turns):
152 | if RL:
153 | words_seq2seq = words_to_words(words, seq2seq_emb_dict, seq2seq_rev_emb_dict, net, use_sampling=sample)
154 | words_bleu = words_to_words(words, bleu_emb_dict, bleu_rev_emb_dict, bleu_net, use_sampling=sample)
155 | words_mutual_RL = words_to_words(words, mutual_emb_dict, mutual_rev_emb_dict, mutual_net, use_sampling=sample)
156 | words_mutual, _ = mutual_words_to_words(words, seq2seq_emb_dict, seq2seq_rev_emb_dict, b_seq2seq_emb_dict,
157 | net, b_net)
158 |
159 | words_prep = words_to_words(words, prep_emb_dict, prep_rev_emb_dict, prep_net, use_sampling=sample)
161 | words_cosine = words_to_words(words, cos_emb_dict, cos_rev_emb_dict, cos_net, use_sampling=sample)
162 |
163 | print('Seq2Seq: ', utils.untokenize(words_seq2seq))
164 | print('BLEU: ', utils.untokenize(words_bleu))
165 | print('Mutual Information (RL): ', utils.untokenize(words_mutual_RL))
166 | print('Mutual Information: ', utils.untokenize(words_mutual))
167 | print('Perplexity: ', utils.untokenize(words_prep))
169 | print('Cosine Similarity: ', utils.untokenize(words_cosine))
170 | else:
171 | if mutual:
172 | words, _ = mutual_words_to_words(words, seq2seq_emb_dict, seq2seq_rev_emb_dict, b_seq2seq_emb_dict,
173 | net, b_net)
174 | else:
175 | words = words_to_words(words, seq2seq_emb_dict, seq2seq_rev_emb_dict, net, use_sampling=sample)
176 |
177 |
181 |
--------------------------------------------------------------------------------