├── .gitignore ├── Agent └── Agent.py ├── Environnement ├── Environnement.py └── data_util.py ├── LICENSE ├── LSTM_Model.py ├── NoisyDense.py ├── PriorityExperienceReplay ├── PriorityExperienceReplay.py └── sum_tree.py └── README.md /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | 49 | # Translations 50 | *.mo 51 | *.pot 52 | 53 | # Django stuff: 54 | *.log 55 | local_settings.py 56 | 57 | # Flask stuff: 58 | instance/ 59 | .webassets-cache 60 | 61 | # Scrapy stuff: 62 | .scrapy 63 | 64 | # Sphinx documentation 65 | docs/_build/ 66 | 67 | # PyBuilder 68 | target/ 69 | 70 | # Jupyter Notebook 71 | .ipynb_checkpoints 72 | 73 | # pyenv 74 | .python-version 75 | 76 | # celery beat schedule file 77 | celerybeat-schedule 78 | 79 | # SageMath parsed files 80 | *.sage.py 81 | 82 | # dotenv 83 | .env 84 | 85 | # virtualenv 86 | .venv 87 | venv/ 88 | ENV/ 89 | 90 | # Spyder project settings 91 | .spyderproject 92 | .spyproject 93 | 94 | # Rope project settings 95 | .ropeproject 96 | 97 | # mkdocs documentation 98 | /site 99 | 100 | # mypy 101 | .mypy_cache/ 102 | -------------------------------------------------------------------------------- /Agent/Agent.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numba as nb 4 | import numpy as np 5 | import math 6 | from random import random, randint 7 | 8 | from keras.optimizers import Adam 9 | from keras.layers import Input, Dense, Embedding, PReLU, BatchNormalization, Conv1D 10 | from keras.models import Model 11 | 12 | from Environnement.Environnement import Environnement 13 | from LSTM_Model import LSTM_Model 14 | from NoisyDense import NoisyDense 15 | from PriorityExperienceReplay.PriorityExperienceReplay import Experience 16 | 17 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' 18 | 19 | 20 | class Agent: 21 | def __init__(self, cutoff=8, from_save=False, gamma=.9, batch_size=32, min_history=64000, lr=0.0000625, 22 | sigma_init=0.5, target_network_period=32000, adam_e=1.5*10e-4, atoms=51, 23 | discriminator_loss_limits=0.1, n_steps=3): 24 | 25 | self.cutoff = cutoff 26 | self.environnement = Environnement(cutoff=cutoff, min_frequency_words=300000) 27 | self.vocab = self.environnement.different_words 28 | 29 | self.batch_size = batch_size 30 | 31 | self.n_steps = n_steps 32 | 33 | self.labels = np.array([1] * self.batch_size + [0] * self.batch_size) 34 | self.gammas = np.array([gamma ** (i + 1) for i in range(self.n_steps + 1)]).astype(np.float32) 35 | 36 | self.atoms = atoms 37 | self.v_max = np.sum([0.5 * gam for gam in self.gammas]) 38 | self.v_min = - self.v_max 39 | self.delta_z = (self.v_max - self.v_min) / 
float(self.atoms - 1) 40 | self.z_steps = np.array([self.v_min + i * self.delta_z for i in range(self.atoms)]).astype(np.float32) 41 | 42 | self.epsilon_greedy_max = 0.8 43 | self.sigma_init = sigma_init 44 | 45 | 46 | self.min_history = min_history 47 | self.lr = lr 48 | self.target_network_period = target_network_period 49 | self.adam_e = adam_e 50 | 51 | self.discriminator_loss_limit = discriminator_loss_limits 52 | 53 | self.model, self.target_model = self._build_model(), self._build_model() 54 | self.discriminator = self._build_discriminator() 55 | 56 | self.dataset_epoch = 0 57 | if from_save is True: 58 | self.model.load_weights('model') 59 | self.target_model.load_weights('model') 60 | self.discriminator.load_weights('discriminator') 61 | 62 | def update_target_model(self): 63 | self.target_model.set_weights(self.model.get_weights()) 64 | 65 | def get_average_noisy_weight(self): 66 | average = [] 67 | for i in range(self.vocab): 68 | average.append(np.mean(self.model.get_layer('Word_'+str(i)).get_weights()[1])) 69 | 70 | return np.mean(average), np.std(average) 71 | 72 | def _build_model(self): 73 | 74 | state_input = Input(shape=(self.cutoff,)) 75 | 76 | embedding = Embedding(self.vocab + 1, 50, input_length=self.cutoff)(state_input) 77 | 78 | main_network = Conv1D(256, 3, padding='same')(embedding) 79 | main_network = PReLU()(main_network) 80 | 81 | main_network = LSTM_Model(main_network, 100, batch_norm=False) 82 | 83 | main_network = Dense(256)(main_network) 84 | main_network = PReLU()(main_network) 85 | 86 | main_network = Dense(512)(main_network) 87 | main_network = PReLU()(main_network) 88 | 89 | dist_list = [] 90 | 91 | for i in range(self.vocab): 92 | dist_list.append(NoisyDense(self.atoms, activation='softmax', sigma_init=self.sigma_init, name='Word_' + str(i))(main_network)) 93 | 94 | 95 | actor = Model(inputs=[state_input], outputs=dist_list) 96 | actor.compile(optimizer=Adam(lr=self.lr, epsilon=self.adam_e), 97 | loss='categorical_crossentropy') 98 | 99 | return actor 100 | 101 | def _build_discriminator(self): 102 | 103 | state_input = Input(shape=(self.cutoff,)) 104 | 105 | embedding = Embedding(self.vocab + 1, 50, input_length=self.cutoff)(state_input) 106 | 107 | main_network = Conv1D(256, 3, padding='same')(embedding) 108 | main_network = PReLU()(main_network) 109 | main_network = BatchNormalization()(main_network) 110 | 111 | main_network = LSTM_Model(main_network, 100) 112 | 113 | main_network = Dense(256)(main_network) 114 | main_network = PReLU()(main_network) 115 | main_network = BatchNormalization()(main_network) 116 | 117 | main_network = Dense(512)(main_network) 118 | main_network = PReLU()(main_network) 119 | main_network = BatchNormalization()(main_network) 120 | 121 | discriminator_output = Dense(1, activation='sigmoid')(main_network) 122 | 123 | 124 | discriminator = Model(inputs=[state_input], outputs=discriminator_output) 125 | discriminator.compile(optimizer=Adam(), 126 | loss='binary_crossentropy') 127 | 128 | discriminator.summary() 129 | 130 | return discriminator 131 | 132 | def train(self, epoch): 133 | 134 | e, total_frames = 0, 0 135 | while e <= epoch: 136 | print('Epoch :', e) 137 | 138 | discrim_loss, model_loss_array, memory = [1], [], Experience(memory_size=1000000, batch_size=self.batch_size, alpha=0.5) 139 | while np.mean(discrim_loss[-20:]) >= self.discriminator_loss_limit: 140 | discrim_loss.append(self.train_discriminator()) 141 | 142 | for i in range(self.min_history//200): 143 | states, rewards, actions, states_prime = 
self.get_training_batch(200, self.get_epsilon(np.mean(discrim_loss[-20:]))) 144 | for j in range(200): 145 | memory.add((states[j], rewards[j], actions[j], states_prime[j]), 5) 146 | 147 | 148 | trained_frames = 1 149 | while np.mean(discrim_loss[-20:]) < 0.5 + 0.5 * 500000/(trained_frames * 10 * 4 * self.batch_size): 150 | 151 | if trained_frames % (self.target_network_period//(10 * 4 * self.batch_size)) == 0: 152 | self.update_target_model() 153 | 154 | states, rewards, actions, states_prime = self.get_training_batch(10 * self.batch_size, self.get_epsilon(np.mean(discrim_loss[-20:]))) 155 | for j in range(10 * self.batch_size): 156 | memory.add((states[j], rewards[j], actions[j], states_prime[j]), 5) 157 | for j in range(10 * 4): 158 | out, weights, indices = memory.select(min(1, 0.4 + 1.2 * np.mean(discrim_loss[-20:]))) # Scales b value 159 | model_loss_array.append(self.train_on_replay(out, self.batch_size)[0]) 160 | memory.priority_update(indices, [model_loss_array[-1] for _ in range(self.batch_size)]) 161 | 162 | trained_frames += 1 163 | total_frames += 1 164 | discrim_loss.append(self.train_discriminator(evaluate=True)) 165 | 166 | if trained_frames % 100 == 0: 167 | print() 168 | mean, std = self.get_average_noisy_weight() 169 | print('Average loss of model :', np.mean(model_loss_array[-10 * 4 * 20:]), 170 | '\tAverage discriminator loss :', np.mean(discrim_loss[-20:]), 171 | '\tFrames passed :', trained_frames * 10 * 4 * self.batch_size, 172 | '\tTotal frames passed :', total_frames * 10 * 4 * self.batch_size, 173 | '\tAverage Noisy Weights :', mean, 174 | '\tSTD Noisy Weights :', std, 175 | '\tEpoch :', e, 176 | '\tDataset Epoch :', self.dataset_epoch 177 | ) 178 | 179 | self.print_pred() 180 | self.print_pred() 181 | 182 | self.update_target_model() 183 | 184 | e += 1 185 | 186 | def get_epsilon(self, discrim_loss): 187 | epsilon = min(1.0, (0.1 / discrim_loss)) * self.epsilon_greedy_max 188 | return epsilon 189 | 190 | @nb.jit 191 | def train_discriminator(self, evaluate=False): 192 | fake_batch = self.get_fake_batch() 193 | real_batch, done = self.environnement.query_state(self.batch_size) 194 | if done is True: 195 | self.dataset_epoch += 1 196 | print('Current Dataset Epoch :', self.dataset_epoch) 197 | batch = np.vstack((real_batch, fake_batch)) 198 | if evaluate is True: 199 | return self.discriminator.evaluate([batch], [self.labels], verbose=0) 200 | return self.discriminator.train_on_batch([batch], [self.labels]) 201 | 202 | @nb.jit 203 | def make_seed(self, seed=None): 204 | if seed is None: 205 | # This is the kinda Z vector 206 | seed = np.random.random_integers(low=0, high=self.vocab - 1, size=(1, self.cutoff)) 207 | 208 | predictions = self.target_model.predict(seed) 209 | for _ in range(self.cutoff - 1): 210 | numba_optimised_seed_switch(predictions, seed, self.z_steps) 211 | predictions = self.target_model.predict(seed) 212 | numba_optimised_seed_switch(predictions, seed, self.z_steps) 213 | 214 | return seed 215 | 216 | @nb.jit 217 | def get_fake_batch(self): 218 | 219 | seed = self.make_seed() 220 | fake_batch = np.zeros((self.batch_size, self.cutoff)) 221 | for i in range(self.batch_size): 222 | predictions = self.target_model.predict([seed]) 223 | numba_optimised_pred_rollover(predictions, i, seed, fake_batch, self.z_steps) 224 | 225 | return fake_batch 226 | 227 | @nb.jit 228 | def get_training_batch(self, batch_size, epsilon): 229 | seed = self.make_seed() 230 | states = np.zeros((batch_size + self.n_steps, self.cutoff)) 231 | actions = np.zeros((batch_size 
+ self.n_steps, 1)) 232 |         for i in range(batch_size + self.n_steps): 233 |             action = -1 234 |             predictions = self.target_model.predict(seed) 235 |             if random() < epsilon: 236 |                 action = randint(0, self.vocab - 1) 237 | 238 |             numba_optimised_pred_rollover_with_actions(predictions, i, seed, states, self.z_steps, actions, action) 239 | 240 |         rewards = self.get_values(states) 241 |         states_prime = states[self.n_steps:] 242 | 243 |         return states[:-self.n_steps], rewards, actions, states_prime 244 | 245 |     @nb.jit 246 |     def get_values(self, fake_batch): 247 |         values = self.discriminator.predict(fake_batch) 248 |         return numba_optimised_nstep_value_function(values, values.shape[0] - self.n_steps, self.n_steps, self.gammas) 249 | 250 |     @nb.jit 251 |     def print_pred(self): 252 |         fake_state = self.make_seed() 253 | 254 |         pred = "" 255 |         for _ in range(4): 256 |             for j in range(self.cutoff): 257 |                 pred += self.environnement.ind_to_word[fake_state[0][j]] 258 |                 pred += " " 259 |             fake_state = self.make_seed(fake_state) 260 |         for j in range(self.cutoff): 261 |             pred += self.environnement.ind_to_word[fake_state[0][j]] 262 |             pred += " " 263 |         print(pred) 264 | 265 | 266 | 267 | # @nb.jit 268 |     def train_on_replay(self, data, batch_size): 269 |         states, reward, actions, state_prime = make_dataset(data=data, batch_size=batch_size) 270 | 271 |         m_prob = np.zeros((batch_size, self.vocab, self.atoms)) 272 | 273 |         z = self.target_model.predict(state_prime) 274 |         z = np.array(z) 275 |         z = np.swapaxes(z, 0, 1) 276 |         q = np.sum(np.multiply(z, self.z_steps), axis=-1) 277 |         optimal_action_idxs = np.argmax(q, axis=-1) 278 | 279 |         update_m_prob(self.batch_size, self.atoms, self.v_max, self.v_min, reward, self.gammas[-1], 280 |                       self.z_steps, self.delta_z, m_prob, actions, z, optimal_action_idxs) 281 | 282 |         return self.model.train_on_batch(states, [m_prob[:,i,:] for i in range(self.vocab)]) 283 | 284 | 285 | @nb.jit(nb.void(nb.int64,nb.int64,nb.float32,nb.float32, nb.float32[:],nb.float32, 286 |         nb.float32[:],nb.float32,nb.float32[:,:,:],nb.float32[:,:], nb.float32[:,:,:], nb.float32[:])) 287 | def update_m_prob(batch_size, atoms, v_max, v_min, reward, gamma, z_steps, delta_z, m_prob, actions, z, optimal_action_idxs): 288 |     for i in range(batch_size): 289 |         for j in range(atoms): 290 |             Tz = min(v_max, max(v_min, reward[i] + gamma * z_steps[j])) 291 |             bj = (Tz - v_min) / delta_z 292 |             m_l, m_u = math.floor(bj), math.ceil(bj) 293 |             m_prob[i, actions[i, 0], int(m_l)] += z[i, optimal_action_idxs[i], j] * (m_u - bj) 294 |             m_prob[i, actions[i, 0], int(m_u)] += z[i, optimal_action_idxs[i], j] * (bj - m_l) 295 | 296 | # @nb.jit 297 | def make_dataset(data, batch_size): 298 |     states, reward, actions, state_prime = [], [], [], [] 299 |     for i in range(batch_size): 300 |         states.append(data[i][0]) 301 |         reward.append(data[i][1]) 302 |         actions.append(data[i][2]) 303 |         state_prime.append(data[i][3]) 304 |     states = np.array(states) 305 |     reward = np.array(reward) 306 |     actions = np.array(actions).astype(np.int64) 307 |     state_prime = np.array(state_prime) 308 |     return states, reward, actions, state_prime 309 | 310 | @nb.jit(nb.int64(nb.float32[:,:], nb.float32[:])) 311 | def get_optimal_action(z, z_distrib): 312 | 313 |     z_concat = np.vstack(z) 314 |     q = np.sum(np.multiply(z_concat, z_distrib), axis=1) 315 |     action_idx = np.argmax(q) 316 | 317 |     return action_idx 318 | 319 | 320 | # Some strong numba optimisation in bottlenecks 321 | # N_Step reward function 322 | @nb.jit(nb.float32[:,:](nb.float32[:,:], nb.int64, nb.int64, nb.float32[:])) 323 | def 
numba_optimised_nstep_value_function(values, batch_size, n_step, gammas): 324 | for i in range(batch_size): 325 | for j in range(n_step): 326 | values[i] += values[i + j + 1] * gammas[j] 327 | return values[:batch_size] 328 | 329 | 330 | @nb.jit(nb.void(nb.float32[:,:], nb.int64, nb.float32[:,:], nb.float32[:,:], nb.float32[:])) 331 | def numba_optimised_pred_rollover(predictions, index, seed, fake_batch, z_distrib): 332 | seed[:, :-1] = seed[:, 1:] 333 | seed[:, -1] = get_optimal_action(predictions, z_distrib) 334 | fake_batch[index] = seed 335 | 336 | @nb.jit(nb.void(nb.float32[:,:], nb.int64, nb.float32[:,:], nb.float32[:,:], nb.float32[:], nb.float32[:,:], nb.int64)) 337 | def numba_optimised_pred_rollover_with_actions(predictions, index, seed, fake_batch, z_distrib, actions, action): 338 | if action != -1: 339 | choice = action 340 | else: 341 | choice = get_optimal_action(predictions, z_distrib) 342 | seed[:, :-1] = seed[:, 1:] 343 | seed[:, -1] = choice 344 | actions[index] = choice 345 | fake_batch[index] = seed 346 | 347 | @nb.jit(nb.void(nb.float32[:,:], nb.int64, nb.float32[:,:])) 348 | def numba_optimised_seed_switch(predictions, seed, z_distrib): 349 | seed[:, :-1] = seed[:, 1:] 350 | seed[:, -1] = get_optimal_action(predictions, z_distrib) 351 | 352 | 353 | if __name__ == '__main__': 354 | agent = Agent(cutoff=5, from_save=False, batch_size=32) 355 | agent.train(epoch=5000) 356 | -------------------------------------------------------------------------------- /Environnement/Environnement.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import numba 3 | from Environnement import data_util 4 | 5 | class Environnement: 6 | def __init__(self, cutoff=4, min_frequency_words=30000): 7 | self.ind_to_word, self.datas = data_util.convert_text_to_nptensor(cutoff=cutoff, min_frequency_words=min_frequency_words) 8 | self.different_words = len(self.ind_to_word) 9 | self.index = 0 10 | 11 | @numba.jit 12 | def query_state(self, batch_size): 13 | 14 | state = self.datas[self.index: self.index + batch_size] 15 | self.index += batch_size 16 | # End of epoch, shuffle dataset for next epoch 17 | if self.index + batch_size >= self.datas.shape[0]: 18 | self.index = 0 19 | np.random.shuffle(self.datas) 20 | return state, True 21 | else: 22 | return state, False 23 | 24 | if __name__ == '__main__': 25 | env = Environnement() 26 | env.query_state(2) -------------------------------------------------------------------------------- /Environnement/data_util.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | from string import ascii_lowercase as al 4 | from collections import Counter 5 | import pickle 6 | 7 | # This does two passes through the data, the first one to figure out what are the characters 8 | # or words that interest us, getting rid of those that are not present enough 9 | def convert_text_to_nptensor(directory='../datas/BillionWords/', cutoff=5, min_frequency_words=50000, max_lines=50000000, name='Billion_Words'): 10 | if os.path.isfile('../datas/TransformedData/' + 'text_' + name + str(cutoff) + '_' + str(min_frequency_words) + '.npy'): 11 | X = np.load('../datas/TransformedData/' + 'text_' + name + str(cutoff) + '_' + str(min_frequency_words) + '.npy') 12 | ind_to_word = pickle.load(open('../datas/TransformedData/' + 'ind_to_word_' + name + str(cutoff) + '_' + str(min_frequency_words) + '.pickle', 'rb')) 13 | else: 14 | words = Counter() 15 | n_lines = 0 
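        # The first pass below only counts token frequencies so that rare tokens
        # (those seen fewer than `min_frequency_words` times) can be dropped from
        # the vocabulary; the second pass re-reads the files and writes fixed
        # windows of `cutoff` surviving tokens into the int16 tensor X, skipping
        # any window that contains a dropped token.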
16 | files = [] 17 | 18 | # First pass to gather statistics on data 19 | for file in os.listdir(directory): 20 | files.append(directory + file) 21 | for file_idx in range(len(files)): 22 | with open(files[file_idx], encoding='utf-8') as f: 23 | text = f.readlines() 24 | for line in text: 25 | line = line.replace('\n', '').replace('.', '').replace('!', '').replace('?', '').replace('\t', ' ') 26 | line = line.lower() 27 | words.update(line.split(' ')) 28 | n_lines += len(line.split(' '))//cutoff 29 | 30 | number_of_words, removed = 0, 0 31 | words_init = len(words) 32 | 33 | for k in list(words): 34 | number_of_words += words[k] 35 | if words[k] < min_frequency_words: 36 | removed += words[k] 37 | del words[k] 38 | print('% of raw words remaining :', (number_of_words - removed)/number_of_words*100.0) 39 | print('Initial amount of tokens :', words_init) 40 | print('Current amount of tokens :', len(words)) 41 | print('% of remaining tokens :', len(words)/words_init) 42 | print('Max amount of lines :', n_lines) 43 | 44 | # We reserve 0 for 0 padding 45 | word_to_ind = dict((c, i) for i, c in enumerate(list(set(words)))) 46 | ind_to_word = dict((i, c) for i, c in enumerate(list(set(words)))) 47 | X = np.zeros((max_lines, cutoff), dtype=np.int16) 48 | 49 | lines_added = 0 50 | 51 | for file_idx in range(len(files)): 52 | with open(files[file_idx], encoding='utf-8') as f: 53 | text = f.readlines() 54 | for line in text: 55 | line = line.replace('\n', '').replace('.', '').replace('!', '').replace('?', '').replace('\t', ' ') 56 | line = line.lower() 57 | line = line.split(' ') 58 | offset = 0 59 | while(len(line) > offset + cutoff) & (lines_added < max_lines): 60 | 61 | # This makes sure that every word in the coming section of text is in the vocabulary, else we pass it 62 | check_word = True 63 | for word in line[offset: offset + cutoff]: 64 | try: 65 | word_to_ind[word] 66 | except KeyError: 67 | check_word = False 68 | if check_word == False: 69 | offset += cutoff 70 | 71 | else: 72 | for t, word in enumerate(line[offset: offset + cutoff]): 73 | X[lines_added, t] = word_to_ind[word] 74 | 75 | offset += cutoff 76 | lines_added += 1 77 | print(lines_added) 78 | with open('../datas/TransformedData/ind_to_word_' + name + str(cutoff) + '_' + str(min_frequency_words) + '.pickle', 'wb') as pck: 79 | pickle.dump(ind_to_word, pck) 80 | np.save('../datas/TransformedData/text_' + name + str(cutoff) + '_' + str(min_frequency_words), X[:lines_added]) 81 | return ind_to_word, X 82 | 83 | 84 | # This does two passes through the data, the first one to figure out what are the characters 85 | # or words that interest us, getting rid of those that are not present enough 86 | def convert_text_to_nptensor_word(directory='../datas/BillionWords/', cutoff=180, max_lines=50000000, name='Billion_Words'): 87 | if os.path.isfile('../datas/TransformedData/' + 'text_' + name + str(cutoff) + '_' + str(min_frequency_words) + '.npy'): 88 | X = np.load('../datas/TransformedData/' + 'text_' + name + str(cutoff) + '_' + str(min_frequency_words) + '.npy') 89 | ind_to_word = pickle.load(open('../datas/TransformedData/' + 'ind_to_word_' + name + str(cutoff) + '_' + str(min_frequency_words) + '.pickle', 'rb')) 90 | else: 91 | words = Counter() 92 | n_lines = 0 93 | files = [] 94 | 95 | # First pass to gather statistics on data 96 | for file in os.listdir(directory): 97 | files.append(directory + file) 98 | for file_idx in range(len(files)): 99 | with open(files[file_idx], encoding='utf-8') as f: 100 | text = f.readlines() 101 | for 
line in text: 102 | line = line.replace('\n', '').replace('.', '').replace('!', '').replace('?', '').replace('\t', ' ') 103 | line = line.lower() 104 | words.update(line.split(' ')) 105 | n_lines += len(line.split(' '))//cutoff 106 | 107 | number_of_words, removed = 0, 0 108 | words_init = len(words) 109 | 110 | for k in list(words): 111 | number_of_words += words[k] 112 | if words[k] < min_frequency_words: 113 | removed += words[k] 114 | del words[k] 115 | print('% of raw words remaining :', (number_of_words - removed)/number_of_words*100.0) 116 | print('Initial amount of tokens :', words_init) 117 | print('Current amount of tokens :', len(words)) 118 | print('% of remaining tokens :', len(words)/words_init) 119 | print('Max amount of lines :', n_lines) 120 | 121 | # We reserve 0 for 0 padding 122 | word_to_ind = dict((c, i) for i, c in enumerate(list(set(words)))) 123 | ind_to_word = dict((i, c) for i, c in enumerate(list(set(words)))) 124 | X = np.zeros((max_lines, cutoff), dtype=np.int16) 125 | 126 | lines_added = 0 127 | 128 | for file_idx in range(len(files)): 129 | with open(files[file_idx], encoding='utf-8') as f: 130 | text = f.readlines() 131 | for line in text: 132 | line = line.replace('\n', '').replace('.', '').replace('!', '').replace('?', '').replace('\t', ' ') 133 | line = line.lower() 134 | line = line.split(' ') 135 | offset = 0 136 | while(len(line) > offset + cutoff) & (lines_added < max_lines): 137 | 138 | # This makes sure that every word in the coming section of text is in the vocabulary, else we pass it 139 | check_word = True 140 | for word in line[offset: offset + cutoff]: 141 | try: 142 | word_to_ind[word] 143 | except KeyError: 144 | check_word = False 145 | if check_word == False: 146 | offset += cutoff 147 | 148 | else: 149 | for t, word in enumerate(line[offset: offset + cutoff]): 150 | X[lines_added, t] = word_to_ind[word] 151 | 152 | offset += cutoff 153 | lines_added += 1 154 | print(lines_added) 155 | with open('../datas/TransformedData/ind_to_word_' + name + str(cutoff) + '_' + str(min_frequency_words) + '.pickle', 'wb') as pck: 156 | pickle.dump(ind_to_word, pck) 157 | np.save('../datas/TransformedData/text_' + name + str(cutoff) + '_' + str(min_frequency_words), X[:lines_added]) 158 | return ind_to_word, X 159 | 160 | 161 | 162 | if __name__ == '__main__': 163 | convert_text_to_nptensor(cutoff=16, min_frequency_words=150000) -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2017 Louis Clouâtre 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /LSTM_Model.py: -------------------------------------------------------------------------------- 1 | from keras.layers import CuDNNLSTM 2 | from keras import layers 3 | from keras.layers.normalization import BatchNormalization 4 | from keras.layers.advanced_activations import PReLU 5 | 6 | 7 | def LSTM_Model(input_tensor, gru_cells, batch_norm=True): 8 | model_input = input_tensor 9 | x = CuDNNLSTM(gru_cells, return_sequences=True)(model_input) 10 | x = PReLU()(x) 11 | 12 | y = CuDNNLSTM(gru_cells, return_sequences=True)(x) 13 | y = PReLU()(y) 14 | 15 | z = layers.concatenate([x, y]) 16 | if batch_norm is True: 17 | z = BatchNormalization()(z) 18 | z = CuDNNLSTM(gru_cells)(z) 19 | z = PReLU()(z) 20 | if batch_norm is True: 21 | z = BatchNormalization()(z) 22 | 23 | return z 24 | 25 | -------------------------------------------------------------------------------- /NoisyDense.py: -------------------------------------------------------------------------------- 1 | from keras import backend as K 2 | from keras.engine.topology import Layer 3 | from keras import activations, initializers, regularizers, constraints 4 | 5 | class NoisyDense(Layer): 6 | 7 | def __init__(self, units, 8 | sigma_init=0.02, 9 | activation=None, 10 | use_bias=True, 11 | kernel_initializer='glorot_uniform', 12 | bias_initializer='zeros', 13 | kernel_regularizer=None, 14 | bias_regularizer=None, 15 | activity_regularizer=None, 16 | kernel_constraint=None, 17 | bias_constraint=None, 18 | **kwargs): 19 | if 'input_shape' not in kwargs and 'input_dim' in kwargs: 20 | kwargs['input_shape'] = (kwargs.pop('input_dim'),) 21 | super(NoisyDense, self).__init__(**kwargs) 22 | self.units = units 23 | self.sigma_init = sigma_init 24 | self.activation = activations.get(activation) 25 | self.use_bias = use_bias 26 | self.kernel_initializer = initializers.get(kernel_initializer) 27 | self.bias_initializer = initializers.get(bias_initializer) 28 | self.kernel_regularizer = regularizers.get(kernel_regularizer) 29 | self.bias_regularizer = regularizers.get(bias_regularizer) 30 | self.activity_regularizer = regularizers.get(activity_regularizer) 31 | self.kernel_constraint = constraints.get(kernel_constraint) 32 | self.bias_constraint = constraints.get(bias_constraint) 33 | 34 | def build(self, input_shape): 35 | assert len(input_shape) >= 2 36 | self.input_dim = input_shape[-1] 37 | 38 | self.kernel = self.add_weight(shape=(self.input_dim, self.units), 39 | initializer=self.kernel_initializer, 40 | name='kernel', 41 | regularizer=self.kernel_regularizer, 42 | constraint=self.kernel_constraint) 43 | 44 | self.sigma_kernel = self.add_weight(shape=(self.input_dim, self.units), 45 | initializer=initializers.Constant(value=self.sigma_init), 46 | name='sigma_kernel' 47 | ) 48 | 49 | 50 | if self.use_bias: 51 | self.bias = self.add_weight(shape=(self.units,), 52 | initializer=self.bias_initializer, 53 | name='bias', 54 | regularizer=self.bias_regularizer, 55 | constraint=self.bias_constraint) 56 | self.sigma_bias = self.add_weight(shape=(self.units,), 57 | initializer=initializers.Constant(value=self.sigma_init), 58 | name='sigma_bias') 59 | else: 60 | self.bias = None 61 | 
self.epsilon_bias = None 62 | # self.sample_noise() 63 | super(NoisyDense, self).build(input_shape) 64 | 65 | 66 | def call(self, X): 67 | perturbation = self.sigma_kernel * K.random_normal(shape=(self.input_dim, self.units), mean=0, stddev=1) 68 | perturbed_kernel = self.kernel + perturbation 69 | output = K.dot(X, perturbed_kernel) 70 | if self.use_bias: 71 | bias_perturbation = self.sigma_bias * K.random_normal(shape=(self.units,), mean=0, stddev=1) 72 | perturbed_bias = self.bias + bias_perturbation 73 | output = K.bias_add(output, perturbed_bias) 74 | if self.activation is not None: 75 | output = self.activation(output) 76 | return output 77 | 78 | def compute_output_shape(self, input_shape): 79 | assert input_shape and len(input_shape) >= 2 80 | assert input_shape[-1] 81 | output_shape = list(input_shape) 82 | output_shape[-1] = self.units 83 | return tuple(output_shape) 84 | 85 | def remove_noise(self): 86 | self.sigma_kernel = K.zeros(shape=(self.input_dim, self.units)) 87 | self.sigma_bias = K.zeros(shape=(self.units,)) 88 | 89 | def get_config(self): 90 | config = { 91 | 'units': self.units, 92 | 'sigma_init': self.sigma_init, 93 | 'sigma_kernel': self.sigma_kernel, 94 | 'sigma_bias': self.sigma_bias, 95 | # 'epsilon_bias': self.epsilon_bias, 96 | # 'epsilon_kernel': self.epsilon_kernel, 97 | 'activation': activations.serialize(self.activation), 98 | 'use_bias': self.use_bias, 99 | 'kernel_initializer': initializers.serialize(self.kernel_initializer), 100 | 'bias_initializer': initializers.serialize(self.bias_initializer), 101 | 'kernel_regularizer': regularizers.serialize(self.kernel_regularizer), 102 | 'bias_regularizer': regularizers.serialize(self.bias_regularizer), 103 | 'activity_regularizer': regularizers.serialize(self.activity_regularizer), 104 | 'kernel_constraint': constraints.serialize(self.kernel_constraint), 105 | 'bias_constraint': constraints.serialize(self.bias_constraint) 106 | } 107 | base_config = super(NoisyDense, self).get_config() 108 | return dict(list(base_config.items()) + list(config.items())) 109 | -------------------------------------------------------------------------------- /PriorityExperienceReplay/PriorityExperienceReplay.py: -------------------------------------------------------------------------------- 1 | # Taken from https://github.com/takoika/PrioritizedExperienceReplay/ 2 | 3 | import numpy as np 4 | import random 5 | from PriorityExperienceReplay import sum_tree 6 | 7 | 8 | class Experience(object): 9 | """ The class represents prioritized experience replay buffer. 10 | The class has functions: store samples, pick samples with 11 | probability in proportion to sample's priority, update 12 | each sample's priority, reset alpha. 13 | see https://arxiv.org/pdf/1511.05952.pdf . 14 | """ 15 | 16 | def __init__(self, memory_size, batch_size, alpha): 17 | """ Prioritized experience replay buffer initialization. 18 | 19 | Parameters 20 | ---------- 21 | memory_size : int 22 | sample size to be stored 23 | batch_size : int 24 | batch size to be selected by `select` method 25 | alpha: float 26 | exponent determine how much prioritization. 27 | Prob_i \sim priority_i**alpha/sum(priority**alpha) 28 | """ 29 | self.tree = sum_tree.SumTree(memory_size) 30 | self.memory_size = memory_size 31 | self.batch_size = batch_size 32 | self.alpha = alpha 33 | 34 | def add(self, data, priority): 35 | """ Add new sample. 
36 | 37 | Parameters 38 | ---------- 39 | data : object 40 | new sample 41 | priority : float 42 | sample's priority 43 | """ 44 | self.tree.add(data, priority ** self.alpha) 45 | 46 | def select(self, beta): 47 | """ The method return samples randomly. 48 | 49 | Parameters 50 | ---------- 51 | beta : float 52 | 53 | Returns 54 | ------- 55 | out : 56 | list of samples 57 | weights: 58 | list of weight 59 | indices: 60 | list of sample indices 61 | The indices indicate sample positions in a sum tree. 62 | """ 63 | 64 | if self.tree.filled_size() < self.batch_size: 65 | return None, None, None 66 | 67 | out = [] 68 | indices = [] 69 | weights = [] 70 | priorities = [] 71 | for _ in range(self.batch_size): 72 | r = random.random() 73 | data, priority, index = self.tree.find(r) 74 | priorities.append(priority) 75 | weights.append((1. / self.memory_size / priority) ** beta if priority > 1e-16 else 0) 76 | indices.append(index) 77 | out.append(data) 78 | self.priority_update([index], [0]) # To avoid duplicating 79 | 80 | self.priority_update(indices, priorities) # Revert priorities 81 | weights /= np.max(weights) # Normalize for stability 82 | return out, weights, indices 83 | 84 | def priority_update(self, indices, priorities): 85 | """ The methods update samples's priority. 86 | 87 | Parameters 88 | ---------- 89 | indices : 90 | list of sample indices 91 | """ 92 | for i, p in zip(indices, priorities): 93 | self.tree.val_update(i, p ** self.alpha) 94 | 95 | def reset_alpha(self, alpha): 96 | """ Reset a exponent alpha. 97 | Parameters 98 | ---------- 99 | alpha : float 100 | """ 101 | self.alpha, old_alpha = alpha, self.alpha 102 | priorities = [self.tree.get_val(i) ** -old_alpha for i in range(self.tree.filled_size())] 103 | self.priority_update(range(self.tree.filled_size()), priorities) -------------------------------------------------------------------------------- /PriorityExperienceReplay/sum_tree.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import math 4 | 5 | 6 | class SumTree(object): 7 | def __init__(self, max_size): 8 | self.max_size = max_size 9 | self.tree_level = math.ceil(math.log(max_size + 1, 2)) + 1 10 | self.tree_size = 2 ** self.tree_level - 1 11 | self.tree = [0 for _ in range(self.tree_size)] 12 | self.data = [None for _ in range(self.max_size)] 13 | self.size = 0 14 | self.cursor = 0 15 | 16 | def add(self, contents, value): 17 | index = self.cursor 18 | self.cursor = (self.cursor + 1) % self.max_size 19 | self.size = min(self.size + 1, self.max_size) 20 | 21 | self.data[index] = contents 22 | self.val_update(index, value) 23 | 24 | def get_val(self, index): 25 | tree_index = 2 ** (self.tree_level - 1) - 1 + index 26 | return self.tree[tree_index] 27 | 28 | def val_update(self, index, value): 29 | tree_index = 2 ** (self.tree_level - 1) - 1 + index 30 | diff = value - self.tree[tree_index] 31 | self.reconstruct(tree_index, diff) 32 | 33 | def reconstruct(self, tindex, diff): 34 | self.tree[tindex] += diff 35 | if not tindex == 0: 36 | tindex = int((tindex - 1) / 2) 37 | self.reconstruct(tindex, diff) 38 | 39 | def find(self, value, norm=True): 40 | if norm: 41 | value *= self.tree[0] 42 | return self._find(value, 0) 43 | 44 | def _find(self, value, index): 45 | if 2 ** (self.tree_level - 1) - 1 <= index: 46 | return self.data[index - (2 ** (self.tree_level - 1) - 1)], self.tree[index], index - ( 47 | 2 ** (self.tree_level - 1) - 1) 48 | 49 | left = self.tree[2 * index + 1] 50 | 51 | if value <= 
left: 52 |             return self._find(value, 2 * index + 1) 53 |         else: 54 |             return self._find(value - left, 2 * (index + 1)) 55 | 56 |     def print_tree(self): 57 |         for k in range(1, self.tree_level + 1): 58 |             for j in range(2 ** (k - 1) - 1, 2 ** k - 1): 59 |                 print(self.tree[j], end=' ') 60 |             print() 61 | 62 |     def filled_size(self): 63 |         return self.size 64 | 65 | 66 | if __name__ == '__main__': 67 |     s = SumTree(10) 68 |     for i in range(20): 69 |         s.add(2 ** i, i) 70 |     s.print_tree() 71 |     print(s.find(0.5)) -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Generative_NLP_RL_GAN 2 | Trying to train an NLP generative model in a reinforcement learning setting. 3 | 4 | I am currently trying to train a Rainbow DQN (https://arxiv.org/abs/1710.02298) to generate text from a truncated Google Billion Word dataset, using only the Noisy Networks, C51 and prioritized experience replay components, since the paper seemed to show they give the biggest gains by far. 5 | 6 | The general idea is to have my DQN beat a series of environments that get progressively harder. In this case, the environment is a discriminator trained to differentiate between my DQN's output and the dataset until its loss reaches a threshold (0.1 in this case). The reward is the output of the discriminator, and the environment is considered beaten when the discriminator's loss reaches another threshold (0.9 in this case). 7 | 8 | The model adds one word at a time to a fixed-size word vector, and the discriminator then evaluates the full vector. 9 | --------------------------------------------------------------------------------
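Below is a minimal, self-contained sketch of the loop described above: the generator appends one word at a time to a fixed-size window, the discriminator scores the full window, and that score is used as the reward. The names (`ToyGenerator`, `ToyDiscriminator`, `rollout`) are illustrative placeholders only, not classes from this repository; the real generator and discriminator are the Keras networks built in `Agent/Agent.py`.

```python
import numpy as np

rng = np.random.default_rng(0)
VOCAB, CUTOFF = 50, 8                 # toy vocabulary size and window length

class ToyGenerator:
    """Stand-in for the DQN: picks the next word and slides the window."""
    def step(self, state):
        next_word = rng.integers(0, VOCAB)       # the real agent picks via its value distribution
        return np.append(state[1:], next_word)   # drop the oldest word, append the new one

class ToyDiscriminator:
    """Stand-in for the Keras discriminator: scores a full window in [0, 1]."""
    def score(self, state):
        return float(rng.random())

def rollout(generator, discriminator, steps=CUTOFF):
    state = rng.integers(0, VOCAB, size=CUTOFF)  # random seed window
    rewards = []
    for _ in range(steps):
        state = generator.step(state)
        rewards.append(discriminator.score(state))   # reward = discriminator output on the full window
    return state, rewards

if __name__ == '__main__':
    final_state, rewards = rollout(ToyGenerator(), ToyDiscriminator())
    print(final_state, rewards)
```

In the repository itself, the discriminator is first trained until its loss falls below the 0.1 limit, the DQN is then trained on these rewards, and the cycle repeats once the discriminator's loss climbs back above the moving threshold.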