├── models └── .keep ├── diagnostics └── .keep ├── tests ├── __init__.py ├── test_loss.py ├── test_models.py └── test_data.py ├── paragraphvec ├── __init__.py ├── loss.py ├── export_vectors.py ├── models.py ├── utils.py ├── train.py └── data.py ├── .github ├── dmdbow.png ├── learned_vectors_pca.png ├── ISSUE_TEMPLATE.md └── PULL_REQUEST_TEMPLATE.md ├── codecov.yml ├── requirements.txt ├── .coveragerc ├── .travis.yml ├── data └── example.csv ├── setup.py ├── LICENSE ├── .gitignore └── README.md /models/.keep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /diagnostics/.keep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /paragraphvec/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.github/dmdbow.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/inejc/paragraph-vectors/HEAD/.github/dmdbow.png -------------------------------------------------------------------------------- /.github/learned_vectors_pca.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/inejc/paragraph-vectors/HEAD/.github/learned_vectors_pca.png -------------------------------------------------------------------------------- /codecov.yml: -------------------------------------------------------------------------------- 1 | comment: 2 | layout: header, changes, diff, sunburst, uncovered 3 | coverage: 4 | status: 5 | patch: 6 | default: 7 | target: '75' 8 | project: 9 | default: 10 | target: auto 11 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE.md: -------------------------------------------------------------------------------- 1 | 5 | 6 | ##### PyTorch and CUDA versions 7 | 8 | 9 | ##### Description 10 | 11 | 12 | ##### Additional info (stack trace, etc.) 
13 | 14 | 15 | ##### Steps to reproduce the behaviour 16 | 17 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | certifi==2017.7.27.1 2 | chardet==3.0.4 3 | cycler==0.10.0 4 | fire==0.1.2 5 | idna==2.6 6 | matplotlib==2.1.0 7 | numpy==1.13.1 8 | py==1.4.34 9 | pyparsing==2.2.0 10 | pytest==3.2.3 11 | python-dateutil==2.6.1 12 | pytz==2017.2 13 | PyYAML==3.12 14 | requests==2.18.4 15 | six==1.10.0 16 | torchtext==0.2.0 17 | tqdm==4.15.0 18 | urllib3==1.22 19 | -------------------------------------------------------------------------------- /.github/PULL_REQUEST_TEMPLATE.md: -------------------------------------------------------------------------------- 1 | ##### Issue 2 | 3 | 4 | 5 | 6 | ##### Description of changes 7 | 8 | 9 | 10 | ##### Includes 11 | 12 | 13 | - [X] Code changes 14 | - [ ] Tests 15 | -------------------------------------------------------------------------------- /.coveragerc: -------------------------------------------------------------------------------- 1 | [run] 2 | branch = True 3 | 4 | concurrency = multiprocessing 5 | parallel = True 6 | 7 | source = 8 | paragraphvec 9 | 10 | omit = 11 | *tests* 12 | 13 | [report] 14 | exclude_lines = 15 | pragma: no cover 16 | pass 17 | def __repr__ 18 | if self\.debug 19 | raise AssertionError 20 | raise NotImplementedError 21 | if 0: 22 | if __name__ == .__main__.: 23 | 24 | ignore_errors = True 25 | -------------------------------------------------------------------------------- /tests/test_loss.py: -------------------------------------------------------------------------------- 1 | from unittest import TestCase 2 | 3 | import torch 4 | 5 | from paragraphvec.loss import NegativeSampling 6 | 7 | 8 | class NegativeSamplingTest(TestCase): 9 | 10 | def setUp(self): 11 | self.loss_f = NegativeSampling() 12 | 13 | def test_forward(self): 14 | # todo: test actual value 15 | scores = torch.FloatTensor([[12.1, 1.3, 6.5], [18.9, 2.1, 9.4]]) 16 | loss = self.loss_f.forward(scores) 17 | self.assertTrue(loss.data[0] >= 0) 18 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: python 2 | 3 | sudo: false 4 | 5 | python: 6 | - "3.5" 7 | 8 | cache: 9 | pip: true 10 | 11 | before_install: 12 | - pip install -U pip wheel setuptools 13 | - pip install pytest pytest-cov 14 | - pip install codecov 15 | 16 | install: 17 | - pip install http://download.pytorch.org/whl/cu75/torch-0.2.0.post3-cp35-cp35m-manylinux1_x86_64.whl 18 | - pip install -r requirements.txt 19 | 20 | script: 21 | - py.test --cov-report xml --cov paragraphvec 22 | 23 | after_success: 24 | - codecov 25 | -------------------------------------------------------------------------------- /data/example.csv: -------------------------------------------------------------------------------- 1 | text 2 | "In the week before their departure to Arrakis, when all the final scurrying about had reached a nearly unbearable frenzy, an old crone came to visit the mother of the boy, Paul." 3 | "It was a warm night at Castle Caladan, and the ancient pile of stone that had served the Atreides family as home for twenty-six generations bore that cooled-sweat feeling it acquired before a change in the weather." 
4 | "The old woman was let in by the side door down the vaulted passage by Paul's room and she was allowed a moment to peer in at him where he lay in his bed." 5 | "By the half-light of a suspensor lamp, dimmed and hanging near the floor, the awakened boy could see a bulky female shape at his door, standing one step ahead of his mother. The old woman was a witch shadow - hair like matted spiderwebs, hooded 'round darkness of features, eyes like glittering jewels." 6 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | description = 'A PyTorch implementation of Paragraph Vectors (doc2vec).' 4 | 5 | with open('README.md') as f: 6 | long_description = f.read() 7 | 8 | with open('requirements.txt') as f: 9 | requires = f.read().splitlines() 10 | 11 | setup( 12 | name='paragraph-vectors', 13 | version='0.0.1', 14 | author='Nejc Ilenic', 15 | description=description, 16 | long_description=long_description, 17 | license='MIT', 18 | keywords='nlp documents embedding machine-learning', 19 | install_requires=requires, 20 | packages=find_packages(), 21 | test_suite='tests', 22 | classifiers=[ 23 | 'Topic :: Scientific/Engineering :: Artificial Intelligence', 24 | 'License :: OSI Approved :: MIT License', 25 | 'Natural Language :: English', 26 | 'Operating System :: OS Independent', 27 | 'Programming Language :: Python :: 3.5', 28 | ], 29 | ) 30 | -------------------------------------------------------------------------------- /paragraphvec/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class NegativeSampling(nn.Module): 6 | """Negative sampling loss as proposed by T. Mikolov et al. in Distributed 7 | Representations of Words and Phrases and their Compositionality. 8 | """ 9 | def __init__(self): 10 | super(NegativeSampling, self).__init__() 11 | self._log_sigmoid = nn.LogSigmoid() 12 | 13 | def forward(self, scores): 14 | """Computes the value of the loss function. 15 | 16 | Parameters 17 | ---------- 18 | scores: autograd.Variable of size (batch_size, num_noise_words + 1) 19 | Sparse unnormalized log probabilities. The first element in each 20 | row is the ground truth score (i.e. the target), other elements 21 | are scores of samples from the noise distribution. 22 | """ 23 | k = scores.size()[1] - 1 24 | return -torch.sum( 25 | self._log_sigmoid(scores[:, 0]) 26 | + torch.sum(self._log_sigmoid(-scores[:, 1:]), dim=1) / k 27 | ) / scores.size()[0] 28 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2017 Nejc Ilenic 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | ### macOS ### 2 | *.DS_Store 3 | .AppleDouble 4 | .LSOverride 5 | 6 | # Icon must end with two \r 7 | Icon 8 | 9 | # Thumbnails 10 | ._* 11 | 12 | # Files that might appear in the root of a volume 13 | .DocumentRevisions-V100 14 | .fseventsd 15 | .Spotlight-V100 16 | .TemporaryItems 17 | .Trashes 18 | .VolumeIcon.icns 19 | .com.apple.timemachine.donotpresent 20 | 21 | # Directories potentially created on remote AFP share 22 | .AppleDB 23 | .AppleDesktop 24 | Network Trash Folder 25 | Temporary Items 26 | .apdisk 27 | 28 | ### PyCharm ### 29 | .idea/ 30 | 31 | ### Python ### 32 | # Byte-compiled / optimized / DLL files 33 | __pycache__/ 34 | *.py[cod] 35 | *$py.class 36 | 37 | # C extensions 38 | *.so 39 | 40 | # Distribution / packaging 41 | .Python 42 | build/ 43 | develop-eggs/ 44 | dist/ 45 | downloads/ 46 | eggs/ 47 | .eggs/ 48 | lib/ 49 | lib64/ 50 | parts/ 51 | sdist/ 52 | var/ 53 | wheels/ 54 | *.egg-info/ 55 | .installed.cfg 56 | *.egg 57 | 58 | # PyInstaller 59 | # Usually these files are written by a python script from a template 60 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 61 | *.manifest 62 | *.spec 63 | 64 | # Installer logs 65 | pip-log.txt 66 | pip-delete-this-directory.txt 67 | 68 | # Unit test / coverage reports 69 | htmlcov/ 70 | .tox/ 71 | .coverage 72 | .coverage.* 73 | .cache 74 | nosetests.xml 75 | coverage.xml 76 | *,cover 77 | .hypothesis/ 78 | 79 | # Sphinx documentation 80 | docs/_build/ 81 | 82 | # PyBuilder 83 | target/ 84 | 85 | # pyenv 86 | .python-version 87 | 88 | # dotenv 89 | .env 90 | 91 | # virtualenv 92 | .venv 93 | venv/ 94 | env/ 95 | 96 | # custom 97 | data/*.csv 98 | !data/example.csv 99 | diagnostics/*.csv 100 | diagnostics/*.png 101 | models/*.pth.tar 102 | -------------------------------------------------------------------------------- /paragraphvec/export_vectors.py: -------------------------------------------------------------------------------- 1 | import csv 2 | import re 3 | from os.path import join 4 | 5 | import fire 6 | import torch 7 | 8 | from paragraphvec.data import load_dataset 9 | from paragraphvec.models import DM, DBOW 10 | from paragraphvec.utils import DATA_DIR, MODELS_DIR 11 | 12 | 13 | def start(data_file_name, model_file_name): 14 | """Saves trained paragraph vectors to a csv file in the *data* directory. 15 | 16 | Parameters 17 | ---------- 18 | data_file_name: str 19 | Name of a file in the *data* directory that was used during training. 20 | 21 | model_file_name: str 22 | Name of a file in the *models* directory (a model trained on 23 | the *data_file_name* dataset). 
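Examples
--------
Illustrative invocation via Fire from the command line (this mirrors the
README; the model file name is just the example used there):

    python export_vectors.py start --data_file_name 'example.csv' --model_file_name 'example_model.dbow_numnoisewords.2_vecdim.100_batchsize.32_lr.0.001000_epoch.25_loss.0.981524.pth.tar'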
24 | """ 25 | dataset = load_dataset(data_file_name) 26 | 27 | vec_dim = int(re.search('_vecdim\.(\d+)_', model_file_name).group(1)) 28 | 29 | model = _load_model( 30 | model_file_name, 31 | vec_dim, 32 | num_docs=len(dataset), 33 | num_words=len(dataset.fields['text'].vocab) - 1) 34 | 35 | _write_to_file(data_file_name, model_file_name, model, vec_dim) 36 | 37 | 38 | def _load_model(model_file_name, vec_dim, num_docs, num_words): 39 | model_ver = re.search('_model\.(dm|dbow)', model_file_name).group(1) 40 | if model_ver is None: 41 | raise ValueError("Model file name contains an invalid" 42 | "version of the model") 43 | 44 | model_file_path = join(MODELS_DIR, model_file_name) 45 | 46 | try: 47 | checkpoint = torch.load(model_file_path) 48 | except AssertionError: 49 | checkpoint = torch.load( 50 | model_file_path, 51 | map_location=lambda storage, location: storage) 52 | 53 | if model_ver == 'dbow': 54 | model = DBOW(vec_dim, num_docs, num_words) 55 | else: 56 | model = DM(vec_dim, num_docs, num_words) 57 | 58 | model.load_state_dict(checkpoint['model_state_dict']) 59 | return model 60 | 61 | 62 | def _write_to_file(data_file_name, model_file_name, model, vec_dim): 63 | result_lines = [] 64 | 65 | with open(join(DATA_DIR, data_file_name)) as f: 66 | reader = csv.reader(f) 67 | 68 | for i, line in enumerate(reader): 69 | # skip text 70 | result_line = line[1:] 71 | if i == 0: 72 | # header line 73 | result_line += ["d{:d}".format(x) for x in range(vec_dim)] 74 | else: 75 | vector = model.get_paragraph_vector(i - 1) 76 | result_line += [str(x) for x in vector] 77 | 78 | result_lines.append(result_line) 79 | 80 | result_file_name = model_file_name[:-7] + 'csv' 81 | 82 | with open(join(DATA_DIR, result_file_name), 'w') as f: 83 | writer = csv.writer(f) 84 | writer.writerows(result_lines) 85 | 86 | 87 | if __name__ == '__main__': 88 | fire.Fire() 89 | -------------------------------------------------------------------------------- /paragraphvec/models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class DM(nn.Module): 6 | """Distributed Memory version of Paragraph Vectors. 7 | 8 | Parameters 9 | ---------- 10 | vec_dim: int 11 | Dimensionality of vectors to be learned (for paragraphs and words). 12 | 13 | num_docs: int 14 | Number of documents in a dataset. 15 | 16 | num_words: int 17 | Number of distinct words in a daset (i.e. vocabulary size). 18 | """ 19 | def __init__(self, vec_dim, num_docs, num_words): 20 | super(DM, self).__init__() 21 | # paragraph matrix 22 | self._D = nn.Parameter( 23 | torch.randn(num_docs, vec_dim), requires_grad=True) 24 | # word matrix 25 | self._W = nn.Parameter( 26 | torch.randn(num_words, vec_dim), requires_grad=True) 27 | # output layer parameters 28 | self._O = nn.Parameter( 29 | torch.FloatTensor(vec_dim, num_words).zero_(), requires_grad=True) 30 | 31 | def forward(self, context_ids, doc_ids, target_noise_ids): 32 | """Sparse computation of scores (unnormalized log probabilities) 33 | that should be passed to the negative sampling loss. 34 | 35 | Parameters 36 | ---------- 37 | context_ids: torch.Tensor of size (batch_size, num_context_words) 38 | Vocabulary indices of context words. 39 | 40 | doc_ids: torch.Tensor of size (batch_size,) 41 | Document indices of paragraphs. 42 | 43 | target_noise_ids: torch.Tensor of size (batch_size, num_noise_words + 1) 44 | Vocabulary indices of target and noise words. 
The first element in 45 | each row is the ground truth index (i.e. the target), other 46 | elements are indices of samples from the noise distribution. 47 | 48 | Returns 49 | ------- 50 | autograd.Variable of size (batch_size, num_noise_words + 1) 51 | """ 52 | # combine a paragraph vector with word vectors of 53 | # input (context) words 54 | x = torch.add( 55 | self._D[doc_ids, :], torch.sum(self._W[context_ids, :], dim=1)) 56 | 57 | # sparse computation of scores (unnormalized log probabilities) 58 | # for negative sampling 59 | return torch.bmm( 60 | x.unsqueeze(1), 61 | self._O[:, target_noise_ids].permute(1, 0, 2)).squeeze() 62 | 63 | def get_paragraph_vector(self, index): 64 | return self._D[index, :].data.tolist() 65 | 66 | 67 | class DBOW(nn.Module): 68 | """Distributed Bag of Words version of Paragraph Vectors. 69 | 70 | Parameters 71 | ---------- 72 | vec_dim: int 73 | Dimensionality of vectors to be learned (for paragraphs and words). 74 | 75 | num_docs: int 76 | Number of documents in a dataset. 77 | 78 | num_words: int 79 | Number of distinct words in a dataset (i.e. vocabulary size). 80 | """ 81 | def __init__(self, vec_dim, num_docs, num_words): 82 | super(DBOW, self).__init__() 83 | # paragraph matrix 84 | self._D = nn.Parameter( 85 | torch.randn(num_docs, vec_dim), requires_grad=True) 86 | # output layer parameters 87 | self._O = nn.Parameter( 88 | torch.FloatTensor(vec_dim, num_words).zero_(), requires_grad=True) 89 | 90 | def forward(self, doc_ids, target_noise_ids): 91 | """Sparse computation of scores (unnormalized log probabilities) 92 | that should be passed to the negative sampling loss. 93 | 94 | Parameters 95 | ---------- 96 | doc_ids: torch.Tensor of size (batch_size,) 97 | Document indices of paragraphs. 98 | 99 | target_noise_ids: torch.Tensor of size (batch_size, num_noise_words + 1) 100 | Vocabulary indices of target and noise words. The first element in 101 | each row is the ground truth index (i.e. the target), other 102 | elements are indices of samples from the noise distribution.
103 | 104 | Returns 105 | ------- 106 | autograd.Variable of size (batch_size, num_noise_words + 1) 107 | """ 108 | # sparse computation of scores (unnormalized log probabilities) 109 | # for negative sampling 110 | return torch.bmm( 111 | self._D[doc_ids, :].unsqueeze(1), 112 | self._O[:, target_noise_ids].permute(1, 0, 2)).squeeze() 113 | 114 | def get_paragraph_vector(self, index): 115 | return self._D[index, :].data.tolist() 116 | -------------------------------------------------------------------------------- /tests/test_models.py: -------------------------------------------------------------------------------- 1 | from unittest import TestCase 2 | 3 | import torch 4 | 5 | from paragraphvec.loss import NegativeSampling 6 | from paragraphvec.models import DM, DBOW 7 | 8 | 9 | class DMTest(TestCase): 10 | 11 | def setUp(self): 12 | self.batch_size = 2 13 | self.num_noise_words = 2 14 | self.num_docs = 3 15 | self.num_words = 15 16 | self.vec_dim = 10 17 | 18 | self.context_ids = torch.LongTensor([[0, 2, 5, 6], [3, 4, 1, 6]]) 19 | self.doc_ids = torch.LongTensor([1, 2]) 20 | self.target_noise_ids = torch.LongTensor([[1, 3, 4], [2, 4, 7]]) 21 | self.model = DM( 22 | self.vec_dim, self.num_docs, self.num_words) 23 | 24 | def test_num_parameters(self): 25 | self.assertEqual( 26 | sum([x.size()[0] * x.size()[1] for x in self.model.parameters()]), 27 | self.num_docs * self.vec_dim + 2 * self.num_words * self.vec_dim) 28 | 29 | def test_forward(self): 30 | x = self.model.forward( 31 | self.context_ids, self.doc_ids, self.target_noise_ids) 32 | 33 | self.assertEqual(x.size()[0], self.batch_size) 34 | self.assertEqual(x.size()[1], self.num_noise_words + 1) 35 | 36 | def test_backward(self): 37 | cost_func = NegativeSampling() 38 | optimizer = torch.optim.SGD(self.model.parameters(), lr=0.001) 39 | for _ in range(2): 40 | x = self.model.forward( 41 | self.context_ids, self.doc_ids, self.target_noise_ids) 42 | x = cost_func.forward(x) 43 | self.model.zero_grad() 44 | x.backward() 45 | optimizer.step() 46 | 47 | self.assertEqual(torch.sum(self.model._D.grad[0, :].data), 0) 48 | self.assertNotEqual(torch.sum(self.model._D.grad[1, :].data), 0) 49 | self.assertNotEqual(torch.sum(self.model._D.grad[2, :].data), 0) 50 | 51 | context_ids = self.context_ids.numpy().flatten() 52 | target_noise_ids = self.target_noise_ids.numpy().flatten() 53 | 54 | for word_id in range(15): 55 | if word_id in context_ids: 56 | self.assertNotEqual( 57 | torch.sum(self.model._W.grad[word_id, :].data), 0) 58 | else: 59 | self.assertEqual( 60 | torch.sum(self.model._W.grad[word_id, :].data), 0) 61 | 62 | if word_id in target_noise_ids: 63 | self.assertNotEqual( 64 | torch.sum(self.model._O.grad[:, word_id].data), 0) 65 | else: 66 | self.assertEqual( 67 | torch.sum(self.model._O.grad[:, word_id].data), 0) 68 | 69 | 70 | class DBOWTest(TestCase): 71 | 72 | def setUp(self): 73 | self.batch_size = 2 74 | self.num_noise_words = 2 75 | self.num_docs = 3 76 | self.num_words = 15 77 | self.vec_dim = 10 78 | 79 | self.doc_ids = torch.LongTensor([1, 2]) 80 | self.target_noise_ids = torch.LongTensor([[1, 3, 4], [2, 4, 7]]) 81 | self.model = DBOW( 82 | self.vec_dim, self.num_docs, self.num_words) 83 | 84 | def test_num_parameters(self): 85 | self.assertEqual( 86 | sum([x.size()[0] * x.size()[1] for x in self.model.parameters()]), 87 | self.num_docs * self.vec_dim + self.num_words * self.vec_dim) 88 | 89 | def test_forward(self): 90 | x = self.model.forward(self.doc_ids, self.target_noise_ids) 91 | 92 | self.assertEqual(x.size()[0], 
self.batch_size) 93 | self.assertEqual(x.size()[1], self.num_noise_words + 1) 94 | 95 | def test_backward(self): 96 | cost_func = NegativeSampling() 97 | optimizer = torch.optim.SGD(self.model.parameters(), lr=0.001) 98 | for _ in range(2): 99 | x = self.model.forward(self.doc_ids, self.target_noise_ids) 100 | x = cost_func.forward(x) 101 | self.model.zero_grad() 102 | x.backward() 103 | optimizer.step() 104 | 105 | self.assertEqual(torch.sum(self.model._D.grad[0, :].data), 0) 106 | self.assertNotEqual(torch.sum(self.model._D.grad[1, :].data), 0) 107 | self.assertNotEqual(torch.sum(self.model._D.grad[2, :].data), 0) 108 | 109 | target_noise_ids = self.target_noise_ids.numpy().flatten() 110 | 111 | for word_id in range(15): 112 | if word_id in target_noise_ids: 113 | self.assertNotEqual( 114 | torch.sum(self.model._O.grad[:, word_id].data), 0) 115 | else: 116 | self.assertEqual( 117 | torch.sum(self.model._O.grad[:, word_id].data), 0) 118 | -------------------------------------------------------------------------------- /paragraphvec/utils.py: -------------------------------------------------------------------------------- 1 | from os import remove 2 | from os.path import join, dirname, isfile 3 | 4 | import matplotlib.pyplot as plt 5 | import torch 6 | 7 | _root_dir = dirname(dirname(__file__)) 8 | 9 | DATA_DIR = join(_root_dir, 'data') 10 | MODELS_DIR = join(_root_dir, 'models') 11 | _DIAGNOSTICS_DIR = join(_root_dir, 'diagnostics') 12 | 13 | _DM_MODEL_NAME = ("{:s}_model.{:s}.{:s}_contextsize.{:d}_numnoisewords.{:d}" 14 | "_vecdim.{:d}_batchsize.{:d}_lr.{:f}_epoch.{:d}_loss.{:f}" 15 | ".pth.tar") 16 | _DM_DIAGNOSTIC_FILE_NAME = ("{:s}_model.{:s}.{:s}_contextsize.{:d}" 17 | "_numnoisewords.{:d}_vecdim.{:d}_batchsize.{:d}" 18 | "_lr.{:f}.csv") 19 | _DBOW_MODEL_NAME = ("{:s}_model.{:s}_numnoisewords.{:d}_vecdim.{:d}" 20 | "_batchsize.{:d}_lr.{:f}_epoch.{:d}_loss.{:f}.pth.tar") 21 | _DBOW_DIAGNOSTIC_FILE_NAME = ("{:s}_model.{:s}_numnoisewords.{:d}_vecdim.{:d}" 22 | "_batchsize.{:d}_lr.{:f}.csv") 23 | 24 | 25 | def save_training_state(data_file_name, 26 | model_ver, 27 | vec_combine_method, 28 | context_size, 29 | num_noise_words, 30 | vec_dim, 31 | batch_size, 32 | lr, 33 | epoch_i, 34 | loss, 35 | model_state, 36 | save_all, 37 | generate_plot, 38 | is_best_loss, 39 | prev_model_file_path, 40 | model_ver_is_dbow): 41 | """Saves the state of the model. If generate_plot is True, it also 42 | saves current epoch's loss value and generates a plot of all loss 43 | values up to this epoch. 
44 | 45 | Returns 46 | ------- 47 | str representing a model file path from the previous epoch 48 | """ 49 | if generate_plot: 50 | # save the loss value for a diagnostic plot 51 | if model_ver_is_dbow: 52 | diagnostic_file_name = _DBOW_DIAGNOSTIC_FILE_NAME.format( 53 | data_file_name[:-4], 54 | model_ver, 55 | num_noise_words, 56 | vec_dim, 57 | batch_size, 58 | lr) 59 | else: 60 | diagnostic_file_name = _DM_DIAGNOSTIC_FILE_NAME.format( 61 | data_file_name[:-4], 62 | model_ver, 63 | vec_combine_method, 64 | context_size, 65 | num_noise_words, 66 | vec_dim, 67 | batch_size, 68 | lr) 69 | 70 | diagnostic_file_path = join(_DIAGNOSTICS_DIR, diagnostic_file_name) 71 | 72 | if epoch_i == 0 and isfile(diagnostic_file_path): 73 | remove(diagnostic_file_path) 74 | 75 | with open(diagnostic_file_path, 'a') as f: 76 | f.write('{:f}\n'.format(loss)) 77 | 78 | # generate a diagnostic loss plot 79 | with open(diagnostic_file_path) as f: 80 | loss_values = [float(l.rstrip()) for l in f.readlines()] 81 | 82 | diagnostic_plot_file_path = diagnostic_file_path[:-3] + 'png' 83 | fig = plt.figure() 84 | plt.plot(range(1, epoch_i + 2), loss_values, color='r') 85 | plt.xlabel('epoch') 86 | plt.ylabel('training loss') 87 | fig.savefig(diagnostic_plot_file_path, bbox_inches='tight') 88 | plt.close() 89 | 90 | # save the model 91 | if model_ver_is_dbow: 92 | model_file_name = _DBOW_MODEL_NAME.format( 93 | data_file_name[:-4], 94 | model_ver, 95 | num_noise_words, 96 | vec_dim, 97 | batch_size, 98 | lr, 99 | epoch_i + 1, 100 | loss) 101 | else: 102 | model_file_name = _DM_MODEL_NAME.format( 103 | data_file_name[:-4], 104 | model_ver, 105 | vec_combine_method, 106 | context_size, 107 | num_noise_words, 108 | vec_dim, 109 | batch_size, 110 | lr, 111 | epoch_i + 1, 112 | loss) 113 | 114 | model_file_path = join(MODELS_DIR, model_file_name) 115 | 116 | if save_all: 117 | torch.save(model_state, model_file_path) 118 | return None 119 | elif is_best_loss: 120 | if prev_model_file_path is not None: 121 | remove(prev_model_file_path) 122 | 123 | torch.save(model_state, model_file_path) 124 | return model_file_path 125 | else: 126 | return prev_model_file_path 127 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Paragraph Vectors 2 | [![Build Status](https://travis-ci.org/inejc/paragraph-vectors.svg?branch=master)](https://travis-ci.org/inejc/paragraph-vectors) 3 | [![codecov](https://codecov.io/gh/inejc/paragraph-vectors/branch/master/graph/badge.svg)](https://codecov.io/gh/inejc/paragraph-vectors) 4 | [![codebeat badge](https://codebeat.co/badges/e5008ad0-240c-48e9-a158-2547989b798e)](https://codebeat.co/projects/github-com-inejc-paragraph-vectors-master) 5 | [![Codacy Badge](https://api.codacy.com/project/badge/Grade/c865067aa4194184ae0c649b865b1fd2)](https://www.codacy.com/app/inejc/paragraph-vectors?utm_source=github.com&utm_medium=referral&utm_content=inejc/paragraph-vectors&utm_campaign=Badge_Grade) 6 | 7 | A PyTorch implementation of Paragraph Vectors (doc2vec). 8 |

9 | 10 |

11 | 12 | All models minimize the Negative Sampling objective as proposed by T. Mikolov et al. [1]. This provides scope for sparse updates (i.e. only vectors of sampled noise words are used in forward and backward passes). In addition to that, batches of training data (with noise sampling) are generated in parallel on a CPU while the model is trained on a GPU. 13 | 14 | **Caveat emptor!** Be warned that **`paragraph-vectors`** is in an early-stage development phase. Feedback, comments, suggestions, contributions, etc. are more than welcome. 15 | 16 | ### Installation 17 | 1. Install [PyTorch](http://pytorch.org) (follow the link for instructions). 18 | 2. Install the **`paragraph-vectors`** library. 19 | ``` 20 | git clone https://github.com/inejc/paragraph-vectors.git 21 | cd paragraph-vectors 22 | pip install -e . 23 | ``` 24 | Note that installation in a virtual environment is the recommended way. 25 | 26 | ### Usage 27 | 1. Put a csv file in the [data](data) directory. Each row represents a single document and the first column should always contain the text. Note that a header line is mandatory. 28 | ```text 29 | data/example.csv 30 | ---------------- 31 | text,... 32 | "In the week before their departure to Arrakis, when all the final scurrying about had reached a nearly unbearable frenzy, an old crone came to visit the mother of the boy, Paul.",... 33 | "It was a warm night at Castle Caladan, and the ancient pile of stone that had served the Atreides family as home for twenty-six generations bore that cooled-sweat feeling it acquired before a change in the weather.",... 34 | ... 35 | ``` 36 | 2. Run [train.py](paragraphvec/train.py) with selected parameters (models are saved in the [models](models) directory). 37 | ```bash 38 | python train.py start --data_file_name 'example.csv' --num_epochs 100 --batch_size 32 --num_noise_words 2 --vec_dim 100 --lr 1e-3 39 | ``` 40 | 41 | #### Parameters 42 | * **`data_file_name`**: str\ 43 | Name of a file in the *data* directory. 44 | * **`model_ver`**: str, one of ('dm', 'dbow'), default='dbow'\ 45 | Version of the model as proposed by Q. V. Le et al. [5], Distributed Representations of Sentences and Documents. 'dbow' stands for Distributed Bag Of Words, 'dm' stands for Distributed Memory. 46 | * **`vec_combine_method`**: str, one of ('sum', 'concat'), default='sum'\ 47 | Method for combining paragraph and word vectors when model_ver='dm'. Currently only the 'sum' operation is implemented. 48 | * **`context_size`**: int, default=0\ 49 | Half the size of a neighbourhood of target words when model_ver='dm' (i.e. how many words left and right are regarded as context). When model_ver='dm', context_size has to be greater than 0; when model_ver='dbow', context_size has to be 0. 50 | * **`num_noise_words`**: int\ 51 | Number of noise words to sample from the noise distribution. 52 | * **`vec_dim`**: int\ 53 | Dimensionality of vectors to be learned (for paragraphs and words). 54 | * **`num_epochs`**: int\ 55 | Number of iterations to train the model (i.e. number of times every example is seen during training). 56 | * **`batch_size`**: int\ 57 | Number of examples per single gradient update. 58 | * **`lr`**: float\ 59 | Learning rate of the Adam optimizer. 60 | * **`save_all`**: bool, default=False\ 61 | Indicates whether a checkpoint is saved after each epoch. If false, only the best performing model is saved.
62 | * **`generate_plot`**: bool, default=True\ 63 | Indicates whether a diagnostic plot displaying loss value over epochs is generated after each epoch. 64 | * **`max_generated_batches`**: int, default=5\ 65 | Maximum number of pre-generated batches. 66 | * **`num_workers`**: int, default=1\ 67 | Number of batch generator jobs to run in parallel. If value is set to -1, total number of machine CPUs is used. Note that order of batches is not guaranteed when **`num_workers`** > 1. 68 | 69 | 3. Export trained paragraph vectors to a csv file (vectors are saved in the [data](data) directory). 70 | ```bash 71 | python export_vectors.py start --data_file_name 'example.csv' --model_file_name 'example_model.dbow_numnoisewords.2_vecdim.100_batchsize.32_lr.0.001000_epoch.25_loss.0.981524.pth.tar' 72 | ``` 73 | 74 | #### Parameters 75 | * **`data_file_name`**: str\ 76 | Name of a file in the *data* directory that was used during training. 77 | * **`model_file_name`**: str\ 78 | Name of a file in the *models* directory (a model trained on the **`data_file_name`** dataset). 79 | 80 | ### Example of trained vectors 81 | First two principal components (1% cumulative variance explained) of 300-dimensional document vectors trained on arXiv abstracts. Shown are two subcategories from Computer Science. Dataset was comprised of 74219 documents and 91417 unique words. 82 |
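A sketch of how such a projection can be computed from an exported csv file (illustrative only; it assumes pandas and scikit-learn are available, neither of which is a declared dependency of this project, and it reuses the example file name from the usage section above):
```python
import pandas as pd
from sklearn.decomposition import PCA

# vectors exported by export_vectors.py are written to the data directory;
# the file name below is the illustrative one from the usage section
vectors = pd.read_csv('data/example_model.dbow_numnoisewords.2_vecdim.100'
                      '_batchsize.32_lr.0.001000_epoch.25_loss.0.981524.csv')
dims = [c for c in vectors.columns if c.startswith('d')]
projected = PCA(n_components=2).fit_transform(vectors[dims].values)
```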

83 | 84 |

85 | 86 | ### Resources 87 | * [1] [Distributed Representations of Words and Phrases and their Compositionality, T. Mikolov et al.](https://arxiv.org/abs/1310.4546) 88 | * [2] [Learning word embeddings efficiently with noise-contrastive estimation, A. Mnih et al.](http://papers.nips.cc/paper/5165-learning-word-embeddings-efficiently-with) 89 | * [3] [Notes on Noise Contrastive Estimation and Negative Sampling, C. Dyer](https://arxiv.org/abs/1410.8251) 90 | * [4] [Approximating the Softmax (a blog post), S. Ruder](http://ruder.io/word-embeddings-softmax/index.html) 91 | * [5] [Distributed Representations of Sentences and Documents, Q. V. Le et al.](https://arxiv.org/abs/1405.4053) 92 | * [6] [Document Embedding with Paragraph Vectors, A. M. Dai et al.](https://arxiv.org/abs/1507.07998) 93 | -------------------------------------------------------------------------------- /tests/test_data.py: -------------------------------------------------------------------------------- 1 | import time 2 | from unittest import TestCase 3 | 4 | from paragraphvec.data import load_dataset, NCEData 5 | 6 | 7 | class NCEDataTest(TestCase): 8 | 9 | def setUp(self): 10 | self.dataset = load_dataset('example.csv') 11 | 12 | def test_num_examples_for_different_batch_sizes(self): 13 | len_1 = self._num_examples_with_batch_size(1) 14 | 15 | for batch_size in range(2, 100): 16 | len_x = self._num_examples_with_batch_size(batch_size) 17 | self.assertEqual(len_x, len_1) 18 | 19 | def _num_examples_with_batch_size(self, batch_size): 20 | nce_data = NCEData( 21 | self.dataset, 22 | batch_size=batch_size, 23 | context_size=2, 24 | num_noise_words=3, 25 | max_size=1, 26 | num_workers=1) 27 | num_batches = len(nce_data) 28 | nce_data.start() 29 | nce_generator = nce_data.get_generator() 30 | 31 | total = 0 32 | for _ in range(num_batches): 33 | batch = next(nce_generator) 34 | total += len(batch) 35 | nce_data.stop() 36 | return total 37 | 38 | def test_multiple_iterations(self): 39 | nce_data = NCEData( 40 | self.dataset, 41 | batch_size=16, 42 | context_size=3, 43 | num_noise_words=3, 44 | max_size=1, 45 | num_workers=1) 46 | num_batches = len(nce_data) 47 | nce_data.start() 48 | nce_generator = nce_data.get_generator() 49 | 50 | iter0_targets = [] 51 | for _ in range(num_batches): 52 | batch = next(nce_generator) 53 | iter0_targets.append([x[0] for x in batch.target_noise_ids]) 54 | 55 | iter1_targets = [] 56 | for _ in range(num_batches): 57 | batch = next(nce_generator) 58 | iter1_targets.append([x[0] for x in batch.target_noise_ids]) 59 | 60 | for ts0, ts1 in zip(iter0_targets, iter1_targets): 61 | for t0, t1 in zip(ts0, ts1): 62 | self.assertEqual(t0, t1) 63 | nce_data.stop() 64 | 65 | def test_different_batch_sizes(self): 66 | nce_data = NCEData( 67 | self.dataset, 68 | batch_size=16, 69 | context_size=1, 70 | num_noise_words=3, 71 | max_size=1, 72 | num_workers=1) 73 | num_batches = len(nce_data) 74 | nce_data.start() 75 | nce_generator = nce_data.get_generator() 76 | 77 | targets0 = [] 78 | for _ in range(num_batches): 79 | batch = next(nce_generator) 80 | for ts in batch.target_noise_ids: 81 | targets0.append(ts[0]) 82 | nce_data.stop() 83 | 84 | nce_data = NCEData( 85 | self.dataset, 86 | batch_size=19, 87 | context_size=1, 88 | num_noise_words=3, 89 | max_size=1, 90 | num_workers=1) 91 | num_batches = len(nce_data) 92 | nce_data.start() 93 | nce_generator = nce_data.get_generator() 94 | 95 | targets1 = [] 96 | for _ in range(num_batches): 97 | batch = next(nce_generator) 98 | for ts in batch.target_noise_ids: 99
| targets1.append(ts[0]) 100 | nce_data.stop() 101 | 102 | for t0, t1 in zip(targets0, targets1): 103 | self.assertEqual(t0, t1) 104 | 105 | def test_tensor_sizes(self): 106 | nce_data = NCEData( 107 | self.dataset, 108 | batch_size=32, 109 | context_size=5, 110 | num_noise_words=3, 111 | max_size=1, 112 | num_workers=1) 113 | nce_data.start() 114 | nce_generator = nce_data.get_generator() 115 | batch = next(nce_generator) 116 | nce_data.stop() 117 | 118 | self.assertEqual(batch.context_ids.size()[0], 32) 119 | self.assertEqual(batch.context_ids.size()[1], 10) 120 | self.assertEqual(batch.doc_ids.size()[0], 32) 121 | self.assertEqual(batch.target_noise_ids.size()[0], 32) 122 | self.assertEqual(batch.target_noise_ids.size()[1], 4) 123 | 124 | def test_parallel(self): 125 | # serial version has max_size=3, because in the parallel version two 126 | # processes advance the state before they are blocked by the queue.put() 127 | nce_data = NCEData( 128 | self.dataset, 129 | batch_size=32, 130 | context_size=5, 131 | num_noise_words=1, 132 | max_size=3, 133 | num_workers=1) 134 | nce_data.start() 135 | time.sleep(1) 136 | nce_data.stop() 137 | state_serial = nce_data._generator._state 138 | 139 | nce_data = NCEData( 140 | self.dataset, 141 | batch_size=32, 142 | context_size=5, 143 | num_noise_words=1, 144 | max_size=2, 145 | num_workers=2) 146 | nce_data.start() 147 | time.sleep(1) 148 | nce_data.stop() 149 | state_parallel = nce_data._generator._state 150 | 151 | self.assertEqual( 152 | state_parallel._doc_id.value, 153 | state_serial._doc_id.value) 154 | self.assertEqual( 155 | state_parallel._in_doc_pos.value, 156 | state_serial._in_doc_pos.value) 157 | 158 | def test_no_context(self): 159 | nce_data = NCEData( 160 | self.dataset, 161 | batch_size=16, 162 | context_size=0, 163 | num_noise_words=3, 164 | max_size=1, 165 | num_workers=1) 166 | nce_data.start() 167 | nce_generator = nce_data.get_generator() 168 | batch = next(nce_generator) 169 | nce_data.stop() 170 | 171 | self.assertEqual(batch.context_ids, None) 172 | 173 | 174 | class DataUtilsTest(TestCase): 175 | 176 | def setUp(self): 177 | self.dataset = load_dataset('example.csv') 178 | 179 | def test_load_dataset(self): 180 | self.assertEqual(len(self.dataset), 4) 181 | 182 | def test_vocab(self): 183 | self.assertTrue(self.dataset.fields['text'].use_vocab) 184 | self.assertTrue(len(self.dataset.fields['text'].vocab) > 0) 185 | -------------------------------------------------------------------------------- /paragraphvec/train.py: -------------------------------------------------------------------------------- 1 | import time 2 | from sys import float_info, stdout 3 | 4 | import fire 5 | import torch 6 | from torch.optim import Adam 7 | 8 | from paragraphvec.data import load_dataset, NCEData 9 | from paragraphvec.loss import NegativeSampling 10 | from paragraphvec.models import DM, DBOW 11 | from paragraphvec.utils import save_training_state 12 | 13 | 14 | def start(data_file_name, 15 | num_noise_words, 16 | vec_dim, 17 | num_epochs, 18 | batch_size, 19 | lr, 20 | model_ver='dbow', 21 | context_size=0, 22 | vec_combine_method='sum', 23 | save_all=False, 24 | generate_plot=True, 25 | max_generated_batches=5, 26 | num_workers=1): 27 | """Trains a new model. The latest checkpoint and the best performing 28 | model are saved in the *models* directory. 29 | 30 | Parameters 31 | ---------- 32 | data_file_name: str 33 | Name of a file in the *data* directory. 
34 | 35 | model_ver: str, one of ('dm', 'dbow'), default='dbow' 36 | Version of the model as proposed by Q. V. Le et al., Distributed 37 | Representations of Sentences and Documents. 'dbow' stands for 38 | Distributed Bag Of Words, 'dm' stands for Distributed Memory. 39 | 40 | vec_combine_method: str, one of ('sum', 'concat'), default='sum' 41 | Method for combining paragraph and word vectors when model_ver='dm'. 42 | Currently only the 'sum' operation is implemented. 43 | 44 | context_size: int, default=0 45 | Half the size of a neighbourhood of target words when model_ver='dm' 46 | (i.e. how many words left and right are regarded as context). When 47 | model_ver='dm', context_size has to be greater than 0; when 48 | model_ver='dbow', context_size has to be 0. 49 | 50 | num_noise_words: int 51 | Number of noise words to sample from the noise distribution. 52 | 53 | vec_dim: int 54 | Dimensionality of vectors to be learned (for paragraphs and words). 55 | 56 | num_epochs: int 57 | Number of iterations to train the model (i.e. number 58 | of times every example is seen during training). 59 | 60 | batch_size: int 61 | Number of examples per single gradient update. 62 | 63 | lr: float 64 | Learning rate of the Adam optimizer. 65 | 66 | save_all: bool, default=False 67 | Indicates whether a checkpoint is saved after each epoch. 68 | If false, only the best performing model is saved. 69 | 70 | generate_plot: bool, default=True 71 | Indicates whether a diagnostic plot displaying loss value over 72 | epochs is generated after each epoch. 73 | 74 | max_generated_batches: int, default=5 75 | Maximum number of pre-generated batches. 76 | 77 | num_workers: int, default=1 78 | Number of batch generator jobs to run in parallel. If value is set 79 | to -1, the total number of machine CPUs is used.
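Examples
--------
Illustrative invocation via Fire from the command line (taken from the
README):

    python train.py start --data_file_name 'example.csv' --num_epochs 100 --batch_size 32 --num_noise_words 2 --vec_dim 100 --lr 1e-3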
80 | """ 81 | if model_ver not in ('dm', 'dbow'): 82 | raise ValueError("Invalid version of the model") 83 | 84 | model_ver_is_dbow = model_ver == 'dbow' 85 | 86 | if model_ver_is_dbow and context_size != 0: 87 | raise ValueError("Context size has to be zero when using dbow") 88 | if not model_ver_is_dbow: 89 | if vec_combine_method not in ('sum', 'concat'): 90 | raise ValueError("Invalid method for combining paragraph and word " 91 | "vectors when using dm") 92 | if context_size <= 0: 93 | raise ValueError("Context size must be positive when using dm") 94 | 95 | dataset = load_dataset(data_file_name) 96 | nce_data = NCEData( 97 | dataset, 98 | batch_size, 99 | context_size, 100 | num_noise_words, 101 | max_generated_batches, 102 | num_workers) 103 | nce_data.start() 104 | 105 | try: 106 | _run(data_file_name, dataset, nce_data.get_generator(), len(nce_data), 107 | nce_data.vocabulary_size(), context_size, num_noise_words, vec_dim, 108 | num_epochs, batch_size, lr, model_ver, vec_combine_method, 109 | save_all, generate_plot, model_ver_is_dbow) 110 | except KeyboardInterrupt: 111 | nce_data.stop() 112 | 113 | 114 | def _run(data_file_name, 115 | dataset, 116 | data_generator, 117 | num_batches, 118 | vocabulary_size, 119 | context_size, 120 | num_noise_words, 121 | vec_dim, 122 | num_epochs, 123 | batch_size, 124 | lr, 125 | model_ver, 126 | vec_combine_method, 127 | save_all, 128 | generate_plot, 129 | model_ver_is_dbow): 130 | 131 | if model_ver_is_dbow: 132 | model = DBOW(vec_dim, num_docs=len(dataset), num_words=vocabulary_size) 133 | else: 134 | model = DM(vec_dim, num_docs=len(dataset), num_words=vocabulary_size) 135 | 136 | cost_func = NegativeSampling() 137 | optimizer = Adam(params=model.parameters(), lr=lr) 138 | 139 | if torch.cuda.is_available(): 140 | model.cuda() 141 | 142 | print("Dataset comprised of {:d} documents.".format(len(dataset))) 143 | print("Vocabulary size is {:d}.\n".format(vocabulary_size)) 144 | print("Training started.") 145 | 146 | best_loss = float("inf") 147 | prev_model_file_path = None 148 | 149 | for epoch_i in range(num_epochs): 150 | epoch_start_time = time.time() 151 | loss = [] 152 | 153 | for batch_i in range(num_batches): 154 | batch = next(data_generator) 155 | if torch.cuda.is_available(): 156 | batch.cuda_() 157 | 158 | if model_ver_is_dbow: 159 | x = model.forward(batch.doc_ids, batch.target_noise_ids) 160 | else: 161 | x = model.forward( 162 | batch.context_ids, 163 | batch.doc_ids, 164 | batch.target_noise_ids) 165 | 166 | x = cost_func.forward(x) 167 | 168 | loss.append(x.item()) 169 | model.zero_grad() 170 | x.backward() 171 | optimizer.step() 172 | _print_progress(epoch_i, batch_i, num_batches) 173 | 174 | # end of epoch 175 | loss = torch.mean(torch.FloatTensor(loss)) 176 | is_best_loss = loss < best_loss 177 | best_loss = min(loss, best_loss) 178 | 179 | state = { 180 | 'epoch': epoch_i + 1, 181 | 'model_state_dict': model.state_dict(), 182 | 'best_loss': best_loss, 183 | 'optimizer_state_dict': optimizer.state_dict() 184 | } 185 | 186 | prev_model_file_path = save_training_state( 187 | data_file_name, 188 | model_ver, 189 | vec_combine_method, 190 | context_size, 191 | num_noise_words, 192 | vec_dim, 193 | batch_size, 194 | lr, 195 | epoch_i, 196 | loss, 197 | state, 198 | save_all, 199 | generate_plot, 200 | is_best_loss, 201 | prev_model_file_path, 202 | model_ver_is_dbow) 203 | 204 | epoch_total_time = round(time.time() - epoch_start_time) 205 | print(" ({:d}s) - loss: {:.4f}".format(epoch_total_time, loss)) 206 | 207 | 208 | def 
_print_progress(epoch_i, batch_i, num_batches): 209 | progress = round((batch_i + 1) / num_batches * 100) 210 | print("\rEpoch {:d}".format(epoch_i + 1), end='') 211 | stdout.write(" - {:d}%".format(progress)) 212 | stdout.flush() 213 | 214 | 215 | if __name__ == '__main__': 216 | fire.Fire() 217 | -------------------------------------------------------------------------------- /paragraphvec/data.py: -------------------------------------------------------------------------------- 1 | import multiprocessing 2 | import os 3 | import re 4 | import signal 5 | from math import ceil 6 | from os.path import join 7 | 8 | import numpy as np 9 | import torch 10 | from numpy.random import choice 11 | from torchtext.data import Field, TabularDataset 12 | 13 | from paragraphvec.utils import DATA_DIR 14 | 15 | 16 | def load_dataset(file_name): 17 | """Loads contents from a file in the *data* directory into a 18 | torchtext.data.TabularDataset instance. 19 | """ 20 | file_path = join(DATA_DIR, file_name) 21 | text_field = Field(pad_token=None, tokenize=_tokenize_str) 22 | 23 | dataset = TabularDataset( 24 | path=file_path, 25 | format='csv', 26 | fields=[('text', text_field)], 27 | skip_header=True) 28 | 29 | text_field.build_vocab(dataset) 30 | return dataset 31 | 32 | 33 | def _tokenize_str(str_): 34 | # keep only alphanumeric and punctations 35 | str_ = re.sub(r'[^A-Za-z0-9(),.!?\'`]', ' ', str_) 36 | # remove multiple whitespace characters 37 | str_ = re.sub(r'\s{2,}', ' ', str_) 38 | # punctations to tokens 39 | str_ = re.sub(r'\(', ' ( ', str_) 40 | str_ = re.sub(r'\)', ' ) ', str_) 41 | str_ = re.sub(r',', ' , ', str_) 42 | str_ = re.sub(r'\.', ' . ', str_) 43 | str_ = re.sub(r'!', ' ! ', str_) 44 | str_ = re.sub(r'\?', ' ? ', str_) 45 | # split contractions into multiple tokens 46 | str_ = re.sub(r'\'s', ' \'s', str_) 47 | str_ = re.sub(r'\'ve', ' \'ve', str_) 48 | str_ = re.sub(r'n\'t', ' n\'t', str_) 49 | str_ = re.sub(r'\'re', ' \'re', str_) 50 | str_ = re.sub(r'\'d', ' \'d', str_) 51 | str_ = re.sub(r'\'ll', ' \'ll', str_) 52 | # lower case 53 | return str_.strip().lower().split() 54 | 55 | 56 | class NCEData(object): 57 | """An infinite, parallel (multiprocess) batch generator for 58 | noise-contrastive estimation of word vector models. 59 | 60 | Parameters 61 | ---------- 62 | dataset: torchtext.data.TabularDataset 63 | Dataset from which examples are generated. A column labeled *text* 64 | is expected and should be comprised of a list of tokens. Each row 65 | should represent a single document. 66 | 67 | batch_size: int 68 | Number of examples per single gradient update. 69 | 70 | context_size: int 71 | Half the size of a neighbourhood of target words (i.e. how many 72 | words left and right are regarded as context). 73 | 74 | num_noise_words: int 75 | Number of noise words to sample from the noise distribution. 76 | 77 | max_size: int 78 | Maximum number of pre-generated batches. 79 | 80 | num_workers: int 81 | Number of jobs to run in parallel. If value is set to -1, total number 82 | of machine CPUs is used. 
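Examples
--------
A typical consumption loop (illustrative sketch; this mirrors how
train.py and the tests drive the generator):

    nce_data = NCEData(dataset, batch_size=32, context_size=0,
                       num_noise_words=2, max_size=5, num_workers=1)
    nce_data.start()
    generator = nce_data.get_generator()
    for _ in range(len(nce_data)):
        batch = next(generator)
        # feed batch.doc_ids / batch.target_noise_ids to a model here
    nce_data.stop()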
83 | """ 84 | # code inspired by parallel generators in https://github.com/fchollet/keras 85 | def __init__(self, dataset, batch_size, context_size, 86 | num_noise_words, max_size, num_workers): 87 | self.max_size = max_size 88 | 89 | self.num_workers = num_workers if num_workers != -1 else os.cpu_count() 90 | if self.num_workers is None: 91 | self.num_workers = 1 92 | 93 | self._generator = _NCEGenerator( 94 | dataset, 95 | batch_size, 96 | context_size, 97 | num_noise_words, 98 | _NCEGeneratorState(context_size)) 99 | 100 | self._queue = None 101 | self._stop_event = None 102 | self._processes = [] 103 | 104 | def __len__(self): 105 | return len(self._generator) 106 | 107 | def vocabulary_size(self): 108 | return self._generator.vocabulary_size() 109 | 110 | def start(self): 111 | """Starts num_worker processes that generate batches of data.""" 112 | self._queue = multiprocessing.Queue(maxsize=self.max_size) 113 | self._stop_event = multiprocessing.Event() 114 | 115 | for _ in range(self.num_workers): 116 | process = multiprocessing.Process(target=self._parallel_task) 117 | process.daemon = True 118 | self._processes.append(process) 119 | process.start() 120 | 121 | def _parallel_task(self): 122 | while not self._stop_event.is_set(): 123 | try: 124 | batch = self._generator.next() 125 | # queue blocks a call to put() until a free slot is available 126 | self._queue.put(batch) 127 | except KeyboardInterrupt: 128 | self._stop_event.set() 129 | 130 | def get_generator(self): 131 | """Returns a generator that yields batches of data.""" 132 | while self._is_running(): 133 | yield self._queue.get() 134 | 135 | def stop(self): 136 | """Terminates all processes that were created with start().""" 137 | if self._is_running(): 138 | self._stop_event.set() 139 | 140 | for process in self._processes: 141 | if process.is_alive(): 142 | os.kill(process.pid, signal.SIGINT) 143 | process.join() 144 | 145 | if self._queue is not None: 146 | self._queue.close() 147 | 148 | self._queue = None 149 | self._stop_event = None 150 | self._processes = [] 151 | 152 | def _is_running(self): 153 | return self._stop_event is not None and not self._stop_event.is_set() 154 | 155 | 156 | class _NCEGenerator(object): 157 | """An infinite, process-safe batch generator for noise-contrastive 158 | estimation of word vector models. 159 | 160 | Parameters 161 | ---------- 162 | state: paragraphvec.data._NCEGeneratorState 163 | Initial (indexing) state of the generator. 164 | 165 | For other parameters see the NCEData class. 166 | """ 167 | def __init__(self, dataset, batch_size, context_size, 168 | num_noise_words, state): 169 | self.dataset = dataset 170 | self.batch_size = batch_size 171 | self.context_size = context_size 172 | self.num_noise_words = num_noise_words 173 | 174 | self._vocabulary = self.dataset.fields['text'].vocab 175 | self._sample_noise = None 176 | self._init_noise_distribution() 177 | self._state = state 178 | 179 | def _init_noise_distribution(self): 180 | # we use a unigram distribution raised to the 3/4rd power, 181 | # as proposed by T. Mikolov et al. 
in Distributed Representations 182 | # of Words and Phrases and their Compositionality 183 | probs = np.zeros(len(self._vocabulary) - 1) 184 | 185 | for word, freq in self._vocabulary.freqs.items(): 186 | probs[self._word_to_index(word)] = freq 187 | 188 | probs = np.power(probs, 0.75) 189 | probs /= np.sum(probs) 190 | 191 | self._sample_noise = lambda: choice( 192 | probs.shape[0], self.num_noise_words, p=probs).tolist() 193 | 194 | def __len__(self): 195 | num_examples = sum(self._num_examples_in_doc(d) for d in self.dataset) 196 | return ceil(num_examples / self.batch_size) 197 | 198 | def vocabulary_size(self): 199 | return len(self._vocabulary) - 1 200 | 201 | def next(self): 202 | """Updates state for the next process in a process-safe manner 203 | and generates the current batch.""" 204 | prev_doc_id, prev_in_doc_pos = self._state.update_state( 205 | self.dataset, 206 | self.batch_size, 207 | self.context_size, 208 | self._num_examples_in_doc) 209 | 210 | # generate the actual batch 211 | batch = _NCEBatch(self.context_size) 212 | 213 | while len(batch) < self.batch_size: 214 | if prev_doc_id == len(self.dataset): 215 | # last document exhausted 216 | batch.torch_() 217 | return batch 218 | if prev_in_doc_pos <= (len(self.dataset[prev_doc_id].text) - 1 219 | - self.context_size): 220 | # more examples in the current document 221 | self._add_example_to_batch(prev_doc_id, prev_in_doc_pos, batch) 222 | prev_in_doc_pos += 1 223 | else: 224 | # go to the next document 225 | prev_doc_id += 1 226 | prev_in_doc_pos = self.context_size 227 | 228 | batch.torch_() 229 | return batch 230 | 231 | def _num_examples_in_doc(self, doc, in_doc_pos=None): 232 | if in_doc_pos is not None: 233 | # number of remaining 234 | if len(doc.text) - in_doc_pos >= self.context_size + 1: 235 | return len(doc.text) - in_doc_pos - self.context_size 236 | return 0 237 | 238 | if len(doc.text) >= 2 * self.context_size + 1: 239 | # total number 240 | return len(doc.text) - 2 * self.context_size 241 | return 0 242 | 243 | def _add_example_to_batch(self, doc_id, in_doc_pos, batch): 244 | doc = self.dataset[doc_id].text 245 | batch.doc_ids.append(doc_id) 246 | 247 | # sample from the noise distribution 248 | current_noise = self._sample_noise() 249 | current_noise.insert(0, self._word_to_index(doc[in_doc_pos])) 250 | batch.target_noise_ids.append(current_noise) 251 | 252 | if self.context_size == 0: 253 | return 254 | 255 | current_context = [] 256 | context_indices = (in_doc_pos + diff for diff in 257 | range(-self.context_size, self.context_size + 1) 258 | if diff != 0) 259 | 260 | for i in context_indices: 261 | context_id = self._word_to_index(doc[i]) 262 | current_context.append(context_id) 263 | batch.context_ids.append(current_context) 264 | 265 | def _word_to_index(self, word): 266 | return self._vocabulary.stoi[word] - 1 267 | 268 | 269 | class _NCEGeneratorState(object): 270 | """Batch generator state that is represented with a document id and 271 | in-document position. 
It abstracts a process-safe indexing mechanism.""" 272 | def __init__(self, context_size): 273 | # use raw values because both indices have 274 | # to manually be locked together 275 | self._doc_id = multiprocessing.RawValue('i', 0) 276 | self._in_doc_pos = multiprocessing.RawValue('i', context_size) 277 | self._lock = multiprocessing.Lock() 278 | 279 | def update_state(self, dataset, batch_size, 280 | context_size, num_examples_in_doc): 281 | """Returns current indices and computes new indices for the 282 | next process.""" 283 | with self._lock: 284 | doc_id = self._doc_id.value 285 | in_doc_pos = self._in_doc_pos.value 286 | self._advance_indices( 287 | dataset, batch_size, context_size, num_examples_in_doc) 288 | return doc_id, in_doc_pos 289 | 290 | def _advance_indices(self, dataset, batch_size, 291 | context_size, num_examples_in_doc): 292 | num_examples = num_examples_in_doc( 293 | dataset[self._doc_id.value], self._in_doc_pos.value) 294 | 295 | if num_examples > batch_size: 296 | # more examples in the current document 297 | self._in_doc_pos.value += batch_size 298 | return 299 | 300 | if num_examples == batch_size: 301 | # just enough examples in the current document 302 | if self._doc_id.value < len(dataset) - 1: 303 | self._doc_id.value += 1 304 | else: 305 | self._doc_id.value = 0 306 | self._in_doc_pos.value = context_size 307 | return 308 | 309 | while num_examples < batch_size: 310 | if self._doc_id.value == len(dataset) - 1: 311 | # last document: reset indices 312 | self._doc_id.value = 0 313 | self._in_doc_pos.value = context_size 314 | return 315 | 316 | self._doc_id.value += 1 317 | num_examples += num_examples_in_doc( 318 | dataset[self._doc_id.value]) 319 | 320 | self._in_doc_pos.value = (len(dataset[self._doc_id.value].text) 321 | - context_size 322 | - (num_examples - batch_size)) 323 | 324 | 325 | class _NCEBatch(object): 326 | def __init__(self, context_size): 327 | self.context_ids = [] if context_size > 0 else None 328 | self.doc_ids = [] 329 | self.target_noise_ids = [] 330 | 331 | def __len__(self): 332 | return len(self.doc_ids) 333 | 334 | def torch_(self): 335 | if self.context_ids is not None: 336 | self.context_ids = torch.LongTensor(self.context_ids) 337 | self.doc_ids = torch.LongTensor(self.doc_ids) 338 | self.target_noise_ids = torch.LongTensor(self.target_noise_ids) 339 | 340 | def cuda_(self): 341 | if self.context_ids is not None: 342 | self.context_ids = self.context_ids.cuda() 343 | self.doc_ids = self.doc_ids.cuda() 344 | self.target_noise_ids = self.target_noise_ids.cuda() 345 | --------------------------------------------------------------------------------
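For completeness, here is a minimal end-to-end sketch of driving the library from Python instead of the Fire command line. It is illustrative only: it re-wires the pieces shown above (load_dataset, NCEData, DBOW, NegativeSampling) the same way train.py does, with the hyperparameter values taken from the README example.
```python
import torch
from torch.optim import Adam

from paragraphvec.data import load_dataset, NCEData
from paragraphvec.loss import NegativeSampling
from paragraphvec.models import DBOW

# load the bundled example dataset and start the parallel batch generator
dataset = load_dataset('example.csv')
nce_data = NCEData(dataset, batch_size=32, context_size=0,
                   num_noise_words=2, max_size=5, num_workers=1)
nce_data.start()

model = DBOW(vec_dim=100, num_docs=len(dataset),
             num_words=nce_data.vocabulary_size())
cost_func = NegativeSampling()
optimizer = Adam(params=model.parameters(), lr=1e-3)

# a single pass over the data (train.py repeats this for num_epochs)
generator = nce_data.get_generator()
for _ in range(len(nce_data)):
    batch = next(generator)
    scores = model.forward(batch.doc_ids, batch.target_noise_ids)
    loss = cost_func.forward(scores)
    model.zero_grad()
    loss.backward()
    optimizer.step()

nce_data.stop()
```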