├── cnn_text_trainer
│   ├── __init__.py
│   ├── rw
│   │   ├── __init__.py
│   │   ├── wordvecs.py
│   │   └── datasets.py
│   ├── config
│   │   ├── __init__.py
│   │   └── config.py
│   └── core
│       ├── __init__.py
│       ├── multichannel
│       │   ├── __init__.py
│       │   └── model.py
│       ├── unichannel
│       │   ├── __init__.py
│       │   └── model.py
│       └── nn_classes.py
├── test
│   ├── __init__.py
│   ├── testConfig.json
│   └── test_cnn_text_trainer.py
├── sample
│   ├── configs
│   │   ├── sampleMCConfig.json
│   │   ├── sampleStaticConfig.json
│   │   ├── sampleNonStaticConfig.json
│   │   ├── mc
│   │   │   ├── config-mc5.json
│   │   │   ├── config-mc1.json
│   │   │   ├── config-mc2.json
│   │   │   ├── config-mc3.json
│   │   │   └── config-mc4.json
│   │   ├── static
│   │   │   ├── config-static1.json
│   │   │   ├── config-static2.json
│   │   │   ├── config-static3.json
│   │   │   ├── config-static4.json
│   │   │   └── config-static5.json
│   │   └── nonstatic
│   │       ├── config-nonstatic5.json
│   │       ├── config-nonstatic1.json
│   │       ├── config-nonstatic2.json
│   │       ├── config-nonstatic3.json
│   │       └── config-nonstatic4.json
│   └── datasets
│       └── sst_small_sample.csv
├── requirements.txt
├── downloadWordVecs.sh
├── .gitignore
├── train.py
├── make
│   └── gdown.pl
├── gpu_to_cpu.py
├── README.md
├── server.py
├── test.py
└── LICENSE

/cnn_text_trainer/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/test/__init__.py:
--------------------------------------------------------------------------------
__author__ = 'devashish.shankar'
--------------------------------------------------------------------------------
/cnn_text_trainer/rw/__init__.py:
--------------------------------------------------------------------------------
__author__ = 'devashish.shankar'
--------------------------------------------------------------------------------
/cnn_text_trainer/config/__init__.py:
--------------------------------------------------------------------------------
__author__ = 'devashish.shankar'
--------------------------------------------------------------------------------
/cnn_text_trainer/core/__init__.py:
--------------------------------------------------------------------------------
__author__ = 'devashish.shankar'
--------------------------------------------------------------------------------
/cnn_text_trainer/core/multichannel/__init__.py:
--------------------------------------------------------------------------------
__author__ = 'devashish.shankar'
--------------------------------------------------------------------------------
/cnn_text_trainer/core/unichannel/__init__.py:
--------------------------------------------------------------------------------
__author__ = 'devashish.shankar'
--------------------------------------------------------------------------------
/test/testConfig.json:
--------------------------------------------------------------------------------
{
    "word2vec":"GoogleNews-vectors-negative300.bin",
    "dim":300,
    "max_l":56,
    "filter_h":5,
    "filter_hs":[3,4,5],
    "mlp_hidden_units":[],
    "dropout_rate":0.5,
    "shuffle_batch":true,
    "n_epochs":5,
    "batch_size":50,
    "lr_decay":0.95,
    "conv_non_linear":"relu",
    "mode":"static"
}
--------------------------------------------------------------------------------
/sample/configs/sampleMCConfig.json:
--------------------------------------------------------------------------------
{
    "word2vec":"GoogleNews-vectors-negative300.bin",
    "dim":300,
    "max_l":56,
    "filter_h":5,
    "filter_hs":[3,4,5],
"mlp_hidden_units":[], 8 | "dropout_rate":0.5, 9 | "shuffle_batch":true, 10 | "n_epochs":5, 11 | "batch_size":50, 12 | "lr_decay":0.95, 13 | "conv_non_linear":"relu", 14 | "mode":"multichannel" 15 | } -------------------------------------------------------------------------------- /sample/configs/sampleStaticConfig.json: -------------------------------------------------------------------------------- 1 | { 2 | "word2vec":"GoogleNews-vectors-negative300.bin", 3 | "dim":300, 4 | "max_l":56, 5 | "filter_h":5, 6 | "filter_hs":[3,4,5], 7 | "mlp_hidden_units":[], 8 | "dropout_rate":0.5, 9 | "shuffle_batch":true, 10 | "n_epochs":5, 11 | "batch_size":50, 12 | "lr_decay":0.95, 13 | "conv_non_linear":"relu", 14 | "mode":"static" 15 | } -------------------------------------------------------------------------------- /sample/configs/sampleNonStaticConfig.json: -------------------------------------------------------------------------------- 1 | { 2 | "word2vec":"GoogleNews-vectors-negative300.bin", 3 | "dim":300, 4 | "max_l":56, 5 | "filter_h":5, 6 | "filter_hs":[3,4,5], 7 | "mlp_hidden_units":[], 8 | "dropout_rate":0.5, 9 | "shuffle_batch":true, 10 | "n_epochs":5, 11 | "batch_size":50, 12 | "lr_decay":0.95, 13 | "conv_non_linear":"relu", 14 | "mode":"nonstatic" 15 | } -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | Flask==0.10.1 2 | Jinja2==2.7.3 3 | MarkupSafe==0.23 4 | Theano==0.7.0 5 | Werkzeug==0.10.4 6 | argparse==1.2.1 7 | distribute==0.6.24 8 | gunicorn==19.3.0 9 | itsdangerous==0.24 10 | nltk==3.0.3 11 | numpy==1.9.2 12 | pandas==0.16.2 13 | python-dateutil==2.4.2 14 | pytz==2015.4 15 | scikit-learn==0.16.1 16 | scipy==0.15.1 17 | six==1.9.0 18 | wsgiref==0.1.2 19 | -------------------------------------------------------------------------------- /sample/configs/mc/config-mc5.json: -------------------------------------------------------------------------------- 1 | { 2 | "word2vec":"GoogleNews-vectors-negative300.bin", 3 | "dim":300, 4 | "conv_features":100, 5 | "max_l":56, 6 | "filter_h":5, 7 | "filter_hs":[3,4,5], 8 | "mlp_hidden_units":[], 9 | "dropout_rate":0.5, 10 | "shuffle_batch":true, 11 | "n_epochs":50, 12 | "batch_size":50, 13 | "lr_decay":0.95, 14 | "conv_non_linear":"relu", 15 | "mode":"multichannel" 16 | } -------------------------------------------------------------------------------- /sample/configs/mc/config-mc1.json: -------------------------------------------------------------------------------- 1 | { 2 | "word2vec":"GoogleNews-vectors-negative300.bin", 3 | "dim":300, 4 | "conv_features":200, 5 | "max_l":56, 6 | "filter_h":5, 7 | "filter_hs":[3,4,5], 8 | "mlp_hidden_units":[100], 9 | "dropout_rate":0.5, 10 | "shuffle_batch":true, 11 | "n_epochs":50, 12 | "batch_size":50, 13 | "lr_decay":0.95, 14 | "conv_non_linear":"relu", 15 | "mode":"multichannel" 16 | } -------------------------------------------------------------------------------- /sample/configs/mc/config-mc2.json: -------------------------------------------------------------------------------- 1 | { 2 | "word2vec":"GoogleNews-vectors-negative300.bin", 3 | "dim":300, 4 | "conv_features":200, 5 | "max_l":56, 6 | "filter_h":5, 7 | "filter_hs":[3,4,5], 8 | "mlp_hidden_units":[50], 9 | "dropout_rate":0.5, 10 | "shuffle_batch":true, 11 | "n_epochs":50, 12 | "batch_size":50, 13 | "lr_decay":0.95, 14 | "conv_non_linear":"relu", 15 | "mode":"multichannel" 16 | } 
--------------------------------------------------------------------------------
/sample/configs/mc/config-mc3.json:
--------------------------------------------------------------------------------
{
    "word2vec":"GoogleNews-vectors-negative300.bin",
    "dim":300,
    "conv_features":300,
    "max_l":56,
    "filter_h":5,
    "filter_hs":[3,4,5],
    "mlp_hidden_units":[100],
    "dropout_rate":0.5,
    "shuffle_batch":true,
    "n_epochs":50,
    "batch_size":50,
    "lr_decay":0.95,
    "conv_non_linear":"relu",
    "mode":"multichannel"
}
--------------------------------------------------------------------------------
/sample/configs/mc/config-mc4.json:
--------------------------------------------------------------------------------
{
    "word2vec":"GoogleNews-vectors-negative300.bin",
    "dim":300,
    "conv_features":300,
    "max_l":56,
    "filter_h":5,
    "filter_hs":[3,4,5],
    "mlp_hidden_units":[50],
    "dropout_rate":0.5,
    "shuffle_batch":true,
    "n_epochs":50,
    "batch_size":50,
    "lr_decay":0.95,
    "conv_non_linear":"relu",
    "mode":"multichannel"
}
--------------------------------------------------------------------------------
/sample/configs/static/config-static1.json:
--------------------------------------------------------------------------------
{
    "word2vec":"GoogleNews-vectors-negative300.bin",
    "dim":300,
    "conv_features":200,
    "max_l":56,
    "filter_h":5,
    "filter_hs":[3,4,5],
    "mlp_hidden_units":[100],
    "dropout_rate":0.5,
    "shuffle_batch":true,
    "n_epochs":50,
    "batch_size":50,
    "lr_decay":0.95,
    "conv_non_linear":"relu",
    "mode":"static"
}
--------------------------------------------------------------------------------
/sample/configs/static/config-static2.json:
--------------------------------------------------------------------------------
{
    "word2vec":"GoogleNews-vectors-negative300.bin",
    "dim":300,
    "conv_features":200,
    "max_l":56,
    "filter_h":5,
    "filter_hs":[3,4,5],
    "mlp_hidden_units":[50],
    "dropout_rate":0.5,
    "shuffle_batch":true,
    "n_epochs":50,
    "batch_size":50,
    "lr_decay":0.95,
    "conv_non_linear":"relu",
    "mode":"static"
}
--------------------------------------------------------------------------------
/sample/configs/static/config-static3.json:
--------------------------------------------------------------------------------
{
    "word2vec":"GoogleNews-vectors-negative300.bin",
    "dim":300,
    "conv_features":300,
    "max_l":56,
    "filter_h":5,
    "filter_hs":[3,4,5],
    "mlp_hidden_units":[100],
    "dropout_rate":0.5,
    "shuffle_batch":true,
    "n_epochs":50,
    "batch_size":50,
    "lr_decay":0.95,
    "conv_non_linear":"relu",
    "mode":"static"
}
--------------------------------------------------------------------------------
/sample/configs/static/config-static4.json:
--------------------------------------------------------------------------------
{
    "word2vec":"GoogleNews-vectors-negative300.bin",
    "dim":300,
    "conv_features":300,
    "max_l":56,
    "filter_h":5,
    "filter_hs":[3,4,5],
    "mlp_hidden_units":[50],
    "dropout_rate":0.5,
    "shuffle_batch":true,
    "n_epochs":50,
    "batch_size":50,
    "lr_decay":0.95,
    "conv_non_linear":"relu",
    "mode":"static"
}
--------------------------------------------------------------------------------
/sample/configs/static/config-static5.json:
--------------------------------------------------------------------------------
{
    "word2vec":"GoogleNews-vectors-negative300.bin",
    "dim":300,
    "conv_features":100,
    "max_l":56,
    "filter_h":5,
    "filter_hs":[3,4,5],
    "mlp_hidden_units":[],
    "dropout_rate":0.5,
    "shuffle_batch":true,
    "n_epochs":50,
    "batch_size":50,
    "lr_decay":0.95,
    "conv_non_linear":"relu",
    "mode":"static"
}
--------------------------------------------------------------------------------
/sample/configs/nonstatic/config-nonstatic5.json:
--------------------------------------------------------------------------------
{
    "word2vec":"GoogleNews-vectors-negative300.bin",
    "dim":300,
    "conv_features":100,
    "max_l":56,
    "filter_h":5,
    "filter_hs":[3,4,5],
    "mlp_hidden_units":[],
    "dropout_rate":0.5,
    "shuffle_batch":true,
    "n_epochs":50,
    "batch_size":50,
    "lr_decay":0.95,
    "conv_non_linear":"relu",
    "mode":"nonstatic"
}
--------------------------------------------------------------------------------
/sample/configs/nonstatic/config-nonstatic1.json:
--------------------------------------------------------------------------------
{
    "word2vec":"GoogleNews-vectors-negative300.bin",
    "dim":300,
    "conv_features":200,
    "max_l":56,
    "filter_h":5,
    "filter_hs":[3,4,5],
    "mlp_hidden_units":[100],
    "dropout_rate":0.5,
    "shuffle_batch":true,
    "n_epochs":50,
    "batch_size":50,
    "lr_decay":0.95,
    "conv_non_linear":"relu",
    "mode":"nonstatic"
}
--------------------------------------------------------------------------------
/sample/configs/nonstatic/config-nonstatic2.json:
--------------------------------------------------------------------------------
{
    "word2vec":"GoogleNews-vectors-negative300.bin",
    "dim":300,
    "conv_features":200,
    "max_l":56,
    "filter_h":5,
    "filter_hs":[3,4,5],
    "mlp_hidden_units":[50],
    "dropout_rate":0.5,
    "shuffle_batch":true,
    "n_epochs":50,
    "batch_size":50,
    "lr_decay":0.95,
    "conv_non_linear":"relu",
    "mode":"nonstatic"
}
--------------------------------------------------------------------------------
/sample/configs/nonstatic/config-nonstatic3.json:
--------------------------------------------------------------------------------
{
    "word2vec":"GoogleNews-vectors-negative300.bin",
    "dim":300,
    "conv_features":300,
    "max_l":56,
    "filter_h":5,
    "filter_hs":[3,4,5],
    "mlp_hidden_units":[100],
    "dropout_rate":0.5,
    "shuffle_batch":true,
    "n_epochs":50,
    "batch_size":50,
    "lr_decay":0.95,
    "conv_non_linear":"relu",
    "mode":"nonstatic"
}
--------------------------------------------------------------------------------
/sample/configs/nonstatic/config-nonstatic4.json:
--------------------------------------------------------------------------------
{
    "word2vec":"GoogleNews-vectors-negative300.bin",
    "dim":300,
    "conv_features":300,
    "max_l":56,
    "filter_h":5,
    "filter_hs":[3,4,5],
    "mlp_hidden_units":[50],
    "dropout_rate":0.5,
    "shuffle_batch":true,
    "n_epochs":50,
    "batch_size":50,
    "lr_decay":0.95,
    "conv_non_linear":"relu",
    "mode":"nonstatic"
}
--------------------------------------------------------------------------------
/downloadWordVecs.sh:
--------------------------------------------------------------------------------
#!/bin/bash
if [ ! -f GoogleNews-vectors-negative300.bin ]; then
    echo "Downloading Google word2vec vectors"
    perl ./make/gdown.pl "https://docs.google.com/uc?export=download&confirm=Kqnw&id=0B7XkCwpI5KDYNlNUTTlSS21pQmM" GoogleNews-vectors-negative300.bin.gz
    echo "done downloading word2vec. Uncompressing it"
    gunzip GoogleNews-vectors-negative300.bin.gz
    echo "done uncompressing word2vec"
fi

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
# Intellij
.idea/
.idea/*
*.iml
*.iws

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]

# C extensions
*.so

# Distribution / packaging
.Python
env/
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
*.egg-info/
.installed.cfg
*.egg

# PyInstaller
#  Usually these files are written by a python script from a template
#  before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*,cover

# Translations
*.mo
*.pot

# Django stuff:
*.log

# Sphinx documentation
docs/_build/

# PyBuilder
target/
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
import sys
from cnn_text_trainer.rw import datasets
from cnn_text_trainer.config import config
from cnn_text_trainer.core.multichannel.model import MultiChannelTrainer
from cnn_text_trainer.core.unichannel.model import TextCNNModelTrainer
from cnn_text_trainer.rw import wordvecs

__author__ = 'devashish.shankar'

if __name__=="__main__":
    if len(sys.argv)<5:
        print "Usage: train.py"
        print "\t<config file>"
        print "\t<train data file>"
        print "\t<model output file>"
        print "\t<preprocess (true/false)>"
        exit(0)

    # parse command line arguments
    config_file=sys.argv[1]
    train_data_file=sys.argv[2]
    model_output_file=sys.argv[3]
    preprocess=sys.argv[4].lower()

    training_config = config.get_training_config_from_json(config_file)
    sentences, vocab, labels = datasets.build_data(train_data_file,preprocess)
    print "Dataset loaded"
    word_vecs = wordvecs.load_wordvecs(training_config.word2vec,vocab)
    print "Loaded word vecs from file"

    if training_config.mode=="multichannel":
        nntrainer = MultiChannelTrainer(training_config,word_vecs,sentences,labels)
    else:
        nntrainer = TextCNNModelTrainer(training_config,word_vecs,sentences,labels)

    nntrainer.train(model_output_file)
--------------------------------------------------------------------------------
/make/gdown.pl:
--------------------------------------------------------------------------------
#!/usr/local/bin/perl
#
# Google Drive direct download of big files
# ./gdown.pl 'gdrive file url' ['desired file name']
#
# v1.0 by circulosmeos 04-2014.
# http://circulosmeos.wordpress.com/2014/04/12/google-drive-direct-download-of-big-files
# Distributed under GPL 3 (http://www.gnu.org/licenses/gpl-3.0.html)
#
use strict;

my $TEMP='/tmp';
my $COMMAND;
my $confirm;
my $check;
sub execute_command();

my $URL=shift;
die "\n./gdown.pl 'gdrive file url' [desired file name]\n\n" if $URL eq '';
my $FILENAME=shift;
$FILENAME='gdown' if $FILENAME eq '';

execute_command();

while (-s $FILENAME < 100000) { # only if the file isn't the download yet
    open fFILENAME, '<', $FILENAME;
    $check=0;
    foreach (<fFILENAME>) {
        if (/href="(\/uc\?export=download[^"]+)/) {
            $URL='https://docs.google.com'.$1;
            $URL=~s/&amp;/&/g;
            $confirm='';
            $check=1;
            last;
        }
        if (/confirm=([^;&]+)/) {
            $confirm=$1;
            $check=1;
            last;
        }
        if (/"downloadUrl":"([^"]+)/) {
            $URL=$1;
            $URL=~s/\\u003d/=/g;
            $URL=~s/\\u0026/&/g;
            $confirm='';
            $check=1;
            last;
        }
    }
    close fFILENAME;
    die "Couldn't download the file :-(\n" if ($check==0);
    $URL=~s/confirm=([^;&]+)/confirm=$confirm/ if $confirm ne '';

    execute_command();
}

sub execute_command() {
    $COMMAND="wget --load-cookie $TEMP/cookie.txt --save-cookie $TEMP/cookie.txt \"$URL\"";
    $COMMAND.=" -O \"$FILENAME\"" if $FILENAME ne '';
    `$COMMAND`;
    return 1;
}
--------------------------------------------------------------------------------
/test/test_cnn_text_trainer.py:
--------------------------------------------------------------------------------
import cPickle
import os
from cnn_text_trainer.config.config import get_training_config_from_json
from cnn_text_trainer.core.unichannel.model import TextCNNModelTrainer
from cnn_text_trainer.rw import wordvecs
from cnn_text_trainer.rw.datasets import build_data

__author__ = 'devashish.shankar'

def test_config_reader():
    #TODO improve this test case, probably check if values are actually getting correctly parsed from config
    config = get_training_config_from_json("testConfig.json")
    assert config.mode == "static"
    print config

def test_dataset_reader():
    sentences,vocabs,labels = build_data("../sample/datasets/sst_small_sample.csv")
    assert len(sentences) == 300
    assert len(labels) == 2
    assert "neg" in labels and "pos" in labels

def trainer_helper(configFile,dataSetFile,tempModel):
    print "Training model on ",configFile,dataSetFile
    config = get_training_config_from_json(configFile)
    sentences, vocab, labels = build_data(dataSetFile,True)
    word_vecs = wordvecs.load_wordvecs(config.word2vec,vocab)
    trainer = TextCNNModelTrainer(config,word_vecs,sentences,labels)
    trainer.train(tempModel)
    print "Successfully trained model on ",configFile,dataSetFile," and model is at ",tempModel
    print "Will proceed to test the model on the same data. If everything is correct, you should see the same accuracy"
    model = cPickle.load(open(tempModel,"rb"))
    op = model.classify(sentences)
    os.remove(tempModel)

def test_all_trainers():
    trainer_helper("../sample/configs/sampleMCConfig.json","../sample/datasets/sst_small_sample.csv","tempModel.p")
    trainer_helper("../sample/configs/sampleNonStaticConfig.json","../sample/datasets/sst_small_sample.csv","tempModel.p")
    trainer_helper("../sample/configs/sampleStaticConfig.json","../sample/datasets/sst_small_sample.csv","tempModel.p")
    #TODO validate embeddings change in MC in test case
    #TODO validate if preprocess flag is working
--------------------------------------------------------------------------------
/cnn_text_trainer/rw/wordvecs.py:
--------------------------------------------------------------------------------
import cPickle
import os
import numpy as np

def load_wordvecs_from_binfile(word_vec_file,vocab=None):
    """
    Load word vectors from bin file
    :param word_vec_file: file path
    :param vocab: vocabulary. If not None, only words from this vocab will be loaded
    :return: dictionary of word to word_vector
    """
    with open(word_vec_file, "rb") as f:
        word_vecs = {}
        header = f.readline()
        vocab_size, layer1_size = map(int, header.split())
        binary_len = np.dtype('float32').itemsize * layer1_size
        for line in xrange(vocab_size):
            # read the word one character at a time, up to the separating space
            word = []
            while True:
                ch = f.read(1)
                if ch == ' ':
                    word = ''.join(word)
                    break
                if ch != '\n':
                    word.append(ch)
            if vocab is None or word in vocab:
                word_vecs[word] = np.fromstring(f.read(binary_len), dtype='float32')
            else:
                f.read(binary_len)
    return word_vecs


def load_wordvecs(word_vec_file,vocab=None):
    i = 0
    cwd = os.getcwd()
    os.chdir(os.path.dirname(os.path.realpath(__file__)))
    while not os.path.isfile(word_vec_file):  #TODO this is a hack. Find better way
        word_vec_file='../'+word_vec_file
        i+=1
        if i==4:
            raise Exception("File "+word_vec_file+" not found. Searched "+str(i)+" levels above the module directory: till "+os.path.abspath(word_vec_file))

    word_vec_file = os.path.abspath(word_vec_file)

    os.chdir(cwd)
    if word_vec_file.endswith('.bin'):
        return load_wordvecs_from_binfile(word_vec_file,vocab)
    else:
        # a pickled model tuple, where model[2] is the word->index map and model[3] the embedding matrix
        model=cPickle.load(open(word_vec_file,"rb"))
        word_idx_map, W = model[2], model[3]
        word_vecs = {}
        for word in word_idx_map:
            word_vecs[word]=W[word_idx_map[word]]
        return word_vecs
--------------------------------------------------------------------------------
/cnn_text_trainer/rw/datasets.py:
--------------------------------------------------------------------------------
from collections import defaultdict
import csv
import re


def build_data(fname,preprocess=True):
    """
    Reads a CSV file with headers 'labels' and 'text' (containing the label string and text respectively)
    and outputs sentences, vocab and labels
    :param fname: file name to read
    :param preprocess: should data be preprocessed
    :return: sentences is a list of dictionaries (a format which NNTrainer accepts) [{'text': , 'y':