├── LICENSE ├── datasets ├── world_order.py ├── pypi_lang.py ├── preprocess.py ├── iter_dataset.py └── dataset.py ├── test.py ├── dataloader.py ├── model.py ├── utils.py ├── trainer.py ├── train.py ├── .gitignore ├── sgns_loss.py └── README.md /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Devin de Hueck 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /datasets/world_order.py: -------------------------------------------------------------------------------- 1 | from .dataset import SkipGramDataset 2 | import re 3 | 4 | 5 | class WorldOrderDataset(SkipGramDataset): 6 | 7 | def __init__(self, args, examples_path=None, dict_path=None): 8 | SkipGramDataset.__init__(self, args) 9 | self.name = 'World Order Book Dataset' 10 | self.queries = ['nuclear', 'mankind', 'khomeini', 'ronald'] 11 | 12 | if examples_path is not None and dict_path is not None: 13 | self.load(examples_path, dict_path) 14 | else: 15 | self.files = self.tokenize_files() 16 | self.generate_examples_serial() 17 | 18 | print(f'There are {len(self.dictionary)} tokens and {len(self.examples)} examples.') 19 | 20 | def load_files(self): 21 | return self.files 22 | 23 | def tokenize_files(self): 24 | files = [] 25 | with open('data/world_order_kissinger.txt') as f: 26 | for line in f: 27 | words_no_dig_punc = (re.sub(r'[^\w]', ' ', line.lower())).split() 28 | words_no_dig_punc = [x for x in words_no_dig_punc if not any(c.isdigit() for c in x)] 29 | files.append(words_no_dig_punc) 30 | return files 31 | -------------------------------------------------------------------------------- /datasets/pypi_lang.py: -------------------------------------------------------------------------------- 1 | from .dataset import SkipGramDataset 2 | import pandas as pd 3 | from .preprocess import Tokenizer 4 | from tqdm import tqdm 5 | 6 | 7 | class PyPILangDataset(SkipGramDataset): 8 | 9 | def __init__(self, args, examples_path=None, dict_path=None): 10 | SkipGramDataset.__init__(self, args) 11 | self.name = 'PyPI Language Dataset' 12 | self.queries = ['tensorflow', 'pytorch', 'nlp', 'performance', 'encryption'] 13 | 14 | if examples_path is not None and dict_path is not None: 15 | self.load(examples_path, dict_path) 16 | else: 17 | self.tokenizer = Tokenizer(args) 18 | self.files = self.tokenize_files() 19 | self.generate_examples_serial() 20 | 
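            # Cache the generated examples and dictionary so later runs can skip tokenization,
            # e.g. PyPILangDataset(args, examples_path='data/pypi_examples.pth', dict_path='data/pypi_dict.pth')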
21 |             self.save('pypi_examples.pth', 'pypi_dict.pth')
22 | 
23 |         print(f'There are {len(self.dictionary)} tokens and {len(self.examples)} examples.')
24 | 
25 |     def load_files(self):
26 |         return self.files
27 | 
28 |     def tokenize_files(self):
29 |         node_lang_df = pd.read_csv(self.args.dataset_dir, na_filter=False)
30 |         lang_data = node_lang_df['language'].values[:50000]
31 |         return [self.tokenizer.tokenize_doc(f) for f in tqdm(lang_data, desc='Tokenizing Docs')]
32 | 
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
 1 | import numpy as np
 2 | import torch as t
 3 | import torch.nn as nn
 4 | 
 5 | x = np.array([0.87, 0.23, 0.65, 0.99, 0.003, -0.12, 1.3])
 6 | y = np.array([1., 0., 1., 1., 0., 0., 1.])
 7 | 
 8 | x2 = np.array([0.12, 0.43, 0.789, 0.63, 0.213, -0.34, 1.4])
 9 | y2 = np.array([0., 1., 1., 1., 0., 0., 1.])
10 | 
11 | xt = t.from_numpy(x)
12 | yt = t.from_numpy(y)
13 | 
14 | x2t = t.from_numpy(x2)
15 | y2t = t.from_numpy(y2)
16 | 
17 | true_loss = nn.BCEWithLogitsLoss()
18 | 
19 | print("##########################################")
20 | print(f"The value we want is: {true_loss(xt, yt) + true_loss(x2t, y2t)}")
21 | print("##########################################")
22 | 
23 | 
24 | def torch_loss(input, target):
25 |     max_val = input.clamp(min=0)
26 |     loss = input - input * target + max_val + ((-max_val).exp() + (-input - max_val).exp()).log()
27 |     return loss.mean()
28 | 
29 | 
30 | print("\n##########################################")
31 | print(f"TORCH Loss is: {torch_loss(xt, yt) + torch_loss(x2t, y2t)}")
32 | print("##########################################")
33 | 
34 | 
35 | def my_loss(x, y):
36 |     max_val = np.clip(x, 0, None)
37 |     loss = x - x * y + max_val + np.log(np.exp(-max_val) + np.exp((-x - max_val)))
38 |     return loss.mean()
39 | 
40 | 
41 | print("\n##########################################")
42 | print(f"MY Loss is: {my_loss(x, y) + my_loss(x2, y2)}")
43 | print("##########################################")
--------------------------------------------------------------------------------
/dataloader.py:
--------------------------------------------------------------------------------
 1 | import numpy as onp
 2 | from torch.utils import data
 3 | 
 4 | 
 5 | def numpy_collate(batch):
 6 |     if isinstance(batch[0], onp.ndarray):
 7 |         return onp.stack(batch)
 8 |     elif isinstance(batch[0], (tuple, list)):
 9 |         transposed = zip(*batch)
10 |         return [numpy_collate(samples) for samples in transposed]
11 |     else:
12 |         return onp.array(batch)
13 | 
14 | 
15 | class NumpyLoader(data.DataLoader):
16 |     def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0,
17 |                  pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None):
18 |         super(self.__class__, self).__init__(dataset,
19 |                                              batch_size=batch_size,
20 |                                              shuffle=shuffle,
21 |                                              sampler=sampler,
22 |                                              batch_sampler=batch_sampler,
23 |                                              num_workers=num_workers,
24 |                                              collate_fn=numpy_collate,
25 |                                              pin_memory=pin_memory,
26 |                                              drop_last=drop_last,
27 |                                              timeout=timeout,
28 |                                              worker_init_fn=worker_init_fn
29 |                                              )
30 | 
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
 1 | import jax
 2 | import torch
 3 | import numpy as onp
 4 | 
 5 | 
 6 | class SkipGramEmbeddings:
 7 | 
 8 |     def __init__(self, vocab_size, embed_len):
 9 |         super(SkipGramEmbeddings, self).__init__()
10 |         # This 
initialization is important! 11 | torch_embed = torch.nn.Embedding(vocab_size, embed_len) 12 | self.word_embeds = jax.numpy.array(torch_embed.weight.detach().numpy()) 13 | 14 | def forward(self, center, context): 15 | """ 16 | Acts as a lookup for the center and context words' embeddings 17 | 18 | :param center: The center words indicies 19 | :param context: The context words indicies 20 | :return: The embedding parameters 21 | """ 22 | return self.word_embeds[center], self.word_embeds[context] 23 | 24 | @staticmethod 25 | def nearest_neighbors(word, dictionary, vectors): 26 | """ 27 | Finds vector closest to word_idx vector 28 | :param word: String 29 | :param dictionary: Gensim dictionary object 30 | :return: Integer corresponding to word vector in self.word_embeds 31 | """ 32 | vectorsW = onp.asarray(vectors) 33 | index = dictionary.token2id[word] 34 | query = vectors[index] 35 | 36 | ranks = vectorsW.dot(query).squeeze() 37 | denom = query.T.dot(query).squeeze() 38 | denom = denom * onp.sum(vectorsW ** 2, 1) 39 | denom = onp.sqrt(denom) 40 | ranks = ranks / denom 41 | mostSimilar = [] 42 | [mostSimilar.append(idx) for idx in ranks.argsort()[::-1]] 43 | nearest_neighbors = mostSimilar[:10] 44 | nearest_neighbors = [dictionary[comp] for comp in nearest_neighbors] 45 | 46 | return nearest_neighbors 47 | 48 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import torch as t 2 | import numpy as np 3 | 4 | 5 | class AliasMultinomial(object): 6 | """ 7 | Fast sampling from a multinomial distribution. 8 | https://hips.seas.harvard.edu/blog/2013/03/03/the-alias-method-efficient-sampling-with-many-discrete-outcomes/ 9 | 10 | Code taken from: https://github.com/TropComplique/lda2vec-pytorch/blob/master/utils/alias_multinomial.py 11 | """ 12 | 13 | def __init__(self, probs): 14 | """ 15 | probs: a float tensor with shape [K]. 16 | It represents probabilities of different outcomes. 17 | There are K outcomes. Probabilities sum to one. 
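        Illustrative usage (a minimal sketch):
            sampler = AliasMultinomial([0.5, 0.3, 0.2]); idxs = sampler.draw(4)  # numpy array of 4 sampled outcome indices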
18 |         """
19 |         K = len(probs)
20 |         self.q = t.zeros(K)
21 |         self.J = t.LongTensor([0] * K)
22 | 
23 |         # sort the data into the outcomes with probabilities
24 |         # that are larger and smaller than 1/K
25 |         smaller = []
26 |         larger = []
27 |         for kk, prob in enumerate(probs):
28 |             self.q[kk] = K * prob
29 |             if self.q[kk] < 1.0:
30 |                 smaller.append(kk)
31 |             else:
32 |                 larger.append(kk)
33 | 
34 |         # loop through and create little binary mixtures that
35 |         # appropriately allocate the larger outcomes over the
36 |         # overall uniform mixture
37 |         while len(smaller) > 0 and len(larger) > 0:
38 |             small = smaller.pop()
39 |             large = larger.pop()
40 | 
41 |             self.J[small] = large
42 |             self.q[large] = (self.q[large] - 1.0) + self.q[small]
43 | 
44 |             if self.q[large] < 1.0:
45 |                 smaller.append(large)
46 |             else:
47 |                 larger.append(large)
48 | 
49 |         self.q.clamp_(0.0, 1.0)
50 |         self.J.clamp_(0, K - 1)
51 | 
52 |     def draw(self, N):
53 |         """Draw N samples from the distribution."""
54 | 
55 |         K = self.J.size(0)
56 |         r = t.LongTensor(np.random.randint(0, K, size=N))
57 |         q = self.q.index_select(0, r).clamp(0.0, 1.0)
58 |         j = self.J.index_select(0, r)
59 |         b = t.bernoulli(q)
60 |         oq = r.mul(b.long())
61 |         oj = j.mul((1 - b).long())
62 | 
63 |         return (oq + oj).numpy()
64 | 
--------------------------------------------------------------------------------
/trainer.py:
--------------------------------------------------------------------------------
 1 | from jax import jit, grad
 2 | import optax as optim
 3 | from dataloader import NumpyLoader
 4 | from model import SkipGramEmbeddings
 5 | from sgns_loss import SGNSLoss
 6 | from tqdm import tqdm
 7 | # from datasets.pypi_lang import PyPILangDataset
 8 | from datasets.world_order import WorldOrderDataset
 9 | from functools import partial
10 | import numpy as np
11 | 
12 | 
13 | class Trainer:
14 | 
15 |     def __init__(self, args):
16 |         # Load data
17 |         self.args = args
18 |         self.dataset = WorldOrderDataset(args)#, examples_path='data/pypi_examples.pth', dict_path='data/pypi_dict.pth')
19 |         self.vocab_size = len(self.dataset.dictionary)
20 |         print("Finished loading dataset")
21 | 
22 |         self.dataloader = NumpyLoader(self.dataset, batch_size=args.batch_size,
23 |                                       shuffle=True, num_workers=args.workers)
24 | 
25 |         self.model = SkipGramEmbeddings(self.vocab_size, args.embedding_len)
26 |         self.sgns = SGNSLoss(self.dataset)
27 |         # Set up optimizer - rmsprop seems to work the best
28 |         optimizer = optim.adam(args.lr)
29 |         self.opt_init = optimizer.init
30 |         self.opt_update = optimizer.update
31 |         self.apply_updates = optim.apply_updates
32 | 
33 |     @partial(jit, static_argnums=(0,))
34 |     def update(self, params, opt_state, batch):
35 |         g = grad(self.sgns.forward)(params, batch)
36 |         updates, opt_state = self.opt_update(g, opt_state)
37 |         params = self.apply_updates(params, updates)
38 |         return opt_state, params, g
39 | 
40 |     def train(self):
41 |         # Initialize optimizer state!
42 |         params = self.model.word_embeds
43 |         opt_state = self.opt_init(params)
44 |         for epoch in range(self.args.epochs):
45 |             print(f'Beginning epoch: {epoch + 1}/{self.args.epochs}')
46 |             for i, batch in enumerate(tqdm(self.dataloader)):
47 |                 opt_state, params, g = self.update(params, opt_state, batch)
48 |                 self.log_step(epoch, params, g)
49 | 
50 | 
51 | 
52 |     def log_step(self, epoch, params, g):
53 |         print(f'EPOCH: {epoch} | GRAD MAGNITUDE: {np.linalg.norm(g)}')
54 |         # Log embeddings!
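        # For each query word, print its nearest neighbors (cosine similarity over the current embedding matrix)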
55 |         print('\nLearned embeddings:')
56 |         for word in self.dataset.queries:
57 |             print(f'word: {word} neighbors: {self.model.nearest_neighbors(word, self.dataset.dictionary, params)}')
58 | 
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
 1 | import argparse
 2 | from trainer import Trainer
 3 | import torch as t
 4 | 
 5 | 
 6 | def get_args():
 7 |     parser = argparse.ArgumentParser(description="JAX Skip-Gram Negative Sampling Training")
 8 | 
 9 |     """
10 |     Data handling
11 |     """
12 |     parser.add_argument('--dataset-dir', type=str, default='data/',
13 |                         help='dataset directory (default: data/)')
14 |     parser.add_argument('--workers', type=int, default=1, metavar='N',
15 |                         help='dataloader threads (default: 1)')
16 |     parser.add_argument('--window-size', type=int, default=5, help='Window size\
17 |                         used when generating training examples (default: 5)')
18 |     parser.add_argument('--file-batch-size', type=int, default=250, help='Batch size\
19 |                         used when multi-threading the generation of training examples\
20 |                         (default: 250)')
21 | 
22 |     """
23 |     Model Parameters
24 |     """
25 |     parser.add_argument('--embedding-len', type=int, default=128, help='Length of\
26 |                         embeddings in model (default: 128)')
27 | 
28 |     """
29 |     Training Hyperparameters
30 |     """
31 |     parser.add_argument('--epochs', type=int, default=15, metavar='N',
32 |                         help='number of epochs to train for - iterations over the dataset (default: 15)')
33 |     parser.add_argument('--batch-size', type=int, default=1024,
34 |                         metavar='N', help='number of examples in a training batch (default: 1024)')
35 |     parser.add_argument('--lr', type=float, default=1e-3, metavar='LR',
36 |                         help='learning rate (default: 1e-3)')
37 |     parser.add_argument('--seed', type=int, default=42, metavar='S',
38 |                         help='random seed (default: 42)')
39 | 
40 |     """
41 |     Checkpoint options
42 |     """
43 |     parser.add_argument('--log-step', type=int, default=250, help='Number of training steps between\
44 |                         logging training info (default: 250)')
45 | 
46 |     """
47 |     Training Settings
48 |     """
49 |     parser.add_argument('--device', type=str, default=t.device("cuda:0" if t.cuda.is_available() else "cpu"),
50 |                         help='device to train on (default: cuda:0 if cuda is available otherwise cpu)')
51 | 
52 |     return parser.parse_args()
53 | 
54 | 
55 | if __name__ == '__main__':
56 |     args = get_args()
57 |     trainer = Trainer(args)
58 |     # Begin Training!
59 |     trainer.train()
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
 1 | *.pth
 2 | experiments/
 3 | .idea/
 4 | 
 5 | # Byte-compiled / optimized / DLL files
 6 | __pycache__/
 7 | *.py[cod]
 8 | *$py.class
 9 | 
10 | # C extensions
11 | *.so
12 | 
13 | # Distribution / packaging
14 | .Python
15 | build/
16 | develop-eggs/
17 | dist/
18 | downloads/
19 | eggs/
20 | .eggs/
21 | lib/
22 | lib64/
23 | parts/
24 | sdist/
25 | var/
26 | wheels/
27 | pip-wheel-metadata/
28 | share/python-wheels/
29 | *.egg-info/
30 | .installed.cfg
31 | *.egg
32 | MANIFEST
33 | 
34 | # PyInstaller
35 | # Usually these files are written by a python script from a template
36 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
37 | *.manifest 38 | *.spec 39 | 40 | # Installer logs 41 | pip-log.txt 42 | pip-delete-this-directory.txt 43 | 44 | # Unit test / coverage reports 45 | htmlcov/ 46 | .tox/ 47 | .nox/ 48 | .coverage 49 | .coverage.* 50 | .cache 51 | nosetests.xml 52 | coverage.xml 53 | *.cover 54 | *.py,cover 55 | .hypothesis/ 56 | .pytest_cache/ 57 | 58 | # Translations 59 | *.mo 60 | *.pot 61 | 62 | # Django stuff: 63 | *.log 64 | local_settings.py 65 | db.sqlite3 66 | db.sqlite3-journal 67 | 68 | # Flask stuff: 69 | instance/ 70 | .webassets-cache 71 | 72 | # Scrapy stuff: 73 | .scrapy 74 | 75 | # Sphinx documentation 76 | docs/_build/ 77 | 78 | # PyBuilder 79 | target/ 80 | 81 | # Jupyter Notebook 82 | .ipynb_checkpoints 83 | 84 | # IPython 85 | profile_default/ 86 | ipython_config.py 87 | 88 | # pyenv 89 | .python-version 90 | 91 | # pipenv 92 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 93 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 94 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 95 | # install all needed dependencies. 96 | #Pipfile.lock 97 | 98 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 99 | __pypackages__/ 100 | 101 | # Celery stuff 102 | celerybeat-schedule 103 | celerybeat.pid 104 | 105 | # SageMath parsed files 106 | *.sage.py 107 | 108 | # Environments 109 | .env 110 | .venv 111 | env/ 112 | venv/ 113 | ENV/ 114 | env.bak/ 115 | venv.bak/ 116 | 117 | # Spyder project settings 118 | .spyderproject 119 | .spyproject 120 | 121 | # Rope project settings 122 | .ropeproject 123 | 124 | # mkdocs documentation 125 | /site 126 | 127 | # mypy 128 | .mypy_cache/ 129 | .dmypy.json 130 | dmypy.json 131 | 132 | # Pyre type checker 133 | .pyre/ 134 | -------------------------------------------------------------------------------- /sgns_loss.py: -------------------------------------------------------------------------------- 1 | import jax.numpy as np 2 | import jax.nn as nn 3 | from utils import AliasMultinomial 4 | 5 | 6 | class SGNSLoss: 7 | BETA = 0.75 # exponent to adjust sampling frequency 8 | NUM_SAMPLES = 15 9 | 10 | def __init__(self, dataset): 11 | super(SGNSLoss, self).__init__() 12 | self.dataset = dataset 13 | self.vocab_len = len(dataset.dictionary) 14 | 15 | # Helpful values for unigram distribution generation 16 | # Should use cfs instead but: https://github.com/RaRe-Technologies/gensim/issues/2574 17 | self.transformed_freq_vec = np.array([dataset.dictionary.dfs[i] for i in range(self.vocab_len)]) ** self.BETA 18 | self.freq_sum = np.sum(self.transformed_freq_vec) 19 | # Generate table 20 | self.unigram_table = self.generate_unigram_table() 21 | 22 | def forward(self, params, batch): 23 | # Unpack data 24 | center_ids, context_ids = batch 25 | # Get vectors 26 | center, context = params[center_ids], params[context_ids] 27 | # Squeeze into dimensions we want 28 | center, context = center.squeeze(), context.squeeze() # batch_size x embed_size 29 | 30 | # Compute true portion 31 | true_scores = (center * context).sum(-1) # batch_size 32 | loss = self.bce_loss_w_logits(true_scores, np.ones_like(true_scores)) 33 | 34 | # Compute negatively sampled portion - NUM_SAMPLES # of negative samples for each true context 35 | for i in range(self.NUM_SAMPLES): 36 | samples = self.get_unigram_samples(n=center.shape[0], word_embeds=params) 37 | neg_sample_scores = (center * samples).sum(-1) 38 | # Update loss 39 | loss += 
self.bce_loss_w_logits(neg_sample_scores, np.zeros_like(neg_sample_scores)) 40 | 41 | return loss 42 | 43 | @staticmethod 44 | def bce_loss_w_logits(x, y): 45 | max_val = np.clip(-x, 0, None) 46 | loss = (1-y) * x + max_val + np.log(np.exp(-max_val) + np.exp((-x - max_val))) 47 | return loss.mean() 48 | 49 | def get_unigram_samples(self, n, word_embeds): 50 | """ 51 | Returns a sample according to a unigram distribution 52 | Randomly choose a value from self.unigram_table 53 | """ 54 | rand_idxs = self.unigram_table.draw(n) 55 | return word_embeds[rand_idxs].squeeze() 56 | 57 | def get_unigram_prob(self, token_idx): 58 | return (self.transformed_freq_vec[token_idx].item()) / self.freq_sum.item() 59 | 60 | def generate_unigram_table(self): 61 | # Probability at each index corresponds to probability of selecting that token 62 | pdf = [self.get_unigram_prob(t_idx) for t_idx in range(0, self.vocab_len)] 63 | # Generate the table from PDF 64 | return AliasMultinomial(pdf) 65 | -------------------------------------------------------------------------------- /datasets/preprocess.py: -------------------------------------------------------------------------------- 1 | import spacy 2 | import re 3 | 4 | 5 | class Tokenizer: 6 | 7 | def __init__(self, args, custom_stop=set()): 8 | self.args = args 9 | self.custom_stop = custom_stop 10 | # Define pipeline - use different nlp is using pretrained 11 | self.nlp = spacy.load("en_core_web_sm", disable=[]) 12 | # Merge named entities 13 | merge_ents = self.nlp.create_pipe("merge_entities") 14 | self.nlp.add_pipe(merge_ents) 15 | 16 | def tokenize_doc(self, doc_str): 17 | """ 18 | Tokenize a document string 19 | Modified version of Moody's Tokenization in: 20 | https://github.com/cemoody/lda2vec/blob/master/lda2vec/preprocess.py 21 | 22 | :params doc_str: String 23 | :returns: list of Strings, i.e. tokens 24 | """ 25 | 26 | # Send doc_str through pipeline 27 | spacy_doc = self.nlp(doc_str) 28 | # Filter 29 | filtered_doc = filter(self.is_valid_token, spacy_doc) 30 | # Convert to text make lowercase 31 | clean_doc = [t.text.lower().strip() for t in filtered_doc] 32 | # Only allow characters in the alphabet, '_', and digits 33 | clean_doc = [re.sub('[^a-zA-Z0-9]', '', t) for t in clean_doc] 34 | # Remove any resulting empty indices 35 | clean_doc = [t for t in clean_doc if len(t) > 0] 36 | # Filter out any custom stop 37 | clean_doc = [t for t in clean_doc if t not in self.custom_stop] 38 | 39 | return clean_doc 40 | 41 | def is_valid_token(self, token): 42 | """ 43 | Determines if a token is valid or not 44 | 45 | :params token: String 46 | :returns: Boolean 47 | """ 48 | if token.like_url: 49 | return False 50 | if token.like_email: 51 | return False 52 | if token.is_stop or token.text in self.custom_stop: 53 | return False 54 | 55 | return True 56 | 57 | def moodys_merge_noun_chunks(self, doc): 58 | """ 59 | Merge noun chunks into a single token. 60 | 61 | Modified from sources of: 62 | - https://github.com/cemoody/lda2vec/blob/master/lda2vec/preprocess.py 63 | - https://spacy.io/api/pipeline-functions#merge_noun_chunks 64 | 65 | :params doc: Doc object. 66 | :returns: Doc object with merged noun chunks. 67 | """ 68 | bad_deps = ('amod', 'compound') 69 | 70 | if not doc.is_parsed: 71 | return doc 72 | with doc.retokenize() as retokenizer: 73 | for np in doc.noun_chunks: 74 | 75 | # Only keep adjectives and nouns, e.g. 
"good ideas" 76 | while len(np) > 1 and np[0].dep_ not in bad_deps: 77 | np = np[1:] 78 | 79 | if len(np) > 1: 80 | # Merge NPs 81 | attrs = {"tag": np.root.tag, "dep": np.root.dep} 82 | retokenizer.merge(np, attrs=attrs) 83 | return doc 84 | 85 | 86 | class SimpleTokenizer: 87 | 88 | def __init__(self, args, custom_stop=set()): 89 | self.args = args 90 | self.custom_stop = custom_stop 91 | 92 | def tokenize_doc(self, doc_str): 93 | # Filter out any custom stop 94 | clean_doc = [t for t in doc_str.split() if t not in self.custom_stop] 95 | return clean_doc 96 | -------------------------------------------------------------------------------- /datasets/iter_dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from tqdm import tqdm 4 | import pandas as pd 5 | from itertools import cycle 6 | from gensim.corpora import Dictionary 7 | from torch.utils.data.dataset import Dataset 8 | 9 | 10 | class SkipGramDataset(Dataset): 11 | 12 | def __init__(self, args, fname='data/pypi_nodes_lang.csv'): 13 | self.args = args 14 | self.fname = fname 15 | self.dictionary = None 16 | self.examples = [] 17 | self.name = '' 18 | 19 | def _get_generator(self): 20 | for df in pd.read_csv(self.fname, header=None, chunksize=1): 21 | doc = self._tokenize(df.values) 22 | for example in self._generate_examples_from_file(doc): 23 | yield example 24 | 25 | def _tokenize(self, doc): 26 | return doc.split() 27 | 28 | def __iter__(self, index): 29 | return cycle(self._get_generator()) 30 | 31 | def _build_dictionary(self): 32 | """ 33 | Creates a Gensim Dictionary 34 | :return: None - modifies self.dictionary 35 | """ 36 | print("Building Dictionary...") 37 | self.dictionary = Dictionary(self.load_files()) 38 | 39 | def _generate_examples_from_file(self, file): 40 | """ 41 | Generate all examples from a file within window size 42 | :param file: File from self.files 43 | :returns: List of examples 44 | """ 45 | 46 | examples = [] 47 | for i, token in enumerate(file): 48 | if token == -1: 49 | # Out of dictionary token 50 | continue 51 | 52 | # Generate context tokens for the current token 53 | context_words = self._generate_contexts(i, file) 54 | 55 | # Form Examples: 56 | # center, context - follows form: (input, target) 57 | new_examples = [(token, ctxt) for ctxt in context_words if ctxt != -1] 58 | 59 | # Add to class 60 | examples.extend(new_examples) 61 | return examples 62 | 63 | def _generate_contexts(self, token_idx, tokenized_doc): 64 | """ 65 | Generate Token's Context Words 66 | Generates all the context words within the window size defined 67 | during initialization around token. 
68 | 
69 |         :param token_idx: Index at which center token is found in tokenized_doc
70 |         :param tokenized_doc: List - Document broken into tokens
71 |         :returns: List of context words
72 |         """
73 |         contexts = []
74 |         # Iterate over each position in window
75 |         for w in range(-self.args.window_size, self.args.window_size + 1):
76 |             context_pos = token_idx + w
77 | 
78 |             # Make sure current center and context are valid
79 |             is_outside_doc = context_pos < 0 or context_pos >= len(tokenized_doc)
80 |             center_is_context = token_idx == context_pos
81 | 
82 |             if is_outside_doc or center_is_context:
83 |                 # Not valid - skip to next window position
84 |                 continue
85 | 
86 |             contexts.append(tokenized_doc[context_pos])
87 |         return contexts
88 | 
89 |     def _example_to_tensor(self, center, target):
90 |         """
91 |         Takes raw example and turns it into tensor values
92 | 
93 |         :params example: Tuple of form: (center word, document id)
94 |         :params target: String of the target word
95 |         :returns: A tuple of tensors
96 |         """
97 |         center, target = np.array([int(center)]), np.array([int(target)])
98 |         return center, target
99 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
 1 | # Skip-Gram Negative Sampling Built in Jax
 2 | 
 3 | A JAX implementation of word2vec's skip-gram model with negative sampling, as described in Mikolov et al., 2013. I found that with JAX this model runs about 5x-10x faster than a raw PyTorch implementation on the CPU (I have not yet tested on a GPU)!
 4 | 
 5 | ## Tokenizing a dataset:
 6 | 
 7 | The code provided should be easily extendable to other datasets. Please see `datasets/dataset.py` for the base class your own dataset should inherit from (just like `datasets/world_order.py` does). Feel free to create an issue if you are having trouble.
 8 | 
 9 | ```python
10 | from .dataset import SkipGramDataset
11 | from .preprocess import Tokenizer
12 | 
13 | 
14 | class ExampleDataset(SkipGramDataset):
15 | 
16 |     def __init__(self, args, examples_path=None, dict_path=None):
17 |         SkipGramDataset.__init__(self, args)
18 |         self.name = 'Example Dataset'
19 |         self.queries = ['words', 'to', 'watch', 'during', 'training']
20 | 
21 |         if examples_path is not None and dict_path is not None:
22 |             self.load(examples_path, dict_path)
23 |         else:
24 |             self.tokenizer = Tokenizer(args)
25 |             # Set self.files to a list of tokenized data!
26 |             self.files = self.tokenize_files()
27 |             # Generates examples in window size - e.g. (center_idx, context_idx)
28 |             self.generate_examples_serial()
29 |             # Save dataset files - this tokenization and example generation can take a while with a lot of data
30 |             self.save('training_examples.pth', 'dictionary.pth')
31 | 
32 |         print(f'There are {len(self.dictionary)} tokens and {len(self.examples)} examples.')
33 | 
34 |     def load_files(self):
35 |         """ Required by SkipGramDataset to generate examples - must return tokenized files """
36 |         return self.files
37 | 
38 |     def tokenize_files(self):
39 |         # read in from a file or wherever your data is kept
40 |         raw_data = ["this is document_1", "this is document_2", ..., "this is document_n"]
41 |         return [self.tokenizer.tokenize_doc(f) for f in raw_data]
42 | 
43 | ```
44 | 
45 | ## Running the model:
46 | 
47 | There are many hyperparameters that are easy to set via command-line arguments when calling `train.py`:
48 | 
49 | Example:
50 | 
51 | ``python train.py --embedding-len 64 --batch-size 2048 --epochs 500``
52 | 
53 | All hyperparameters in `train.py`:
54 | 
55 | ```
56 | optional arguments:
57 |   -h, --help            show this help message and exit
58 |   --dataset-dir DATASET_DIR
59 |                         dataset directory (default: data/)
60 |   --workers N           dataloader threads (default: 1)
61 |   --window-size WINDOW_SIZE
62 |                         Window size used when generating training examples
63 |                         (default: 5)
64 |   --file-batch-size FILE_BATCH_SIZE
65 |                         Batch size used when multi-threading the generation of
66 |                         training examples (default: 250)
67 |   --embedding-len EMBEDDING_LEN
68 |                         Length of embeddings in model (default: 128)
69 |   --epochs N            number of epochs to train for - iterations over the
70 |                         dataset (default: 15)
71 |   --batch-size N        number of examples in a training batch (default: 1024)
72 |   --lr LR               learning rate (default: 1e-3)
73 |   --seed S              random seed (default: 42)
74 |   --log-step LOG_STEP   Number of training steps between logging training info
75 |                         (default: 250)
76 |   --device DEVICE       device to train on (default: cuda:0 if cuda is
77 |                         available otherwise cpu)
78 | ```
79 | 
--------------------------------------------------------------------------------
/datasets/dataset.py:
--------------------------------------------------------------------------------
 1 | import torch
 2 | import numpy as np
 3 | from tqdm import tqdm
 4 | from gensim.corpora import Dictionary
 5 | from torch.utils.data.dataset import Dataset
 6 | 
 7 | 
 8 | class SkipGramDataset(Dataset):
 9 | 
10 |     def __init__(self, args):
11 |         self.args = args
12 |         self.dictionary = None
13 |         self.examples = []
14 |         self.name = ''
15 | 
16 |     def __getitem__(self, index):
17 |         center, context = self.examples[index]
18 |         return np.array([center]), np.array(context)
19 | 
20 |     def __len__(self):
21 |         return len(self.examples)
22 | 
23 |     def save(self, examples_path, dict_path):
24 |         print('Saving Dataset Examples...')
25 |         torch.save({
26 |             'examples': self.examples,
27 |         }, examples_path)
28 |         print('Saving Dataset Dictionary...')
29 |         self.dictionary.save(dict_path)
30 |         print('Saved Dataset!')
31 | 
32 |     def load(self, examples_path, dict_path):
33 |         print('Loading Dataset Examples...')
34 |         self.examples = torch.load(examples_path)['examples']
35 |         print('Loading Dataset Dictionary...')
36 |         self.dictionary = Dictionary().load(dict_path)
37 |         print('Loaded Saved Dataset!')
38 | 
39 |     def generate_examples_serial(self):
40 |         """
41 |         Generates examples with no multiprocessing - straight through!
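        Pipeline: build the Gensim dictionary from self.load_files(), filter very rare/common tokens,
        then map each document to token ids and expand it into (center, context) pairs within the window.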
42 | :return: None - updates class properties 43 | """ 44 | # Now we have a Gensim Dictionary to work with 45 | self._build_dictionary() 46 | # Remove any tokens with a frequency less than 10 47 | self.dictionary.filter_extremes(no_below=10, no_above=0.75) 48 | 49 | self.examples = [] 50 | for file in tqdm(self.load_files(), desc="Generating Examples (serial)"): 51 | file = self.dictionary.doc2idx(file) 52 | self.examples.extend(self._generate_examples_from_file(file)) 53 | 54 | def load_files(self): 55 | """ 56 | Sets self.files as a list of tokenized documents! 57 | :returns: List of files 58 | """ 59 | # Needs to be implemented by child class 60 | raise NotImplementedError 61 | 62 | def _build_dictionary(self): 63 | """ 64 | Creates a Gensim Dictionary 65 | :return: None - modifies self.dictionary 66 | """ 67 | print("Building Dictionary...") 68 | self.dictionary = Dictionary(self.load_files()) 69 | 70 | def _generate_examples_from_file(self, file): 71 | """ 72 | Generate all examples from a file within window size 73 | :param file: File from self.files 74 | :returns: List of examples 75 | """ 76 | 77 | examples = [] 78 | for i, token in enumerate(file): 79 | if token == -1: 80 | # Out of dictionary token 81 | continue 82 | 83 | # Generate context tokens for the current token 84 | context_words = self._generate_contexts(i, file) 85 | 86 | # Form Examples: 87 | # center, context - follows form: (input, target) 88 | new_examples = [(token, ctxt) for ctxt in context_words if ctxt != -1] 89 | 90 | # Add to class 91 | examples.extend(new_examples) 92 | return examples 93 | 94 | def _generate_contexts(self, token_idx, tokenized_doc): 95 | """ 96 | Generate Token's Context Words 97 | Generates all the context words within the window size defined 98 | during initialization around token. 99 | 100 | :param token_idx: Index at which center token is found in tokenized_doc 101 | :param tokenized_doc: List - Document broken into tokens 102 | :returns: List of context words 103 | """ 104 | contexts = [] 105 | # Iterate over each position in window 106 | for w in range(-self.args.window_size, self.args.window_size + 1): 107 | context_pos = token_idx + w 108 | 109 | # Make sure current center and context are valid 110 | is_outside_doc = context_pos < 0 or context_pos >= len(tokenized_doc) 111 | center_is_context = token_idx == context_pos 112 | 113 | if is_outside_doc or center_is_context: 114 | # Not valid - skip to next window position 115 | continue 116 | 117 | contexts.append(tokenized_doc[context_pos]) 118 | return contexts 119 | 120 | def _example_to_tensor(self, center, target): 121 | """ 122 | Takes raw example and turns it into tensor values 123 | 124 | :params example: Tuple of form: (center word, document id) 125 | :params target: String of the target word 126 | :returns: A tuple of tensors 127 | """ 128 | center, target = np.array([int(center)]), np.array([int(target)]) 129 | return center, target 130 | --------------------------------------------------------------------------------