├── .gitignore ├── LICENSE ├── RBM.py ├── README-datasets.md ├── README.md ├── Sampling.py ├── ShortTextCodec.py ├── Utils.py ├── compare_models.py ├── requirements.txt ├── sample.py ├── samples ├── README.markdown ├── actors.txt ├── actors_unique.txt ├── games.txt ├── games_unique.txt ├── raw │ ├── actors.txt │ ├── actors_first.txt │ ├── games.txt │ ├── repos.txt │ └── usgeo.txt ├── repos.txt ├── repos2.txt ├── repos2_unique.txt ├── repos_unique.txt ├── usgeo.txt └── usgeo_unique.txt └── train.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.pickle 2 | *.pyc 3 | *.html 4 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2016 Colin Morris 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /RBM.py: -------------------------------------------------------------------------------- 1 | """Restricted Boltzmann Machine with softmax visible units. 2 | Based on sklearn's BernoulliRBM class. 3 | """ 4 | 5 | # Authors: Yann N. Dauphin 6 | # Vlad Niculae 7 | # Gabriel Synnaeve 8 | # Lars Buitinck 9 | # License: BSD 3 clause 10 | 11 | import time 12 | import re 13 | 14 | import numpy as np 15 | import scipy.sparse as sp 16 | 17 | from sklearn.base import BaseEstimator 18 | from sklearn.base import TransformerMixin 19 | from sklearn.externals.six.moves import xrange 20 | from sklearn.utils import check_array 21 | from sklearn.utils import check_random_state 22 | from sklearn.utils import gen_even_slices 23 | from sklearn.utils import issparse 24 | from sklearn.utils import shuffle 25 | from sklearn.utils.extmath import safe_sparse_dot, log_logistic 26 | from sklearn.utils.fixes import expit # logistic function 27 | from sklearn.utils.validation import check_is_fitted 28 | 29 | import Utils 30 | 31 | # Experiment: when sampling with high temperature (>1), use the softmax probabilities 32 | # of the biases as the prior rather than a uniform distribution. Based on the observation 33 | # that annealing starting from a high temperature often resulted in samples that were 34 | # highly biased toward long strings (because a uniform distribution over the visible 35 | # units will tend to produce strings of the maximum length). 36 | # This kind of helped but wasn't amazing. Possibly I just needed a longer/gentler annealing schedule? 37 | BIASED_PRIOR = 0 38 | 39 | class BernoulliRBM(BaseEstimator, TransformerMixin): 40 | """Bernoulli Restricted Boltzmann Machine (RBM). 41 | 42 | A Restricted Boltzmann Machine with binary visible units and 43 | binary hiddens. 
Parameters are estimated using Stochastic Maximum 44 | Likelihood (SML), also known as Persistent Contrastive Divergence (PCD) 45 | [2]. 46 | 47 | The time complexity of this implementation is ``O(d ** 2)`` assuming 48 | d ~ n_features ~ n_components. 49 | 50 | Parameters 51 | ---------- 52 | 53 | n_components : int, optional 54 | Number of binary hidden units. 55 | 56 | learning_rate : float, optional 57 | The learning rate for weight updates. It is *highly* recommended 58 | to tune this hyper-parameter. Reasonable values are in the 59 | 10**[0., -3.] range. 60 | 61 | batch_size : int, optional 62 | Number of examples per minibatch. 63 | 64 | n_iter : int, optional 65 | Number of iterations/sweeps over the training dataset to perform 66 | during training. 67 | 68 | verbose : int, optional 69 | The verbosity level. The default, zero, means silent mode. 70 | 71 | random_state : integer or numpy.RandomState, optional 72 | A random number generator instance to define the state of the 73 | random permutations generator. If an integer is given, it fixes the 74 | seed. Defaults to the global numpy random number generator. 75 | 76 | Attributes 77 | ---------- 78 | intercept_hidden_ : array-like, shape (n_components,) 79 | Biases of the hidden units. 80 | 81 | intercept_visible_ : array-like, shape (n_features,) 82 | Biases of the visible units. 83 | 84 | components_ : array-like, shape (n_components, n_features) 85 | Weight matrix, where n_features in the number of 86 | visible units and n_components is the number of hidden units. 87 | 88 | References 89 | ---------- 90 | 91 | [1] Hinton, G. E., Osindero, S. and Teh, Y. A fast learning algorithm for 92 | deep belief nets. Neural Computation 18, pp 1527-1554. 93 | http://www.cs.toronto.edu/~hinton/absps/fastnc.pdf 94 | 95 | [2] Tieleman, T. Training Restricted Boltzmann Machines using 96 | Approximations to the Likelihood Gradient. 
International Conference 97 | on Machine Learning (ICML) 2008 98 | """ 99 | 100 | def __init__(self, n_components=256, learning_rate=0.1, batch_size=10, 101 | n_iter=10, verbose=0, random_state=None, lr_backoff=False, weight_cost=0): 102 | self.n_components = n_components 103 | self.base_learning_rate = learning_rate 104 | self.learning_rate = learning_rate 105 | self.lr_backoff = lr_backoff 106 | self.batch_size = batch_size 107 | self.n_iter = n_iter 108 | self.verbose = verbose 109 | self.random_state = random_state 110 | self.rng_ = check_random_state(self.random_state) 111 | self.weight_cost = weight_cost 112 | # A history of some summary statistics recorded at the end of each epoch of training 113 | # Each key maps to a 2-d array. One row per 'session', one value per epoch. 114 | # (Another session means this model was pickled, then loaded and fit again.) 115 | self.history = {'pseudo-likelihood': [], 'overfit': []} 116 | 117 | # TODO 118 | # Experimental: How many times more fantasy particles compared to minibatch size 119 | @property 120 | def fantasy_to_batch(self): 121 | return 1 122 | 123 | def record(self, name, value): 124 | if not hasattr(self, 'history'): 125 | self.history = {'pseudo-likelihood': [], 'overfit': []} 126 | self.history[name][-1].append(value) 127 | 128 | def _mean_hiddens(self, v, temperature=1.0): 129 | """Computes the probabilities P(h=1|v). 130 | 131 | v : array-like, shape (n_samples, n_features) 132 | Values of the visible layer. 133 | 134 | Returns 135 | ------- 136 | h : array-like, shape (n_samples, n_components) 137 | Corresponding mean field values for the hidden layer. 138 | """ 139 | p = safe_sparse_dot(v, self.components_.T/temperature) 140 | p += self.intercept_hidden_/(min(1.0, temperature) if BIASED_PRIOR else temperature) 141 | return expit(p, out=p) 142 | 143 | def _sample_hiddens(self, v, temperature=1.0): 144 | """Sample from the distribution P(h|v). 
145 | 146 | v : array-like, shape (n_samples, n_features) 147 | Values of the visible layer to sample from. 148 | 149 | Returns 150 | ------- 151 | h : array-like, shape (n_samples, n_components) 152 | Values of the hidden layer. 153 | """ 154 | p = self._mean_hiddens(v, temperature) 155 | return (self.rng_.random_sample(size=p.shape) < p) 156 | 157 | def _sample_visibles(self, h, temperature=1.0): 158 | """Sample from the distribution P(v|h). 159 | 160 | h : array-like, shape (n_samples, n_components) 161 | Values of the hidden layer to sample from. 162 | 163 | Returns 164 | ------- 165 | v : array-like, shape (n_samples, n_features) 166 | Values of the visible layer. 167 | """ 168 | p = np.dot(h, self.components_/temperature) 169 | p += self.intercept_visible_/(min(1.0, temperature) if BIASED_PRIOR else temperature) 170 | expit(p, out=p) 171 | return (self.rng_.random_sample(size=p.shape) < p) 172 | 173 | def _free_energy(self, v): 174 | """Computes the free energy F(v) = - log sum_h exp(-E(v,h)). 175 | 176 | v : array-like, shape (n_samples, n_features) 177 | Values of the visible layer. 178 | 179 | Returns 180 | ------- 181 | free_energy : array-like, shape (n_samples,) 182 | The value of the free energy. 183 | """ 184 | return (- safe_sparse_dot(v, self.intercept_visible_) 185 | - np.logaddexp(0, safe_sparse_dot(v, self.components_.T) 186 | + self.intercept_hidden_).sum(axis=1)) 187 | 188 | def gibbs(self, v, temperature=1.0): 189 | """Perform one Gibbs sampling step. 190 | 191 | v : array-like, shape (n_samples, n_features) 192 | Values of the visible layer to start from. 193 | 194 | Returns 195 | ------- 196 | v_new : array-like, shape (n_samples, n_features) 197 | Values of the visible layer after one Gibbs step. 
198 | """ 199 | check_is_fitted(self, "components_") 200 | h_ = self._sample_hiddens(v, temperature) 201 | v_ = self._sample_visibles(h_, temperature) 202 | 203 | return v_ 204 | 205 | def repeated_gibbs(self, v, niters): 206 | """Perform n rounds of alternating Gibbs sampling starting from the 207 | given visible vectors. 208 | """ 209 | for i in range(niters): 210 | h = self._sample_hiddens(v) 211 | v = self._sample_visibles(h, temperature=1.0) 212 | return v 213 | 214 | def partial_fit(self, X, y=None): 215 | """Fit the model to the data X which should contain a partial 216 | segment of the data. 217 | 218 | X : array-like, shape (n_samples, n_features) 219 | Training data. 220 | 221 | Returns 222 | ------- 223 | self : BernoulliRBM 224 | The fitted model. 225 | """ 226 | X = check_array(X, accept_sparse='csr', dtype=np.float) 227 | if not hasattr(self, 'components_'): 228 | self.components_ = np.asarray( 229 | self.rng_.normal( 230 | 0, 231 | 0.01, 232 | (self.n_components, X.shape[1]) 233 | ), 234 | order='fortran') 235 | if not hasattr(self, 'intercept_hidden_'): 236 | self.intercept_hidden_ = np.zeros(self.n_components, ) 237 | if not hasattr(self, 'intercept_visible_'): 238 | self.intercept_visible_ = np.zeros(X.shape[1], ) 239 | if not hasattr(self, 'h_samples_'): 240 | self.h_samples_ = np.zeros((self.batch_size, self.n_components)) 241 | 242 | self._fit(X) 243 | 244 | def _fit(self, v_pos): 245 | """Inner fit for one mini-batch. 246 | 247 | Adjust the parameters to maximize the likelihood of v using 248 | Stochastic Maximum Likelihood (SML). 249 | 250 | v_pos : array-like, shape (n_samples, n_features) 251 | The data to use for training. 252 | """ 253 | h_pos = self._mean_hiddens(v_pos) 254 | # TODO: Worth trying with visible probabilities rather than binary states. 255 | # PG: it is common to use p_i instead of sampling a binary value'... 'it reduces 256 | # sampling noise this allowing faster learning. 
There is some evidence that it leads 257 | # to slightly worse density models' 258 | 259 | # I'm confounded by the fact that we seem to get more effective models WITHOUT 260 | # softmax visible units. The only explanation I can think of is that it's like 261 | # a pseudo-version of using visible probabilities. Without softmax, v_neg 262 | # can have multiple 1s per one-hot vector, which maybe somehow accelerates learning? 263 | # Need to think about this some more. 264 | v_neg = self._sample_visibles(self.h_samples_) 265 | h_neg = self._mean_hiddens(v_neg) 266 | 267 | lr = float(self.learning_rate) / v_pos.shape[0] 268 | update = safe_sparse_dot(v_pos.T, h_pos, dense_output=True).T 269 | update -= np.dot(h_neg.T, v_neg) / self.fantasy_to_batch 270 | # L2 weight penalty 271 | update -= self.components_ * self.weight_cost 272 | self.components_ += lr * update 273 | self.intercept_hidden_ += lr * (h_pos.sum(axis=0) - h_neg.sum(axis=0)/self.fantasy_to_batch) 274 | self.intercept_visible_ += lr * (np.asarray( 275 | v_pos.sum(axis=0)).squeeze() - 276 | v_neg.sum(axis=0)/self.fantasy_to_batch) 277 | 278 | h_neg[self.rng_.uniform(size=h_neg.shape) < h_neg] = 1.0 # sample binomial 279 | self.h_samples_ = np.floor(h_neg, h_neg) 280 | 281 | def corrupt(self, v): 282 | # Randomly corrupt one feature in each sample in v. 283 | ind = (np.arange(v.shape[0]), 284 | self.rng_.randint(0, v.shape[1], v.shape[0])) 285 | if issparse(v): 286 | data = -2 * v[ind] + 1 287 | v_ = v + sp.csr_matrix((data.A.ravel(), ind), shape=v.shape) 288 | else: 289 | v_ = v.copy() 290 | v_[ind] = 1 - v_[ind] 291 | return v_, None 292 | 293 | def uncorrupt(self, visibles, state): 294 | pass 295 | 296 | @Utils.timeit 297 | def score_samples(self, X): 298 | """Compute the pseudo-likelihood of X. 299 | 300 | X : {array-like, sparse matrix} shape (n_samples, n_features) 301 | Values of the visible layer. Must be all-boolean (not checked). 
302 | 303 | Returns 304 | ------- 305 | pseudo_likelihood : array-like, shape (n_samples,) 306 | Value of the pseudo-likelihood (proxy for likelihood). 307 | 308 | Notes 309 | ----- 310 | This method is not deterministic: it computes a quantity called the 311 | free energy on X, then on a randomly corrupted version of X, and 312 | returns the log of the logistic function of the difference. 313 | """ 314 | check_is_fitted(self, "components_") 315 | 316 | v = check_array(X, accept_sparse='csr') 317 | fe = self._free_energy(v) 318 | 319 | v_, state = self.corrupt(v) 320 | # TODO: If I wanted to be really fancy here, I would do one of those "with..." things. 321 | fe_corrupted = self._free_energy(v) 322 | self.uncorrupt(v, state) 323 | 324 | # See https://en.wikipedia.org/wiki/Pseudolikelihood 325 | # Let x be some visible vector. x_i is the ith entry. x_-i is the vector except that entry. 326 | # x_iflipped is x with the ith bit flipped. F() is free energy. 327 | # P(x_i | x_-i) = P(x) / P(x_-i) = P(x) / (P(x) + p(x_iflipped)) 328 | # expand def'n of P(x), cancel out the partition function on each term, and divide top and bottom by e^{-F(x)} to get... 329 | # 1 / (1 + e^{F(x) - F(x_iflipped)}) 330 | # So we're just calculating the log of that. We multiply by the number of 331 | # visible units because we're approximating P(x) as the product of the conditional likelihood 332 | # of each individual unit. But we're too lazy to do each one individually, so we say the unit 333 | # we tested represents an average. 
334 | if hasattr(self, 'codec'): 335 | normalizer = self.codec.shape()[0] 336 | else: 337 | normalizer = v.shape[1] 338 | return normalizer * log_logistic(fe_corrupted - fe) 339 | 340 | # TODO: No longer used 341 | def pseudolikelihood_ratio(self, good, bad): 342 | assert good.shape == bad.shape 343 | good_energy = self._free_energy(good) 344 | bad_energy = self._free_energy(bad) 345 | # Let's do ratio of log probabilities instead 346 | return (bad_energy - good_energy).mean() 347 | 348 | @Utils.timeit 349 | def score_validation_data(self, train, validation): 350 | """Return the energy difference between the given validation data, and a 351 | subset of the training data. This is useful for monitoring overfitting. 352 | If the model isn't overfitting, the difference should be around 0. The 353 | greater the difference, the more the model is overfitting. 354 | """ 355 | # It's important to use the same subset of the training data every time (per Hinton's "Practical Guide") 356 | return self._free_energy(train[:validation.shape[0]]).mean(), self._free_energy(validation).mean() 357 | 358 | def fit(self, X, validation=None): 359 | """Fit the model to the data X. 360 | 361 | X : {array-like, sparse matrix} shape (n_samples, n_features) 362 | Training data. 363 | 364 | validation : {array-like, sparse matrix} 365 | 366 | Returns 367 | ------- 368 | self : BernoulliRBM 369 | The fitted model. 370 | """ 371 | X = check_array(X, accept_sparse='csr', dtype=np.float) 372 | n_samples = X.shape[0] 373 | 374 | if not hasattr(self, 'components_'): 375 | self.components_ = np.asarray( 376 | self.rng_.normal(0, 0.01, (self.n_components, X.shape[1])), 377 | order='fortran') 378 | self.intercept_hidden_ = np.zeros(self.n_components, ) 379 | # 'It is usually helpful to initialize the bias of visible unit i to log[p_i/(1-p_i)] where p_i is the prptn of training vectors where i is on' - Practical Guide 380 | # TODO: Make this configurable? 
381 | if 1: 382 | counts = X.sum(axis=0).A.reshape(-1) 383 | # There should be no units that are always on 384 | assert np.max(counts) < X.shape[0], "Found a visible unit always on in the training data. Fishy." 385 | # There might be some units never on. Add a pseudo-count of 1 to avoid inf 386 | vis_priors = (counts + 1) / float(X.shape[0]) 387 | self.intercept_visible_ = np.log( vis_priors / (1 - vis_priors) ) 388 | else: 389 | self.intercept_visible_ = np.zeros(X.shape[1], ) 390 | 391 | # If this already *does* have weights and biases before fit() is called, 392 | # we'll start from them rather than wiping them out. May want to train 393 | # a model further with a different learning rate, or even on a different 394 | # dataset. 395 | else: 396 | print "Reusing existing weights and biases" 397 | # Don't necessarily want to reuse h_samples if we have one leftover from before - batch size might have changed 398 | self.h_samples_ = np.zeros((self.batch_size * self.fantasy_to_batch, self.n_components)) 399 | 400 | # Add new inner lists for this session 401 | if not hasattr(self, 'history'): 402 | self.history = {'pseudo-likelihood': [], 'overfit': []} 403 | for session in self.history.itervalues(): 404 | session.append([]) 405 | 406 | n_batches = int(np.ceil(float(n_samples) / self.batch_size)) 407 | batch_slices = list(gen_even_slices(n_batches * self.batch_size, 408 | n_batches, n_samples)) 409 | verbose = self.verbose 410 | begin = time.time() 411 | for iteration in xrange(1, self.n_iter + 1): 412 | if self.lr_backoff: 413 | # If, e.g., we're doing 10 epochs, use the full learning rate for 414 | # the first iteration, 90% of the base learning rate for the second 415 | # iteration... 
and 10% for the final iteration 416 | self.learning_rate = ((self.n_iter - (iteration - 1)) / (self.n_iter+0.0)) * self.base_learning_rate 417 | print "Using learning rate of {:.3f} (base LR={:.3f})".format(self.learning_rate, self.base_learning_rate) 418 | 419 | for batch_slice in batch_slices: 420 | self._fit(X[batch_slice]) 421 | 422 | if verbose and iteration != self.n_iter: 423 | end = time.time() 424 | self.wellness_check(iteration, end - begin, X, validation) 425 | begin = end 426 | if iteration != self.n_iter: 427 | X = shuffle(X) 428 | 429 | return self 430 | 431 | def wellness_check(self, epoch, duration, train, validation): 432 | """Log some diagnostic information on how the model is doing so far.""" 433 | validation_debug = '' 434 | if validation is not None: 435 | t_energy, v_energy = self.score_validation_data(train, validation) 436 | validation_debug = "\nE(vali):\t{:.2f}\tE(train):\t{:.2f}\tdifference: {:.2f}".format( 437 | v_energy, t_energy, v_energy-t_energy) 438 | self.record('overfit', (v_energy, t_energy)) 439 | 440 | # TODO: This is pretty expensive. Figure out why? Or just do less often. 441 | # Also, can use crippling amounts of memory for large datasets. Hack... 442 | pseudo = self.score_samples(train[:min(train.shape[0], 10**5)]) 443 | self.record('pseudo-likelihood', pseudo.mean()) 444 | print re.sub('\n *', '\n', """[{}] Iteration {}/{}\tt = {:.2f}s 445 | Pseudo-log-likelihood sum: {:.2f}\tAverage per instance: {:.2f}{}""".format 446 | (type(self).__name__, epoch, self.n_iter, duration, 447 | pseudo.sum(), pseudo.mean(), validation_debug, 448 | )) 449 | 450 | 451 | class CharBernoulliRBM(BernoulliRBM): 452 | 453 | def __init__(self, codec, *args, **kwargs): 454 | """ 455 | codec is the ShortTextCodec used to create the vectors being fit. The 456 | most important function of the codec is as a proxy to the shape of the 457 | softmax units in the visible layer (if you're using the CharBernoulliRBMSoftmax 458 | subclass). 
It's also used to decode and print 459 | fantasy particles at the end of each epoch. 460 | """ 461 | # Attaching this to the object is really helpful later on when models 462 | # are loaded from pickle in visualize.py and sample.py 463 | self.codec = codec 464 | self.softmax_shape = codec.shape() 465 | # Old-style class :( 466 | BernoulliRBM.__init__(self, *args, **kwargs) 467 | 468 | def wellness_check(self, epoch, duration, train, validation): 469 | BernoulliRBM.wellness_check(self, epoch, duration, train, validation) 470 | fantasy_samples = '|'.join([self.codec.decode(vec) for vec in 471 | self._sample_visibles(self.h_samples_[:3], temperature=0.1)]) 472 | print "Fantasy samples: {}".format(fantasy_samples) 473 | 474 | def corrupt(self, v): 475 | n_softmax, n_opts = self.softmax_shape 476 | # Select a random index in to the indices of the non-zero values of each input 477 | # TODO: In the char-RBM case, if I wanted to really challenge the model, I would avoid selecting any 478 | # trailing spaces here. Cause any dumb model can figure out that it should assign high energy to 479 | # any instance of / [^ ]/ 480 | meta_indices_to_corrupt = self.rng_.randint(0, n_softmax, v.shape[0]) + np.arange(0, n_softmax * v.shape[0], n_softmax) 481 | 482 | # Offset these indices by a random amount (but not 0 - we want to actually change them) 483 | offsets = self.rng_.randint(1, n_opts, v.shape[0]) 484 | # Also, do some math to make sure we don't "spill over" into a different softmax. 485 | # E.g. 
if n_opts=5, and we're corrupting index 3, we should choose offsets from {-3, -2, -1, +1} 486 | # 1-d array that matches with meta_i_t_c but which contains the indices themselves 487 | indices_to_corrupt = v.indices[meta_indices_to_corrupt] 488 | # Sweet lucifer 489 | offsets = offsets - (n_opts * (((indices_to_corrupt % n_opts) + offsets.ravel()) >= n_opts)) 490 | 491 | v.indices[meta_indices_to_corrupt] += offsets 492 | return v, (meta_indices_to_corrupt, offsets) 493 | 494 | def uncorrupt(self, visibles, state): 495 | mitc, offsets = state 496 | visibles.indices[mitc] -= offsets 497 | 498 | 499 | class CharBernoulliRBMSoftmax(CharBernoulliRBM): 500 | 501 | def _sample_visibles(self, h, temperature=1.0): 502 | """Sample from the distribution P(v|h). This obeys the softmax constraint 503 | on visible units. i.e. sum(v) == softmax_shape[0] for any visible 504 | configuration v. 505 | 506 | h : array-like, shape (n_samples, n_components) 507 | Values of the hidden layer to sample from. 508 | 509 | Returns 510 | ------- 511 | v : array-like, shape (n_samples, n_features) 512 | Values of the visible layer. 513 | """ 514 | p = np.dot(h, self.components_/temperature) 515 | p += self.intercept_visible_/(min(1.0, temperature) if BIASED_PRIOR else temperature) 516 | nsamples, nfeats = p.shape 517 | reshaped = np.reshape(p, (nsamples,) + self.softmax_shape) 518 | return Utils.softmax_and_sample(reshaped).reshape((nsamples, nfeats)) 519 | 520 | 521 | -------------------------------------------------------------------------------- /README-datasets.md: -------------------------------------------------------------------------------- 1 | # Datasets 2 | 3 | ## A note on preprocessing 4 | 5 | In my experiments, I removed duplicates as part of the preprocessing step for each dataset below. I didn't think much of it at the time, but if I were to do these experiments again, I would *not* dedupe. 
The main argument against deduping is that it's giving the model a censored version of the density function you're trying to get it to learn. Also, as you change the amount of data you collect for your training set, you're changing the actual shape of the distribution. As you collect more data, the modes of your data will form increasingly smaller proportions of the dataset. In the limit, you would approach the uniform distribution. This concern is not entirely theoretical - for example, there are many GitHub repositories that seem to have been named completely randomly (example from the training set: `LlHZRbrhqYXMlX`). 6 | 7 | On the other hand, having regions of extremely high probability can hurt your model's mixing rate and make sampling more difficult, so it's not an obvious decision. 8 | 9 | ## Usgeo 10 | 11 | To download the geonames corpus of US geographical names: 12 | 13 | wget http://download.geonames.org/export/dump/US.zip 14 | unzip US.zip 15 | cut -f 2 US.txt > usnames.txt 16 | 17 | That should give you around 2.2m names. 18 | 19 | I filtered out punctuation and numerals in the beginning to make the problem as easy as possible, but empirically, adding a few more characters doesn't slow down training that much, and doesn't seem to hurt sample quality. 20 | 21 | For more information on the geonames data, and to download names from other countries, check out the [geonames website](http://www.geonames.org/export/). 22 | 23 | ## Actors 24 | 25 | I used [actors.list.gz](ftp://ftp.fu-berlin.de/pub/misc/movies/database/actors.list.gz) from [IMDB's public datasets](http://www.imdb.com/interfaces). Note that you'll only get male names in this list - if you want female names as well, you'll want to grab `actresses.list.gz`. 26 | 27 | There's nothing special about *actor* names in particular that I wanted to capture - this was just the easiest way to get a big list of full names. 
28 | 29 | ## First/last names 30 | 31 | Check out [this directory](http://www.cs.cmu.edu/afs/cs/project/ai-repository/ai/areas/nlp/corpora/names/) for deduped lists of first/last names. 32 | 33 | At around 60k tokens, this dataset is relatively small - you'll probably want to do many epochs of training. 34 | 35 | ## GitHub repositories 36 | 37 | I [used Google BigQuery](https://www.githubarchive.org/#bigquery) to grab all distinct repository names (n=3.7m) from GitHub's 2014 archive. This involved puzzling over a lot of help articles and giving Google my credit card information, so to make things easier for future interested parties, I've dumped the dataset into [a GitHub repo](https://github.com/colinmorris/reponames-dataset). 38 | 39 | ## Board Games 40 | 41 | I grabbed a scrape of BoardGameGeek data from [here](https://github.com/ThaWeatherman/scrapers/blob/master/boardgamegeek/games.csv). Thanks to /u/thaweatherman for [posting this on /r/datasets](https://www.reddit.com/r/datasets/comments/3lm8p4/boardgamegeek_data/). 42 | 43 | There are around 80k games total, which is more than the first/last names dataset, but these names are much longer and higher-entropy. I didn't have much luck learning a model of this data. 44 | 45 | ## Other ideas 46 | 47 | It's not hard to imagine other domains we could apply this to. For example, the names/titles of... 48 | 49 | - books 50 | - movies 51 | - songs 52 | - bands 53 | - prescription drugs 54 | 55 | I was able to find large public datasets for some of these domains (e.g. [the Project Gutenberg catalog](http://www.gutenberg.org/wiki/Gutenberg:Offline_Catalogs) for books), but a common problem was that they would often contain names from many different languages mixed together. This makes the problem harder, because the data distribution becomes more complex and multi-modal. It also makes it harder to qualitatively assess outputs.
56 | 57 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # char-boltzmann 2 | 3 | Character-level RBMs for short text. For more information, check out [my blog post](https://colinmorris.github.io/blog/dreaming-rbms). 4 | 5 | # Requirements 6 | 7 | scikit-learn and its dependencies (numpy, scipy) are the main requirement. Also enum34. `pip install -r requirements.txt` might be all you need to do. 8 | 9 | # How-to 10 | 11 | The two important scripts are: 12 | 13 | - `train.py`: trains an RBM model on a text file with one short text per line. It has a whole bunch of command line options you can supply, but the defaults are all pretty reasonable. The one you're most likely to need to change is `--extra-chars` - the default behaviour is to use only `[a-z ]` (and `[A-Z]` implicitly downcased), which is not appropriate for datasets with lots of numerals or punctuation. 14 | - `sample.py`: generates new short texts given a pickled model file generated by `train.py` 15 | 16 | (The remaining script, `compare_models.py`, is only really relevant if you're training a bunch of different models on the same dataset and enjoy spreadsheets.) 17 | 18 | More details on the arguments to these scripts can be seen by running them with '-h'. 19 | 20 | README-datasets.md has pointers to some suitable datasets. 21 | 22 | # Example 23 | 24 | To train a small model on first names: 25 | 26 | wget http://www.cs.cmu.edu/afs/cs/project/ai-repository/ai/areas/nlp/corpora/names/other/names.txt 27 | python train.py --maxlen 10 --extra-chars '' --hid 100 names.txt 28 | python sample.py names__nh100.pickle 29 | 30 | This should give you some output like... 31 | 32 | wietzer 33 | sarnimono 34 | buttheo 35 | ressinosoo 36 | bernington 37 | 38 | # Interpreting train.py output 39 | 40 | During training, you'll see debug output like...
41 | 42 | [CharBernoulliRBMSoftmax] Iteration 3/5 t = 14.46s 43 | Pseudo-log-likelihood sum: -115047.96 Average per instance: -2.13 44 | E(vali): -14.00 E(train): -14.07 difference: 0.07 45 | Fantasy samples: moll$$$$$$|anderd$$$$|gronbel$$$ 46 | 47 | Without going into too much detail, the pseudo-log-likelihood (-2.13 above) is a pretty decent estimate of how well the model is currently fitting the training data. Higher (i.e. closer to zero) is better. 48 | 49 | The next line compares the energy assigned to the training data vs. the validation set. The difference (0.07 in this case) gives an idea of how much the model is overfitting. The higher the difference, the worse. A difference of around 0 indicates no overfitting. 50 | 51 | The final line has string representations of a few of the "fantasy particles" used for the [persistent contrastive divergence](http://www.cs.toronto.edu/~tijmen/pcd/pcd.pdf) training. 52 | 53 | # More details 54 | 55 | The core RBM code is cannibalized from scikit-learn's [BernoulliRBM](http://scikit-learn.org/stable/modules/generated/sklearn.neural_network.BernoulliRBM.html#sklearn.neural_network.BernoulliRBM) implementation. I tacked on some additional features including: 56 | 57 | - L2 weight cost 58 | - softmax sampling 59 | - sampling with temperature (for simulated annealing) 60 | - flag to gradually reduce learning rate 61 | - initializing visible biases to the log-odds of each unit's frequency in the training set 62 | 63 | This code has the same performance limitations as the base sklearn implementation. In particular, it can't run on a GPU. 64 | 65 | The 'workspace' branch has a lot of extra scripts and data files which *might* be useful to someone, but which are kind of messy (even relative to the already-kinda-messy master). They mostly relate to model visualization and experiments with different sampling techniques.
66 | -------------------------------------------------------------------------------- /Sampling.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import argparse 3 | import pickle 4 | import numpy as np 5 | import enum 6 | 7 | import Utils 8 | from ShortTextCodec import ShortTextCodec 9 | 10 | MAX_PROG_SAMPLE_INTERVAL = 10000 11 | 12 | # If true, then subtract a constant between iterations when annealing, rather than dividing by a constant. 13 | # Literature seems divided on the best way to do this? Anecdotally, seem to get better results with exp 14 | # decay most of the time, but haven't looked very carefully. 15 | LINEAR_ANNEAL = 0 16 | 17 | BIG_NUMBER = 3.0 # Logit offset that shrink_model adds to / subtracts from the padding character's visible bias 18 | 19 | class shrink_model(object): 20 | # Context manager: while active, lowers the padding character's visible bias at positions before min_length and raises it at positions from max_length on, nudging sampled strings toward lengths in [min_length, max_length]. __exit__ restores the saved biases. 21 | def __init__(self, model, min_length, max_length): 22 | assert 1 <= max_length <= model.codec.maxlen 23 | assert 0 <= min_length <= model.codec.maxlen 24 | self.model = model 25 | self.min_length = min_length 26 | self.max_length = max_length 27 | 28 | def __enter__(self): 29 | codec = self.model.codec 30 | model = self.model 31 | padidx = codec.char_lookup[codec.filler] 32 | self.prev_biases = [model.intercept_visible_[codec.nchars*posn+padidx] for posn in range(codec.maxlen)] 33 | # Discourage (lower the bias of) the padding character at positions before min_length - a soft push, not a hard constraint 34 | for posn in range(self.min_length): 35 | model.intercept_visible_[codec.nchars*posn + padidx] += -1*BIG_NUMBER 36 | 37 | # Encourage (raise the bias of) the padding character at positions from max_length on - also a soft push 38 | for posn in range(self.max_length, codec.maxlen): 39 | model.intercept_visible_[codec.nchars*posn + padidx] += BIG_NUMBER 40 | 41 | def __exit__(self, *args): 42 | padidx = self.model.codec.char_lookup[self.model.codec.filler] 43 | for posn, bias in enumerate(self.prev_biases): 44 | self.model.intercept_visible_[self.model.codec.nchars*posn + padidx] = bias 45 | 46 | 47 | 48 | class VisInit(enum.Enum): 49 | """Ways of initializing visible units before repeated
gibbs sampling.""" 50 | # All zeros. Should be basically equivalent to deferring to the *hidden* biases. 51 | zeros = 1 52 | # Treat visible biases as softmax 53 | biases = 2 54 | # Turn on each unit (not just each one-hot vector) with p=.5 55 | uniform = 3 56 | spaces = 4 # Every position set to the ' ' character 57 | padding = 7 # Every position set to the codec's filler character. Old models use ' ' as filler, making this identical to the above 58 | # Training examples 59 | train = 5 60 | # Choose a random length. Fill in that many uniformly random chars. Fill the rest with padding character. 61 | chunks = 6 62 | # Use training examples but randomly mutate non-space/padding characters. Only the "shape" is preserved. 63 | silhouettes = 8 64 | # Valid one-hot vectors, each chosen uniformly at random 65 | uniform_chars = 9 66 | # Raised by starting_visible_configs when the requested init method can't be used with the given model 67 | class BadInitMethodException(Exception): 68 | pass 69 | 70 | def starting_visible_configs(init_method, n, model, training_examples_fname=None): 71 | """Return an ndarray of n visible configurations for the given model 72 | according to the specified init method (which should be a member of the VisInit enum) 73 | """ 74 | vis_shape = (n, model.intercept_visible_.shape[0]) 75 | maxlen, nchars = model.codec.maxlen, model.codec.nchars 76 | if init_method == VisInit.biases: 77 | sm = np.tile(model.intercept_visible_, [n, 1]).reshape( (-1,) + model.codec.shape() ) 78 | return Utils.softmax_and_sample(sm).reshape(vis_shape) 79 | elif init_method == VisInit.zeros: 80 | return np.zeros(vis_shape) 81 | elif init_method == VisInit.uniform: 82 | return np.random.randint(0, 2, vis_shape) 83 | # This will fail if ' ' isn't in the alphabet of this model 84 | elif init_method == VisInit.spaces or init_method == VisInit.padding: 85 | fillchar = {VisInit.spaces: ' ', VisInit.padding: model.codec.filler}[init_method] 86 | vis = np.zeros( (n,) + model.codec.shape()) 87 | try: 88 | fill = model.codec.char_lookup[fillchar] 89 | except KeyError: 90 | raise BadInitMethodException(fillchar + " is not in model alphabet") 91 | 92 | vis[:,:,fill] = 1 93 | 
return vis.reshape(vis_shape) 94 | elif init_method == VisInit.train or init_method == VisInit.silhouettes: 95 | assert training_examples_fname is not None, "No training examples provided to initialize with" 96 | mutagen = model.codec.mutagen_silhouettes if init_method == VisInit.silhouettes else None 97 | examples = Utils.vectors_from_txtfile(training_examples_fname, model.codec, limit=n, mutagen=mutagen) 98 | return examples 99 | elif init_method == VisInit.chunks or init_method == VisInit.uniform_chars: 100 | # This works, but probably isn't idiomatic numpy. 101 | # I don't think I'll ever write idiomatic numpy. 102 | 103 | # Start w uniform dist 104 | char_indices = np.random.randint(0, nchars, (n,maxlen)) 105 | if init_method == VisInit.chunks: 106 | # Choose some random lengths 107 | lengths = np.clip(maxlen*.25 * np.random.randn(n) + (maxlen*.66), 1, maxlen 108 | ).astype('int8').reshape(n, 1) 109 | _, i = np.indices((n, maxlen)) 110 | char_indices[i>=lengths] = model.codec.char_lookup[model.codec.filler] 111 | 112 | # TODO: This is a useful little trick. Make it a helper function and reuse it elsewhere? 
113 | return np.eye(nchars)[char_indices.ravel()].reshape(vis_shape) 114 | else: 115 | raise ValueError("Unrecognized init method: {}".format(init_method)) 116 | 117 | 118 | def print_sample_callback(sample_strings, i, energy=None): 119 | if energy is not None: 120 | print "\n".join('{}\t{:.2f}'.format(t[0], t[1]) for t in zip(sample_strings, energy)) 121 | else: 122 | print "\n".join(sample_strings) 123 | print 124 | 125 | @Utils.timeit 126 | def sample_model(model, n, iters, sample_iter_indices, 127 | start_temp=1.0, final_temp=1.0, 128 | callback=print_sample_callback, init_method=VisInit.biases, training_examples=None, 129 | sample_energy=False, starting_vis=None, min_length=0, max_length=0, 130 | ): 131 | if callback is None: 132 | callback = lambda: None 133 | if starting_vis is not None: 134 | vis = starting_vis 135 | else: 136 | vis = starting_visible_configs(init_method, n, model, training_examples) 137 | 138 | args = [model, vis, iters, sample_iter_indices, start_temp, final_temp, callback, sample_energy] 139 | if min_length or max_length: 140 | if max_length == 0: 141 | max_length = model.codec.maxlen 142 | with shrink_model(model, min_length, max_length): 143 | return _sample_model(*args) 144 | else: 145 | return _sample_model(*args) 146 | 147 | def _sample_model(model, vis, iters, sample_iter_indices, start_temp, final_temp, callback, 148 | sample_energy): 149 | 150 | temp = start_temp 151 | temp_decay = (final_temp/start_temp)**(1/iters) 152 | temp_delta = (final_temp-start_temp)/iters 153 | next_sample_metaindex = 0 154 | for i in range(iters): 155 | if i == sample_iter_indices[next_sample_metaindex]: 156 | # Time to take samples 157 | sample_strings = [model.codec.decode(v, pretty=True, strict=False) for v in vis] 158 | if sample_energy: 159 | energy = model._free_energy(vis) 160 | callback(sample_strings, i, energy) 161 | else: 162 | callback(sample_strings, i) 163 | next_sample_metaindex += 1 164 | if next_sample_metaindex == 
len(sample_iter_indices): 165 | break 166 | vis = model.gibbs(vis, temp) 167 | if LINEAR_ANNEAL: 168 | temp += temp_delta 169 | else: 170 | temp *= temp_decay 171 | return vis 172 | 173 | 174 | if __name__ == '__main__': 175 | parser = argparse.ArgumentParser(description='Sample short texts from a pickled model', 176 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 177 | parser.add_argument('model_fname', metavar='model.pickle', nargs='+', 178 | help='One or more pickled RBM models') 179 | parser.add_argument('-n', '--n-samples', dest='n_samples', type=int, default=30, 180 | help='How many samples to draw') 181 | parser.add_argument('-i', '--iters', dest='iters', type=int, default=10**4, 182 | help='How many rounds of Gibbs sampling to perform before generating the outputs') 183 | parser.add_argument('--prog', '--progressively-sample', dest='prog', action='store_true', 184 | help='Output n samples after 0 rounds of sampling, then 1, 10, 100, 1000... until we reach a power of 10 >=iters') 185 | parser.add_argument('--init', '--init-method', dest='init_method', default='silhouettes', help="How to initialize vectors before sampling") 186 | parser.add_argument('--energy', action='store_true', help='Along with each sample generated, print its free energy') 187 | parser.add_argument('--every', type=int, default=None, help='Sample once every this many iters. 
Incompatible with --prog and --table.') 188 | 189 | args = parser.parse_args() 190 | 191 | args.init_method = VisInit[args.init_method] 192 | 193 | for model_fname in args.model_fname: 194 | print "Drawing samples from model defined at {}".format(model_fname) 195 | f = open(model_fname) 196 | model = pickle.load(f) 197 | f.close() 198 | # TODO: add as arg 199 | if 'usgeo' in model_fname: 200 | example_file = 'data/usgeo.txt' 201 | elif 'reponames' in model_fname: 202 | example_file = 'data/reponames.txt' 203 | elif 'names' in model_fname: 204 | example_file = 'data/names2.txt' 205 | 206 | -------------------------------------------------------------------------------- /ShortTextCodec.py: -------------------------------------------------------------------------------- 1 | from sklearn.utils import issparse 2 | import numpy as np 3 | import random 4 | 5 | 6 | class NonEncodableTextException(Exception): 7 | 8 | def __init__(self, reason=None, *args): 9 | self.reason = reason 10 | super(NonEncodableTextException, self).__init__(*args) 11 | 12 | 13 | class ShortTextCodec(object): 14 | # TODO: problematic if this char appears in the training text 15 | FILLER = '$' 16 | 17 | # If a one-hot vector can't be decoded meaningfully, render this char in its place 18 | MYSTERY = '?' 19 | 20 | # Backward-compatibility. 
Was probably a mistake to have FILLER be a class var rather than instance 21 | @property 22 | def filler(self): 23 | if self.__class__.FILLER in self.alphabet: 24 | return self.__class__.FILLER 25 | # Old versions of this class used ' ' as filler 26 | return ' ' 27 | 28 | def __init__(self, extra_chars, maxlength, minlength=0, preserve_case=False, leftpad=False): 29 | assert 0 <= minlength <= maxlength 30 | if self.FILLER not in extra_chars and maxlength != minlength: 31 | extra_chars = self.FILLER + extra_chars 32 | self.maxlen = maxlength 33 | self.minlen = minlength 34 | self.char_lookup = {} 35 | self.leftpad_ = leftpad 36 | self.alphabet = '' 37 | for i, o in enumerate(range(ord('a'), ord('z') + 1)): 38 | self.char_lookup[chr(o)] = i 39 | self.alphabet += chr(o) 40 | nextidx = len(self.alphabet) 41 | for i, o in enumerate(range(ord('A'), ord('Z') + 1)): 42 | if preserve_case: 43 | self.char_lookup[chr(o)] = nextidx 44 | nextidx += 1 45 | self.alphabet += chr(o) 46 | else: 47 | self.char_lookup[chr(o)] = i 48 | 49 | offset = len(self.alphabet) 50 | for i, extra in enumerate(extra_chars): 51 | self.char_lookup[extra] = i + offset 52 | self.alphabet += extra 53 | 54 | def debug_description(self): 55 | return ' '.join('{}={}'.format(attr, repr(getattr(self, attr, None))) for attr in ['maxlen', 'minlen', 'leftpad', 'alphabet', 'nchars']) 56 | 57 | @property 58 | def leftpad(self): 59 | return getattr(self, 'leftpad_', False) 60 | 61 | @property 62 | def nchars(self): 63 | return len(self.alphabet) 64 | 65 | @property 66 | def non_special_char_alphabet(self): 67 | return ''.join(c for c in self.alphabet if (c != ' ' and c != self.FILLER)) 68 | 69 | def _encode(self, s, padlen): 70 | if len(s) > padlen: 71 | raise NonEncodableTextException(reason='toolong') 72 | padding = [self.char_lookup[self.filler] for _ in range(padlen - len(s))] 73 | try: 74 | payload = [self.char_lookup[c] for c in s] 75 | except KeyError: 76 | raise 
NonEncodableTextException(reason='illegal_char') 77 | if self.leftpad: 78 | return padding + payload 79 | else: 80 | return payload + padding 81 | 82 | 83 | def encode(self, s, mutagen=None): 84 | if len(s) > self.maxlen: 85 | raise NonEncodableTextException(reason='toolong') 86 | elif (hasattr(self, 'minlen') and len(s) < self.minlen): 87 | raise NonEncodableTextException(reason='tooshort') 88 | if mutagen: 89 | s = mutagen(s) 90 | return self._encode(s, self.maxlen) 91 | 92 | def encode_onehot(self, s): 93 | indices = self.encode(s) 94 | return np.eye(self.nchars)[indices].ravel() 95 | 96 | def decode(self, vec, pretty=False, strict=True): 97 | # TODO: Whether we should use 'strict' mode depends on whether the model 98 | # we got this vector from does softmax sampling of visibles. Anywhere this 99 | # is called on fantasy samples, we should use the model to set this param. 100 | if issparse(vec): 101 | vec = vec.toarray().reshape(-1) 102 | assert vec.shape == (self.nchars * self.maxlen,) 103 | chars = [] 104 | for position_index in range(self.maxlen): 105 | # Hack - insert a tab between name parts in binomial mode 106 | if isinstance(self, BinomialShortTextCodec) and pretty and position_index == self.maxlen/2: 107 | chars.append('\t') 108 | subarr = vec[position_index * self.nchars:(position_index + 1) * self.nchars] 109 | if np.count_nonzero(subarr) != 1 and strict: 110 | char = self.MYSTERY 111 | else: 112 | char_index = np.argmax(subarr) 113 | char = self.alphabet[char_index] 114 | if pretty and char == self.FILLER: 115 | # Hack 116 | char = ' ' if isinstance(self, BinomialShortTextCodec) else '' 117 | chars.append(char) 118 | return ''.join(chars) 119 | 120 | def shape(self): 121 | """The shape of a set of RBM inputs given this codecs configuration.""" 122 | return (self.maxlen, len(self.alphabet)) 123 | 124 | def mutagen_nudge(self, s): 125 | # Mutate a single character chosen uniformly at random. 
126 | # If s is shorter than the max length, include an extra virtual character at the end 127 | i = random.randint(0, min(len(s), self.maxlen-1)) 128 | def roll(forbidden): 129 | newchar = random.choice(self.alphabet) 130 | while newchar in forbidden: 131 | newchar = random.choice(self.alphabet) 132 | return newchar 133 | 134 | if i == len(s): 135 | return s + roll(self.FILLER + ' ') 136 | if i == len(s)-1: 137 | replacement = roll(' ' + s[-1]) 138 | if replacement == self.FILLER: 139 | return s[:-1] 140 | return s[:-1] + roll(' ' + s[-1]) 141 | else: 142 | return s[:i] + roll(s[i] + self.FILLER) + s[i+1:] 143 | 144 | 145 | def mutagen_silhouettes(self, s): 146 | newchars = [] 147 | for char in s: 148 | if char == ' ': 149 | newchars.append(char) 150 | else: 151 | newchars.append(random.choice(self.non_special_char_alphabet)) 152 | return ''.join(newchars) 153 | 154 | def mutagen_noise(self, s): 155 | return ''.join(random.choice(self.alphabet) for _ in range(self.maxlen)) 156 | 157 | class BinomialShortTextCodec(ShortTextCodec): 158 | """Encodes two-part names (e.g. "John Smith"), padding each part separately 159 | to the same length. (Presumed to help learning.) 
160 | """ 161 | 162 | def __init__(self, *args, **kwargs): 163 | super(BinomialShortTextCodec, self).__init__(*args, **kwargs) 164 | self.separator = ',' 165 | # Hack: require maxlen to be even, and give each part of the name 166 | # an equal share 167 | assert self.maxlen % 2 == 0, "Maxlen must be even for binomial codec" 168 | 169 | def encode(self, s, mutagen=None): 170 | namelen = self.maxlen / 2 171 | if self.separator not in s: 172 | first = s 173 | last = '' 174 | else: 175 | try: 176 | last, first = s.split(self.separator) 177 | except ValueError: 178 | raise NonEncodableTextException(reason='too many separators') 179 | last = last.strip() 180 | first = first.strip() 181 | if mutagen: 182 | first = mutagen(first) 183 | last = mutagen(last) 184 | return self._encode(first, namelen) + self._encode(last, namelen) 185 | 186 | # We don't really need to override decode(). It should do basically the right 187 | # thing (modulo some funny spacing) 188 | 189 | # TODO: Probably *do* need to override some or all mutagen methods. Leaving 190 | # them for now since they're only necessary for evaluation. 
191 | -------------------------------------------------------------------------------- /Utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import time 3 | import logging 4 | from collections import Counter 5 | 6 | from ShortTextCodec import NonEncodableTextException 7 | 8 | from sklearn.preprocessing import OneHotEncoder 9 | 10 | DEBUG_TIMING = False 11 | 12 | # Taken from StackOverflow 13 | def timeit(f): 14 | if not DEBUG_TIMING: 15 | return f 16 | 17 | def timed(*args, **kw): 18 | 19 | ts = time.time() 20 | result = f(*args, **kw) 21 | te = time.time() 22 | 23 | print 'func:%r took: %2.4f sec' % \ 24 | (f.__name__, te - ts) 25 | return result 26 | 27 | return timed 28 | 29 | def vectors_from_txtfile(fname, codec, limit=-1, mutagen=None): 30 | f = open(fname) 31 | skipped = Counter() 32 | vecs = [] 33 | for line in f: 34 | line = line.strip() 35 | try: 36 | vecs.append(codec.encode(line, mutagen=mutagen)) 37 | if len(vecs) == limit: 38 | break 39 | except NonEncodableTextException as e: 40 | # Too long, or illegal characters 41 | skipped[e.reason] += 1 42 | 43 | logging.debug("Gathered {} vectors. Skipped {} ({})".format(len(vecs), 44 | sum(skipped.values()), dict(skipped))) 45 | vecs = np.asarray(vecs) 46 | # TODO: Why default to dtype=float? Seems wasteful? Maybe it doesn't really matter. Actually, docs here seem inconsistent? Constructor docs say default float. transform docs say int. Should file a bug on sklearn. 47 | return OneHotEncoder(len(codec.alphabet)).fit_transform(vecs) 48 | 49 | # Adapted from sklearn.utils.extmath.softmax 50 | def softmax(X, copy=True): 51 | if copy: 52 | X = np.copy(X) 53 | X_shape = X.shape 54 | a, b, c = X_shape 55 | # This will cause overflow when large values are exponentiated. 
56 | # Hence the largest value in each row is subtracted from each data 57 | max_prob = np.max(X, axis=2).reshape((X.shape[0], X.shape[1], 1)) 58 | X -= max_prob 59 | np.exp(X, X) 60 | sum_prob = np.sum(X, axis=2).reshape((X.shape[0], X.shape[1], 1)) 61 | X /= sum_prob 62 | return X 63 | 64 | def softmax_and_sample(X, copy=True): 65 | """ 66 | Given an array of 2-d arrays, each having shape (M, N) representing M softmax 67 | units with N possible values each, return an array of the same shape where 68 | each N-dimensional inner array has a 1 at one index, and zero everywhere 69 | else. The 1 is assigned according to the corresponding softmax probabilities 70 | (i.e. np.exp(X) / np.sum(np.exp(X)) ) 71 | 72 | Parameters 73 | ---------- 74 | X: array-like, shape (n_samples, M, N), dtype=float 75 | Argument to the logistic function 76 | copy: bool, optional 77 | Copy X or not. 78 | Returns 79 | ------- 80 | out: array of 0,1, shape (n_samples, M, N) 81 | Softmax function evaluated at every point in x and sampled 82 | """ 83 | a,b,c = X.shape 84 | X_shape = X.shape 85 | X = softmax(X, copy) 86 | # We've got our probabilities, now sample from them 87 | thresholds = np.random.rand(X.shape[0], X.shape[1], 1) 88 | cumsum = np.cumsum(X, axis=2, out=X) 89 | x, y, z = np.indices(cumsum.shape) 90 | # This relies on the fact that, if there are multiple instances of the max 91 | # value in an array, argmax returns the index of the first one 92 | to_select = np.argmax(cumsum > thresholds, axis=2).reshape(a, b, 1) 93 | bin_sample = np.zeros(X_shape) 94 | bin_sample[x, y, to_select] = 1 95 | 96 | return bin_sample 97 | -------------------------------------------------------------------------------- /compare_models.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import sys 3 | import pickle 4 | import Utils 5 | import csv 6 | import os 7 | import sklearn.metrics.pairwise 8 | from sklearn.utils.extmath import log_logistic 9 
| 10 | from short_text_codec import BinomialShortTextCodec 11 | 12 | Utils.DEBUG_TIMING = True 13 | 14 | FIELDS = (['nchars', 'minlen', 'maxlen', 'nhidden', 'batch_size', 'epochs', 'weight_cost',] 15 | + ['pseudol9'] 16 | + ['{}_{}'.format(metric, mut) for metric in ('LR', 'Err',) 17 | for mut in ('nudge', 'sil', 'noise')] 18 | + ['recon_error', 'filler', 'name', 'grade'] 19 | ) 20 | 21 | FORCE_MINLEN = False 22 | 23 | SUBDIR_TO_SCORE = { 24 | 'bad': 1, 25 | 'okay': 2, 26 | 'good': 3, 27 | 'great': 4, 28 | } 29 | 30 | @Utils.timeit 31 | def eval_model(model, trainfile, n): 32 | row = {'name': model.name} 33 | # Comparing models with different codec params seems problematic when they change the 34 | # set of examples each model is looking at (some strings will be too short/long for one 35 | # model but not another). This could introduce a systematic bias where some models get 36 | # strings that are a little easier or harder. Quick experiment performed to clamp minlen 37 | # and maxlen to a shared middle ground for all models. Didn't really affect ranking. 
38 | codec = model.codec 39 | if FORCE_MINLEN: 40 | old_minlen = getattr(codec, 'minlen', None) 41 | codec.minlen = FORCE_MINLEN 42 | row['minlen'] = '{} ({})'.format(old_minlen, codec.minlen) 43 | else: 44 | row['minlen'] = getattr(codec, 'minlen', None) 45 | 46 | row['nchars'] = codec.nchars 47 | row['maxlen'] = codec.maxlen 48 | row['nhidden'] = model.intercept_hidden_.shape[0] 49 | row['filler'] = codec.filler 50 | row['batch_size'] = model.batch_size 51 | # TODO: Not accurate for incrementally trained models 52 | row['epochs'] = model.n_iter 53 | row['weight_cost'] = getattr(model, 'weight_cost', 'NA') 54 | row['grade'] = getattr(model, 'grade', '?') 55 | 56 | 57 | # The untainted vectorizations 58 | good = Utils.vectors_from_txtfile(trainfile, codec, n) 59 | good_energy = model._free_energy(good) 60 | row['pseudol9'] = model.score_samples(good).mean() 61 | for name, mutagen in [ ('nudge', codec.mutagen_nudge), 62 | ('sil', codec.mutagen_silhouettes), 63 | ('noise', codec.mutagen_noise), 64 | ]: 65 | bad = Utils.vectors_from_txtfile(trainfile, codec, n, mutagen) 66 | bad_energy = model._free_energy(bad) 67 | 68 | # TODO: too lazy to implement 69 | if isinstance(model.codec, BinomialShortTextCodec): 70 | break 71 | 72 | # log-likelihood ratio 73 | # This is precisely log(P_model(good)/P_model(bad)) 74 | # i.e. according to the model, how much more likely is the authentic data compared to the noised version? 75 | # Which seems like a really useful thing to know, but actually gives results that are pretty counterintuitive. 76 | # Some models score *really* well under this metric, but very poorly on the 'error rate' metric below and 77 | # on pseudo-likelihood. In fact, this metric seems to be inversely correlated with success on other metrics. 78 | # It's not clear to me why this is. My vague hypothesis is that models unconstrained by weight costs have 79 | # learned to associate really-really-really high (relative) energy to certain configurations. 
So the good 80 | # models 'win' more often (assigning lower energy to authentic examples), but the bad models sometimes win 81 | # by a lot more. (And for our purposes, maybe this shouldn't really be worth many more points. We just want 82 | # the strings we get from sampling to be reasonable, and for unreasonable strings to have high enough energy 83 | # that we won't encounter them. Whether "asdasdsf" gets HIGH_ENERGY, or HIGH_ENERGY x 10^100 doesn't really 84 | # matter to us.) 85 | # The connection to KL-divergence here is interesting. If we say P is the model distribution over the training 86 | # data and Q is the corresponding distribution which 'sees' the noised version (i.e. Q(v) := P(mutate(v))) 87 | # then D_KL(P||Q) = \sum{v} P(v) * (energy(mutate(v)) - energy(v)) 88 | # So our log-likelihood ratio is identical to KL-divergence except for the P(v) term (which is of course intractable). 89 | # The 'beating a dead horse' hypothesis is consistent with the 'bad' models having low KL-divergence in 90 | # spite of having a better log-likelihood ratio. These models may be assigning low absolute probabilities 91 | # to the training examples, but even lower probabilities to the mutants. So -log(P(v)) = 10^10^100, -log(P(mutate(v))) = 92 | # 10^10^200 isn't worth that many points, because the large ratio is tempered by the low P(v). 93 | # One could hypothesize that the opposite phenomenon is occurring, and the bad models are overfit to the 94 | # training data, and have learned to assign precisely those strings very low energy. But using test data 95 | # (or even a different dataset similar to the training data - e.g. testing on Canadian geo names models trained 96 | # on US geo names) results in the same rankings. 
97 | row['LR_{}'.format(name)] = (bad_energy - good_energy).mean() 98 | 99 | # "Error rate" (how often is lower energy assigned to the evil twin) 100 | row['Err_{}'.format(name)] = 100 * (bad_energy < good_energy).sum() / float(n) 101 | 102 | goodish = model.gibbs(good) 103 | row['recon_error'] = sklearn.metrics.pairwise.paired_distances(good, goodish).mean() 104 | #goodisher = model.repeated_gibbs(good, 20) 105 | #row['mix_20'] = sklearn.metrics.pairwise.paired_distances(good, goodisher).mean() 106 | # TODO: This is too slow. Like 20 minutes per model. 107 | #goodish = model.repeated_gibbs(goodisher, 200) 108 | #row['mix_200'] = 0.0 #sklearn.metrics.pairwise.paired_distances(goodisher, goodish).mean() 109 | for k in row: 110 | if isinstance(row[k], float): 111 | # Why doesn't this work? 112 | if 0 < abs(row[k]) < 10**(-3): 113 | fmt_string = '{:.1E}' 114 | else: 115 | fmt_string = '{:.4g}' 116 | row[k] = fmt_string.format(row[k]) 117 | # By default, None is rendered as empty string, which messes up column output 118 | elif row[k] is None: 119 | row[k] = 'NA' 120 | elif row[k] == '': 121 | row[k] = "''" 122 | elif row[k] == ' ': 123 | row[k] = "" 124 | return row 125 | 126 | if __name__ == '__main__': 127 | # TODO: "Append" mode so we don't have to do a bunch of redundant calculations when we add one or two new models 128 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 129 | parser.add_argument('models', metavar='model', nargs='+', help='Pickled RBM models') 130 | parser.add_argument('trainfile', help='File with training examples') 131 | parser.add_argument('-a', '--append', action='store_true', help='If there exists a model_comparison_.csv' 132 | + ' file, append a row for each model file passed in, rather than clobbering. 
Does not ' 133 | + 'attempt to dedupe rows.') 134 | parser.add_argument('-t', '--tag', default='', help='A tag to append to the output csv filename') 135 | parser.add_argument('-n', type=int, default=10**4, help="Number of samples to average over." + 136 | "Default is pretty fast and, anecdotally, seems to give pretty reliable results." 137 | + " Increasing it by a factor of 5-10 doesn't change much.") 138 | args = parser.parse_args() 139 | 140 | if args.trainfile.endswith('.pickle'): 141 | print "trainfile is mandatory" 142 | parser.print_usage() 143 | sys.exit(1) 144 | 145 | models = [] 146 | for fname in args.models: 147 | if os.path.isdir(fname): 148 | print "Received directory. Assuming this contains subdirs /bad, /okay, /good, /great with pickles" 149 | for dirname, _, fnames in os.walk(fname): 150 | leafdir = dirname.split(os.path.sep)[-1] 151 | try: 152 | grade = SUBDIR_TO_SCORE[leafdir] 153 | except KeyError: 154 | print "Ignoring unrecognized subdir {}".format(dirname) 155 | continue 156 | for fname in fnames: 157 | path = os.path.join(dirname, fname) 158 | f = open(path) 159 | model = pickle.load(f) 160 | f.close() 161 | model.name = os.path.basename(fname) 162 | model.grade = grade 163 | models.append(model) 164 | 165 | 166 | else: 167 | f = open(fname) 168 | models.append(pickle.load(f)) 169 | models[-1].name = os.path.basename(fname) 170 | f.close() 171 | 172 | # We could try to be efficient and only load the training data once for all models 173 | # But then we would need to require that all models passed in use equivalent codecs 174 | # Or do something clever to only load n times for n distinct codecs 175 | # Let's just do the dumb thing for now 176 | if not os.path.exists('model_comparisons/'): 177 | print "Creating model_comparisons dir" 178 | os.mkdir("model_comparisons") 179 | outname = 'model_comparisons/model_comparison_{}.csv'.format(args.tag) 180 | append = args.append 181 | if append and not os.path.exists(outname): 182 | print "WARNING: 
received append option, but found no existing file {}".format(outname) 183 | append = False 184 | f = open(outname, 'a' if append else 'w') 185 | writer = csv.DictWriter(f, FIELDS, delimiter='\t') 186 | if not append: 187 | writer.writeheader() 188 | 189 | for i, model in enumerate(models): 190 | print "Evaluating {} [{}/{}]".format(model.name, i+1, len(models)) 191 | row = eval_model(model, args.trainfile, args.n) 192 | writer.writerow(row) 193 | print 194 | 195 | f.close() 196 | print "Wrote results to " + outname 197 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | scikit-learn==0.17.1 2 | enum34==1.1.6 3 | -------------------------------------------------------------------------------- /sample.py: -------------------------------------------------------------------------------- 1 | import Sampling 2 | import sys 3 | import pickle 4 | import argparse 5 | import colorama 6 | colorama.init() 7 | 8 | SAMPLES = [] 9 | def horizontal_cb(strings, i, energy=None): 10 | global SAMPLES 11 | if energy is not None: 12 | SAMPLES.append(zip(strings, energy)) 13 | else: 14 | SAMPLES.append(strings) 15 | 16 | DEDUPE_SEEN = [] 17 | def dedupe_cb(strings, i, energy=None): 18 | global DEDUPE_SEEN 19 | if not DEDUPE_SEEN: 20 | DEDUPE_SEEN = [set() for _ in strings] 21 | for i in range(len(strings)): 22 | if strings[i] in DEDUPE_SEEN[i]: 23 | continue 24 | print strings[i] + "\t" + ("{:.2f}".format(energy[i]) if energy is not None else "") 25 | DEDUPE_SEEN[i].add(strings[i]) 26 | print 27 | 28 | def bold(s): 29 | return "\033[31m" + s + "\033[0m" 30 | 31 | def print_columns(maxlen): 32 | col_width = maxlen+2 33 | for fantasy_index in range(len(SAMPLES[0])): 34 | particles = [s[fantasy_index] for s in SAMPLES] 35 | if args.energy: 36 | min_energy = min(particles, key=lambda tup: tup[1]) 37 | print "".join( 38 | bold(p[0].ljust(col_width)) if p == 
min_energy 39 | else p[0].ljust(col_width) 40 | for p in particles) 41 | else: 42 | print "".join(s[fantasy_index].ljust(col_width) for s in SAMPLES) 43 | 44 | if __name__ == '__main__': 45 | parser = argparse.ArgumentParser(description='Sample short texts from a pickled model', 46 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 47 | parser.add_argument('model_fname', metavar='model.pickle', nargs='+', 48 | help='One or more pickled RBM models') 49 | parser.add_argument('--every', type=int, default=-1, help='How often to sample.' + 50 | ' If -1 (default) only sample after the last iteration.') 51 | parser.add_argument('-n', '--n-samples', dest='n_samples', type=int, default=30, 52 | help='How many samples to draw') 53 | parser.add_argument('-f', '--first', dest='first', type=int, default=-1, 54 | help='Which iteration to draw the first sample at ' + 55 | '(if --every is provided and this is not, defaults to --every)') 56 | parser.add_argument('-i', '--iters', dest='iters', type=int, default=10**3, 57 | help='How many rounds of Gibbs sampling to perform') 58 | parser.add_argument('--energy', action='store_true', help='Along with each sample generated, print its free energy') 59 | parser.add_argument('-s', '--start-temp', dest='start_temp', type=float, default=1.0, help="Temperature for first iteration") 60 | parser.add_argument('-e', '--end-temp', dest='end_temp', type=float, default=1.0, help="Temperature at last iteration") 61 | parser.add_argument('--no-col', dest='columns', action='store_false') 62 | parser.add_argument('--dedupe', action='store_true') 63 | parser.add_argument('--sil', help='data file for silhouettes') 64 | 65 | args = parser.parse_args() 66 | 67 | 68 | for model_fname in args.model_fname: 69 | if len(args.model_fname) > 1 or not args.columns: 70 | print "Drawing samples from model defined at {}".format(model_fname) 71 | f = open(model_fname) 72 | model = pickle.load(f) 73 | f.close() 74 | 75 | if args.every == -1: 76 | sample_indices 
= [args.iters-1] 77 | else: 78 | first = args.every if args.first == -1 else args.first 79 | sample_indices = range(first, args.iters, args.every) 80 | if sample_indices[-1] != args.iters - 1: 81 | sample_indices.append(args.iters-1) 82 | 83 | if args.columns: 84 | cb = horizontal_cb 85 | elif args.dedupe: 86 | cb = dedupe_cb 87 | else: 88 | cb = Sampling.print_sample_callback 89 | 90 | kwargs = dict(start_temp=args.start_temp, final_temp=args.end_temp, sample_energy=args.energy, 91 | callback=cb) 92 | if args.sil: 93 | kwargs['init_method'] = Sampling.VisInit.silhouettes 94 | kwargs['training_examples'] = args.sil 95 | 96 | vis = Sampling.sample_model(model, args.n_samples, args.iters, sample_indices, **kwargs) 97 | 98 | if args.columns: 99 | print_columns(model.codec.maxlen) 100 | 101 | if args.energy: 102 | fe = model._free_energy(vis) 103 | sys.stderr.write('Final energy: {:.2f} (stdev={:.2f})\n'.format(fe.mean(), fe.std())) 104 | 105 | -------------------------------------------------------------------------------- /samples/README.markdown: -------------------------------------------------------------------------------- 1 | This directory contains a bunch of samples drawn from models trained on a few different datasets. `foo.txt` has all the sampled names with repetitions removed. `foo_unique.txt` is deduped against the corresponding training set. 2 | 3 | Details of each model are below. 4 | 5 | # Actors 6 | 7 | ### Model 8 | 9 | - 180 hidden units 10 | - trained for 20 epochs with LR=.1 decaying linearly per epoch 11 | - batch size = 20 12 | - alphabet = `[a-z]$ .-` 13 | - used specialized "binomial" codec (see `short_text_codec.py`) 14 | 15 | ### Sampling 16 | 17 | Sampled with simulated annealing going from T=1.3 to T=0.3 over 800 iterations. 18 | 19 | python sample_every.py --dedupe -s 1.3 -n 4000 -e 0.3 -i 800 --energy $model 200 20 | 21 | ### Deduping 22 | 23 | Out of 7,600 generated names... 
24 | 25 | - 5 exactly match a name in the training set 26 | - 40 out of 775 distinct first names exist in the training set (n=1.5m) 27 | - 32 out of 3,454 distinct last names exist in the training set 28 | 29 | # US place names 30 | 31 | ### Model 32 | 33 | - 350 hidden units 34 | - 20 + 20 epochs on usgeo dataset with max length = 20 35 | - lr=.05 in the first round, then .00 in second. decayed during each run. 36 | - batch size 20 in first round then 40 37 | - weight cost of 0.0001 38 | 39 | ### Sampling 40 | 41 | Sampled with simulated annealing going from T=1.0 to T=0.3 over 2k iterations. Started from 'silhouettes' of training data (see `VisInit` enum in `Sampling.py`). 42 | 43 | python sample_every.py --sil data/usgeo.txt --no-col --dedupe --energy -f 250 -s 1.0 -e 0.3 -i 2000 -n 10000 $model 250 44 | # dedupe and filter by energy < -145 45 | 46 | ### Deduping 47 | 48 | 2,920 / 42,709 generated place names exist in the training set (n=700k). 49 | 50 | # GitHub repositories 51 | 52 | ### Model 53 | 54 | - 350 hidden units 55 | - trained for 4 rounds of 20 epochs 56 | - LR for final round = .001 57 | - maxlen = 20, minlen = 6, alphabet is case sensitive and includes all special chars in dataset (nchars=66) 58 | 59 | ### Sampling 60 | 61 | Sampled with simulated annealing going from T=1.5 to T=0.2 over 1k iterations. Particles initialized randomly according to biases of the visible units. 62 | 63 | python sample_every.py -s 1.5 -e 0.2 -i 1000 -f 400 --energy --no-col --dedupe -n 10000 $model 100 64 | 65 | ### Deduping 66 | 67 | 2,605 / 20,000 generated repo names exist in the training set (n=3.7m) 68 | 69 | # Games 70 | 71 | ### Model 72 | 73 | - 250 hidden units 74 | - alphabet = `[a-z][0-9]$ !-:&.'` 75 | - trained for 80+10 epochs 76 | 77 | ### Sampling 78 | 79 | Sampled with simulated annealing going from T=1.0 to T=0.2 over 250 iterations. Started from 'silhouettes' of training data (see `VisInit` enum in `Sampling.py`).
80 | 81 | python sample_every.py -n 10000 -f 100 -i 250 -s 1.0 -e 0.2 --sil data/games.txt $model 50 82 | 83 | ### Deduping 84 | 85 | 34 / 3,490 generated games exist in the training set (n=80k) 86 | -------------------------------------------------------------------------------- /samples/games_unique.txt: -------------------------------------------------------------------------------- 1 | stopeest game 2 | chef ths gome 3 | the mitean game 4 | eleppitt care game 5 | chipling gome 6 | the sidal sgame 7 | elepes on the game 8 | the hing board game 9 | black wory 10 | conquest pame 11 | hocket'& pace 12 | charit's board game 13 | spop the gime 14 | 5tla: the card game 15 | brauk 16 | spel the dige 17 | green tibtard game 18 | the gogat wor 19 | ale mathen game 20 | blanz 21 | pocket's ragitg 22 | chafling gamise 23 | exerestor wor game 24 | the monien game 25 | the stre game 26 | the nation gome 27 | the lase game 28 | comp the game 29 | plaxk 30 | the mandon game 31 | ove nichecken game 32 | quizd 33 | the mivil game 34 | coppetst bame 35 | the woblar card game 36 | the rine game 37 | the bine game 38 | one mitheckon game 39 | chen ihe game 40 | chickiny care game 41 | catelite card game 42 | croy che cord game 43 | pocket quizs 44 | the minil wun game 45 | glidk 46 | the miviy war game 47 | ster the game 48 | empere of the ange 49 | the pothe bcard game 50 | blook 51 | pro: the cord game 52 | tlank 53 | constery gome 54 | the mibaon game 55 | sved the bige 56 | elepes on the glee 57 | packet's pick 58 | eleemitor whe game 59 | pocket poce 60 | clax 61 | the stee game 62 | stot the bame 63 | black wirs 64 | the mithe wal game 65 | triends of fy game 66 | chimling game 67 | quibk 68 | uno: the goven 69 | eleepite card game 70 | triends of or game 71 | the wing game 72 | 10 f:ict board game 73 | huppy sack card game 74 | operation wert 75 | monstera game 76 | the pite care game 77 | triemat of the poal 78 | das sthe picesspiel 79 | the cane game 80 | the milell game 
81 | black worl 82 | the file game 83 | empers of the ange 84 | sved the bime 85 | trind 86 | the nithe bor game 87 | the rane game 88 | elepes of the geme 89 | the mine gad game 90 | the fuce game 91 | stad the bagese 92 | brakd 93 | sour the game 94 | triends of ty gace 95 | glauk 96 | uvo: the cord game 97 | the mituin game 98 | wher the game 99 | plep the card game 100 | uno: the bace 101 | qlabk 102 | the nical cgame 103 | 1050: c0 board game 104 | over thonr ch game 105 | gree the game 106 | tera the gome 107 | chip the gome 108 | the mitis wol game 109 | chicling gade 110 | chicling manise 111 | friendly fire pack 112 | hoon the bige 113 | chushin' board game 114 | erepes of the glee 115 | the wire game 116 | popperst marden game 117 | the bine gome 118 | shores hick 119 | the mare game 120 | globel on the game 121 | the sinele game 122 | over the gome 123 | emples on the game 124 | the great wols 125 | glay 126 | operation wougame 127 | who: the board game 128 | trie the popile 129 | monsterpolo 130 | conquest comn 131 | chicling bomige 132 | triends of wo game 133 | stat the booe 134 | black warss 135 | grean torrans game 136 | forby baar card game 137 | preamiter wan game 138 | chne: co board game 139 | popeyine card game 140 | operation clingo 141 | the jine gome 142 | rver the game 143 | pyppetst cale game 144 | competse bame 145 | stro: che card game 146 | the clet game 147 | stat the baee 148 | the gatien game 149 | char the games 150 | the bica game 151 | etoam the wor game 152 | the nadaln game 153 | trie the bor 154 | star: the card game 155 | evo: the pand game 156 | eneersthe bld game 157 | the great wark 158 | black wors 159 | clack wand the dess 160 | the bingane 161 | chunders gome 162 | one nithe ble game 163 | monster rone 164 | cheruthe game 165 | cold the game 166 | the ithe game 167 | the fike game 168 | sved the gume 169 | stit the tire 170 | the erbe game 171 | chucking board game 172 | empole of the gume 173 | pocket quiz: gace 174 | 
pocket's wame 175 | huppes of the game 176 | constery game 177 | 10 f:'ct board game 178 | the ethe gome 179 | tlop 180 | rounte's roving 181 | the manden game 182 | trienas of the porl 183 | conquest wale 184 | star the gome 185 | star the tam 186 | chickino caad game 187 | pocket quizhe game 188 | the bicy game 189 | clack hars 190 | battle of the gime 191 | star the games 192 | the hondes card game 193 | the minien ge game 194 | the manien game 195 | glodey co the game 196 | chur the game 197 | pocket'& game 198 | plock wars 199 | the midel cgame 200 | i t: the card game 201 | klipk 202 | croymite tard game 203 | plop the card game 204 | the maphens 205 | chof: the card game 206 | the natpan game 207 | star the grme 208 | blau 209 | chipling gade 210 | hurketse game 211 | the great dar 212 | erepes of the glle 213 | frax 214 | the minden ge game 215 | the sute game 216 | overasthenven game 217 | the bige game 218 | blaak 219 | the greal gades 220 | pocket hick 221 | kligk 222 | hockyt't card game 223 | the griat card game 224 | chop: the card game 225 | uno: the piven 226 | glienite card game 227 | stat the bime 228 | pocket's racke 229 | chapling games 230 | monstery gome 231 | cold che game 232 | flauk 233 | chafling gane 234 | blotk 235 | conquestace 236 | chir the game 237 | the minder game 238 | miday baal card game 239 | the bile gale 240 | one mithenven game 241 | triends of an game 242 | the bige gamede 243 | evermetor whe game 244 | triadas of the poal 245 | popperst marken game 246 | chifling gome 247 | heef the game 248 | conquest las 249 | uno: the bige 250 | the fite game 251 | star the bime 252 | trives of the ande 253 | poperist card game 254 | the moval game 255 | the mangen game 256 | the fure gome 257 | the bily game 258 | stab the game 259 | 5 to: the card game 260 | operation warda 261 | 3 t: the cord game 262 | the liye game 263 | 20 : the board game 264 | over the bord game 265 | uno: the gige 266 | glogk 267 | glaik 268 | emperice card 
game 269 | over tion: ch game 270 | emperdel card game 271 | slop the gime 272 | a tay debcard game 273 | chary co roard game 274 | gera the game 275 | glawk 276 | evo: the burd game 277 | glaez 278 | tren the bo 279 | the lant game 280 | over the dard game 281 | quazs 282 | black pole 283 | over thond ch game 284 | chimlind game 285 | thim: tactics: game 286 | shar: the card game 287 | trapk 288 | triages of the poal 289 | chipling fame 290 | uno: the bage 291 | chadling gomise 292 | pocket cuis: game 293 | the bilt game 294 | bligk 295 | the minge ggame 296 | croymane card game 297 | the mothen game 298 | blakd 299 | spok the bige 300 | block ward 301 | itoam the car game 302 | chacling bonese 303 | chundire gome 304 | triends of to game 305 | heed the game 306 | ghoj: che card game 307 | conquest word 308 | poppetse card game 309 | monsterpane 310 | chefling gamede 311 | hurket & game 312 | pocket's ravil 313 | glodes in the game 314 | conquest comns 315 | the mice gime 316 | hood the game 317 | therethe curd game 318 | the fase game 319 | egeles on the gale 320 | wounde's roving 321 | uno: the bimes 322 | chunne's roving 323 | pocket's kack 324 | clock word 325 | popperst masken game 326 | the malions 327 | purket p game 328 | conquest card game 329 | chan the gome 330 | over the baee 331 | thef the game 332 | aneenithe for game 333 | don't toon the bort 334 | the bate game 335 | brank 336 | chundere game 337 | conquest gume 338 | uno: the ward game 339 | the nile game 340 | sta: the board game 341 | the civen game 342 | croynine word game 343 | 1862: che card game 344 | eneam the war game 345 | the michon game 346 | the manker game 347 | the manion 348 | trie the go 349 | 1002: c0 board game 350 | the mithe war game 351 | pocket's rigilg 352 | brac 353 | uco: the bardenspiel 354 | chafling bade 355 | the pite gare game 356 | anima tactics: kerka 357 | chazy co board game 358 | hoo: the rige 359 | checling gome 360 | 5 t': the card game 361 | blaod 362 | blant 
363 | over then: ch game 364 | pocket pucs 365 | ove mithe war game 366 | anima tactics: derai 367 | geld the game 368 | stor the game 369 | the birl game 370 | the mone game 371 | braok 372 | friends of bangame 373 | flook 374 | sheanthe came game 375 | conquest dopire 376 | triekas of the boal 377 | anima tactics: derya 378 | gloek 379 | pocket's bogile 380 | xtto: the card game 381 | the matlen game 382 | black wols 383 | over the bans game 384 | one kitheckon game 385 | flidk 386 | the masren game 387 | hara the game 388 | exore the war game 389 | blao 390 | ove mithe wor game 391 | blonk 392 | the pice came game 393 | elemes of the gale 394 | uno: the bole 395 | the pige game 396 | pocket's rich 397 | the bocl game 398 | head the gome 399 | poperise card game 400 | the pinge of game 401 | bront 402 | blix 403 | the fite game game 404 | tlox 405 | sver the game 406 | triends of bangade 407 | flau 408 | the sare gome 409 | stot the times 410 | the midal ggame 411 | pocket's back 412 | the stre gome 413 | 3to: the board game 414 | the cang gome 415 | block wards 416 | triends of bongame 417 | the dile game 418 | chafling gowes 419 | chef the game 420 | trienas of the poal 421 | chelling game 422 | triends of botgame 423 | treves of the sless 424 | blandy 425 | trab the game 426 | blalk 427 | the buhe gamee 428 | chacling rogis 429 | poppetsa game 430 | uno: the bold game 431 | checking rang game 432 | chickino dare game 433 | conquest popire 434 | the fece game 435 | blurd toon the boll 436 | stack wand the ders 437 | triik 438 | the pithicken game 439 | coundere gome 440 | the manline 441 | theenghe game game 442 | e pols of the gale 443 | operation wringo 444 | klank 445 | trep the gome 446 | chary co doard game 447 | uno: the bive 448 | chefling game 449 | the bite game 450 | chatlenger 451 | the cire gome 452 | blayz 453 | the mival fgame 454 | chim: taclics: game 455 | blick harring game 456 | the dind game 457 | the mitir war game 458 | conquest bar game 
459 | black words 460 | pocket's rack 461 | heep the game 462 | chifling game 463 | over thens or game 464 | monsters rone 465 | hoppythe cord game 466 | operation wores 467 | shepes & game 468 | the mipol game 469 | pocket's jice 470 | klayk 471 | conquest wopile 472 | stop the ge 473 | the micol card game 474 | cheshing board game 475 | eneanster che game 476 | the winder 477 | operation vertuie 478 | charstha board game 479 | eneensthe bhe game 480 | the fagl game 481 | emperidos ard game 482 | chicking bomigame 483 | stor the tage 484 | thip the gome 485 | horkey & games 486 | the miner wun game 487 | the bana game 488 | the monken game 489 | the fire game 490 | the moway game 491 | coppetse game 492 | chickiny fand game 493 | the fital sgame 494 | operation worgace 495 | conquest ofpales 496 | a day mibrard game 497 | whecking ring game 498 | one kichecken game 499 | cher the game 500 | gond 501 | chepling bamese 502 | triends of ry gace 503 | over the game 504 | over the bamd game 505 | 5to: the board game 506 | pocket's corise 507 | the mineu war game 508 | tren the io 509 | the biwe gome 510 | hippy sack card game 511 | geed the game 512 | the nimien game 513 | hour the game 514 | eveepite card game 515 | quizz 516 | the gase game 517 | cheanthe card game 518 | tlick 519 | moustert game 520 | gara 521 | hupers of the game 522 | emeles of the glee 523 | blocd ward 524 | the mate game 525 | shop she game 526 | chicling mobing 527 | the mise game 528 | chundire game 529 | blap 530 | the bindane 531 | uno: the gimen 532 | spores hick 533 | block wers 534 | elepes of the gale 535 | black ward 536 | ove sthe rang game 537 | cypric's board game 538 | the monden game 539 | ivo: the bard game 540 | chifling bame 541 | the nidal ggame 542 | clack wore 543 | 100 :.c. 
board game 544 | onora the war game 545 | the nival sgame 546 | block hars 547 | the great wales game 548 | the wase game 549 | the mingine 550 | slack word 551 | over the ward game 552 | hoppes of the game 553 | 1110: che card game 554 | the michen game 555 | emperdes care game 556 | preanthe pard game 557 | corder's barigame 558 | the male game 559 | chickino card game 560 | the sthe board game 561 | the mabiin game 562 | shaf the lame 563 | the sabeln game 564 | thechiny came game 565 | the bele gome 566 | slep the grme 567 | chicking biig game 568 | the mace game 569 | blac 570 | hoppetse gome 571 | the gane game 572 | the mitken game 573 | cypanshe card game 574 | tref the gome 575 | pocket's rogil 576 | hampes of the game 577 | trie the pine 578 | the jice game 579 | the nivel sgame 580 | the wise game 581 | ino: the bard game 582 | the great gorn 583 | the pity came game 584 | triends of ty rack 585 | store the card game 586 | chapling game 587 | conquest ofpides 588 | the macien game 589 | hour the gome 590 | theat ward card game 591 | the wite game 592 | glack 593 | cher the games 594 | chamin's board game 595 | chin the game 596 | 100 :'ct board game 597 | hepp the game 598 | one nichecken game 599 | gren the game 600 | chicling rogils 601 | cher the grme 602 | triedat of the doy 603 | the piny cave game 604 | triends of ty pack 605 | pocket'e rack 606 | the disa gome 607 | conquest pomnde 608 | pocket's pach 609 | block harr 610 | spot the game 611 | the orgat mer 612 | ontra the war game 613 | stad the bice 614 | the bise gome 615 | svan the bame 616 | spor the game 617 | glaen 618 | the ninien game 619 | tree the lo 620 | the minee war game 621 | shopeshe game 622 | the the game 623 | stoam che card game 624 | 101p:'ct board game 625 | the stae game 626 | eveles of the game 627 | the boll racien game 628 | shokes hick 629 | stad the rame 630 | the wegk game 631 | tren the ga 632 | operation wort 633 | chipling gomise 634 | stad the bage 635 | black 
woly 636 | the matlon game 637 | hoon the gome 638 | gleen te card game 639 | uno: the bamd game 640 | the mond game 641 | chassi's board game 642 | blork 643 | emperide cord game 644 | the face game 645 | stap the gome 646 | exerustir wer game 647 | gliemite cord game 648 | the stce board game 649 | gleemiter che game 650 | the erle game 651 | over the bard game 652 | conquest war game 653 | battle of the bale 654 | bland 655 | klaxk 656 | enperide care game 657 | thap the game 658 | das sthe gacesspiel 659 | the mites war game 660 | hocket's rich 661 | the banger 662 | tree the fo 663 | i s: the pard game 664 | black harr 665 | the gine game game 666 | stot the pire 667 | spores hich 668 | brapk 669 | cous the game 670 | gleen to card game 671 | hvel the game 672 | checking bane game 673 | the bune game 674 | chepning game 675 | the fide game game 676 | gees the bome 677 | the bile gome 678 | cper the game 679 | chicking boess 680 | blabk 681 | pocket quizle game 682 | the manian game 683 | charsthe board game 684 | blunk 685 | dreves of the gade 686 | the ninal cgame 687 | over the biame 688 | the stha game 689 | chap the game 690 | monstertace 691 | the erce game 692 | the lote game 693 | dou't doen the boat 694 | elepes on the gtee 695 | 5vo: the bord game 696 | dou't duen the port 697 | klagk 698 | operation wringd 699 | operation vorsuie 700 | coppuest game 701 | flalk 702 | uno: the gigen 703 | thup the game 704 | the mege game 705 | greenster che game 706 | the wang game 707 | counters rocil 708 | gloods on the gume 709 | the pile game game 710 | qlazk 711 | poppetse game 712 | the mingen game 713 | black 714 | experily card game 715 | exeamster was game 716 | thun the gome 717 | operation vorss 718 | the motiin game 719 | stop the pimes 720 | the dine game 721 | black tars 722 | shad the gime 723 | compit'a boord game 724 | the mitmon game 725 | cerp the game 726 | chacling bogis 727 | clock wars 728 | the fane games 729 | conquest dapir 730 | chanlind 
game 731 | uno: the war game 732 | eleppine cord game 733 | klook 734 | hoppetst pame 735 | ship the gome 736 | chepling game 737 | the bicd gome 738 | triakes of the poel 739 | the zure game 740 | uno: the bio 741 | tref the game 742 | bladk 743 | the foge game 744 | gleenster che game 745 | tria the game 746 | shapes of the game 747 | egeepite cord game 748 | empers on the game 749 | operation fricet 750 | wvel the bage 751 | the mites are game 752 | hoppythe game 753 | the mice came game 754 | the sthe boork game 755 | exerestor war game 756 | the great gars 757 | the michuckon game 758 | chen the games 759 | the mande sgame 760 | the silel sgame 761 | blagd 762 | the mithe bar game 763 | monstert gome 764 | the sico board game 765 | the bine gamer 766 | dobo: the board game 767 | chen ghe gome 768 | black wole 769 | the matlin game 770 | the morian game 771 | stat the rage 772 | the migal sgame 773 | emperdas card game 774 | black wars 775 | pocket's kace 776 | chima tactics: game 777 | blick! 
778 | gold: the card game 779 | the nise game 780 | triahes of the pial 781 | stot the tage 782 | trienat of the pors 783 | quint 784 | the lole game 785 | the mithe bol game 786 | hocket's dack 787 | flao 788 | plopethe card game 789 | the lond game 790 | horde 'f board game 791 | operation wornade 792 | chepling gome 793 | triakes of the pial 794 | gleenstor: ch game 795 | the mander 796 | chicling borise 797 | tren the gam 798 | a to: the card game 799 | chipling fane 800 | bracz 801 | blays 802 | cunquest game 803 | speles of wers 804 | operation werte 805 | imperdel care game 806 | tren the game 807 | the bihe game 808 | uno: the bame 809 | shopline game 810 | the moral game 811 | pocket qhos 812 | thef the gome 813 | chicking eabings 814 | the athe game 815 | charic's board game 816 | operation wortace 817 | slok the rige 818 | exeles of the game 819 | ever the bard game 820 | the ping came game 821 | chipking game 822 | the woll gacien game 823 | the pice game 824 | the great warze 825 | the pander 826 | chersthe board game 827 | the movan game 828 | the sire gamene 829 | the lank game 830 | herp the gome 831 | the winges 832 | 3no: the bard game 833 | ovel the bale 834 | the bose game 835 | glank 836 | the late game 837 | a day taar card game 838 | chacling game 839 | chicking roving 840 | triyk 841 | flayk 842 | the mithen game 843 | the ballers 844 | greens on the game 845 | thele tol wor game 846 | pocket picd 847 | the bele game 848 | monquest gome 849 | i day hebrand game 850 | popperst card game 851 | monsterpano 852 | the minge fgame 853 | glirk 854 | ev : the cand game 855 | pyrapide card game 856 | shop the gome 857 | chicling banese 858 | pocket's pagale 859 | battle of the bald 860 | operation vorncces 861 | the fice card game 862 | fliek 863 | taro 864 | clack ward 865 | chipling mamise 866 | 2110: che card game 867 | corder's bag game 868 | tref the fale 869 | spip the bame 870 | 2vd: the pirg game 871 | cypsic's board game 872 | croymite cord 
game 873 | the nine game 874 | houn the game 875 | flop the game 876 | operation fricge 877 | the mire game 878 | the ditel sgame 879 | the ficl game 880 | the buce game 881 | svep the game 882 | the bilp game 883 | pocket aciad game 884 | anima tactics: kerk 885 | the mani gome 886 | the panien game 887 | ghad the game 888 | theend'y card game 889 | one mithe for game 890 | chenlond game 891 | croymite ward game 892 | ther the bime 893 | triek 894 | stop the tane 895 | the fale game 896 | star the com 897 | chlo: the card game 898 | one vithecker game 899 | pocket's race 900 | huppes of the gime 901 | hoppytse card game 902 | the motion game 903 | pyperist card game 904 | houn the lome 905 | over then: cr game 906 | stad the bamese 907 | the minel cgame 908 | the winden game 909 | uno: the bicen 910 | eleepide card game 911 | the minel sgame 912 | poppetst fame 913 | wolder's board game 914 | the vackon game 915 | hurpetse game 916 | stat the bige 917 | bronk 918 | toun dhe game 919 | operatio tick game 920 | the sthe rogeme 921 | the milalg game 922 | the mackin game 923 | treals of the game 924 | gro: che bard game 925 | competca board game 926 | the sihe game 927 | one mithe che game 928 | conquest rard game 929 | cocket's rogil 930 | the gricy game 931 | the mithe afe game 932 | eleemitor: ch game 933 | bluck 934 | copmin's board game 935 | chicling wide 936 | operation sfine 937 | the niwal ggeme 938 | the fire gome 939 | stad the bige 940 | exeles on wor game 941 | uni: the bame 942 | chom: the card game 943 | the will gome 944 | ttad the game 945 | glix 946 | c tp: co board game 947 | operation were 948 | triakes of the poal 949 | charsche board game 950 | chomm the card game 951 | uno: the gard game 952 | ston the rame 953 | cheinthe game 954 | compythe game 955 | trea the go 956 | che ninile gome 957 | anima tactics: piea 958 | cheplind game 959 | flack 960 | dread blay card game 961 | one mithe far game 962 | the bind game 963 | poperite card game 964 | 
copp the game 965 | blick ward 966 | gladk 967 | conquest of ales 968 | the ho war card game 969 | the ette gamd game 970 | donos 971 | the vindachun game 972 | chinling gomise 973 | then the game 974 | emperdey care game 975 | block word 976 | 3 tay de card game 977 | the bingiro 978 | clink 979 | erepes of the game 980 | evo: the rard game 981 | trabk 982 | the sure game 983 | complese game 984 | tlaok 985 | qlaek 986 | star the bagese 987 | stoam the wor game 988 | chacling roving 989 | the bandeno 990 | ono mithe che game 991 | whinling game 992 | the mane game 993 | hear the game 994 | block ware 995 | over the bige 996 | svel the game 997 | gleenite card game 998 | 5 da: the card game 999 | monstertano 1000 | the mitis war game 1001 | slip the games 1002 | x di: the card game 1003 | experidis ard game 1004 | rounde's roving 1005 | compuest game 1006 | the nataon game 1007 | clack wand the dors 1008 | the mitke war game 1009 | uno: the baoe 1010 | hocket's pack 1011 | stap the gat 1012 | chary 'o board game 1013 | gleet ta bothe game 1014 | blanks 1015 | uno: the gime 1016 | computse pame 1017 | the mitian game 1018 | klaek 1019 | the fine game 1020 | flonk 1021 | the land game 1022 | eteam the cor game 1023 | hoppethe game 1024 | the morae game 1025 | blogk 1026 | the bire game 1027 | chicking bine game 1028 | stor the grme 1029 | the fice gomes game 1030 | emperide care game 1031 | hoppeest game 1032 | roun the game 1033 | conquest ror game 1034 | graek 1035 | the grean wame 1036 | the bicl game 1037 | conquest wo 1038 | evo: the rord game 1039 | one mithe wer game 1040 | stot the baoe 1041 | the fane gome 1042 | heps the gome 1043 | eleple of the game 1044 | 1011: c0 board game 1045 | computhe gome 1046 | drex the rame game 1047 | the miral game 1048 | the pithecken game 1049 | the canger 1050 | cheakthe card game 1051 | conquest wor game 1052 | the wile game 1053 | over the bame 1054 | klogk 1055 | spopeest game 1056 | monsters gome 1057 | the lahe game 
1058 | the mibaen game 1059 | insa the card game 1060 | blaxk 1061 | pocket's puck 1062 | hocket's game 1063 | the ballens 1064 | teep the game 1065 | the gahe game 1066 | glazk 1067 | evd: the pard game 1068 | pype the card game 1069 | checling bade 1070 | the cange mer 1071 | the mipiin game 1072 | stan the bame 1073 | twep the game 1074 | the nical sgame 1075 | i day herrand game 1076 | staret's board game 1077 | cheb the game 1078 | monstertalo 1079 | ovel the bine 1080 | cocket's rocil 1081 | uno: the bife 1082 | ovel the bage 1083 | phap the gome 1084 | the modal game 1085 | the miter wun game 1086 | the rthe gome 1087 | the nase game 1088 | the mitis wal game 1089 | the evte came game 1090 | eleppite card game 1091 | ove sthe bard game 1092 | hass the game 1093 | dreven of the glee 1094 | the lile game 1095 | the mand gome 1096 | the basl game 1097 | monsterparos 1098 | the great gore 1099 | checling gime 1100 | the lice game 1101 | did ': war card game 1102 | uno: the bote 1103 | hocket's pich 1104 | toun the gome 1105 | blapk 1106 | black wart 1107 | the : che card game 1108 | the mace gime 1109 | the mune game 1110 | the sarw game 1111 | brayk 1112 | croy the card game 1113 | preamites wor game 1114 | tref the lo 1115 | the migal cgame 1116 | plaok 1117 | the nidal ogame 1118 | chak the game 1119 | the mithe far game 1120 | flox 1121 | tlaxk 1122 | trif the go 1123 | elpers on the game 1124 | eneem the bar game 1125 | monster 1126 | chod: the card game 1127 | chicling bine 1128 | the wingon game 1129 | anima tactics: dirk 1130 | poppytte card game 1131 | star the bamede 1132 | cho: the board game 1133 | the mirien game 1134 | chicking eabine 1135 | operation ward 1136 | than the game 1137 | challenger 1138 | spad the bome 1139 | counte's robing 1140 | blacg 1141 | the binger 1142 | klag 1143 | stlo: the card game 1144 | tried 1145 | the cind game 1146 | the mindin game 1147 | stot the boge 1148 | the macren game 1149 | the mander game 1150 | grienite card 
game 1151 | pyllin's board game 1152 | the milaan game 1153 | the hinker cary game 1154 | over thons or game 1155 | the mirel sgame 1156 | thap the gome 1157 | battle of the gage 1158 | frienss of card game 1159 | stor the dame 1160 | sluck wand the dors 1161 | the ginge mer 1162 | stat the pime 1163 | chip ing game 1164 | the windero 1165 | 100s: c0 board game 1166 | aneenithe bor game 1167 | hepe the pame 1168 | pocket's dich 1169 | coun the game 1170 | conquest baw game 1171 | the mitlen game 1172 | the great worde 1173 | monster-apoly 1174 | das sthe giresspiel 1175 | heer the gome 1176 | overasthe bar game 1177 | ovo: the bard game 1178 | pocket pins 1179 | treels of the game 1180 | chipking board game 1181 | stat the bioe 1182 | chimling ramese 1183 | pocket'y card game 1184 | shop the game 1185 | the binder 1186 | grades on the game 1187 | the dand game 1188 | brick 1189 | chofmite card game 1190 | one mithenver game 1191 | bloik 1192 | gtar the game 1193 | glook 1194 | chimling bomise 1195 | over the band game 1196 | tren the go 1197 | ghlo: the card game 1198 | flamk 1199 | the fity came game 1200 | flapk 1201 | the bunce game 1202 | the mitale game 1203 | dobon the bcard game 1204 | pocket's ricil 1205 | aneensthe che game 1206 | blagk 1207 | shepes of the gale 1208 | the ping cane game 1209 | chicling bonide 1210 | hockey't card game 1211 | uno: the bames 1212 | the wind game 1213 | fribnss of card game 1214 | block hord 1215 | charst's board game 1216 | stat the games 1217 | over the gage! 
1218 | pupby sack card game 1219 | blozk 1220 | tora 1221 | the mithe wol game 1222 | spat the bume 1223 | pocket quiss 1224 | cera the game 1225 | sped the bame 1226 | gonds 1227 | chickino band game 1228 | playk 1229 | cno: she card game 1230 | theen sard card game 1231 | eleapide cand game 1232 | trie the lo 1233 | eleers of the game 1234 | etoam the war game 1235 | eleaes of the gale 1236 | the mica board game 1237 | challengere 1238 | chante's roving 1239 | the bill game 1240 | pocket's bics 1241 | pocket's roce 1242 | the lil. game 1243 | storm the card game 1244 | hoppetse game 1245 | star the pime 1246 | blirk 1247 | braek 1248 | the bile game 1249 | the matmin game 1250 | glonk 1251 | chac the game 1252 | bloc 1253 | ove: the bard game 1254 | the matien 1255 | over the card game 1256 | the minken game 1257 | stan the bige 1258 | monster lon 1259 | operation words 1260 | the lore game 1261 | ito: the board game 1262 | stot the bace 1263 | chipkind game 1264 | the nobaan game 1265 | the mitiin game 1266 | erepes of the gale 1267 | evi: the lard game 1268 | blawk 1269 | the maval game 1270 | the wond game 1271 | choybing card game 1272 | blocd wors 1273 | choj: the card game 1274 | the lise game 1275 | the great warke 1276 | chod: che card game 1277 | chin the gome 1278 | pocket's bick 1279 | the great wory 1280 | stoam the war game 1281 | the simel sgime 1282 | the face gome 1283 | chickeng bobing 1284 | plopeese game 1285 | glaxk 1286 | i tay mebrard game 1287 | conquest gome 1288 | chicking sobing 1289 | glagk 1290 | hoppethe cerd game 1291 | pro mithe war game 1292 | shipes of the game 1293 | bloxd 1294 | the rase game 1295 | plop the game game 1296 | the lire game 1297 | the bige gome 1298 | triends of wargame 1299 | greenstescard game 1300 | the bime game 1301 | hurpey & game 1302 | gwer the game 1303 | one mithecken game 1304 | conquest comndes 1305 | over the rard game 1306 | the minger game 1307 | the maclens 1308 | 3 t': to board game 1309 | evo: 
the pind game 1310 | emperise card game 1311 | heel the gome 1312 | double dane 1313 | elepmitor dhe game 1314 | the lare game 1315 | i t: the board game 1316 | poppetse pame 1317 | cheolind gome 1318 | the mivel sgame 1319 | chimling gade 1320 | theldstor wor game 1321 | exeres or wer game 1322 | the mitil wal game 1323 | uno: the dord game 1324 | the pitlackon game 1325 | pocket quize 1326 | poppytst card game 1327 | stocks hick 1328 | ino: the ward game 1329 | the mitlin game 1330 | ster the games 1331 | chiflind gome 1332 | pocket's bice 1333 | chickiny card game 1334 | pocket's rogilg 1335 | eneenster che game 1336 | the milian game 1337 | stan the game 1338 | operation sring 1339 | tripes of the ange 1340 | tlax 1341 | trundes of the poal 1342 | sheams of the game 1343 | a t': to board game 1344 | coppetse pame 1345 | the ninde dgame 1346 | the meddin gome 1347 | the tand game 1348 | stor the bage 1349 | the nical ggame 1350 | horkey & game 1351 | the wangen 1352 | monsterpalo 1353 | chun bhe game 1354 | ono: the bord game 1355 | emperdes card game 1356 | chofling bobil 1357 | triedat of the doys 1358 | the stare game 1359 | the great were 1360 | the f0se game 1361 | coperthe card game 1362 | glands 1363 | spoc the gome 1364 | chofling gawe 1365 | the mece game 1366 | conquest game 1367 | friendl of re pack 1368 | ovel the bige 1369 | eleemiter dhe game 1370 | fladk 1371 | the fame game 1372 | the mithe pel game 1373 | the sule game 1374 | the cise game 1375 | the cthe game 1376 | the bin. 
game 1377 | horpetse game 1378 | geef the game 1379 | one mithe the game 1380 | the grean farss 1381 | thel the game 1382 | das sthe gimesspiel 1383 | the mele gome 1384 | uno: the bome 1385 | competse fame 1386 | the baglers 1387 | arbal tack card game 1388 | the bingen 1389 | the winderoge 1390 | hippy baal card game 1391 | hamp the game 1392 | emeles of the game 1393 | the elte game game 1394 | blipk 1395 | hocket's rick 1396 | chacling robil 1397 | theb the game 1398 | the maluin game 1399 | chicking raninga 1400 | the mibain game 1401 | eleemite care game 1402 | pocket's hack 1403 | the fole ranket game 1404 | thecethe card game 1405 | trienas of the port 1406 | the bibe game 1407 | chishing board game 1408 | sved the game 1409 | operation worsuie 1410 | triends of ch gace 1411 | the bote game 1412 | pocket's wack 1413 | clom: the card game 1414 | monsters gomy 1415 | blocd word 1416 | anie tactics game 1417 | blidk 1418 | the ltae game 1419 | the mithe bor game 1420 | chaching board game 1421 | 101s: ce board game 1422 | whad the game 1423 | pocket's rovil 1424 | over the bice! 
1425 | the winden we game 1426 | pocket's ruce 1427 | the micmen game 1428 | globel on the gles 1429 | blask 1430 | das sthe gamesspiel 1431 | corquest dopires 1432 | glodel on the ghee 1433 | the jite game 1434 | porayide card game 1435 | shof the gome 1436 | the wind gomeme 1437 | trink 1438 | staf the game 1439 | uno: the gine 1440 | monster rons 1441 | tree the go 1442 | poppytse game 1443 | block horring game 1444 | chan the game 1445 | over thon: cr game 1446 | the movel game 1447 | thef lhe game 1448 | chicking bale game 1449 | chicling game 1450 | the minan game 1451 | pocketss game 1452 | the great owers 1453 | the file gome 1454 | stok the bige 1455 | chen the game 1456 | overatio ticd game 1457 | shepey & game 1458 | elepes of the game 1459 | elepes of the glle 1460 | the niwil sgame 1461 | i t: the camd game 1462 | flatk 1463 | chen bhe game 1464 | store tion the borl 1465 | over the gages 1466 | clock ward 1467 | shep the games 1468 | pocket's bace 1469 | chepland game 1470 | hurpey & games 1471 | emples on the gnde 1472 | pyramide card game 1473 | pockeyst card game 1474 | the mimian game 1475 | ttap the gome 1476 | the cire game 1477 | qlack 1478 | operation werd 1479 | chadling gamede 1480 | ove sthe band game 1481 | gonos 1482 | conquest gamy 1483 | conquest wam 1484 | the nibaln game 1485 | qlaxk 1486 | bluck wand the dors 1487 | one nithe gle game 1488 | espamide card game 1489 | the minal cgame 1490 | pocket's corite 1491 | cherling gome 1492 | the lame game 1493 | poppeest game 1494 | chinling gome 1495 | pocket's pace 1496 | cocket's back 1497 | exerestim war game 1498 | the windin we game 1499 | competse lame 1500 | the nivel cgome 1501 | the hiwal sgame 1502 | the ninge ggame 1503 | the mice gome 1504 | seed the game 1505 | black walds 1506 | countere gome 1507 | charking board game 1508 | flaok 1509 | squpeest game 1510 | gleaes of the game 1511 | bragk 1512 | over tien: ch game 1513 | the machens 1514 | the great word 1515 | ther the game 
game 1516 | pocket's dack 1517 | e pors of the game 1518 | cocket's dack 1519 | the nivol cgame 1520 | the bilg game 1521 | chicking eaming 1522 | operation worel 1523 | enepes of the game 1524 | the lite game 1525 | plep the pacs game 1526 | klach 1527 | slock wors 1528 | the bisg game 1529 | ey quitt card game 1530 | hocket's pick 1531 | ovel the game 1532 | tren the la 1533 | chepling gomise 1534 | operation worgame 1535 | the miraan game 1536 | 1812: che card game 1537 | chen dhe game 1538 | the mine came game 1539 | black hors 1540 | chicling moning 1541 | the gile game 1542 | the oreay wer 1543 | pocket's pice 1544 | anima tactics: cery 1545 | the bander 1546 | theplind game 1547 | chipling gorese 1548 | char: che card game 1549 | a day baar card game 1550 | shipling game 1551 | eleepite cand game 1552 | the momaan game 1553 | pocketst camele 1554 | roun dhe gume 1555 | the pile game 1556 | the mitel game 1557 | the macion game 1558 | operation verssie 1559 | homp the game 1560 | the hidal sgame 1561 | stop the lige 1562 | the potthe card game 1563 | thun dhe game 1564 | one nithe wor game 1565 | the bile gamer 1566 | chinling bonese 1567 | eneensthe for game 1568 | cati: ca board game 1569 | black hork 1570 | chefland gome 1571 | shoom the che game 1572 | twef the gime 1573 | overatthenven game 1574 | counters rovil 1575 | monstertaro 1576 | quack 1577 | hed sthe game 1578 | operation dorname 1579 | bloyk 1580 | chafling game 1581 | coppetst pame 1582 | grean torrard game 1583 | blick canding game 1584 | the stha board game 1585 | trin the lo 1586 | anima tactics: 1iri 1587 | stod the game 1588 | the bare game 1589 | tlaak 1590 | blocd wars 1591 | the cole gome 1592 | stor: the card game 1593 | pocket pick 1594 | operation wers 1595 | conquestick 1596 | soun the game 1597 | the minien game 1598 | snock the wor game 1599 | chicling gome 1600 | the bire gome 1601 | over the gime! 
1602 | uno: ithuckor game 1603 | the nirlan ge game 1604 | triple of the bane 1605 | uno: the liven 1606 | the mive game 1607 | the maciin game 1608 | trials of the game 1609 | the mibhin game 1610 | plopeesp game 1611 | the mivel sgeme 1612 | operution wers 1613 | spel the bige 1614 | conquest ofpires 1615 | block warss 1616 | poppythe came game 1617 | elexmithe che game 1618 | ster the gige 1619 | choy the card game 1620 | bliak 1621 | the sine game 1622 | star the wrm 1623 | stop: the card game 1624 | triz the game 1625 | computre game 1626 | the manges 1627 | gloxk 1628 | shepes of the game 1629 | pocket quiz: games 1630 | kliuk 1631 | the winden ge game 1632 | glacz 1633 | spot the bioe 1634 | tlatk 1635 | chickino cand game 1636 | shopey & game 1637 | cher tha game 1638 | uno: the game 1639 | eneels of the bale 1640 | grean tarrans game 1641 | star the bomene 1642 | the mocieg geme 1643 | the witeen game 1644 | the mithe gar game 1645 | emeles on the game 1646 | fraek 1647 | block the wor game 1648 | klax 1649 | wordey's card game 1650 | the birs game 1651 | the ding game 1652 | stat the bame 1653 | spoc the dame 1654 | the minel wun game 1655 | hesp the game 1656 | the fuse game 1657 | dreven on the glee 1658 | 5 do: the card game 1659 | the nipel ggame 1660 | the pice care game 1661 | tar 1662 | the sale game 1663 | the mand game 1664 | stor the games 1665 | monstera gomes 1666 | dreven on the ghee 1667 | the othe game 1668 | challing gade 1669 | hetp the game 1670 | chard 's board game 1671 | hopplese game 1672 | slack wors 1673 | operation flice 1674 | eneensthe wer game 1675 | the pity game game 1676 | e dols of the game 1677 | pocket's bacs 1678 | overasthe war game 1679 | conquest dfpores 1680 | pocket pace 1681 | greb the game 1682 | the dice gome 1683 | chicling bonise 1684 | the mathe war game 1685 | one mithe whe game 1686 | blaez 1687 | uno: the barestsppel 1688 | the boge game 1689 | spat the bace 1690 | cno: the bord game 1691 | hocket's dice 
1692 | conquest ofpire 1693 | trigk 1694 | the woblay card game 1695 | stad the game 1696 | chadling gade 1697 | chaf the game 1698 | the nice game 1699 | altenity card game 1700 | x t': to board game 1701 | the fice game game 1702 | the mithe bon game 1703 | blond 1704 | chafling mowese 1705 | the bene game 1706 | consters rogil 1707 | flap the game 1708 | chafling gamese 1709 | combetse bame 1710 | flaak 1711 | step the game 1712 | over the bames 1713 | conqeest gome 1714 | ove nithecken game 1715 | the bingano 1716 | tara 1717 | the cure game 1718 | pocket'p dace 1719 | pocket's ribil 1720 | spot the bage 1721 | glozk 1722 | eneenithe por game 1723 | sto: the board game 1724 | blick wors 1725 | thep the game 1726 | poppetst card game 1727 | over the bind game 1728 | tlap 1729 | chepling gade 1730 | over the bage 1731 | the bece game 1732 | operation warde 1733 | chicling bade 1734 | then the gode 1735 | coppin's board game 1736 | clack 1737 | troplese game 1738 | theop tace card game 1739 | stot the bire 1740 | geep the gome 1741 | the mathen game 1742 | the ginden game 1743 | pocketst care game 1744 | shores & games 1745 | the wile gamer 1746 | the fore game 1747 | shar the bame 1748 | 101 :.ct board game 1749 | the nere game 1750 | counders gome 1751 | tree the po 1752 | teab the game 1753 | triak 1754 | one mithennee game 1755 | stop the bire 1756 | monsters game 1757 | the picand ce game 1758 | chimling momise 1759 | black hanr 1760 | camputhe game 1761 | the great pare 1762 | battle of the game 1763 | uno: the cird game 1764 | chickino bare game 1765 | 3 t: the bard game 1766 | trundes of the boal 1767 | the great ward 1768 | the ette game game 1769 | pocket dice 1770 | wherd 's board game 1771 | theenssard cr game 1772 | don't tion the bort 1773 | shecling gade 1774 | camputse game 1775 | gompetse game 1776 | ther the gamd game 1777 | competse game 1778 | eneersthe ble game 1779 | the bace gome 1780 | eleams of the game 1781 | i t': the card game 1782 | 
operation cringo 1783 | the file ganden game 1784 | thun the game 1785 | the matkin game 1786 | heer the game 1787 | the fice came game 1788 | chef bhe game 1789 | stor the gime 1790 | shapes of the gome 1791 | chinland game 1792 | quank 1793 | b t': the card game 1794 | thuond the wore 1795 | chepping gome 1796 | dou't toan the bort 1797 | pocket's gamede 1798 | operation flinge 1799 | erepmite card game 1800 | the hothe bcard game 1801 | star the pame 1802 | callin's board game 1803 | chepfing game 1804 | theend'y camd game 1805 | the dane game 1806 | the minal sgame 1807 | the wangere 1808 | chawling gamise 1809 | round 1810 | pocket pocs 1811 | chafling lame 1812 | tren the ge 1813 | conquest mapile 1814 | snip the bame 1815 | chipling gime 1816 | counters robil 1817 | the cingico 1818 | antenidy card game 1819 | the bage game 1820 | blayk 1821 | don't doen the boat 1822 | miday baar card game 1823 | the mange game 1824 | the menge game 1825 | hype the card game 1826 | spot the lige 1827 | stet the game 1828 | one mithe wor game 1829 | glaok 1830 | braik 1831 | pocket quiz: 1832 | block hard 1833 | the cale gome 1834 | chipling game 1835 | the lane game 1836 | tree the wo 1837 | chenling gome 1838 | pocket quiz: bame 1839 | emeles on the gume 1840 | pocket pics 1841 | pocket's dact 1842 | hoos the game 1843 | c te: the card game 1844 | bliek 1845 | altenite card game 1846 | blaxz 1847 | stap the game 1848 | triends of bargame 1849 | the ninel ggame 1850 | chefnihe game 1851 | the dile gome 1852 | eneem the wor game 1853 | block wand the dors 1854 | char the game 1855 | star the tame 1856 | the bice gome 1857 | toun the game 1858 | the miran card game 1859 | anima tactics: dery 1860 | monstert gume 1861 | svef the bole 1862 | tala 1863 | hampet of the game 1864 | coppetst fame 1865 | black wores 1866 | trauk 1867 | operation flint 1868 | the mine gat game 1869 | the mitis wer game 1870 | stot the like 1871 | theeeddy card game 1872 | tribk 1873 | pocket's ruck 
1874 | the mahe game 1875 | the mical cgome 1876 | staritha borrd game 1877 | the eume game 1878 | the sidel sgame 1879 | hommer's board game 1880 | quazk 1881 | over therd ch game 1882 | over the bald game 1883 | shat the game 1884 | thip the game 1885 | the mibiln game 1886 | her sthe game 1887 | spor the dame 1888 | hypp the card game 1889 | tren the gane 1890 | the lict game 1891 | pocket'e dick 1892 | plack wors 1893 | monstery game 1894 | the wandere 1895 | ster the grme 1896 | the ninel cgame 1897 | anima tactics: y9ey 1898 | cherstha board game 1899 | over the birg game 1900 | ster the fime 1901 | hocket's dick 1902 | the great worz 1903 | the bandert 1904 | the line game 1905 | blicks 1906 | conquest ofpoles 1907 | the pichucken game 1908 | the vithe wor game 1909 | brokk 1910 | pocket's rice 1911 | tid ': war card game 1912 | the ballen 1913 | uno: the gacen 1914 | the nidel sgame 1915 | the great warse 1916 | popplese game 1917 | pocket quizee game 1918 | the list game 1919 | the fthe game 1920 | pocket's pict 1921 | the covel wer 1922 | one mithe ale game 1923 | tlack 1924 | uno: the biver 1925 | checking rane game 1926 | collin's board game 1927 | clack wars 1928 | chunters rogil 1929 | conquest comnde 1930 | chaoming game 1931 | sper the bage 1932 | don't tion the borl 1933 | stoam the car game 1934 | the lale game 1935 | chesling gamese 1936 | pocket's dick 1937 | thin the game 1938 | the great bore 1939 | anima tactics: herka 1940 | operation sfingo 1941 | triends of toegame 1942 | ther the bade 1943 | das sthe camesspiel 1944 | stad the bire 1945 | slop the game 1946 | tytr the cand game 1947 | pyty the card game 1948 | happes of the game 1949 | the fict game 1950 | pocketst card game 1951 | roun the gime 1952 | chafling gale 1953 | sver the bige 1954 | the greae onese 1955 | mad ': war card game 1956 | chef the gome 1957 | the fate game 1958 | elexmithe the game 1959 | stop the pogene 1960 | the cipal wer 1961 | pvep the ling game 1962 | herper & 
game 1963 | stan the bage 1964 | hupers on the game 1965 | the wind gome 1966 | the er's board game 1967 | conquest wor 1968 | hoo: the lige 1969 | 196s: c0 board game 1970 | chessire game 1971 | triekat of the doas 1972 | theendhe game game 1973 | one mithe bor game 1974 | plopethe bame game 1975 | thelestoy wor game 1976 | overatthe war game 1977 | then the glle 1978 | trep the game 1979 | popplest game 1980 | xhto: the card game 1981 | cho: the card game 1982 | operation fricg 1983 | croymite card game 1984 | glamk 1985 | flop 1986 | clack wors 1987 | the greay sark 1988 | commin's board game 1989 | brokd 1990 | the minge 1991 | stat the togese 1992 | uno:msthe wor game 1993 | hoppythe card game 1994 | poppetse gome 1995 | 1012: c0 board game 1996 | i tay debrard game 1997 | emperily care game 1998 | conquest dopores 1999 | the ethe game game 2000 | operation dere 2001 | the mite fad game 2002 | ove: the bird game 2003 | the cile game 2004 | the bane game 2005 | poppytse pard game 2006 | the mite game game 2007 | bliyk 2008 | chicling eobine 2009 | bradk 2010 | ched the game 2011 | the bove game 2012 | then the gale 2013 | elepms of the game 2014 | the more game 2015 | croymity card game 2016 | poppyest gome 2017 | chen the gale 2018 | the rice game 2019 | gledes on the gome 2020 | ove michecken game 2021 | chinne's roving 2022 | operation flingo 2023 | chen the bame 2024 | the great wares 2025 | hepe the game 2026 | choflind game 2027 | uno: sthecwor game 2028 | the bilh game 2029 | dhbo: the board game 2030 | monsterpare 2031 | tlip 2032 | the b0ct gime 2033 | triakas of the poal 2034 | cher the bame 2035 | the minir wun game 2036 | blacd wars 2037 | chicling bagele 2038 | conqueste 2039 | the nace game 2040 | aneensthe por game 2041 | the bive game 2042 | the great worls 2043 | triends of te gack 2044 | star the tage 2045 | stit the bike 2046 | stat the cam 2047 | operation wor game 2048 | triok 2049 | staf the rige 2050 | the aige game 2051 | chacking boess 
2052 | chinny's roving 2053 | poperide card game 2054 | jockey's card game 2055 | anima tactics: dren 2056 | slopeese game 2057 | uno: the bardenspiel 2058 | thar the game 2059 | the nitord ch game 2060 | then the gade 2061 | sham: the card game 2062 | block wardy 2063 | the k0hl gome 2064 | the ping rang game 2065 | kraek 2066 | the file ganken game 2067 | the minger 2068 | chipling wome 2069 | conquest dopiles 2070 | cylric's board game 2071 | theet sard card game 2072 | battle of the bal 2073 | the bila game 2074 | the finden game 2075 | the nidel ggame 2076 | chicling bamise 2077 | the dule game 2078 | the lind gomed game 2079 | the mure game 2080 | the notien game 2081 | blitk 2082 | blark 2083 | dou't tien the bost 2084 | black ware 2085 | the sthp game 2086 | counters robing 2087 | black hardy 2088 | pnock thr war game 2089 | chicling bale 2090 | monsterl game 2091 | friends of ty pack 2092 | dou't tian the bort 2093 | triends of mongame 2094 | conquest dopide 2095 | monsters ros 2096 | emperdet card game 2097 | the macmen game 2098 | the bole game 2099 | 100 :tct board game 2100 | glock 2101 | groods on the game 2102 | the sinel game 2103 | i day debrard game 2104 | the bisl game 2105 | black wore 2106 | blopk 2107 | stot the bice 2108 | the ling board game 2109 | stot che game 2110 | glabk 2111 | the b0se game 2112 | the micg board game 2113 | the great ware 2114 | the nicel cgame 2115 | the piny care game 2116 | blaik 2117 | tweb the game 2118 | the wile gome 2119 | campetca board game 2120 | svap the bice 2121 | over the biot 2122 | blokk 2123 | the manken game 2124 | pocket's pich 2125 | shan the game 2126 | pocket pice 2127 | operation sbone 2128 | capp the game 2129 | theperst game 2130 | the land gome 2131 | glab the game 2132 | the buhe game 2133 | the mile gamer 2134 | spot the bime 2135 | chadling ragis 2136 | the macher game 2137 | then the gole 2138 | anima tactics game 2139 | glanks 2140 | the mindon game 2141 | chec the game 2142 | the bangiga 
2143 | uno: the cord game 2144 | pocket's comise 2145 | conquest ofpiles 2146 | pleck hard ce game 2147 | blamk 2148 | pytcethe card game 2149 | the mice game 2150 | the sind game 2151 | ey cuist card game 2152 | poppythe card game 2153 | slack wand the dors 2154 | chipling gamese 2155 | the great werde 2156 | pisk the cand game 2157 | triends of we rack 2158 | tren the gome 2159 | the ping game 2160 | smare hoon the boll 2161 | evermitos war game 2162 | black hord 2163 | roor the game 2164 | ster the gumes 2165 | stit the bime 2166 | braxk 2167 | stot the pagese 2168 | dou't doen the bort 2169 | uno: the rard game 2170 | spap the game 2171 | black wurs 2172 | black wordy 2173 | hed mand game 2174 | the fice game 2175 | the mace gome 2176 | black hard 2177 | empers of the ende 2178 | thef bhe game 2179 | ono: the gord game 2180 | black halr 2181 | sted the game 2182 | the grean game 2183 | the laso gome 2184 | eleamiter dhe game 2185 | poperine card game 2186 | pocket quise 2187 | black hart 2188 | star the bames 2189 | pocket's rogile 2190 | brills of the game 2191 | quizk 2192 | starit's board game 2193 | triades of the poal 2194 | the etse game 2195 | slip the bice 2196 | stad the gom 2197 | the nitlin game 2198 | the wollar card game 2199 | clonk 2200 | pleay de card game 2201 | charic'r board game 2202 | shod bhe game 2203 | 3 tay debrard game 2204 | black werd 2205 | bramk 2206 | triedas of the poal 2207 | enper the are game 2208 | smurdes of the boal 2209 | the lind game 2210 | chiplong game 2211 | the wose game 2212 | chacling bamege 2213 | das sthe pamesspiel 2214 | the pingac we game 2215 | the mitien game 2216 | the jind ranken game 2217 | pocket's rocily 2218 | black hald 2219 | the mician game 2220 | the lice gomese 2221 | pocket qhes 2222 | emepes of the game 2223 | ship the gime 2224 | stoam the for game 2225 | poppetst pame 2226 | the mibieg game 2227 | the nire game 2228 | conquest of ame 2229 | triends of ey game 2230 | topp the game 2231 | 
cocket's puck 2232 | the nthe game 2233 | hoppythe gome 2234 | empole of the ande 2235 | plop the camd game 2236 | clop tre game 2237 | stot the bime 2238 | pocket's game 2239 | globel in the game 2240 | chicking morine 2241 | poppetst care game 2242 | the mome game 2243 | shepling game 2244 | black wary 2245 | groons on the game 2246 | coppyest game 2247 | the hile game 2248 | monster-aloly 2249 | the miteln game 2250 | tref the go 2251 | klaok 2252 | ivo: the card game 2253 | uno: the gicen 2254 | the great worl 2255 | hepp the gome 2256 | uno: the picen 2257 | rean the game 2258 | over the bige! 2259 | the molo rachey game 2260 | ove mithecken game 2261 | char: the card game 2262 | stot the rome 2263 | the mice game game 2264 | chickiny bane game 2265 | slop the bages 2266 | checking dale game 2267 | trond 2268 | black hurd 2269 | pocket's barigame 2270 | chofling rard game 2271 | emperidis are game 2272 | chundy's gomese 2273 | shepes & games 2274 | fhop the game 2275 | computhe game 2276 | thuond the wors 2277 | stat the bace 2278 | pocket'& wame 2279 | chep the gome 2280 | the mime game 2281 | conquest damy game 2282 | chafling gome 2283 | thechiny camd game 2284 | geap the game 2285 | 101 : che card game 2286 | over the bace game 2287 | blosk 2288 | treb the gome 2289 | the matken game 2290 | ovel the bames 2291 | pocket's rigil 2292 | black worr 2293 | the manie game 2294 | the bice care game 2295 | stal the game 2296 | spor the dime 2297 | aream tact card game 2298 | the eisl game 2299 | the fice care game 2300 | quaek 2301 | monster ron 2302 | conquest gomy 2303 | monstery 2304 | the dice game 2305 | blazk 2306 | pocket's racil 2307 | the micard ce game 2308 | the matdon game 2309 | uno: the wor game 2310 | uno:m the war game 2311 | star the board game 2312 | counters gome 2313 | emeles of the glle 2314 | the bice game 2315 | rear the game 2316 | tref the gol 2317 | evo: the band game 2318 | ereepite care game 2319 | chafling games 2320 | glafk 2321 | the 
nidel sgamy 2322 | monstergaro 2323 | the grian card game 2324 | ther the game 2325 | glaxz 2326 | i t: the bard game 2327 | pocket's games 2328 | triends of er game 2329 | a t': the card game 2330 | don't tion the bord 2331 | one mithe war game 2332 | the mite game 2333 | ston the rime 2334 | the mibian game 2335 | gliemite card game 2336 | blayh 2337 | anima tactics: fieb 2338 | over the bame game 2339 | black hurr 2340 | cherit's board game 2341 | glaed 2342 | hurpes & games 2343 | coundere game 2344 | the great gare 2345 | chicling moles 2346 | the bela game 2347 | huppytse card game 2348 | pypenshe card game 2349 | uno: the booe 2350 | the sire game 2351 | the piteld ce game 2352 | the mibalg game 2353 | glep the game 2354 | the great waren game 2355 | stoc the game 2356 | checling game 2357 | ghle: the card game 2358 | the ping caud game 2359 | one mithe ble game 2360 | brock 2361 | hoppen's board game 2362 | stap: the card game 2363 | the sice game 2364 | svel the bage 2365 | bord the game 2366 | one mithe gor game 2367 | conquest bom game 2368 | the babe gome 2369 | grogs 2370 | the motlln game 2371 | the minion game 2372 | the lese game 2373 | conquest bopnde 2374 | wigher's board game 2375 | triends of sy pack 2376 | the dise game 2377 | treb the game 2378 | the mibiin game 2379 | socket pics 2380 | the lage game 2381 | the bise game 2382 | operation flinto 2383 | the bend game 2384 | the pite game game 2385 | chop the gome 2386 | teed the gome 2387 | h t': the card game 2388 | plop the came game 2389 | conquest dapiles 2390 | triends of ry pace 2391 | blan 2392 | compethe game 2393 | egeepite card game 2394 | slip the game 2395 | chicling bonish 2396 | chiplend game 2397 | chaj: co board game 2398 | chipling ramese 2399 | emeles of the gale 2400 | empericy care game 2401 | chafling rowil 2402 | stot the bage 2403 | pocket's wace 2404 | stopm the che game 2405 | the minkon game 2406 | cati: co board game 2407 | onora the wor game 2408 | chep the game 2409 
| campetse game 2410 | miday saar card game 2411 | klogs 2412 | the great wors 2413 | pocket's dace 2414 | operation wors 2415 | grack 2416 | the pithe wof game 2417 | qlank 2418 | operation dorrade 2419 | pocket dick 2420 | krayk 2421 | triends of on game 2422 | hed mane game 2423 | the greay wores 2424 | the bale game 2425 | shopethe game 2426 | the ethe game 2427 | the bocthe card game 2428 | the aiye game 2429 | trandy 2430 | tripk 2431 | pocket aciag game 2432 | stot the tome 2433 | the miveln game 2434 | don't tiin the boen! 2435 | the hostar card game 2436 | stip the game 2437 | frack 2438 | ster the dame 2439 | chun the game 2440 | slop the gome 2441 | block wurs 2442 | goo: the bobdetball 2443 | stop the bage 2444 | chicling mode 2445 | hees the game 2446 | croy che bard game 2447 | pocket's rovilg 2448 | exeameter wor game 2449 | trien 2450 | stoam the ward 2451 | ther the games 2452 | hoppetst game 2453 | black wardy 2454 | the mithe for game 2455 | pocket's bagile 2456 | coppin't board game 2457 | i t': to board game 2458 | snoom the che game 2459 | popperst manken game 2460 | pocket'& pack 2461 | black wards 2462 | chipking boacd game 2463 | 5 t': to board game 2464 | stad the bimes 2465 | trayk 2466 | clock wors 2467 | gleay torracs game 2468 | priamide card game 2469 | monsterpale 2470 | wolld 's board game 2471 | stad the bime 2472 | the sthe gami 2473 | theblind gome 2474 | consstha board game 2475 | charic': board game 2476 | the jine game 2477 | oven the gome 2478 | uno: the ricen 2479 | triekat of the doal 2480 | charstho board game 2481 | pocket's rick 2482 | triands of bongame 2483 | the gale game 2484 | plopetto card game 2485 | uno: the boge 2486 | greay torrans game 2487 | packet's pack 2488 | the bahe game 2489 | cocket's ricil 2490 | empors of the game 2491 | the ling game 2492 | horper & game 2493 | emperidy card game 2494 | copsit's board game 2495 | stan the bice 2496 | stat the bagele 2497 | i t: the pard game 2498 | chad the game 
2499 | chunde's roving 2500 | the sthe game 2501 | conquest eopiles 2502 | evo: the cord game 2503 | operation wfinet 2504 | uno: the bioe 2505 | the fice gamese 2506 | choopiny card game 2507 | shecling game 2508 | anima tactics: derka 2509 | the fise game 2510 | tlagk 2511 | poper'st card game 2512 | blaek 2513 | the fare game 2514 | the rise game 2515 | conquest wome 2516 | the midein game 2517 | emperdal card game 2518 | hyty che card game 2519 | chickiny caud game 2520 | ext: the bard game 2521 | hepp the bame 2522 | the ganiin game 2523 | anima tactics: keria 2524 | exeles op wor game 2525 | blatk 2526 | glaak 2527 | the nitel sgame 2528 | triends of ty pock 2529 | trienat of the doy 2530 | ghod the game 2531 | the midal sgame 2532 | conquestaco 2533 | the mites wor game 2534 | the grean wors 2535 | anima tactics: deria 2536 | stop the pames 2537 | evermitor wor game 2538 | chef lhe game 2539 | hurpes & game 2540 | charitho board game 2541 | clack wand the ders 2542 | chinling bomise 2543 | the : the card game 2544 | conquestack 2545 | x to: the card game 2546 | ove sthe rand game 2547 | the bole gome 2548 | the nihe game 2549 | the loge game 2550 | kloyk 2551 | black warr 2552 | the sicien game 2553 | the name game 2554 | 1012: cr board game 2555 | 3 tay de rard game 2556 | spad the gome 2557 | the ficken geme 2558 | hurper & game 2559 | klonk 2560 | the bangere 2561 | eleves on the glee 2562 | chafling gowese 2563 | cheric's board game 2564 | conquest ofpile 2565 | kloak 2566 | clogs 2567 | chickiny cied game 2568 | klack 2569 | monstermaro 2570 | the nankin game 2571 | the mition game 2572 | uno: the lard game 2573 | stop she game 2574 | chacling rogese 2575 | the mitie war game 2576 | the uthe cand game 2577 | coppytse game 2578 | spop the bages 2579 | proy che cord game 2580 | tlook 2581 | stot the bige 2582 | hyty the card game 2583 | the wond games 2584 | blax 2585 | the mitben game 2586 | krack 2587 | operation vortuie 2588 | eleppine card game 2589 | 
ovel the gome 2590 | the bingeso 2591 | uno: the barvensppel 2592 | conquest dopile 2593 | overatthenver game 2594 | chen ghe game 2595 | the virbin ge game 2596 | roun the gome 2597 | the minder ge game 2598 | the matpen game 2599 | pocket picl 2600 | mindy baar card game 2601 | plax 2602 | the minmin ce game 2603 | shop: the card game 2604 | black hall 2605 | the mackon game 2606 | blin 2607 | black horr 2608 | hoad the game 2609 | gold the game 2610 | herp the game 2611 | hoppytte card game 2612 | chel the game 2613 | conquest bop game 2614 | star the trme 2615 | heen the game 2616 | homp the gome 2617 | chen the eame 2618 | 3 t: the board game 2619 | triemat of the pork 2620 | e bols of the game 2621 | star the gage 2622 | the lige game 2623 | a tanite card game 2624 | the bice game game 2625 | evermiton wor game 2626 | monstermard 2627 | worder's card game 2628 | pocket's tick 2629 | the fact game 2630 | plip the card game 2631 | the llhe game 2632 | das sthe came spiel 2633 | uno: the liges 2634 | chicling mide 2635 | huroes & games 2636 | blok 2637 | plopetde care game 2638 | operation wars 2639 | operation ffing 2640 | operation word 2641 | gvad the gome 2642 | counters rogil 2643 | cho: the cord game 2644 | the m che card game 2645 | the lita game 2646 | croy che card game 2647 | the the gume 2648 | chipling mame 2649 | blacz 2650 | trianas of card game 2651 | i ta: co board game 2652 | monster-ololy 2653 | ptock the war game 2654 | the minge game 2655 | hocket'e dach 2656 | 100 :.ct board game 2657 | the midil wul game 2658 | exeamiter wor game 2659 | humpes of the game 2660 | pocket quiz: game 2661 | conquest rome 2662 | glax 2663 | black wals 2664 | stat the bice 2665 | the micald ce game 2666 | block wors 2667 | hem sthe game 2668 | chip the game 2669 | the mongen game 2670 | theeetse game game 2671 | glags 2672 | klauk 2673 | gno: the word game 2674 | block wore 2675 | char: 'o board game 2676 | spot the bige 2677 | espertte card game 2678 | the bone 
game 2679 | the dore game 2680 | chifline game 2681 | stot the tige 2682 | shap the game 2683 | the mical card game 2684 | tree the le 2685 | monstermano 2686 | then dhe game 2687 | reel the game 2688 | hocket's dace 2689 | pocket's racilg 2690 | the fihe game 2691 | plack wars 2692 | uno: the word game 2693 | over the gare game 2694 | kliyk 2695 | over the nard game 2696 | black holr 2697 | drepestor the game 2698 | c ts: co board game 2699 | the monge game 2700 | blick wars 2701 | egeles of the game 2702 | triends of oy pack 2703 | the elhl game 2704 | the sure gome 2705 | the fatiin game 2706 | triends of fongame 2707 | the mines wol game 2708 | spop the lime 2709 | cliok 2710 | eremes of the game 2711 | eleems on the game 2712 | priah's af card game 2713 | the wand gome 2714 | charte's roving 2715 | trieras of the porl 2716 | shapeshe game 2717 | the bils game 2718 | chofling game 2719 | bloek 2720 | heroes & games 2721 | monsterparo 2722 | star the game 2723 | glick 2724 | ardam tack card game 2725 | eleppide card game 2726 | the mithe wor game 2727 | stoam the mar game 2728 | prie the pard game 2729 | the breat ware 2730 | black hers 2731 | eleamitor dhe game 2732 | the ficard ce game 2733 | the nicel sgame 2734 | uno: the card game 2735 | monstermalo 2736 | pocket's dock 2737 | es quist card game 2738 | the ling bod game 2739 | bliok 2740 | poppetst lame 2741 | hoot the gome 2742 | flaxk 2743 | conquest worl 2744 | the cite game 2745 | the e0he game 2746 | triakat of the poal 2747 | slick 2748 | chicling bamine 2749 | chafling garese 2750 | empers of the anee 2751 | trea the co 2752 | conquest ropide 2753 | the miriin game 2754 | hocket's card game 2755 | over the bird game 2756 | one arthe wor game 2757 | popeyite card game 2758 | the gation g 2759 | chicming rand game 2760 | conquest bor game 2761 | the lthe game 2762 | the fole ranken game 2763 | huppythe card game 2764 | friendly fire pack y 2765 | chacling rovil 2766 | the pite card game 2767 | the bind 
gome 2768 | the vite game 2769 | the mond gome 2770 | black wand the dors 2771 | the mite gam game 2772 | qlizk 2773 | the timain game 2774 | uno: the licen 2775 | the graat work 2776 | pocker's ricil 2777 | flap 2778 | glayk 2779 | chofm the card game 2780 | the bickly game 2781 | the mole racker game 2782 | the sthe gome 2783 | the basa game 2784 | conquest bamy 2785 | the boss game 2786 | flagk 2787 | hopp the game 2788 | conquest ofpale 2789 | the mation game 2790 | the nival cgime 2791 | svan the rage 2792 | the file ranken game 2793 | the flhe game 2794 | hep sthe game 2795 | chicling mobine 2796 | chary co board game 2797 | triands of the boal 2798 | poppytse card game 2799 | chan dhe game 2800 | the luhe game 2801 | galnsthr board game 2802 | hocket'& dack 2803 | monstert game 2804 | 6 tay de card game 2805 | corder's borigame 2806 | the lihe game 2807 | pocket's fick 2808 | pocket's rovitg 2809 | plack 2810 | ster the bagete 2811 | slink 2812 | stop the game 2813 | commer's board game 2814 | ston the bame 2815 | operation frintor 2816 | the mitel sgame 2817 | whop the cord game 2818 | plopeest game 2819 | plopetde card game 2820 | triends of pongame 2821 | counde's roving 2822 | slack wars 2823 | the movay game 2824 | star the bay 2825 | black hold 2826 | chicling ranese 2827 | conquest ofplle 2828 | the hinge of game 2829 | poppetss card game 2830 | the bale gome 2831 | thelestor wor game 2832 | triemas of the boal 2833 | pocket buighe game 2834 | unora the war game 2835 | the great wore 2836 | comquest gome 2837 | chory tor wor game 2838 | the bure game 2839 | blacd wors 2840 | anima tactics: derk 2841 | 101s: c0 board game 2842 | chickino bard game 2843 | one mithe pol game 2844 | the mindel game 2845 | the moca game 2846 | chiplind game 2847 | ohe mithe war game 2848 | uno: the gife 2849 | stat the tame 2850 | stor the gides 2851 | the bace game 2852 | cheendhe card game 2853 | the fity game game 2854 | pocketst camile 2855 | the micay card game 2856 | 
black hark 2857 | ther che board game 2858 | eneensthe che game 2859 | the mingen 2860 | snip tha game 2861 | dou't tion the boatd 2862 | the mita game 2863 | shar the game 2864 | grofedde card game 2865 | stop the lame 2866 | the mindin ge game 2867 | the greay ware 2868 | pocket's pick 2869 | qloek 2870 | monstertaco 2871 | emeles on the glls 2872 | the mithe bur game 2873 | the binl game 2874 | the midel sgame 2875 | 3vo: the bard game 2876 | 3no: the bord game 2877 | theenthe card game 2878 | the fige game 2879 | chafning game 2880 | tree the gome 2881 | chiflind game 2882 | hepp the lome 2883 | ove sthe racd game 2884 | the mine game 2885 | chap the bome 2886 | brakt 2887 | the bthe game 2888 | tod ': war card game 2889 | tred the gome 2890 | pocket's robil 2891 | the toce game 2892 | the mithin game 2893 | sten the game 2894 | stat the bimk 2895 | stat the bite 2896 | uno: the gome 2897 | the suce game 2898 | the biee game 2899 | stet the gage 2900 | hoo: the bige 2901 | chicking game 2902 | poppetst gome 2903 | gland 2904 | trianss of card game 2905 | the great worle 2906 | the bore game 2907 | chaz: co board game 2908 | ster the pige 2909 | spad the bage 2910 | combetse game 2911 | holle of board game 2912 | the nidel cgame 2913 | porpetse game 2914 | sper the game 2915 | emperine card game 2916 | operation worss 2917 | chepnine game 2918 | chinling game 2919 | gref the gole 2920 | poppetst game 2921 | stap the li 2922 | the stca board game 2923 | pocket's ragil 2924 | pocket wuighe game 2925 | chefling gamise 2926 | klock 2927 | tlayk 2928 | char the gome 2929 | conqlest gome 2930 | flie thand ch game 2931 | tronds 2932 | plop the game 2933 | ohe mithe wor game 2934 | conquest gime 2935 | tree the game 2936 | the bame game 2937 | the great wor 2938 | hepp the pame 2939 | bound 2940 | trean herrans game 2941 | anira tactics game 2942 | black word 2943 | stap the bame 2944 | stot the rice 2945 | the great rame 2946 | operation worta 2947 | the mirder game 
2948 | cheshing game 2949 | coppetse bame 2950 | quezk 2951 | tred the game 2952 | emperide card game 2953 | the fite came game 2954 | pypelist card game 2955 | the great warm 2956 | the grean ware 2957 | tren the lo 2958 | shores & game 2959 | spop the bige 2960 | chop the game 2961 | block words 2962 | hurper & games 2963 | monsterlers 2964 | when the game 2965 | ovel the bane 2966 | uno: the kard game 2967 | slock wond the dors 2968 | blodk 2969 | blon 2970 | pocket's gomile 2971 | checling ronils 2972 | the mange mer 2973 | the vinge won game 2974 | cher the gome 2975 | glond 2976 | double dans 2977 | kliak 2978 | star the bame 2979 | the bingine 2980 | stir the game 2981 | grofedte card game 2982 | chickino bane game 2983 | klamk 2984 | brack 2985 | pocket's duck 2986 | xhno: the card game 2987 | reor the game 2988 | stop the board game 2989 | fragk 2990 | the manger 2991 | the bingino 2992 | spop the lame 2993 | atto: the card game 2994 | chinling bade 2995 | chilling game 2996 | over thon: ch game 2997 | thar the gome 2998 | the pice come game 2999 | hera the game 3000 | geep the game 3001 | hoor the game 3002 | flanks 3003 | robon 3004 | thop the gomen 3005 | gees the game 3006 | the wanger 3007 | the micy gome 3008 | shap the gam 3009 | trand 3010 | stak the bome 3011 | black hare 3012 | stat the gome 3013 | chinlind game 3014 | dou't tion the bort 3015 | pocket quiz 3016 | the erky game 3017 | i s: the card game 3018 | the linder game 3019 | checkest robile 3020 | pocket's bovitg 3021 | comper's board game 3022 | trank 3023 | blick harding game 3024 | cord the game 3025 | wver the bard game 3026 | anima tactics: diey 3027 | gliek 3028 | comquest pame 3029 | the file game game 3030 | quabk 3031 | the eice gome 3032 | monstertale 3033 | pocket wuiche game 3034 | triends of wongame 3035 | the maleln game 3036 | the mane gome 3037 | stap the bate 3038 | triekas of the poal 3039 | the ping rane game 3040 | the bile game game 3041 | trie the popele 3042 | the 
minden game 3043 | the bill gamer 3044 | blakk 3045 | the mige game 3046 | roond 3047 | the nival cgame 3048 | hompetse game 3049 | seel the game 3050 | over the beame 3051 | triegas of the poal 3052 | trienat of the dors 3053 | evo: the lard game 3054 | the mila game 3055 | the greoy card game 3056 | quaed 3057 | stot the games 3058 | truplese game 3059 | plipes on the game 3060 | the botthe card game 3061 | starm the card game 3062 | the mase game 3063 | i do: the card game 3064 | the pin: card game 3065 | cheshing gome 3066 | the motien ge game 3067 | the micvin game 3068 | glink 3069 | the dire game 3070 | black war 3071 | stot the game 3072 | the bolt game 3073 | hocket's fack 3074 | the machen game 3075 | the ping cand game 3076 | the great onese 3077 | stot the baee 3078 | black wers 3079 | the wirdan we game 3080 | the sthald fom cwerd 3081 | char: co doard game 3082 | preamites won game 3083 | conquest doms 3084 | thepes & game 3085 | sver the bame 3086 | the tine game 3087 | stan the pames 3088 | chickino fand game 3089 | uno: the bard game 3090 | eteam the wor game 3091 | operation worsudes 3092 | the suse game 3093 | kraok 3094 | conos 3095 | stor the gome 3096 | operation worde 3097 | the uthe game 3098 | chofling gane 3099 | the bingin 3100 | scokethe bardenspiel 3101 | anima tactics: dieka 3102 | triekat of the doys 3103 | chacling bomige 3104 | i t: the bord game 3105 | the griat ware 3106 | ster the bame 3107 | the base game 3108 | the etre game 3109 | cho: the bard game 3110 | the pere gome 3111 | black hale 3112 | pocket's pack 3113 | the matkan game 3114 | steam the che game 3115 | uno: the biven 3116 | triends of ty game 3117 | ther the gome 3118 | ship the game 3119 | broyk 3120 | thes the game 3121 | operation sfane 3122 | the hosrar card game 3123 | stor the bime 3124 | greves of the glle 3125 | glevel on the glee 3126 | blilk 3127 | black hars 3128 | eleems of the game 3129 | uno: the lime 3130 | eneels of the game 3131 | the nidal sgame 
3132 | the lase gome 3133 | operation vort 3134 | triends of oy rack 3135 | 100 :.ch board game 3136 | uno: the bardensppel 3137 | over the word game 3138 | the ming board game 3139 | the bangin 3140 | the pithucken game 3141 | priepiny card game 3142 | the nidal cgame 3143 | pnock the wor game 3144 | the mice care game 3145 | the bese game 3146 | the ngymen game 3147 | the pithen ce game 3148 | stoam the cor game 3149 | flax 3150 | glodes on the gale 3151 | the mile game game 3152 | chipling rades 3153 | herpes & game 3154 | glixk 3155 | black wark 3156 | the sune game 3157 | slack ward 3158 | block wars 3159 | chafling gade 3160 | c ta: co board game 3161 | klink 3162 | thechiny game game 3163 | chen ghe bome 3164 | drepen of the game 3165 | tren the gole 3166 | bloxk 3167 | conquest wame 3168 | hno: the cord game 3169 | uno: the boo 3170 | uno: the rice 3171 | char: co board game 3172 | harper & game 3173 | evo: the bard game 3174 | monstertare 3175 | tref the gage 3176 | quazz 3177 | spat the gice 3178 | pocket pacs 3179 | chicling boniers 3180 | chicminy card game 3181 | blanm 3182 | triends of boegame 3183 | the sica board game 3184 | black work 3185 | the will game 3186 | bliple of the game 3187 | stet the bame 3188 | the bite game game 3189 | stop the tor 3190 | char: the fard game 3191 | preameter won game 3192 | thef the gole 3193 | the linl game 3194 | the stht game 3195 | emperita card game 3196 | gleens on the game 3197 | whader's borrd game 3198 | the great wer 3199 | exeraston wer game 3200 | svap the lame 3201 | cley the game 3202 | the mind gome 3203 | the nothe bcard game 3204 | conquest dypiles 3205 | the manden ge 3206 | over tien: cr game 3207 | eveamitor wor game 3208 | sved the pige 3209 | shopeest game 3210 | stat the rice 3211 | the pitlachud game 3212 | stad the bame 3213 | the ling gageme 3214 | over the bumd game 3215 | the mote game 3216 | ereppite card game 3217 | toun dhe gome 3218 | spot the fite 3219 | flick 3220 | 1812: the card 
game 3221 | roor the gime 3222 | eneals of the game 3223 | the sula game 3224 | the aile game 3225 | the mica game 3226 | the sing board game 3227 | the fala game 3228 | grayk 3229 | overasthenver game 3230 | blonks 3231 | the ning rog game 3232 | chen the grme 3233 | anima tactics: der2a 3234 | toun the gime 3235 | a day mebrand game 3236 | the filelg game 3237 | qlonk 3238 | whurd 's board game 3239 | the masion game 3240 | conquest pome 3241 | theendhe card game 3242 | the bend gome 3243 | the matien game 3244 | evermitor wer game 3245 | the mage gome 3246 | the midey war game 3247 | the matian game 3248 | hocket'e dick 3249 | poppetse fame 3250 | spat the game 3251 | 3 day debrard game 3252 | the fald game 3253 | uno: the given 3254 | pocket's bomise 3255 | uno: the pard game 3256 | 10 p:ict board game 3257 | operation sringo 3258 | block wart 3259 | the mitlon game 3260 | uno: the rise 3261 | the nateon geme 3262 | the baller 3263 | the matiin game 3264 | rour the game 3265 | chass co board game 3266 | the mihe game 3267 | constere gome 3268 | x do: the card game 3269 | ovel the bame 3270 | chofmine card game 3271 | stoom the bar game 3272 | block wark 3273 | the lani game 3274 | empers of the ande 3275 | uno: the bord game 3276 | thip the lome 3277 | black warsy 3278 | stot the bioe 3279 | the mimele game 3280 | plopetse card game 3281 | 100s:-c0 board game 3282 | eve: the bard game 3283 | chapling gade 3284 | quand 3285 | chec the game game 3286 | preamites war game 3287 | the fuhe game 3288 | blaok 3289 | the great gome 3290 | the great gole 3291 | the lune gome 3292 | the etce game 3293 | emperdal care game 3294 | chiplind gome 3295 | 100s:'c. 
board game 3296 | hompet of the game 3297 | the mathin game 3298 | stat the game 3299 | coppethe game 3300 | glaek 3301 | blauk 3302 | than the gome 3303 | estole of the game 3304 | the mico board game 3305 | chicling balige 3306 | hompes of the game 3307 | tomblese game 3308 | x t': the card game 3309 | pockey's card game 3310 | cocket's rovitg 3311 | pocket's mack 3312 | slank 3313 | pmperica board game 3314 | stoam the che game 3315 | 6 ta: to board game 3316 | computse game 3317 | chimling games 3318 | 2ro: the card game 3319 | pompetse game 3320 | pocket's rocil 3321 | eleans on the game 3322 | the lone game 3323 | the m the card game 3324 | anima tactics: kery 3325 | over the gage 3326 | the mitlin ge game 3327 | plopeite card game 3328 | pocket'e game 3329 | the wege game 3330 | rount 3331 | the fithen game 3332 | imperide card game 3333 | the great wirss 3334 | cherlind game 3335 | qliek 3336 | hoo: the bage 3337 | the minke fgame 3338 | pocket qhess 3339 | trindys of the boal 3340 | stare the card game 3341 | choz: che card game 3342 | monstera gome 3343 | croy ine card game 3344 | chinne's raving 3345 | phip the game 3346 | pocket's rivil 3347 | scokethe barderspiel 3348 | tlink 3349 | the great bare 3350 | pocket quis 3351 | tlask 3352 | theees & games 3353 | the eler game 3354 | over thend ch game 3355 | conquest oopides 3356 | trienat of the poal 3357 | clank 3358 | thep the gome 3359 | poppes on the game 3360 | staf the gomes 3361 | evs: the bard game 3362 | monsters goce 3363 | ghom: the card game 3364 | dou't duan the bort 3365 | blag 3366 | porpeest game 3367 | star the bice 3368 | trie the fo 3369 | quazd 3370 | the wirden we game 3371 | comquest game 3372 | pnock the war game 3373 | star the gades 3374 | coppetst game 3375 | monquest game 3376 | pyrayide card game 3377 | quask 3378 | the great pales game 3379 | the manion game 3380 | geer the game 3381 | the the gome 3382 | the lind gome 3383 | the vithinken game 3384 | poppetsa bored game 3385 | 
aneemsthe che game 3386 | conquest dopires 3387 | the wice game 3388 | over the boame 3389 | tren the fo 3390 | chofling bowe 3391 | the mankon game 3392 | hoppythe pard game 3393 | a day baal card game 3394 | klaak 3395 | the michucken game 3396 | the goniin ge game 3397 | hoppetst bame 3398 | chepking game 3399 | the mitis wor game 3400 | upo: the game 3401 | the pity care game 3402 | treenstor: ch game 3403 | hoo: the bofdetball 3404 | chawling gomese 3405 | the nene game 3406 | quabt 3407 | chicking roning 3408 | pocket's bagilg 3409 | the mile game 3410 | hocket'e rick 3411 | the fing bod game 3412 | the wand game 3413 | thima tactics: game 3414 | plcket quiz: games 3415 | chim tha game 3416 | the ganden game 3417 | char: the ward game 3418 | the wander 3419 | the great wars 3420 | glapk 3421 | stap the rige 3422 | black harl 3423 | the pindin we game 3424 | the fide care game 3425 | ovel the ba 3426 | chipning game 3427 | stoam the card game 3428 | the fure game 3429 | chicling bomish 3430 | greens of the game 3431 | gliok 3432 | uno: the pord game 3433 | hocket's bick 3434 | slopeest game 3435 | chen the gome 3436 | pocket's pact 3437 | then the gome 3438 | chemling game 3439 | shup the game 3440 | ster the gime 3441 | cocket's raving 3442 | the ling gamede 3443 | the nidal cgime 3444 | the vithunwer game 3445 | 3bo: the board game 3446 | comblese game 3447 | tlep the game 3448 | tlabk 3449 | conquest oopiles 3450 | the sare game 3451 | anie stactics game 3452 | grills of the game 3453 | the bule game 3454 | flodk 3455 | black horl 3456 | chipming game 3457 | -------------------------------------------------------------------------------- /samples/repos2.txt: -------------------------------------------------------------------------------- 1 | cs100-plater 2 | python-list 3 | code-cole 4 | LearningCaid 5 | Stading-Test 6 | learning-fication 7 | django-tove-server 8 | python-jection 9 | hackbook-app 10 | linexin 11 | mikin.github.com 12 | schemant.github.io 13 
| BIOMD0000000007 14 | rest-come 15 | websock 16 | comp300_project 17 | Testing-Framework 18 | TestEngection 19 | php_app 20 | test-pp 21 | datadew 22 | mathiel 23 | ServinProject 24 | test1001.github.io 25 | Arduino-Reb 26 | ti 27 | hws 28 | play-chacker 29 | sterentest 30 | marking 31 | testing-android 32 | kanea900.github.io 33 | JavaScript 34 | K038757.000 35 | sinitecr.github.io 36 | iOS-App 37 | php-app 38 | commine 39 | ProjectI 40 | node-mittolization 41 | node-jectation 42 | starppp 43 | C2 44 | K079440.000 45 | click-tracker 46 | TestAngro 47 | liban.github.com 48 | mkl 49 | simple-servine 50 | project_prection 51 | CSS-150 52 | TestReport 53 | Calding 54 | batanack.github.io 55 | K053797.000 56 | django-pplitation 57 | django-Example 58 | dmbx 59 | ass-app 60 | convest 61 | ionnest 62 | cord-cone 63 | testing-talization 64 | java-framework 65 | SSS 66 | python-website 67 | ka 68 | taskapp 69 | web_project 70 | test-example 71 | Devore-Example 72 | K003576.000 73 | simples 74 | objec.github.com 75 | websime-books 76 | ruby-mb 77 | SES_Test 78 | ServelTest 79 | malinian.github.io 80 | webson11.github.io 81 | Penire 82 | K030599.000 83 | catcher 84 | libener 85 | HTML5SApplication 86 | PergitProject 87 | Proyecto_Example 88 | spring-rest 89 | rest-example 90 | SOS_Projection 91 | vagrant-plitation 92 | server_server 93 | RubyTest 94 | myliming.github.io 95 | TS 96 | python-ple 97 | gtt 98 | prading-web 99 | linglig 100 | mylinian.github.io 101 | d43-ang 102 | python-book 103 | twitter-js 104 | myniter 105 | testing-pp 106 | lall2014.github.io 107 | schinter.github.io 108 | scrater 109 | nodejs-work 110 | stlankey.github.io 111 | Project_Jate 112 | python-pralication 113 | bite400m.github.io 114 | mpd 115 | Network-Extension 116 | bunchos 117 | Project_Autorial 118 | ProjectApilation 119 | singide 120 | xcs 121 | django-calization 122 | mankingo.github.io 123 | rubyter 124 | gamerap 125 | scholles.github.io 126 | MODEL1000000001 127 | kanking 128 | 
SpringProject 129 | tsm 130 | SimpleFilation 131 | ran 132 | project-plection 133 | hit 134 | BIOMD0000000286 135 | ardasch 136 | vwm 137 | spring-project 138 | Practice-Project 139 | tesh-meb 140 | 2014-01-00-Example 141 | pong-ap 142 | simple-book-App 143 | python-ligin 144 | abthack 145 | lingxen 146 | networky.com 147 | teather-demo 148 | git-lig 149 | K051056.000 150 | TestMag 151 | K005560.000 152 | Project-Kate 153 | iOS-Rep 154 | django-maplication 155 | bf 156 | K039288.000 157 | HelloWurld 158 | kput 159 | gamer-jest 160 | tla 161 | sanbol01.github.io 162 | CS_1302_Project 163 | MODEL1302030004 164 | TestPro 165 | connection 166 | sls 167 | datatech.github.io 168 | scripth 169 | node-move 170 | d43--ly 171 | reploct 172 | makisian.github.io 173 | JavaScriptsytion 174 | TestApp 175 | dynamrin.github.io 176 | android-buolder 177 | kdp 178 | jenami 179 | arduino_gement 180 | BIOMD0000000276 181 | simple-app 182 | samachin.github.io 183 | learningGame 184 | stlanken.github.io 185 | DS-Project 186 | TowerGame 187 | ecd 188 | opendons.com 189 | wot-projeck 190 | sliding 191 | mll 192 | K021974.000 193 | Python-SExtension 194 | CSS-140 195 | django-too-chicker 196 | dpp 197 | my_book 198 | shackix 199 | cs-plit 200 | TT 201 | grunt-app 202 | cs110-plater 203 | K057986.000 204 | web_projects 205 | OpenScrept 206 | battrit 207 | sdd 208 | dokining.github.io 209 | proyecto_ds 210 | p- 211 | liventem.github.io 212 | arduino_demo 213 | TextGime 214 | hba 215 | turs 216 | MODEL1006260004 217 | python-ly 218 | arduino-js 219 | python-hest 220 | BIOMD0000000005 221 | djaigle 222 | django-upation 223 | class_dest 224 | HTML5_Assignment 225 | stc 226 | arman601.github.io 227 | test.github.io 228 | libaclin.github.io 229 | spring-lacking 230 | d43-moc 231 | test-me 232 | my-project 233 | HelloChold 234 | commane.js 235 | TestRis 236 | android-couther 237 | quickta 238 | LearningGome 239 | django-plugin 240 | A18599 241 | simple-service 242 | A40748 243 | CS160_Project 244 
| SOA 245 | DD 246 | K006955.000 247 | TestLag 248 | stlanten.github.io 249 | partrojs.github.io 250 | SimpleReplitation 251 | Arduino-Project 252 | node-jestion 253 | d43-ank 254 | django-issage 255 | project12 256 | stack_test 257 | djs 258 | testing_mation 259 | jacaling.github.io 260 | manking-web 261 | conminter 262 | projects_web 263 | K003998.000 264 | marte.js 265 | DMM 266 | milinist.github.io 267 | experdit.github.io 268 | nodejsi 269 | rarding 270 | matchin 271 | learning-wacker 272 | stading-leb 273 | node-fonework 274 | python-Example 275 | EOS-APP 276 | Imagent 277 | ons-app 278 | testger 279 | manisian.github.io 280 | Net-Ble 281 | MODEL1000000000 282 | tesh-rip 283 | solicker.github.io 284 | nodejow 285 | ness3501.github.io 286 | mangoom 287 | hellor-ander 288 | nodelub 289 | Chiling 290 | iOS_Web 291 | PODEL1003010005 292 | samanter.github.io 293 | jevim 294 | TestDionApp 295 | SOS_Project 296 | kanka201.github.io 297 | 2014_201 298 | jakebase.github.io 299 | Raf 300 | SmartCone 301 | Autores 302 | Cass_Test 303 | strine-2014 304 | stripe_2014 305 | testfit 306 | daa 307 | scb 308 | class-mation 309 | liz 310 | RPUTest 311 | pyProject 312 | MODEL1306060003 313 | pi 314 | TestEnvection 315 | TestingGim 316 | php-ppit 317 | parte.js 318 | lite-reb 319 | pythont-Demo 320 | schellig.github.io 321 | MSC-201 322 | ProjectCoinder 323 | SimpleExample 324 | levin.github.io 325 | iss-app 326 | ProjectExample 327 | SOP2014-Project 328 | python-xample 329 | junkoden.github.io 330 | A46999 331 | jslange 332 | K059597.000 333 | Android_App 334 | COM1300-Project 335 | node-cole 336 | nodeloc 337 | test_prantion 338 | preatacl.github.io 339 | Clash_Repository 340 | CS350_Project 341 | ims-app 342 | affecher.github.io 343 | hi 344 | contest 345 | python-oralication 346 | ChatApp 347 | CS140_Project 348 | grunt-demo 349 | yai21084.github.io 350 | kankingo.github.io 351 | CS150_Project 352 | git2014 353 | MODEL1006000030 354 | django-canculator 355 | COMP201_Project 
356 | backbone-android 357 | pb 358 | generikh.com 359 | traving-web 360 | dea 361 | python-phome 362 | nodelod 363 | auto-ab 364 | panefine 365 | MODEL1000010001 366 | MODEL1910010001 367 | CSC-150 368 | pma 369 | K043969.000 370 | cdl 371 | my-hlog 372 | TestsBat 373 | TestAngit 374 | CMI-2014 375 | SimpleWer 376 | Swift-Example 377 | Java-Framework 378 | lokin.github.com 379 | spring_project 380 | CS307-Assignment 381 | sms 382 | spacket 383 | lotining.github.io 384 | BIOMD0000000001 385 | polling 386 | OpenCollection 387 | K003994.000 388 | convast 389 | rwc 390 | PCI-2014 391 | HTML-Test 392 | aplander.github.io 393 | project-Framework 394 | testing-template 395 | mychang 396 | schite11.github.io 397 | ywi20148.github.io 398 | PODEL1006090003 399 | parsoglo.github.io 400 | CS-Appp 401 | php-tester 402 | CSS-101 403 | Automes 404 | sik 405 | COCS-2014 406 | d43-fry 407 | file2014.github.io 408 | spp-proje 409 | 20142014 410 | akeaster.github.io 411 | kilingan.github.io 412 | realt-server 413 | python-0-Example 414 | lingring.github.io 415 | d43-tmm 416 | robo.js 417 | test-ap 418 | naneinen.github.io 419 | DataProjectice 420 | MODEL1000010000 421 | hpa 422 | javaling.github.io 423 | python_service 424 | django-reserver 425 | simple-gervine 426 | CS-2014-Project 427 | d43-tul 428 | d43-bw 429 | BC 430 | rails-teplitation 431 | stlankin.github.io 432 | CS401-Assignment- 433 | proyecto_alication 434 | IOA-Project 435 | Arduino_Project 436 | Java-Applimation 437 | dc 438 | yaliy000.github.io 439 | Kastest 440 | node-filest-plugin 441 | tisck 442 | dynaper 443 | mare2014.github.io 444 | pk 445 | scringit.github.io 446 | lineuran.github.io 447 | node-stre 448 | dbc 449 | baskbook.com 450 | linux20 451 | Dust-Repo 452 | meteor-fil 453 | stanker 454 | 100 455 | go--ppp 456 | CS407-Assignments 457 | projectTest 458 | fgs 459 | Arduino-Web 460 | testing-test 461 | tinylian.github.io 462 | A46599 463 | ruby-ry 464 | php-foot 465 | macano10.github.io 466 | WebProject 467 | 
python_2014 468 | sprine-2014 469 | expressing 470 | demolian.github.io 471 | Simple-Website-App 472 | dsb 473 | K007998.000 474 | SOS-Projection 475 | CS401-Assignment 476 | projection.com 477 | TestGitg 478 | getericy.com 479 | PDO 480 | MyFore-Example 481 | meteor-jh 482 | android-biolder 483 | CSS-100 484 | der 485 | carting 486 | K058356.000 487 | ArduinoForter 488 | IOP-Project 489 | PardinProject 490 | K003997.000 491 | langbase.github.io 492 | ServitProject 493 | bi 494 | iOS-Test 495 | PHP_RPP 496 | BIOMD0000000281 497 | malinden.github.io 498 | rankin01.github.io 499 | Project-Extensions 500 | frees.github.com 501 | MODEL1000060000 502 | Test1 503 | K007239.000 504 | python-phate 505 | 2014-12-Example 506 | BIOMD0000000459 507 | iOS-Lib 508 | django-bracker 509 | corling 510 | devalib 511 | simple_app 512 | nodejs 513 | webset-ap 514 | DS_Project 515 | HellorWer 516 | java8014.github.io 517 | vinicich.github.io 518 | Android-App 519 | lib-chapt 520 | TestingApp 521 | iOA_Project 522 | ArduinoFil 523 | hellorky.com 524 | PYG2014 525 | git-project-plugin 526 | imatemp 527 | java-web 528 | kanga201.github.io 529 | 2014-01-20-Example 530 | Testing 531 | multibook 532 | Practice_Example 533 | testingection 534 | ember-hation 535 | project_proj 536 | CS400-Assignment 537 | aprinker.github.io 538 | Java-Extention 539 | angular-dutorial 540 | lasinget.github.io 541 | new_project 542 | d43-tub 543 | SimpleApp-Android 544 | CS140-Projection 545 | samer201.github.io 546 | mython-clack 547 | comming.js 548 | cxe 549 | projects-web 550 | sinalian.github.io 551 | node-strt 552 | SS 553 | django-mation-ble 554 | testerk 555 | git_project 556 | node-ximple 557 | vim-project 558 | php-rools 559 | K003558.000 560 | BIOMD000000035L 561 | project_pontonfing 562 | mocki.github.io 563 | Arduino-Lid-Test 564 | django-log 565 | meter-amp 566 | Bartorm 567 | A46979 568 | helloder.github.io 569 | php-lib 570 | py-project 571 | Varavel_Project 572 | android-mor 573 | mash-cate 574 
| django-tracker 575 | vagrant-ipp 576 | sanjing.js 577 | sageswir.github.io 578 | Python20-Example 579 | neb-pest 580 | django-pplication 581 | my-crocker 582 | tilylock.github.io 583 | schask01.github.io 584 | vagrant-chacking 585 | testApp 586 | rgo 587 | sameriam.github.io 588 | oss-app 589 | TestRepertin 590 | hw 591 | websa401.github.io 592 | studing 593 | py_look 594 | grunt-bo 595 | PHP-Proje 596 | 2014_001 597 | test-pra 598 | K099539.000 599 | Network 600 | ruby-in 601 | siz 602 | pf 603 | Persing-Android 604 | python-21 605 | BIOMD0000000271 606 | testing-web 607 | assignmenter 608 | TestingClient 609 | python-2014 610 | CIC2014 611 | sak 612 | test-praplitation 613 | modile 614 | prt 615 | contect 616 | Netwook-Extensions 617 | reso 618 | libyder 619 | pyt 620 | Project_Extensions 621 | python-lest 622 | swadect 623 | LearningFider 624 | stlinken.github.io 625 | node-lat 626 | buncher 627 | testing-lab 628 | APM-Project-Plugin 629 | OSE 630 | COSP345_Project 631 | djp 632 | mss 633 | A43000 634 | Project_Dema 635 | Project 636 | mylim.github.io 637 | ssd 638 | PHP-Script 639 | python 640 | Github-App 641 | TestGit 642 | ngw-project-plugin 643 | SA 644 | githeb 645 | eventean.github.io 646 | CECS-201-Project 647 | simple-amp 648 | network 649 | MODEL1306200007 650 | denatio 651 | mob 652 | SimpleWeApp 653 | 2014-00-201 654 | sib 655 | learning 656 | meteor-stom 657 | web-ppp 658 | snl_project 659 | generich.com 660 | testingrob 661 | daskbonk.github.io 662 | SHA 663 | K005520.000 664 | P- 665 | django-test 666 | project-proc 667 | MODEL1006010000 668 | A10319 669 | lines501.github.io 670 | dynamhin.github.io 671 | vagrent-plication 672 | sthacker.github.io 673 | PHP-Project-Plugin 674 | IOS-APP 675 | meteor1 676 | spring-ripo 677 | CS500-Assignment 678 | CS140_Projection 679 | pythonPrat 680 | OplineApplication 681 | android-mon 682 | mci 683 | sagage44.github.io 684 | SimpleApplication 685 | twitted 686 | cs3-ppp 687 | Spring-Example 688 | testing-lock 
689 | pertgem 690 | nodejast.github.io 691 | Project-Bate 692 | CSS-server 693 | -IA 694 | E2 695 | sr 696 | K003999.000 697 | d43-moj 698 | test1901.github.io 699 | nww 700 | Net-Projects 701 | linux.github.com 702 | nardinal.github.io 703 | MyShine 704 | CIB 705 | sim32014 706 | mangoon 707 | django-app 708 | manking 709 | server-server 710 | php-lliter 711 | webming-android 712 | PythonProjection 713 | SimpleChacker 714 | COS3300_Project 715 | mangonation 716 | MyCore-Example 717 | kanma011.github.io 718 | sthack01.github.io 719 | lungjing 720 | test-come 721 | django-lag-java 722 | spacher 723 | backbook.com 724 | testing-tramework 725 | django-bite 726 | HTML5_Test 727 | Server_Example 728 | K027699.000 729 | pw 730 | pythonf-whb 731 | leading-web 732 | googleg 733 | parming.js 734 | Network-Extensions 735 | lingxem 736 | fma 737 | django-mpution 738 | spring-lest 739 | github.js 740 | Coursero_Android 741 | K027996.000 742 | K025995.000 743 | sok 744 | OpenScript 745 | schineny.github.io 746 | grunt-search 747 | Project_Patilation 748 | K005195.000 749 | mlc 750 | cscraph 751 | baver-ject 752 | cuntice 753 | gyt 754 | ort 755 | java_has 756 | colting 757 | d43-mo 758 | CS14004 759 | RuilScript 760 | MT-Server 761 | python-apprication 762 | program 763 | bie-tus 764 | node2014 765 | testing-tolization 766 | mmm 767 | manging.js 768 | pynasien.github.io 769 | openconfog 770 | iOS-Web 771 | PythonProg 772 | python-jectice 773 | aptoter 774 | MODEL101615001 775 | MODEL1006200000 776 | WWS 777 | COMS-201-Project 778 | super-server 779 | testing-framework 780 | openSTest 781 | marto.js 782 | git-spp 783 | 140 784 | test-201 785 | learning-app 786 | pyml 787 | edsort 788 | linex201.github.io 789 | testing_pp 790 | K003556.000 791 | rubyser 792 | A16536 793 | CS110_Project 794 | TestPra 795 | DS-Projec 796 | websing-test 797 | testing-js 798 | gja 799 | web_ppp 800 | pymanite 801 | A07707 802 | mikeping.github.io 803 | project 804 | PHP-Project 805 | wo-derp 806 | 
manko.github.com 807 | github 808 | skatler 809 | lag 810 | backbook-app 811 | php-project 812 | emberwam.github.io 813 | COMS-2014 814 | htm 815 | dja 816 | BIOMD0000000058 817 | Java-SApplitation 818 | comerad 819 | 2014-00-Android 820 | sbd 821 | MODEL1006210000 822 | Gemevise 823 | jsong.github.io 824 | Searn-Test 825 | scriter 826 | CS109-Assignment2 827 | node-ap 828 | yan 829 | laravel.js 830 | SAM-Framework 831 | proyection.com 832 | wdd-ppp 833 | node-fe 834 | JDG 835 | CS130_Project 836 | d43-mmp 837 | CSS2014 838 | testing_came 839 | comp300-project 840 | generator-js 841 | Network-Manager 842 | db 843 | Assignment_Demo 844 | ruby-mo 845 | project-patchase 846 | stading-web 847 | iOD-Project 848 | project-enveration 849 | test-ph 850 | django-lotics 851 | kab 852 | zj3 853 | OpenGLI 854 | shine.js 855 | python-work 856 | new 857 | jankiont.github.io 858 | grunt-web 859 | COSS-Repository 860 | comp530 861 | A44553 862 | testera 863 | fuzzybook 864 | django-extoration 865 | testing-ter-server 866 | kplind11.github.io 867 | generator-python 868 | python-r.com 869 | likingat.github.io 870 | repo1899.github.io 871 | COSS-2014 872 | d43-tmq 873 | node-rebood 874 | proyect 875 | C- 876 | vim-web 877 | git-xpp 878 | marsony 879 | MODEL1106090000 880 | Java-CApplitation 881 | kinemayp.github.io 882 | java_web 883 | stlanden.github.io 884 | sprine_2014 885 | project-js 886 | servite-js 887 | bb 888 | cdm 889 | same2014.github.io 890 | block_tep 891 | iOS-Leb 892 | K005599.000 893 | smp 894 | meteor-block 895 | HTML5_Application 896 | angular-tutorial 897 | malindan.github.io 898 | CS3-App 899 | java-works 900 | d43-map 901 | kankinge.github.io 902 | easy-lo 903 | Coursera_Android 904 | test-ry 905 | MS-Project 906 | tumblog 907 | elp-project 908 | Project-Jate 909 | vin-project 910 | android-quich-ap 911 | testrot 912 | pc 913 | simpley 914 | K001789.000 915 | MODEL0010200000 916 | MODEL1306200000 917 | Net-Cli 918 | BIOMD0000000070 919 | test_up 920 | 
CS147-Project 921 | task-leb 922 | A69698 923 | stingev 924 | coll2014 925 | MODEL1306060000 926 | First_Repo 927 | JavaWeb 928 | K039655.000 929 | karling 930 | easy-lig 931 | rss-tos 932 | project_test 933 | testing-pa 934 | tcm 935 | K003596.000 936 | CS1590Project 937 | test-preptitation 938 | colling 939 | OndineProject 940 | Projet-Plugin 941 | phonemis 942 | Project-Framework 943 | docker-ap 944 | CS340_Project 945 | pt 946 | CS32014 947 | Test-Ex-Android 948 | SpringTest 949 | 2014-Ex-Example 950 | First_Test 951 | esd 952 | instara 953 | randopp 954 | CS236_Project 955 | man 956 | tabysky2.github.io 957 | opento-web 958 | Simple_Project 959 | testerp 960 | K029586.000 961 | Arduino-Bation 962 | idlinest.github.io 963 | Arduino 964 | Devire-Example 965 | 20141014.github.io 966 | CS490-Assignment2 967 | bakebise.github.io 968 | COPS-201-Project 969 | HTML-Framework 970 | COM3300_Project 971 | saskApp 972 | my-blog 973 | nodej901.github.io 974 | univero 975 | HelloEWer 976 | lincuran.github.io 977 | demomen 978 | SimpleApp 979 | node-filext-plugin 980 | Hell-CApplitation 981 | pynamic 982 | python-lagin 983 | vik 984 | vcm-project 985 | simplestil 986 | damarian.github.io 987 | d43-kop 988 | linux201 989 | example_js 990 | Javascriptitation 991 | pss_tes 992 | pythont 993 | TestNote 994 | ban 995 | rest.github.io 996 | training-web 997 | BIOMD0000000223 998 | iCE-Project 999 | SOP_Project 1000 | MyChing 1001 | rm 1002 | python-leation 1003 | django-somplate 1004 | lic 1005 | nodejs-wub 1006 | K003296.000 1007 | Sarver-Examples 1008 | node-vestalization 1009 | toke-pp 1010 | BackApp 1011 | test201 1012 | iOS-Project 1013 | peoserd 1014 | echapant.github.io 1015 | ServerTest 1016 | IOS-2014 1017 | CS417-Assignment 1018 | schinked.github.io 1019 | 2014-Ex-Android 1020 | geter-app 1021 | A46975 1022 | BIOMD0000000060 1023 | lizex201.github.io 1024 | auto-an 1025 | stanter 1026 | node-idt 1027 | likining.github.io 1028 | testing-past 1029 | bancher 1030 | TestAng 
1031 | CSC-100-Project 1032 | dsm 1033 | my-book 1034 | MODEL1300040014 1035 | A47579 1036 | dist 1037 | COSS-201-Project 1038 | node-githhe-server 1039 | jchemang.github.io 1040 | python-pliner 1041 | node-py 1042 | mongle-plate 1043 | CSC2014 1044 | pynamis 1045 | sim-lab 1046 | CSC2014-Project 1047 | RubyTack 1048 | conving 1049 | clash20 1050 | demonisy.com 1051 | lilycrsa.github.io 1052 | wertrojs 1053 | COSS310_Project 1054 | Must-Repo 1055 | test_app 1056 | CVS-Server 1057 | devaleb 1058 | d43-tra 1059 | TaskApp 1060 | shating 1061 | CSC1465_Project 1062 | K038581.000 1063 | A67599 1064 | asthamp 1065 | css-app 1066 | wew 1067 | Project2 1068 | vagrant-dutorial 1069 | CP 1070 | CSP-340 1071 | node-pplication 1072 | K017256.000 1073 | SimpleTest 1074 | sanza901.github.io 1075 | d43-bdz 1076 | SOS-2014 1077 | test-preplitation 1078 | nathob 1079 | docker_net 1080 | dresest 1081 | d43-trv 1082 | monglob 1083 | inbermet 1084 | schido11.github.io 1085 | docker-pog 1086 | WoFTest 1087 | shatroj 1088 | serving 1089 | kilinans.github.io 1090 | php-rester 1091 | learning-lization 1092 | project_mate 1093 | Proyecto-Example 1094 | autoreb 1095 | werting 1096 | PHP-Compilication 1097 | Github-Test 1098 | mando.github.com 1099 | BuilScript 1100 | K005596.000 1101 | tudblog 1102 | block-tracker 1103 | traving_test 1104 | COS-2014 1105 | mikining.github.io 1106 | sprine-mation 1107 | maz 1108 | my-s-201 1109 | d4s 1110 | MODEL1300060000 1111 | php-essignment 1112 | python-cestion 1113 | cpr 1114 | hmm 1115 | MODEL1006070002 1116 | jsharter.github.io 1117 | K073555.000 1118 | mancher 1119 | Simple-Cot-Android 1120 | GitHuble-Android 1121 | samashin.github.io 1122 | K053966.000 1123 | iPI-Project 1124 | mobile 1125 | SimpleChocker 1126 | docker-masier 1127 | CS2086 1128 | ali 1129 | klnejack.github.io 1130 | CS100-Assignment 1131 | node-st 1132 | test101 1133 | MSS-2014 1134 | matapem 1135 | K033276.000 1136 | SOS-Project 1137 | TestNet 1138 | TestEngration 1139 | oea 1140 
| node-app 1141 | wib2014 1142 | revin.github.io 1143 | SOC2014 1144 | IMI-Project 1145 | stackar 1146 | node-forework 1147 | node-epp 1148 | test1004.github.io 1149 | K040660.000 1150 | extrenss.com 1151 | hfm 1152 | t- 1153 | Stading_Test 1154 | Same2010Project 1155 | css 1156 | sal 1157 | PHP-Script-Android 1158 | DataProjectite 1159 | lj 1160 | test-app 1161 | Clash-Repository 1162 | testemp 1163 | kansa901.github.io 1164 | libevel 1165 | contapp 1166 | CSC4372_Project 1167 | marling 1168 | zf 1169 | Project_Butilation 1170 | testing-templation 1171 | rails_teplitation 1172 | minicher.github.io 1173 | project-web 1174 | gupe208y.github.io 1175 | python2 1176 | hwm 1177 | My-Appp 1178 | faleb201.github.io 1179 | hymar.github.io 1180 | my-q-2014 1181 | reve.github.io 1182 | 2014-00-android 1183 | docker-pin 1184 | SPM-Project 1185 | K008-84.000 1186 | html5pp 1187 | spring-rab 1188 | arduino 1189 | schister.github.io 1190 | test-reb 1191 | Simple-Project 1192 | Arduino-Tracker 1193 | django-samplate 1194 | git-proje 1195 | netre-search 1196 | pythonts_project 1197 | Arduino-Gracker 1198 | liveraly.github.io 1199 | dacalits.github.io 1200 | Tester 1201 | django-seanchator 1202 | jsang.github.io 1203 | lts 1204 | genero-plo 1205 | ruby-mi 1206 | coderom 1207 | svl 1208 | demoning.github.io 1209 | K051676.000 1210 | python-service 1211 | coderow 1212 | jocaling.github.io 1213 | spring-repo 1214 | manco.github.com 1215 | de 1216 | K009967.000 1217 | MyShirt 1218 | convert 1219 | gamer-ject 1220 | malling.github.com 1221 | HTML-Web-Server 1222 | lanux_201 1223 | hentert 1224 | SimpleServer 1225 | grunt-Assignment 1226 | mizening.github.io 1227 | A41849 1228 | my-flew 1229 | test-an 1230 | Project_Demo 1231 | projest 1232 | docker-example 1233 | node-sample 1234 | TP 1235 | javascripts 1236 | HB-Server 1237 | CS160_Projection 1238 | project-pation 1239 | manchian.github.io 1240 | symfony-js 1241 | CSL 1242 | Multing 1243 | python_app 1244 | generiky.com 1245 | 
node-gate-project 1246 | tuperes 1247 | 20141018.github.io 1248 | K063688.000 1249 | StareNote 1250 | Hell-CApplication 1251 | test-ab 1252 | K016804.000 1253 | nodejs-application 1254 | hellowerexample 1255 | JSON-Test 1256 | demo.js 1257 | OpenCLM 1258 | Google-Framework 1259 | JavaScriptitation 1260 | test-extest 1261 | learn-server 1262 | test-ay 1263 | nec 1264 | node-pp 1265 | Projectorution 1266 | python-nestion 1267 | Ousi 1268 | iOS_Project 1269 | CSS-130 1270 | epec 1271 | icatest 1272 | BIOMD0000000010 1273 | sub 1274 | perchep 1275 | libott 1276 | continter 1277 | kthankun.github.io 1278 | HackProject 1279 | SimpleWebServer 1280 | Project-2014-Toals 1281 | node-jent-client 1282 | manickix.github.io 1283 | Project_Test 1284 | SOS-Framework 1285 | project_st 1286 | SimpleIndentation 1287 | miangrap 1288 | php-lab 1289 | statiz91.github.io 1290 | BIOMD0000000339 1291 | heacher 1292 | CSC1000 1293 | nodej091.github.io 1294 | py-pally 1295 | Testing_Android 1296 | test-pppitication 1297 | python_dest 1298 | compander 1299 | K014560.000-Plugin 1300 | 2014erver 1301 | K001989.000 1302 | emc 1303 | BIOMD0000000000 1304 | MODEL1102010017 1305 | node-jection 1306 | TRFTest 1307 | 2014-bractor 1308 | project-strection 1309 | node2008.github.io 1310 | learn.github.io 1311 | convist 1312 | COMP-101-Project 1313 | K003966.000 1314 | Java-CApplication 1315 | Linux-Framework 1316 | MODEL1006080000 1317 | lab 1318 | javalabs.github.io 1319 | pit 1320 | wintemo 1321 | php-eximple 1322 | alchaph 1323 | san 1324 | dac20041.github.io 1325 | newhjlin.github.io 1326 | django-tmep 1327 | kalankan.github.io 1328 | testing 1329 | PHP-Assignments 1330 | lokining.github.io 1331 | vagrant-tutorial 1332 | jabmavit.github.io 1333 | opencraft 1334 | sprine2 1335 | linyert 1336 | oos 1337 | testest 1338 | asdho.github.io 1339 | d43-gik 1340 | C3 1341 | idff2014.github.io 1342 | repotem 1343 | MODEL1006010009 1344 | websing_test 1345 | sthasker.github.io 1346 | stripri 1347 | 
crick-tracker 1348 | sline.js 1349 | K003575.000 1350 | TestRepert 1351 | iS_Project 1352 | Example_is 1353 | project-website 1354 | laminden.github.io 1355 | BIOMD0000000330 1356 | d43-mog 1357 | sama2014.github.io 1358 | matcher 1359 | LearningGame 1360 | d43-hle 1361 | Hellorver 1362 | html564 1363 | PHP-Project-pester 1364 | kanka901.github.io 1365 | django-application 1366 | websap91.github.io 1367 | KUA 1368 | otm 1369 | RubyWer 1370 | website-book 1371 | proyectapproject 1372 | project_pest 1373 | python-books 1374 | GoogleApplitation 1375 | simple-book-hpp 1376 | laz 1377 | CS101-Assignment- 1378 | lezill99.github.io 1379 | javalian.github.io 1380 | veci 1381 | HelloWer 1382 | git-project 1383 | morting 1384 | carding 1385 | arduino-baicher 1386 | node-witwer-server 1387 | CS406-Assignment 1388 | perling 1389 | WRE 1390 | matates 1391 | K003996.000 1392 | Matling 1393 | HelloSrard 1394 | quicker 1395 | node-sentation 1396 | tbl 1397 | bw 1398 | node-ans 1399 | cttp 1400 | python-application 1401 | bocker-201 1402 | testing_test 1403 | nome2014.github.io 1404 | K003986.000 1405 | my-n-201 1406 | MODEL1306280000 1407 | python_lagin 1408 | batelig 1409 | Marser 1410 | PMS-Script 1411 | SimpleWeblitation 1412 | CS416_Assignment-1 1413 | python-plate 1414 | node-sic-project 1415 | adff2014.github.io 1416 | sag 1417 | JavaApp 1418 | MODEL1006000000 1419 | cs32014 1420 | sample-datch 1421 | vanyard1.github.io 1422 | goc 1423 | wfffchen.github.io 1424 | imagent 1425 | montorp 1426 | K007657.000 1427 | scrapro 1428 | haseback.github.io 1429 | K023555.000 1430 | TestRepd 1431 | sin 1432 | cfs 1433 | simple-web-project 1434 | ruby-inder 1435 | d43-tmv 1436 | mails_app 1437 | BIOMD0000000492 1438 | Sprine-Framework 1439 | MODEL1006060000 1440 | kfs 1441 | sut 1442 | randomp 1443 | django-falculator 1444 | 945-Project 1445 | K003976.000 1446 | testing_xample 1447 | ybp 1448 | K033596.000 1449 | examplexample 1450 | Learning-App 1451 | csheront.github.io 1452 | 
65211014.github.io 1453 | saz 1454 | K003557.000 1455 | python-simplate 1456 | ServerText 1457 | IOS-Project 1458 | learninger 1459 | instack 1460 | lib 1461 | althime 1462 | dt 1463 | K003278.000 1464 | rails-teptitation 1465 | testrapt 1466 | repotes 1467 | htw 1468 | noves.js 1469 | contast 1470 | html5.github.io 1471 | peesonProject 1472 | manasian.github.io 1473 | lobigle 1474 | node-gitwer-server 1475 | python-trawer 1476 | aplandit.github.io 1477 | starling.github.io 1478 | dwd 1479 | Java-Application 1480 | poplexter 1481 | Server-Example 1482 | python-forect 1483 | twitter 1484 | WebSApp 1485 | class-dation 1486 | apsshat 1487 | cloudbat 1488 | caa 1489 | minitar 1490 | samples 1491 | ds 1492 | sanving.js 1493 | css2014 1494 | atheboes.github.io 1495 | firstar 1496 | karkang.github.com 1497 | node-sile-project 1498 | hss 1499 | suni201 1500 | sja 1501 | dynashin.github.io 1502 | BIOMD0000000006 1503 | Project-Extention 1504 | jankling.github.io 1505 | demonian.github.io 1506 | ArduinoScation 1507 | K056559.000 1508 | SpringFramework 1509 | GIF-2014 1510 | testing2.com 1511 | unlond14.github.io 1512 | dimaing 1513 | rest1405.github.io 1514 | python-sacking 1515 | sklingen.github.io 1516 | K003205.000 1517 | TestAngAution 1518 | design-pation 1519 | stardak 1520 | creaspp 1521 | MyScame 1522 | docker-porker 1523 | d43-mop 1524 | siniteck.github.io 1525 | iOS-Lub 1526 | testroject 1527 | ProjectTest 1528 | BaskApp 1529 | django-ubsite 1530 | maschoc 1531 | RelloWer 1532 | robrife 1533 | fcas 1534 | Netwoik_Extensions 1535 | stharmew.github.io 1536 | CS12014 1537 | ballerver 1538 | phonestors 1539 | tc 1540 | TE 1541 | come201 1542 | kanjaing.github.io 1543 | d43-mab 1544 | ria 1545 | tentele 1546 | PHP-Assignment 1547 | code-ap 1548 | haseeng8.github.io 1549 | vagrant-slication 1550 | test-st 1551 | clask_pp 1552 | IS_Project 1553 | webscass.github.io 1554 | vagrant-app 1555 | Carling 1556 | Mystong 1557 | django-palitation 1558 | php-example 1559 | coderon 
1560 | rest-meb 1561 | TestNot 1562 | molting 1563 | testing2 1564 | php-proje 1565 | hcs 1566 | grunt-planter 1567 | jplank14.github.io 1568 | Mobile-Masker 1569 | lin 1570 | grunt-chp 1571 | provesk 1572 | ColePar 1573 | multing 1574 | testing4.com 1575 | colming 1576 | harkang.github.com 1577 | manking.js 1578 | K004510.000 1579 | node-mitiolization 1580 | Project-Controlder 1581 | Personale 1582 | 310-Project 1583 | testing-example 1584 | Practico-Project 1585 | git-stript 1586 | Python-Test 1587 | test_2014 1588 | mdd 1589 | hwc 1590 | testing-cate 1591 | CS 1592 | barling 1593 | mandonation 1594 | vestark 1595 | jan 1596 | LearningGaid 1597 | Contest 1598 | levalig 1599 | DataDan 1600 | MODEL1306080000 1601 | java-cest 1602 | app-techer 1603 | arduino-ig 1604 | strinker.github.io 1605 | open120 1606 | django-plogin 1607 | Simple-Col-Android 1608 | QuickLogin 1609 | marrepo 1610 | aksasten.github.io 1611 | CSS-108 1612 | MODEL1306010000 1613 | maney201.github.io 1614 | COCS-201-Project 1615 | caninuto.github.io 1616 | 2048-cap 1617 | django-pplidations 1618 | SML-Framework 1619 | carturs 1620 | test_ly 1621 | CSP2014 1622 | MODEL1012010002 1623 | LaF 1624 | git-ppp 1625 | python-dio 1626 | COM3380_Project 1627 | pynapes 1628 | CSS-106 1629 | node-lib 1630 | phytest 1631 | d43-gra 1632 | demilation 1633 | java2014.github.io 1634 | CM 1635 | PODEL1306010003 1636 | levaleo 1637 | project-pontonsing 1638 | SAS 1639 | testing-xample 1640 | laby0106.github.io 1641 | django-pulitation 1642 | yong.js 1643 | MODEL1006010003 1644 | MODEL1002010008 1645 | fms 1646 | Inthonal-Project 1647 | twetterver 1648 | liberes 1649 | MODEL1006080003 1650 | comp301_project 1651 | django-hhalle 1652 | projectapproject 1653 | java-work 1654 | marding.js 1655 | auto-201 1656 | untrick 1657 | python-app 1658 | K047275.000 1659 | marking.js 1660 | Project-Rutorial 1661 | nodej008.github.io 1662 | stack-pocker 1663 | partiem 1664 | SimpleSart 1665 | jshemand.github.io 1666 | 
BIOMD0000000004 1667 | randapp 1668 | mynapes 1669 | LearningCoid 1670 | d43-kly 1671 | css-spp 1672 | matacap 1673 | lekin.github.com 1674 | ArduinoLig 1675 | quaiz018.github.io 1676 | CS-2014 1677 | bacashen.github.io 1678 | pails_app 1679 | lsd 1680 | MODEL1000000004 1681 | comviser 1682 | python-lacking 1683 | Project-Website 1684 | example-atp 1685 | K042521.000 1686 | saminten.github.io 1687 | 2014-01-00--xample 1688 | ClashTest 1689 | python-apalication 1690 | d43-dmm 1691 | python-servication 1692 | Fushing 1693 | pynamite 1694 | MODEL1906210000 1695 | comerock.github.io 1696 | coderob 1697 | sjd 1698 | tansasts.github.io 1699 | faminden.github.io 1700 | test_fic_project 1701 | luzonbuy.github.io 1702 | iOS-Webs 1703 | vanking-web 1704 | my_blog 1705 | juths.github.io 1706 | K003984.000 1707 | sase1004.github.io 1708 | vimling 1709 | grunt-js 1710 | K030084.000 1711 | actimith.github.io 1712 | K067958.000 1713 | calar201.github.io 1714 | CSCI-201-Project 1715 | twitter-sserver 1716 | node-sim-project 1717 | scmindin.github.io 1718 | SimpleInventation 1719 | A- 1720 | GitHub_Test 1721 | clyan.github.com 1722 | Unitorl 1723 | WebRime 1724 | ServerProject 1725 | K007997.000 1726 | K035966.000 1727 | chon 1728 | oss-ppp 1729 | node-projection 1730 | Image-Pracker 1731 | rubo.js 1732 | vagrant-plication 1733 | came100w.github.io 1734 | simple-java-client 1735 | kanka501.github.io 1736 | simplexample 1737 | carlier 1738 | karghath.github.io 1739 | geteriss.com 1740 | carling 1741 | javascript 1742 | dsc 1743 | kangling.github.io 1744 | TestGite 1745 | marting 1746 | php-wook 1747 | salintas.github.io 1748 | pj 1749 | MyChime 1750 | grunt-bot 1751 | pr 1752 | if101.github.io 1753 | OpenStor 1754 | Project_Extension 1755 | ishineen.github.io 1756 | bandopp 1757 | creande 1758 | MODEL1306210018 1759 | project_putorial 1760 | grunt-project 1761 | SymfonyManager 1762 | CSS2015 1763 | Practict_Example 1764 | spacker 1765 | SimpleCracker 1766 | Python-Example 1767 | 
twi20148.github.io 1768 | CSC2414 1769 | d43-mos 1770 | Install 1771 | IOS-2014-App 1772 | pyManel 1773 | codeapp 1774 | mancheo 1775 | cs 1776 | SSA 1777 | lak 1778 | TestProject 1779 | ArduinoUpTest 1780 | uninale 1781 | HOS 1782 | lim2014 1783 | node-sention 1784 | K003567.000 1785 | vtc 1786 | python-lige 1787 | python-Comm 1788 | minis-android 1789 | AP 1790 | qs1t 1791 | ProjectChack 1792 | my_ipap 1793 | MODEL1002010000 1794 | python-procal 1795 | dma 1796 | stack-packer 1797 | SimpleApplitation 1798 | mathing 1799 | githing 1800 | CSC4330-Project 1801 | GI 1802 | UnityCrack 1803 | android-quich-app 1804 | devire_server 1805 | grunt-slit 1806 | example_ample 1807 | HTML5-Assignment- 1808 | node2018.github.io 1809 | testing-tation 1810 | wonfert 1811 | MODEL1316010003 1812 | stacker 1813 | divalab 1814 | laminten.github.io 1815 | traging-web 1816 | python-diowe 1817 | offfchen.github.io 1818 | HTML5Framework 1819 | test_sic_project 1820 | qab 1821 | Internate-Android 1822 | testing_tation 1823 | d43-kcs 1824 | sanvin01.github.io 1825 | python-mo-web 1826 | 360-Project 1827 | BIOMD0000000571 1828 | project_pution 1829 | CSS3310-Project 1830 | usf2014 1831 | sad 1832 | testingry.com 1833 | mininist.github.io 1834 | contact 1835 | myan 1836 | Spring-Manager 1837 | node-xomple 1838 | lcg 1839 | bakesant.github.io 1840 | tallerver 1841 | grunt-gom 1842 | gat-lig 1843 | django_llan 1844 | ccp 1845 | scraper_app 1846 | dj 1847 | php-ieva 1848 | stlankun.github.io 1849 | A69869 1850 | scrigits.github.io 1851 | spring2 1852 | sprine 1853 | makebase.github.io 1854 | mevin.github.com 1855 | ProjectThack 1856 | levin.github.com 1857 | SimpleFination 1858 | SML-Projection 1859 | SimpleShocker 1860 | demonizy.com 1861 | lizinzar.github.io 1862 | python-ll 1863 | K057509.000 1864 | la 1865 | django-lag 1866 | django-lub 1867 | K003903.000 1868 | TestRap 1869 | docker-server 1870 | K003560.000 1871 | m4x1 1872 | Python-PExtinsion 1873 | hacker-2014 1874 | django-reanchate 
1875 | Wo-Test 1876 | jankinge.github.io 1877 | minining.github.io 1878 | alchack 1879 | python-log 1880 | nite-ens 1881 | class-1 1882 | home5014.github.io 1883 | css-ppp 1884 | 2014-bracker 1885 | ts 1886 | repoted 1887 | grunt-Dem 1888 | testepp 1889 | example-deb 1890 | testrat 1891 | assignment 1892 | linculan.github.io 1893 | ccu 1894 | K005960.000 1895 | sterppp 1896 | java-test 1897 | saa 1898 | php-lub 1899 | A14905 1900 | COS-Server 1901 | WebServer 1902 | ssm 1903 | d43-bza 1904 | web-project-plugin 1905 | A36949 1906 | ste 1907 | django-mode-sarver 1908 | stading 1909 | nodejask.github.io 1910 | python20-Example 1911 | python-losine 1912 | K009669.000 1913 | bakebace.github.io 1914 | mikingin.github.io 1915 | HTML5-Dema 1916 | stlankuy.github.io 1917 | bichtp45.github.io 1918 | lan 1919 | sca-project 1920 | MyTore-Example 1921 | K111268.000 1922 | sds 1923 | pyshing 1924 | MODEL1306070019 1925 | Autorel 1926 | A49667 1927 | docker-peg 1928 | tuperma 1929 | my_trat 1930 | Testrat 1931 | TestBoy 1932 | node-xample 1933 | wew_project 1934 | MODEL1002010001 1935 | acd 1936 | backbone_android 1937 | project_enveration 1938 | marse.js 1939 | rails_app 1940 | starking.github.io 1941 | easy-lu 1942 | d43-mow 1943 | d43-add 1944 | schant13.github.io 1945 | Class-Framework 1946 | WebSong 1947 | vimitoch.github.io 1948 | MODEL1306090002 1949 | malonkas.github.io 1950 | MyFTest 1951 | hs 1952 | yangest 1953 | d43-mob 1954 | testery 1955 | goo 1956 | K032034.000 1957 | Class-Repository 1958 | repo120 1959 | githama 1960 | d43-lig 1961 | splanden.github.io 1962 | his 1963 | grt 1964 | rwm 1965 | django-varculator 1966 | MODEL1306060007 1967 | evento-web 1968 | ipen2014.github.io 1969 | libbery 1970 | android_rape 1971 | sjs 1972 | stacket 1973 | kam 1974 | Practico_Project 1975 | sxreevi.github.com 1976 | K003590.000 1977 | symfony.js 1978 | JavaScriptProject 1979 | sun 1980 | Testing-Android 1981 | UnityTrack 1982 | batchet 1983 | CS316-Assignment 1984 | 
datathan.github.io 1985 | kmd 1986 | py 1987 | Python3-LDb 1988 | K003599.000 1989 | simple-book 1990 | project-2014-boals 1991 | testing-nalization 1992 | martrojs.github.io 1993 | phypo.github.com 1994 | python-ing 1995 | geteriky.com 1996 | schineng.github.io 1997 | 2014-00.001 1998 | ruby-inger 1999 | BIOMD0000000201 2000 | Rand-Lab 2001 | -------------------------------------------------------------------------------- /samples/repos2_unique.txt: -------------------------------------------------------------------------------- 1 | CSS-130 2 | Dust-Repo 3 | project_putorial 4 | PODEL1003010005 5 | arduino-baicher 6 | CS14004 7 | MODEL1300060000 8 | node-xomple 9 | websing-test 10 | multing 11 | OpenCLM 12 | testing-xample 13 | test-pppitication 14 | icatest 15 | libevel 16 | CSS2015 17 | MODEL1006260004 18 | Java-SApplitation 19 | iOS-Webs 20 | 2014-Ex-Example 21 | generich.com 22 | saskApp 23 | rubyter 24 | meteor-fil 25 | karkang.github.com 26 | Hell-CApplication 27 | martrojs.github.io 28 | MODEL1910010001 29 | vanyard1.github.io 30 | dynashin.github.io 31 | CSS-150 32 | hellorky.com 33 | MyCore-Example 34 | meteor-block 35 | starling.github.io 36 | K001789.000 37 | gamer-jest 38 | commine 39 | sthasker.github.io 40 | angular-dutorial 41 | Bartorm 42 | scrapro 43 | Imagent 44 | project_prection 45 | nathob 46 | node-jectation 47 | SimpleFination 48 | pythonts_project 49 | TestingGim 50 | d43-bw 51 | TestAngro 52 | bakebace.github.io 53 | K040660.000 54 | TestBoy 55 | K035966.000 56 | ServitProject 57 | webson11.github.io 58 | dokining.github.io 59 | edsort 60 | K003966.000 61 | MODEL1302030004 62 | cs-plit 63 | CSC1465_Project 64 | MODEL1306200007 65 | arduino-ig 66 | csheront.github.io 67 | python-ligin 68 | Autorel 69 | simplexample 70 | CIC2014 71 | testing-pp 72 | battrit 73 | python-phome 74 | Arduino-Web 75 | proyection.com 76 | siniteck.github.io 77 | learninger 78 | Net-Cli 79 | Project-Kate 80 | IMI-Project 81 | SimpleApp-Android 82 | matacap 83 | 
css-spp 84 | imatemp 85 | css-app 86 | testing-pa 87 | kanka201.github.io 88 | K053966.000 89 | Project-Framework 90 | vagrant-dutorial 91 | testrat 92 | A43000 93 | untrick 94 | DS_Project 95 | python-phate 96 | singide 97 | sprine2 98 | PythonProjection 99 | my-n-201 100 | Mobile-Masker 101 | TRFTest 102 | projection.com 103 | spring-repo 104 | Proyecto-Example 105 | docker-masier 106 | K111268.000 107 | node-strt 108 | OpenGLI 109 | K009669.000 110 | actimith.github.io 111 | yangest 112 | myan 113 | MODEL1006010003 114 | cs32014 115 | CS500-Assignment 116 | ruby-ry 117 | testing-templation 118 | CS340_Project 119 | git-lig 120 | Devire-Example 121 | node-witwer-server 122 | kanka901.github.io 123 | K003567.000 124 | ServinProject 125 | comerad 126 | Arduino-Reb 127 | vin-project 128 | generiky.com 129 | python-lacking 130 | MODEL1306210018 131 | stanker 132 | asthamp 133 | kplind11.github.io 134 | uninale 135 | peesonProject 136 | K001989.000 137 | CSC-100-Project 138 | stading-leb 139 | peoserd 140 | tallerver 141 | testroject 142 | COMP-101-Project 143 | project_st 144 | vestark 145 | COSS-Repository 146 | Project_Extension 147 | mocki.github.io 148 | Clash_Repository 149 | K004510.000 150 | ipen2014.github.io 151 | A07707 152 | phypo.github.com 153 | come201 154 | Java-Extention 155 | Python-SExtension 156 | docker-pin 157 | starppp 158 | scripth 159 | django-seanchator 160 | schido11.github.io 161 | libyder 162 | asdho.github.io 163 | DS-Projec 164 | node-githhe-server 165 | webscass.github.io 166 | linex201.github.io 167 | CS490-Assignment2 168 | PHP-Script-Android 169 | pertgem 170 | learn-server 171 | K073555.000 172 | python-ple 173 | batanack.github.io 174 | affecher.github.io 175 | twetterver 176 | cs100-plater 177 | android-biolder 178 | A18599 179 | K003590.000 180 | ember-hation 181 | html564 182 | python-leation 183 | hentert 184 | althime 185 | repoted 186 | coderom 187 | bakesant.github.io 188 | django-tmep 189 | jabmavit.github.io 190 | 
testing4.com 191 | MODEL1906210000 192 | comming.js 193 | shackix 194 | MODEL1006080000 195 | ngw-project-plugin 196 | K079440.000 197 | wfffchen.github.io 198 | K006955.000 199 | lokining.github.io 200 | pythont-Demo 201 | Mystong 202 | CVS-Server 203 | Sprine-Framework 204 | nardinal.github.io 205 | suni201 206 | Project_Dema 207 | stlankin.github.io 208 | kanea900.github.io 209 | proyectapproject 210 | ProjectApilation 211 | MODEL1006060000 212 | COM3380_Project 213 | A69869 214 | cs3-ppp 215 | OpenScrept 216 | hackbook-app 217 | quaiz018.github.io 218 | design-pation 219 | marse.js 220 | K039655.000 221 | nodejask.github.io 222 | bancher 223 | django-log 224 | batchet 225 | turs 226 | K003575.000 227 | arman601.github.io 228 | likingat.github.io 229 | php-lliter 230 | CS150_Project 231 | CS236_Project 232 | tisck 233 | demilation 234 | testing-tation 235 | git-stript 236 | yaliy000.github.io 237 | Practict_Example 238 | lines501.github.io 239 | nodejs-application 240 | schite11.github.io 241 | grunt-gom 242 | wew_project 243 | CS109-Assignment2 244 | lokin.github.com 245 | Carling 246 | my_trat 247 | contect 248 | python-diowe 249 | 2014-01-00--xample 250 | test1001.github.io 251 | demonizy.com 252 | test-pra 253 | LearningCaid 254 | HTML5-Dema 255 | experdit.github.io 256 | creande 257 | samanter.github.io 258 | MODEL1002010008 259 | CSS3310-Project 260 | stlanten.github.io 261 | sthack01.github.io 262 | A46975 263 | project_enveration 264 | caninuto.github.io 265 | partiem 266 | kanga201.github.io 267 | test-ry 268 | reploct 269 | COM1300-Project 270 | LearningGaid 271 | openconfog 272 | project-plection 273 | django-lub 274 | d43-map 275 | node-mittolization 276 | WebSong 277 | devalib 278 | samashin.github.io 279 | Server_Example 280 | project-proc 281 | snl_project 282 | manasian.github.io 283 | d43-gik 284 | netre-search 285 | python-hest 286 | lincuran.github.io 287 | haseeng8.github.io 288 | sthacker.github.io 289 | bie-tus 290 | WebSApp 291 | 
ness3501.github.io 292 | simpley 293 | phytest 294 | macano10.github.io 295 | TestNot 296 | quickta 297 | php-ppit 298 | pynasien.github.io 299 | convast 300 | manco.github.com 301 | dresest 302 | RelloWer 303 | nodejsi 304 | Arduino-Tracker 305 | file2014.github.io 306 | convist 307 | python-lige 308 | MODEL1316010003 309 | django-mode-sarver 310 | phonestors 311 | project_proj 312 | my-hlog 313 | K003999.000 314 | IOS-2014 315 | ruby-mb 316 | easy-lo 317 | django-Example 318 | pythonf-whb 319 | django-ubsite 320 | python-sacking 321 | demonian.github.io 322 | rails-teplitation 323 | rubyser 324 | MT-Server 325 | python-losine 326 | projectapproject 327 | K008-84.000 328 | spacher 329 | frees.github.com 330 | MSS-2014 331 | scriter 332 | TestingClient 333 | K030084.000 334 | stackar 335 | convest 336 | wintemo 337 | go--ppp 338 | K003976.000 339 | monglob 340 | html5pp 341 | colming 342 | COMS-201-Project 343 | Assignment_Demo 344 | wdd-ppp 345 | K063688.000 346 | MODEL1006080003 347 | my_ipap 348 | COS3300_Project 349 | Raf 350 | TestGitg 351 | testing-web 352 | K051056.000 353 | python-apalication 354 | learning-wacker 355 | jakebase.github.io 356 | SML-Framework 357 | A69698 358 | Matling 359 | CSC1000 360 | ArduinoLig 361 | node-mitiolization 362 | lanux_201 363 | fcas 364 | mython-clack 365 | examplexample 366 | ruby-mo 367 | IOS-2014-App 368 | Server-Example 369 | Project_Jate 370 | My-Appp 371 | K005520.000 372 | CS401-Assignment 373 | K038581.000 374 | Network-Extensions 375 | PCI-2014 376 | creaspp 377 | helloder.github.io 378 | A47579 379 | iOD-Project 380 | testing-tramework 381 | K027996.000 382 | traving_test 383 | python-forect 384 | vinicich.github.io 385 | hacker-2014 386 | project_pest 387 | laminden.github.io 388 | node-pplication 389 | wertrojs 390 | malinian.github.io 391 | MODEL1000010000 392 | auto-an 393 | MODEL1006210000 394 | docker-ap 395 | python-ll 396 | test-preptitation 397 | javalabs.github.io 398 | randomp 399 | A36949 400 | 
project-2014-boals 401 | Clash-Repository 402 | MODEL1306080000 403 | ProjectCoinder 404 | lungjing 405 | carturs 406 | jchemang.github.io 407 | rest1405.github.io 408 | EOS-APP 409 | node-cole 410 | tesh-meb 411 | MODEL101615001 412 | faleb201.github.io 413 | echapant.github.io 414 | CSC4372_Project 415 | JavaScriptsytion 416 | testing_test 417 | python-simplate 418 | qs1t 419 | Stading-Test 420 | tuperes 421 | mangoon 422 | daskbonk.github.io 423 | meteor-stom 424 | idlinest.github.io 425 | my-q-2014 426 | sprine-2014 427 | MODEL1306200000 428 | molting 429 | python-servication 430 | MODEL1306060003 431 | python-mo-web 432 | java_has 433 | d43-fry 434 | Simple_Project 435 | strine-2014 436 | coderob 437 | montorp 438 | networky.com 439 | linux20 440 | CS3-App 441 | Rand-Lab 442 | K003558.000 443 | instara 444 | Project_Patilation 445 | hymar.github.io 446 | mash-cate 447 | libbery 448 | docker-pog 449 | Hell-CApplitation 450 | CSC2014 451 | web-ppp 452 | testing_xample 453 | webming-android 454 | twitter-sserver 455 | HelloEWer 456 | SymfonyManager 457 | sase1004.github.io 458 | geter-app 459 | SimpleApplitation 460 | Varavel_Project 461 | testrapt 462 | apsshat 463 | sanza901.github.io 464 | RPUTest 465 | android-couther 466 | Stading_Test 467 | CS160_Project 468 | test-201 469 | d43-ang 470 | morting 471 | poplexter 472 | SimpleSart 473 | marting 474 | test-an 475 | wot-projeck 476 | node-sim-project 477 | TestAngAution 478 | myliming.github.io 479 | cuntice 480 | samachin.github.io 481 | karling 482 | ims-app 483 | learning-fication 484 | ColePar 485 | Simple-Col-Android 486 | HTML5Framework 487 | CSC-150 488 | solicker.github.io 489 | comp300-project 490 | gamerap 491 | marking.js 492 | node-epp 493 | lak 494 | project-pation 495 | univero 496 | parte.js 497 | iOS-Lib 498 | pythont 499 | jankling.github.io 500 | CS-2014 501 | makebase.github.io 502 | Projet-Plugin 503 | django-bite 504 | Netwoik_Extensions 505 | JavaScriptitation 506 | openSTest 507 | 
php-essignment 508 | rubo.js 509 | cs110-plater 510 | kankinge.github.io 511 | java2014.github.io 512 | project_pution 513 | if101.github.io 514 | php-eximple 515 | marte.js 516 | Testrat 517 | python-ing 518 | my-s-201 519 | simple-book-hpp 520 | jevim 521 | nodejast.github.io 522 | php-wook 523 | liveraly.github.io 524 | grunt-bot 525 | nodelub 526 | testepp 527 | javaling.github.io 528 | python-jection 529 | sklingen.github.io 530 | Testing_Android 531 | HelloWer 532 | K053797.000 533 | CSC2414 534 | stripe_2014 535 | wib2014 536 | genero-plo 537 | Must-Repo 538 | partrojs.github.io 539 | COMS-2014 540 | hellor-ander 541 | tentele 542 | sterentest 543 | stingev 544 | K043969.000 545 | MODEL1006010009 546 | ywi20148.github.io 547 | MODEL1000060000 548 | django-varculator 549 | lingxen 550 | luzonbuy.github.io 551 | phonemis 552 | sprine 553 | play-chacker 554 | cord-cone 555 | jsharter.github.io 556 | randopp 557 | devaleb 558 | UnityTrack 559 | pyManel 560 | K099539.000 561 | Hellorver 562 | simple-servine 563 | project-patchase 564 | python-website 565 | MODEL1306070019 566 | K067958.000 567 | Project-Extention 568 | PHP-Project-pester 569 | hellowerexample 570 | ruby-inder 571 | iOS-Lub 572 | dynaper 573 | MODEL1306280000 574 | GIF-2014 575 | teather-demo 576 | android_rape 577 | preatacl.github.io 578 | CSS-106 579 | 2014-bractor 580 | Simple-Cot-Android 581 | python-0-Example 582 | kanka501.github.io 583 | Project_Autorial 584 | gamer-ject 585 | learn.github.io 586 | klnejack.github.io 587 | CS100-Assignment 588 | generikh.com 589 | Netwook-Extensions 590 | A40748 591 | class-dation 592 | conving 593 | project_pontonfing 594 | pymanite 595 | proyecto_alication 596 | oss-app 597 | docker_net 598 | minis-android 599 | SOP2014-Project 600 | SimpleFilation 601 | commane.js 602 | geteriss.com 603 | TestAngit 604 | Net-Projects 605 | django-canculator 606 | test-st 607 | Unitorl 608 | K007657.000 609 | example_js 610 | testerk 611 | kanjaing.github.io 612 | 
comviser 613 | 2014-01-00-Example 614 | demomen 615 | node-forework 616 | vimitoch.github.io 617 | repo1899.github.io 618 | K007997.000 619 | linexin 620 | twi20148.github.io 621 | jsong.github.io 622 | eventean.github.io 623 | node-xample 624 | TestGite 625 | BuilScript 626 | K005195.000 627 | node-ans 628 | 2014erver 629 | MODEL1006200000 630 | autoreb 631 | CS400-Assignment 632 | K057986.000 633 | sample-datch 634 | kalankan.github.io 635 | aptoter 636 | test-come 637 | A41849 638 | K003903.000 639 | python-plate 640 | Simple-Website-App 641 | neb-pest 642 | yong.js 643 | node2018.github.io 644 | android-mor 645 | OpenStor 646 | manickix.github.io 647 | K003998.000 648 | manging.js 649 | testing-nalization 650 | Practico_Project 651 | mikin.github.com 652 | testing-tolization 653 | GoogleApplitation 654 | MyChing 655 | K032034.000 656 | rest.github.io 657 | Arduino-Lid-Test 658 | android-quich-ap 659 | extrenss.com 660 | sterppp 661 | SimpleShocker 662 | Project-Controlder 663 | CSC4330-Project 664 | TestRepertin 665 | libaclin.github.io 666 | marding.js 667 | DataProjectice 668 | haseback.github.io 669 | python-jectice 670 | CS307-Assignment 671 | CS130_Project 672 | akeaster.github.io 673 | WoFTest 674 | spp-proje 675 | my-crocker 676 | jplank14.github.io 677 | nodej901.github.io 678 | malinden.github.io 679 | sinalian.github.io 680 | A46999 681 | Practice_Example 682 | mails_app 683 | node-move 684 | HTML5_Assignment 685 | Spring-Manager 686 | tesh-rip 687 | ruby-in 688 | python-list 689 | iOS-Web 690 | aksasten.github.io 691 | project-pontonsing 692 | Same2010Project 693 | mikeping.github.io 694 | block-tracker 695 | linux.github.com 696 | mandonation 697 | code-cole 698 | python-xample 699 | githing 700 | android-quich-app 701 | K056559.000 702 | web-project-plugin 703 | matapem 704 | IOP-Project 705 | adff2014.github.io 706 | coderow 707 | spacker 708 | Multing 709 | carlier 710 | scringit.github.io 711 | ballerver 712 | backbone_android 713 | 
django-mation-ble 714 | easy-lu 715 | django-lag-java 716 | Project-Rutorial 717 | MODEL1000000004 718 | CMI-2014 719 | mikining.github.io 720 | mangoom 721 | geteriky.com 722 | dynamrin.github.io 723 | git-proje 724 | mychang 725 | rails-teptitation 726 | COS-Server 727 | gupe208y.github.io 728 | calar201.github.io 729 | PHP-Project-Plugin 730 | kansa901.github.io 731 | usf2014 732 | gat-lig 733 | d4s 734 | comp301_project 735 | comp300_project 736 | K038757.000 737 | crick-tracker 738 | offfchen.github.io 739 | python_dest 740 | demonisy.com 741 | d43-mob 742 | P- 743 | myniter 744 | bacashen.github.io 745 | mikingin.github.io 746 | naneinen.github.io 747 | grunt-Dem 748 | TestEngration 749 | test_up 750 | matchin 751 | schinter.github.io 752 | milinist.github.io 753 | python-trawer 754 | CS416_Assignment-1 755 | ass-app 756 | pyshing 757 | 2014-00-Android 758 | node-jestion 759 | d43--ly 760 | baver-ject 761 | lotining.github.io 762 | jenami 763 | python-nestion 764 | 310-Project 765 | 2014-bracker 766 | twitted 767 | A67599 768 | SOS-2014 769 | sageswir.github.io 770 | minining.github.io 771 | nodelod 772 | K003984.000 773 | SimpleChacker 774 | test_fic_project 775 | django-extoration 776 | MyFTest 777 | PHP_RPP 778 | K003576.000 779 | learning-lization 780 | CECS-201-Project 781 | CS12014 782 | traging-web 783 | MODEL1006070002 784 | sanjing.js 785 | siz 786 | spring-rab 787 | K059597.000 788 | django-issage 789 | COM3300_Project 790 | repotes 791 | MODEL1006000030 792 | sim32014 793 | K005560.000 794 | qab 795 | CS140_Projection 796 | node-idt 797 | imagent 798 | git-spp 799 | Practico-Project 800 | sprine_2014 801 | lilycrsa.github.io 802 | K030599.000 803 | Marser 804 | django-somplate 805 | CSS-server 806 | SOS-Framework 807 | linculan.github.io 808 | BIOMD000000035L 809 | class_dest 810 | lingxem 811 | samer201.github.io 812 | testingrob 813 | home5014.github.io 814 | mongle-plate 815 | backbone-android 816 | mylim.github.io 817 | tuperma 818 | noves.js 
819 | node-lat 820 | alchaph 821 | node-sention 822 | GitHuble-Android 823 | websime-books 824 | SimpleChocker 825 | K003556.000 826 | Proyecto_Example 827 | gja 828 | langbase.github.io 829 | A49667 830 | task-leb 831 | COS-2014 832 | 2014_201 833 | K003596.000 834 | testing-talization 835 | clask_pp 836 | K003996.000 837 | test-ph 838 | mylinian.github.io 839 | project-js 840 | java-cest 841 | splanden.github.io 842 | django-falculator 843 | SimpleCracker 844 | mancher 845 | vimling 846 | ArduinoUpTest 847 | test1004.github.io 848 | K003997.000 849 | python-lest 850 | juths.github.io 851 | likining.github.io 852 | oss-ppp 853 | damarian.github.io 854 | testing_mation 855 | schask01.github.io 856 | test-ap 857 | TextGime 858 | SOS-Projection 859 | djaigle 860 | maschoc 861 | karghath.github.io 862 | django-reanchate 863 | lekin.github.com 864 | SML-Projection 865 | 2014-00-201 866 | jankinge.github.io 867 | panefine 868 | stlanden.github.io 869 | server-server 870 | simple-book 871 | mankingo.github.io 872 | getericy.com 873 | HB-Server 874 | easy-lig 875 | schemant.github.io 876 | SimpleInventation 877 | django-samplate 878 | simple-amp 879 | A46979 880 | CS101-Assignment- 881 | ArduinoFil 882 | project-enveration 883 | git-ppp 884 | schister.github.io 885 | CS407-Assignments 886 | A10319 887 | K021974.000 888 | tudblog 889 | Google-Framework 890 | UnityCrack 891 | stlankuy.github.io 892 | StareNote 893 | node-filest-plugin 894 | stack-pocker 895 | python-pliner 896 | iCE-Project 897 | kanking 898 | django-tove-server 899 | ardasch 900 | manko.github.com 901 | Searn-Test 902 | grunt-chp 903 | emberwam.github.io 904 | py-pally 905 | py_look 906 | python-lagin 907 | servite-js 908 | SOC2014 909 | stardak 910 | Test-Ex-Android 911 | liventem.github.io 912 | stading 913 | RubyTack 914 | nodejow 915 | vagrant-plitation 916 | pong-ap 917 | wonfert 918 | python_2014 919 | coll2014 920 | MODEL1306090002 921 | scrigits.github.io 922 | malonkas.github.io 923 | 
CS160_Projection 924 | node-st 925 | node-filext-plugin 926 | K007239.000 927 | kput 928 | Arduino-Bation 929 | meteor-jh 930 | PHP-Compilication 931 | CSS-108 932 | testing_pp 933 | CSS-140 934 | PMS-Script 935 | django-hhalle 936 | sxreevi.github.com 937 | server_server 938 | Gemevise 939 | ArduinoForter 940 | COSS310_Project 941 | alchack 942 | jsang.github.io 943 | statiz91.github.io 944 | MODEL1000010001 945 | node-ximple 946 | APM-Project-Plugin 947 | K003994.000 948 | android-mon 949 | m4x1 950 | testing_tation 951 | denatio 952 | WRE 953 | Fushing 954 | K005960.000 955 | HellorWer 956 | simple-book-App 957 | stlinken.github.io 958 | MODEL1002010000 959 | stharmew.github.io 960 | demoning.github.io 961 | nodej008.github.io 962 | opento-web 963 | stripri 964 | COMP201_Project 965 | project-strection 966 | liban.github.com 967 | TestEngection 968 | Inthonal-Project 969 | ruby-mi 970 | jocaling.github.io 971 | django-bracker 972 | devire_server 973 | Calding 974 | batelig 975 | SPM-Project 976 | SOS-Project 977 | meter-amp 978 | opendons.com 979 | iPI-Project 980 | CSS-100 981 | symfony.js 982 | manking.js 983 | COSP345_Project 984 | manking 985 | django-pulitation 986 | django-pplitation 987 | node2008.github.io 988 | pss_tes 989 | backbook-app 990 | lasinget.github.io 991 | app-techer 992 | rss-tos 993 | iOA_Project 994 | node-vestalization 995 | continter 996 | SAM-Framework 997 | Internate-Android 998 | python-apprication 999 | ons-app 1000 | sprine-mation 1001 | OpenCollection 1002 | python-oralication 1003 | testing-ter-server 1004 | sanving.js 1005 | node-gate-project 1006 | aprinker.github.io 1007 | testemp 1008 | rails_teplitation 1009 | yai21084.github.io 1010 | Chiling 1011 | django-calization 1012 | CS406-Assignment 1013 | SOP_Project 1014 | pails_app 1015 | pynamite 1016 | mizening.github.io 1017 | java8014.github.io 1018 | simple-gervine 1019 | Github-App 1020 | OndineProject 1021 | SimpleWeApp 1022 | MODEL1006010000 1023 | BaskApp 1024 | 
node-jent-client 1025 | Persing-Android 1026 | MODEL1012010002 1027 | mangonation 1028 | COCS-2014 1029 | django-reserver 1030 | K003557.000 1031 | iss-app 1032 | datathan.github.io 1033 | websap91.github.io 1034 | datadew 1035 | K003296.000 1036 | HelloSrard 1037 | A14905 1038 | robrife 1039 | mininist.github.io 1040 | K029586.000 1041 | levaleo 1042 | Net-Ble 1043 | python-ly 1044 | Project-Bate 1045 | TestRap 1046 | scmindin.github.io 1047 | kthankun.github.io 1048 | -IA 1049 | K033276.000 1050 | auto-201 1051 | saminten.github.io 1052 | php-foot 1053 | ArduinoScation 1054 | Automes 1055 | SimpleIndentation 1056 | test_sic_project 1057 | lezill99.github.io 1058 | django-pplidations 1059 | Devore-Example 1060 | colting 1061 | lingring.github.io 1062 | 2014-01-20-Example 1063 | testing_came 1064 | sama2014.github.io 1065 | K003560.000 1066 | IS_Project 1067 | CS1590Project 1068 | spring-ripo 1069 | libener 1070 | Python20-Example 1071 | demolian.github.io 1072 | CS140_Project 1073 | reve.github.io 1074 | testingry.com 1075 | ServerText 1076 | kilinans.github.io 1077 | githama 1078 | BIOMD0000000571 1079 | schinked.github.io 1080 | 20142014 1081 | sline.js 1082 | Cass_Test 1083 | mancheo 1084 | MODEL1000000001 1085 | example-atp 1086 | BIOMD0000000000 1087 | LearningGome 1088 | ishineen.github.io 1089 | K047275.000 1090 | python-cestion 1091 | MODEL1006000000 1092 | evento-web 1093 | HTML-Web-Server 1094 | divalab 1095 | K005596.000 1096 | objec.github.com 1097 | jslange 1098 | vagrant-plication 1099 | came100w.github.io 1100 | colling 1101 | iS_Project 1102 | K003986.000 1103 | grunt-Assignment 1104 | googleg 1105 | assignmenter 1106 | epec 1107 | TestsBat 1108 | grunt-bo 1109 | PODEL1006090003 1110 | python-21 1111 | swadect 1112 | SimpleReplitation 1113 | RubyWer 1114 | pynapes 1115 | CS140-Projection 1116 | MODEL1106090000 1117 | CS316-Assignment 1118 | levalig 1119 | Python-PExtinsion 1120 | android-buolder 1121 | pythonPrat 1122 | 2014-12-Example 1123 | 
matates 1124 | provesk 1125 | python_lagin 1126 | Arduino-Gracker 1127 | idff2014.github.io 1128 | COSS-2014 1129 | test1901.github.io 1130 | testingection 1131 | PHP-Proje 1132 | CS32014 1133 | PergitProject 1134 | vanking-web 1135 | SES_Test 1136 | Java-CApplitation 1137 | simple-java-client 1138 | carting 1139 | test-ay 1140 | tilylock.github.io 1141 | CSP-340 1142 | test_prantion 1143 | php-tester 1144 | node-jection 1145 | Kastest 1146 | makisian.github.io 1147 | A46599 1148 | contast 1149 | node-rebood 1150 | realt-server 1151 | django-pplication 1152 | atheboes.github.io 1153 | example-deb 1154 | traving-web 1155 | LearningCoid 1156 | vcm-project 1157 | malling.github.com 1158 | nodej091.github.io 1159 | Linux-Framework 1160 | ClashTest 1161 | skatler 1162 | clash20 1163 | levin.github.com 1164 | django-upation 1165 | sanbol01.github.io 1166 | websa401.github.io 1167 | python-Comm 1168 | node-stre 1169 | stlanken.github.io 1170 | K025995.000 1171 | stack-packer 1172 | K003599.000 1173 | dynamhin.github.io 1174 | nodejs-wub 1175 | cttp 1176 | laby0106.github.io 1177 | lall2014.github.io 1178 | TestRepd 1179 | kangling.github.io 1180 | testrot 1181 | 20141018.github.io 1182 | php-rools 1183 | open120 1184 | 2048-cap 1185 | IOA-Project 1186 | tansasts.github.io 1187 | prading-web 1188 | HTML5-Assignment- 1189 | django-maplication 1190 | SOS_Projection 1191 | class-mation 1192 | repotem 1193 | jankiont.github.io 1194 | revin.github.io 1195 | web_ppp 1196 | ProjectThack 1197 | marling 1198 | Python3-LDb 1199 | stlankey.github.io 1200 | junkoden.github.io 1201 | bakebise.github.io 1202 | SimpleWeblitation 1203 | example_ample 1204 | MODEL1002010001 1205 | shating 1206 | tinylian.github.io 1207 | miangrap 1208 | marrepo 1209 | ccu 1210 | manking-web 1211 | Project-Extensions 1212 | lite-reb 1213 | Java-Applimation 1214 | django-plogin 1215 | baskbook.com 1216 | kanma011.github.io 1217 | MyFore-Example 1218 | scholles.github.io 1219 | linglig 1220 | 
parsoglo.github.io 1221 | jacaling.github.io 1222 | python20-Example 1223 | PYG2014 1224 | CS-2014-Project 1225 | tabysky2.github.io 1226 | block_tep 1227 | heacher 1228 | django-lotics 1229 | test_ly 1230 | LearningFider 1231 | node-gitwer-server 1232 | php-ieva 1233 | sagage44.github.io 1234 | django_llan 1235 | mevin.github.com 1236 | MODEL1102010017 1237 | spacket 1238 | shatroj 1239 | javalian.github.io 1240 | 945-Project 1241 | python-procal 1242 | lib-chapt 1243 | PardinProject 1244 | coderon 1245 | Project_Extensions 1246 | PODEL1306010003 1247 | sca-project 1248 | ServelTest 1249 | spring-lacking 1250 | ionnest 1251 | bite400m.github.io 1252 | same2014.github.io 1253 | MODEL0010200000 1254 | kinemayp.github.io 1255 | 2014-Ex-Android 1256 | K027699.000 1257 | datatech.github.io 1258 | unlond14.github.io 1259 | linux201 1260 | docker-porker 1261 | K014560.000-Plugin 1262 | test201 1263 | testing-lock 1264 | TestEnvection 1265 | rarding 1266 | node-fonework 1267 | Class-Repository 1268 | testing2.com 1269 | DataProjectite 1270 | inbermet 1271 | schellig.github.io 1272 | node-pp 1273 | bocker-201 1274 | simplestil 1275 | laravel.js 1276 | K007998.000 1277 | manchian.github.io 1278 | K057509.000 1279 | K005599.000 1280 | node-fe 1281 | dac20041.github.io 1282 | django-plugin 1283 | multibook 1284 | 2014-00-android 1285 | git-xpp 1286 | php-rester 1287 | nite-ens 1288 | modile 1289 | harkang.github.com 1290 | TestRis 1291 | Example_is 1292 | MSC-201 1293 | vagrant-slication 1294 | K009967.000 1295 | kankingo.github.io 1296 | barling 1297 | django-mpution 1298 | rest-come 1299 | iOS_Web 1300 | comerock.github.io 1301 | mathiel 1302 | libott 1303 | ybp 1304 | node2014 1305 | Wo-Test 1306 | stading-web 1307 | nodejs-work 1308 | marsony 1309 | MODEL1000000000 1310 | conminter 1311 | mynapes 1312 | test-preplitation 1313 | Java-CApplication 1314 | iOS-Leb 1315 | COPS-201-Project 1316 | project_mate 1317 | salintas.github.io 1318 | webset-ap 1319 | carding 1320 | 
Image-Pracker 1321 | wo-derp 1322 | bichtp45.github.io 1323 | django-too-chicker 1324 | perling 1325 | CSCI-201-Project 1326 | testger 1327 | CSP2014 1328 | DataDan 1329 | websing_test 1330 | nome2014.github.io 1331 | backbook.com 1332 | php-lub 1333 | node-projection 1334 | werting 1335 | TestMag 1336 | fuzzybook 1337 | perchep 1338 | K003278.000 1339 | demo.js 1340 | stlankun.github.io 1341 | abthack 1342 | lim2014 1343 | test-extest 1344 | ruby-inger 1345 | minicher.github.io 1346 | malindan.github.io 1347 | testing-past 1348 | d43-mo 1349 | 65211014.github.io 1350 | HTML5SApplication 1351 | node-sic-project 1352 | HelloChold 1353 | bandopp 1354 | K039288.000 1355 | K003205.000 1356 | MyShirt 1357 | kilingan.github.io 1358 | python-application 1359 | CSS-101 1360 | aplandit.github.io 1361 | schineng.github.io 1362 | vagrant-ipp 1363 | corling 1364 | python-pralication 1365 | cscraph 1366 | sim-lab 1367 | css-ppp 1368 | Penire 1369 | OplineApplication 1370 | TestPra 1371 | node-sile-project 1372 | SmartCone 1373 | MODEL1306060000 1374 | my-flew 1375 | Project-Jate 1376 | sanvin01.github.io 1377 | strinker.github.io 1378 | MyScame 1379 | TestDionApp 1380 | CS_1302_Project 1381 | sameriam.github.io 1382 | WebRime 1383 | Javascriptitation 1384 | rankin01.github.io 1385 | test-praplitation 1386 | CS-Appp 1387 | TestLag 1388 | liberes 1389 | CSS2014 1390 | stanter 1391 | Project-2014-Toals 1392 | proyecto_ds 1393 | git-project-plugin 1394 | K023555.000 1395 | MODEL1300040014 1396 | COCS-201-Project 1397 | python-Example 1398 | COSS-201-Project 1399 | django-palitation 1400 | laminten.github.io 1401 | K051676.000 1402 | jshemand.github.io 1403 | sinitecr.github.io 1404 | grunt-planter 1405 | spring-lest 1406 | testing-cate 1407 | Autores 1408 | mathing 1409 | K016804.000 1410 | schant13.github.io 1411 | code-ap 1412 | SimpleWer 1413 | Ousi 1414 | scrater 1415 | arduino_gement 1416 | lizinzar.github.io 1417 | dacalits.github.io 1418 | auto-ab 1419 | maney201.github.io 
1420 | project-Framework 1421 | K042521.000 1422 | lizex201.github.io 1423 | lobigle 1424 | K058356.000 1425 | node-sentation 1426 | toke-pp 1427 | CS2086 1428 | 20141014.github.io 1429 | compander 1430 | lineuran.github.io 1431 | projects_web 1432 | faminden.github.io 1433 | Sarver-Examples 1434 | parming.js 1435 | dmbx 1436 | MODEL1306010000 1437 | ProjectChack 1438 | vagrant-chacking 1439 | Projectorution 1440 | grunt-slit 1441 | CS401-Assignment- 1442 | vagrent-plication 1443 | marto.js 1444 | dimaing 1445 | aplander.github.io 1446 | githeb 1447 | CSC2014-Project 1448 | mare2014.github.io 1449 | newhjlin.github.io 1450 | MyShine 1451 | MyTore-Example 1452 | rest-meb 1453 | MODEL1306060007 1454 | Project_Butilation 1455 | schineny.github.io 1456 | clyan.github.com 1457 | pynamis 1458 | django-lag 1459 | K033596.000 1460 | manisian.github.io 1461 | RuilScript 1462 | python-r.com 1463 | TestRepert 1464 | carling 1465 | 2014-00.001 1466 | 360-Project 1467 | docker-peg 1468 | linyert 1469 | firstar 1470 | Coursero_Android 1471 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import pickle 3 | 4 | from sklearn.cross_validation import train_test_split 5 | 6 | import Utils 7 | from ShortTextCodec import ShortTextCodec, BinomialShortTextCodec 8 | from RBM import CharBernoulliRBM, CharBernoulliRBMSoftmax 9 | 10 | def stringify_param(name, value): 11 | if name == 'tag': 12 | prefix = '' 13 | else: 14 | prefix = ''.join([token[0] for token in name.split('_')]) 15 | 16 | if isinstance(value, bool): 17 | value = '' # The prefix alone tells us what we need to know - that this boolean param is the opposite of its default 18 | elif isinstance(value, float): 19 | # e.g. 
1E-03 20 | value = '{:.0E}'.format(value) 21 | elif not isinstance(value, int) and not isinstance(value, basestring): 22 | raise ValueError("Don't know how to format {}".format(type(value))) 23 | return prefix + str(value) 24 | 25 | def pickle_name(args, parser): 26 | fname = args.input_fname.split('.')[0].split('/')[-1] 27 | fname += '_' 28 | for arg in ['tag', 'batch_size', 'n_hidden', 'softmax', 'learning_rate_backoff', 'preserve_case', 'epochs', 'learning_rate', 'weight_cost', 'left']: 29 | value = getattr(args, arg) 30 | if value != parser.get_default(arg): 31 | fname += '_' + stringify_param(arg, value) 32 | 33 | return fname + '.pickle' 34 | 35 | 36 | if __name__ == '__main__': 37 | # TODO: An option for checkpointing model every n epochs 38 | # TODO: Should maybe separate out vectorization and training? They're sort of 39 | # orthogonal (options like maxlen, preserve-case etc. don't even do anything 40 | # when starting from a pretrained model), and the options here are getting 41 | # bloated. 42 | parser = argparse.ArgumentParser(description='Train a character-level RBM on short texts', 43 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 44 | parser.add_argument('input_fname', metavar='txtfile', 45 | help='A text file to train on, with one instance per line') 46 | parser.add_argument('--test-ratio', dest='test_ratio', type=float, default=0.05, 47 | help='The ratio of data to hold out to monitor for overfitting') 48 | parser.add_argument('--no-softmax', dest='softmax', action='store_false', 49 | help='Don\'t use softmax visible units') 50 | parser.add_argument('--preserve-case', dest='preserve_case', action='store_true', 51 | help="Preserve case, rather than lowercasing all input strings. 
Increases size of visible layer substantially.") 52 | parser.add_argument('--binomial', action='store_true', help='Use the binomial text codec (for comma-separated two-part names)') 53 | parser.add_argument('-b', '--batch-size', dest='batch_size', type=int, default=10, 54 | help='Size of a (mini)batch. This also controls # of fantasy particles.') 55 | parser.add_argument('--maxlen', dest='max_text_length', type=int, default=20, 56 | help='Maximum length of strings (i.e. # of softmax units).' + 57 | ' Longer lines in the input file will be ignored') 58 | parser.add_argument('--minlen', dest='min_text_length', type=int, default=0, 59 | help='Minimum length of strings. Shorter lines in input file will be ignored.') 60 | # TODO: It'd be cool to be able to say "take the n most frequent non-alpha characters in the input file" 61 | parser.add_argument('--extra-chars', dest='extra_chars', default=' ', 62 | help='Characters to consider in addition to [a-zA-Z]') 63 | parser.add_argument('--hid', '--hidden-units', dest='n_hidden', default=180, type=int, 64 | help='Number of hidden units') 65 | parser.add_argument('-l', '--learning-rate', dest='learning_rate', default=0.1, type=float, help="Learning rate.") 66 | parser.add_argument('--weight-cost', dest='weight_cost', default=0.0001, type=float, 67 | help='Multiplied by derivative of L2 norm on weights. Practical Guide recommends 0.0001 to start') 68 | parser.add_argument('--lr-backoff', dest='learning_rate_backoff', action='store_true', 69 | help='Gradually reduce the learning rate at each epoch') 70 | parser.add_argument('-e', '--epochs', dest='epochs', default=5, type=int, help="Number of times to cycle through the training data") 71 | parser.add_argument('--left', action='store_true', help='Pad strings shorter than maxlen from the left rather than the right.') 72 | parser.add_argument('-m', '--model', dest='model', default=None, 73 | help="Start from a previously trained model. 
Options affecting network topology will be ignored.") 74 | parser.add_argument('--tag', dest='tag', default='', 75 | help='A name for this run. The model will be pickled to ' + 76 | 'a corresponding filename. That name will already encode ' + 77 | 'important hyperparams.') 78 | 79 | args = parser.parse_args() 80 | 81 | # TODO: trap ctrl+c and pickle the model before bailing 82 | 83 | # If the path to a pretrained, pickled model was provided, resurrect it, and 84 | # update the attributes that make sense to change (stuff like #hidden units, 85 | # or max string length of course can't be changed) 86 | if args.model: 87 | f = open(args.model) 88 | rbm = pickle.load(f) 89 | f.close() 90 | rbm.learning_rate = args.learning_rate 91 | rbm.base_learning_rate = args.learning_rate 92 | rbm.lr_backoff = args.learning_rate_backoff 93 | rbm.n_iter = args.epochs 94 | rbm.batch_size = args.batch_size 95 | rbm.weight_cost = args.weight_cost 96 | codec = rbm.codec 97 | else: 98 | codec_kls = BinomialShortTextCodec if args.binomial else ShortTextCodec 99 | codec = codec_kls(args.extra_chars, args.max_text_length, 100 | args.min_text_length, args.preserve_case, args.left) 101 | model_kwargs = {'codec': codec, 102 | 'n_components': args.n_hidden, 103 | 'learning_rate': args.learning_rate, 104 | 'lr_backoff': args.learning_rate_backoff, 105 | 'n_iter': args.epochs, 106 | 'verbose': 1, 107 | 'batch_size': args.batch_size, 108 | 'weight_cost': args.weight_cost, 109 | } 110 | kls = CharBernoulliRBMSoftmax if args.softmax else CharBernoulliRBM 111 | rbm = kls(**model_kwargs) 112 | 113 | vecs = Utils.vectors_from_txtfile(args.input_fname, codec) 114 | train, validation = train_test_split(vecs, test_size=args.test_ratio) 115 | print "Training data shape : " + str(train.shape) 116 | 117 | rbm.fit(train, validation) 118 | out_fname = pickle_name(args, parser) 119 | f = open(out_fname, 'wb') 120 | pickle.dump(rbm, f) 121 | f.close() 122 | print "Wrote model to " + out_fname 123 | 
--------------------------------------------------------------------------------