├── README.md
├── classifier.py
├── data
│   ├── countries-classify-gdp-normalised.test.txt
│   └── countries-classify-gdp-normalised.train.txt
├── minimal_training_example.py
├── minimal_working_example.py
├── rnnclassifier.py
└── stanford_sentiment_extractor.py

/README.md:
--------------------------------------------------------------------------------
Theano tutorial
==========================

This repository contains code examples for the Theano tutorial at [http://www.marekrei.com/blog/theano-tutorial/](http://www.marekrei.com/blog/theano-tutorial/)

Minimal Working Example
--------------------------
This is about the smallest Theano example I could come up with.
It calculates the dot product between the vectors [0.2, 0.7] and [1.0, 1.0].

Run with:

    python minimal_working_example.py

The script should print the value 0.9.


Minimal Training Example
-------------------------

The script iteratively modifies the first vector from the previous example, using gradient descent, so that the dot product moves towards the value 20.

Run with:

    python minimal_training_example.py

It should print the value at each of the 10 iterations:

    0.9
    8.54
    13.124
    15.8744
    17.52464
    18.514784
    19.1088704
    19.46532224
    19.679193344
    19.8075160064


Simple Classifier Example
---------------------------

The next example trains a small network on a tiny (but real) dataset.
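Both the minimal training example above and the classifier in this section update their weights with plain gradient descent. As a sanity check on the numbers listed earlier, here is a minimal sketch of that training loop in plain Python, without Theano. It assumes starting weights of [0.2, 0.7] (which give the initial dot product 0.9 with [1.0, 1.0]), a squared-error cost and a learning rate of 0.1; the helper name `minimal_training_loop` is hypothetical and not part of the repository.

```python
# A plain-Python sketch of the minimal training example (Theano is not used here).
# Assumptions, not taken from the repository code: starting weights [0.2, 0.7],
# a fixed input [1.0, 1.0], squared-error cost (target - y)^2, learning rate 0.1.


def minimal_training_loop(weights, x, target, learning_rate=0.1, iterations=10):
    """Hypothetical helper: return the prediction y = dot(x, weights) at each iteration."""
    weights = list(weights)
    history = []
    for _ in range(iterations):
        y = sum(xi * wi for xi, wi in zip(x, weights))
        history.append(y)  # recorded before the update, like the listing above
        # derivative of (target - y)^2 with respect to w_i is -2 * (target - y) * x_i
        gradients = [-2.0 * (target - y) * xi for xi in x]
        weights = [wi - learning_rate * gi for wi, gi in zip(weights, gradients)]
    return history


if __name__ == "__main__":
    for value in minimal_training_loop([0.2, 0.7], [1.0, 1.0], 20.0):
        print(round(value, 10))
```

Under these assumptions each of the two weights grows by 0.1 * 2 * (20 - y) per step, so the prediction moves 40% of the remaining distance towards 20 each time, which is why the values converge geometrically from 0.9 towards 20.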
The task is to predict whether a country's GDP per capita is above the average, based on the following features:

* Population density (per square km)
* Population growth rate (%)
* Urban population (%)
* Life expectancy at birth (years)
* Fertility rate (births per woman)
* Infant mortality (deaths per 1000 births)
* Enrolment in tertiary education (%)
* Unemployment (%)
* Estimated control of corruption (score)
* Estimated government effectiveness (score)
* Internet users (per 100 people)

The *data/* directory contains the files for training (121 countries) and testing (40 countries).
Each row represents one country: the first column is the label, followed by the 11 feature values.
The feature values have been normalised by subtracting the mean and dividing by the standard deviation.
The label is 1 if the country's GDP per capita is above the average, and 0 otherwise.

Run with:

    python classifier.py data/countries-classify-gdp-normalised.train.txt data/countries-classify-gdp-normalised.test.txt

The script will print the training cost and accuracy for each of the 10 epochs, followed by the result on the test set:

    Epoch: 0, Training_cost: 28.4304042768, Training_accuracy: 0.578512396694
    Epoch: 1, Training_cost: 24.5186290354, Training_accuracy: 0.619834710744
    Epoch: 2, Training_cost: 22.1283727037, Training_accuracy: 0.619834710744
    Epoch: 3, Training_cost: 20.7941253329, Training_accuracy: 0.619834710744
    Epoch: 4, Training_cost: 19.9641569475, Training_accuracy: 0.619834710744
    Epoch: 5, Training_cost: 19.3749411377, Training_accuracy: 0.619834710744
    Epoch: 6, Training_cost: 18.8899216914, Training_accuracy: 0.619834710744
    Epoch: 7, Training_cost: 18.4006371608, Training_accuracy: 0.677685950413
    Epoch: 8, Training_cost: 17.7210185975, Training_accuracy: 0.793388429752
    Epoch: 9, Training_cost: 16.315597037, Training_accuracy: 0.876033057851
    Test_cost: 5.01800578051, Test_accuracy: 0.925


RNN Classifier Example
-------------------------

Now let's try a real task of realistic size: sentiment classification on the Stanford sentiment corpus.
We will use a recurrent neural network for this, to show how they work in Theano.

The task is to classify sentences into 5 classes based on their fine-grained sentiment (very negative, slightly negative, neutral, slightly positive, very positive).
We use the dataset published in "Recursive Deep Models for Semantic Compositionality Over a Sentiment Treebank" (Socher et al., 2013).

Start by downloading the dataset from [http://nlp.stanford.edu/sentiment/](http://nlp.stanford.edu/sentiment/) (the main zip file) and unpack it somewhere.
Then create training and test splits in a more convenient format, using the script provided in this repository:

    python stanford_sentiment_extractor.py 1 full /path/to/sentiment/dataset/ > data/sentiment.train.txt
    python stanford_sentiment_extractor.py 2 full /path/to/sentiment/dataset/ > data/sentiment.test.txt

Now we can run the classifier with:

    python rnnclassifier.py data/sentiment.train.txt data/sentiment.test.txt

The code will train for 3 passes over the training data and will then print the performance on the test data:

    Epoch: 0 Cost: 25929.9481023 Accuracy: 0.286633895131
    Epoch: 1 Cost: 21541.7328736 Accuracy: 0.35779494382
    Epoch: 2 Cost: 17857.7320117 Accuracy: 0.443586142322
    Test_cost: 4934.24376649 Test_accuracy: 0.349773755656

In the run above, the accuracy on the test set is about 35%, which isn't a great result.
But it is quite a difficult task: the state-of-the-art system at the time of writing ([Tai et al., 2015](https://aclweb.org/anthology/P/P15/P15-1150.pdf)) achieves 50.9% accuracy, using a large amount of additional phrase-level annotations and a much bigger network based on LSTMs and parse trees. As there are 5 classes to choose from, a system guessing at random would get about 20% accuracy.

--------------------------------------------------------------------------------
/classifier.py:
--------------------------------------------------------------------------------
import theano
import sys
import numpy
import collections

floatX=theano.config.floatX

class Classifier(object):
    def __init__(self, n_features):
        # network parameters
        random_seed = 42
        hidden_layer_size = 5
        l2_regularisation = 0.001

        # random number generator
        rng = numpy.random.RandomState(random_seed)

        # setting up variables for the network
        input_vector = theano.tensor.fvector('input_vector')
        target_value = theano.tensor.fscalar('target_value')
        learningrate = theano.tensor.fscalar('learningrate')

        # input->hidden weights
        W_hidden_vals = numpy.asarray(rng.normal(loc=0.0, scale=0.1, size=(n_features, hidden_layer_size)), dtype=floatX)
        W_hidden = theano.shared(W_hidden_vals, 'W_hidden')

        # calculating the hidden layer
        hidden = theano.tensor.dot(input_vector, W_hidden)
        hidden = theano.tensor.nnet.sigmoid(hidden)

        # hidden->output weights
        W_output_vals = numpy.asarray(rng.normal(loc=0.0, scale=0.1, size=(hidden_layer_size, 1)), dtype=floatX)
        W_output = theano.shared(W_output_vals, 'W_output')

        # calculating the predicted value (output)
        predicted_value = theano.tensor.dot(hidden, W_output)
        predicted_value = theano.tensor.nnet.sigmoid(predicted_value)

        # calculating the cost function
        cost = theano.tensor.sqr(predicted_value - target_value).sum()
        cost += l2_regularisation * (theano.tensor.sqr(W_hidden).sum() + theano.tensor.sqr(W_output).sum())

        # calculating gradient descent updates based on the cost function
        params = [W_hidden, W_output]
        gradients = theano.tensor.grad(cost, params)
        updates = [(p, p - (learningrate * g)) for p, g in zip(params, gradients)]

        # defining Theano functions for training and testing the network
        self.train = theano.function([input_vector, target_value, learningrate], [cost, predicted_value], updates=updates, allow_input_downcast=True)
        self.test = theano.function([input_vector, target_value], [cost, predicted_value], allow_input_downcast=True)

def read_dataset(path):
    """Read a dataset, with each line containing a real-valued label and a feature vector"""
    dataset = []
    with open(path, "r") as f:
        for line in f:
            line_parts = line.strip().split()
            label = float(line_parts[0])
            vector = numpy.array([float(line_parts[i]) for i in xrange(1, len(line_parts))])
            dataset.append((label, vector))
    return dataset


if __name__ == "__main__":
    path_train = sys.argv[1]
    path_test = sys.argv[2]

    # training parameters
    learningrate = 0.1
    epochs = 10

    # reading the datasets
    data_train = read_dataset(path_train)
    data_test = read_dataset(path_test)

    # creating the network
    n_features = len(data_train[0][1])
    classifier = Classifier(n_features)

    # training
    for epoch in xrange(epochs):
        cost_sum = 0.0
        correct = 0
        for label, vector in data_train:
            cost, predicted_value = classifier.train(vector, label, learningrate)
            cost_sum += cost
            if (label == 1.0 and predicted_value >= 0.5) or (label == 0.0 and predicted_value < 0.5):
                correct += 1
        print "Epoch: " + str(epoch) + ", Training_cost: " + str(cost_sum) + ", Training_accuracy: " + str(float(correct) / len(data_train))

    # testing
    cost_sum = 0.0
    correct = 0
    for label, vector in data_test:
        cost, predicted_value = classifier.test(vector, label)
        cost_sum += cost
        if (label == 1.0 and predicted_value >= 0.5) or (label == 0.0 and predicted_value < 0.5):
            correct += 1
    print "Test_cost: " + str(cost_sum) + ", Test_accuracy: " + str(float(correct) / len(data_test))

--------------------------------------------------------------------------------
/data/countries-classify-gdp-normalised.test.txt:
--------------------------------------------------------------------------------
1 | 0 -0.586763655625 0.331139493452 0.263031821779 0.157226080779 0.102457727252 -0.234540097286 0.00685240794282 -0.489590089209 -0.869030679557 -1.02952397087 -0.516838303499 2 | 0 0.179409684 1.14036939968 0.300616612831 -0.53924290613 1.02284896311 0.598805131141 -0.995379388772 0.965743324671 -0.407403927347 -0.834008868142 -0.707732053829 3 | 1 0.918280926655 0.409452065022 1.5374586448 1.25721392426 0.152401437725 -0.907292755652 1.04374673984 -0.319145635331 0.844117489755 1.30579531167 1.08687706905 4 | 0 -0.47324672426 1.1664735902 -0.336596798682 -1.23913154437 1.54369051518 1.38006628279 -0.558184187007 -0.738701214107 -0.592054628231 -1.3336585751 -1.32371950952 5 | 0 -0.391428362806 0.827119113398 -0.187985670844 -2.35735752826 1.51515125205 2.25681490853 -0.969662023963 -0.68625676676 -1.13574835861 -1.25762492404 -1.37256890843 6 | 0 -0.307191941235 -0.425882031729 -0.999298746654 -0.636133027058 -0.575349772021 0.733355662814 -0.766532113509 -0.437145641862 -1.15626510315 -1.7138268304 -1.41795381095 7 | 1 -0.211348831097 -0.939264445357 -0.27870758028 1.05317472842 -0.875012034857 -0.942015473503 1.93341847608 -0.0700345104324 0.823600745212 1.05597045819 0.970123541175 8 | 0 -0.343682666655 0.0787989850584 0.0457312482257 -0.0365541610781 -0.097317114639 0.173451837464 -0.754605219684 -0.0438122867588
-0.42792067189 -0.0953962578509 0.450449084742 9 | 1 -0.252499817835 -1.08718819166 0.916661578808 1.32218729947 -1.01770835049 -0.885591056995 1.79812277425 2.05396560712 1.06980167972 1.15372800955 1.0394134687 10 | 1 -0.498570408406 0.157111556629 1.21388383448 0.458155397545 -0.632428298275 -0.447216744125 -0.720315399938 0.611743305079 1.36729447559 1.00166070743 1.03075222776 11 | 0 -0.0210694821612 -0.843549080104 -1.41661953006 0.183443407618 0.765995594961 -0.568746256603 -1.0445778258 -1.07959012186 -0.0791360146645 -0.258325510121 -0.247300485429 12 | 0 -0.456628056538 -1.10459098534 0.307960767404 0.434217838256 -0.782259429693 -0.811805281561 0.788436668899 1.34596556794 -0.109911131478 0.0892568947219 0.514195818064 13 | 0 -0.647654805874 0.461660446069 -0.75132552753 -0.849291293101 0.273693306016 0.177792177196 -0.933508627056 0.965743324671 0.218356781204 0.0783949445706 -1.0067180911 14 | 0 1.22409958451 -0.330166666476 1.33959848042 1.00643949362 -0.917820929548 -0.703296788277 0.442929463412 -0.410923418188 -0.940839285456 -0.421254762391 0.666980108256 15 | 0 -0.401232497467 1.68855740067 -1.49135710307 -1.02825304587 1.49374680471 0.959053328846 -1.19366399861 0.677298864263 -0.674121606401 -1.6921029301 -1.42730795117 16 | 0 -0.591379952215 0.722702351304 -1.89139752277 -0.976958275971 0.794534858088 1.05020046321 -1.21080890848 -0.699367878597 -1.07419812498 -0.888318618899 -1.37534050553 17 | 0 -0.311808237824 1.34050152703 -1.19327082921 -0.483388601124 1.90043130427 1.02415842482 -0.620054948723 -0.712478990433 -1.08445649725 -1.34452052526 -1.42349700516 18 | 0 -0.35392644861 0.600882795528 -0.156449007088 0.271214458342 0.209479963979 -0.208498058898 -0.514949196892 -0.64692343125 -0.97161440227 -0.834008868142 -0.827257178809 19 | 0 -0.459793517057 1.06205682811 -0.881792273481 -1.72586191656 1.6935216466 2.37400408128 -0.953635260385 -0.817367885127 -1.14600673088 -1.30107272465 -1.40340292617 20 | 1 -0.165933265604 -0.974070032721 
-0.0700471885779 0.596081334396 -0.953495008457 -0.77708256371 0.771291759026 0.598632193242 0.290165387103 0.88217925577 1.3165731788 21 | 0 -0.61758293095 0.270229715564 0.469532168018 -0.467430228266 0.373580726961 0.655229547649 0.123512338167 -0.778034549617 -0.725413467758 -0.453840612845 -0.27051261115 22 | 0 -0.512419298174 0.16581295347 1.61349224509 0.410280278968 -0.247148246057 -0.481939461976 1.63077354528 -0.16181229329 -1.27936557041 -1.2902107745 0.0710867315456 23 | 1 -0.389537879441 -0.0865275549235 0.952086324397 0.70437029308 -0.382709745912 -0.468918442781 -0.249203093858 -0.528923424719 -0.42792067189 0.295633947597 -0.123964414436 24 | 0 -0.251532593788 0.0700975882172 0.377946240397 0.979082283007 -0.682372008748 -0.677254749888 0.320678801708 -0.2011456288 0.587658182972 0.48028710017 0.190611856525 25 | 1 5.06521006459 -0.81744488958 1.66878940894 1.28457113488 -1.00343871893 -0.829166640487 0.0348060653447 -0.384701194514 0.977476329282 1.29493336152 0.970123541175 26 | 1 -0.581531852824 -0.747833714851 1.18666726165 1.1101689172 -0.682372008748 -0.946355813235 2.27967111243 -0.214256740637 2.27003123547 2.3485425262 1.69766778018 27 | 1 -0.514002028433 -2.55772425781 0.490700613553 0.324788995796 -1.03197798206 -0.720658147202 1.22637730153 0.72974331161 0.146548175305 0.849593405316 1.10870339622 28 | 0 -0.598502238381 0.948938669174 -0.508104408424 0.314530041815 -0.0331037726026 -0.369090628959 -0.484759246898 -0.148701181453 0.00293096350609 -0.24746355997 -0.588899828125 29 | 0 -0.643697980226 1.00114705022 -0.629066954339 -1.08752700221 1.42239864689 1.76201617915 -1.08967639307 2.866854541 -0.622829745045 -1.04038592102 -1.26898046678 30 | 0 -0.0678039895361 -0.0778261580824 -0.211746170935 -0.0388339286293 -0.275687509184 0.069283683911 -0.267466150028 -0.358478970841 -0.684379978673 -0.366945011634 -0.922877278792 31 | 0 -0.295717146856 0.365945080816 -1.56220659425 -0.88348780637 0.102457727252 0.420851202154 -0.740814748699 
-1.19759012839 -1.07419812498 -0.953490319807 -1.28387780119 32 | 1 1.51633314097 -0.773937905375 1.17370698888 1.19338043282 -0.732315719221 -0.902952415921 1.5670392064 -0.528923424719 2.17770588503 1.89234061984 1.76695770771 33 | 1 -0.520508808387 -1.20030635059 0.571054304767 0.615459358581 -0.903551297984 -0.924654114578 1.38925394532 0.100409943445 0.997993073825 1.25148556092 1.28192821503 34 | 1 -0.602722852405 0.47036184291 1.12964206144 0.519709121428 -0.018834141039 -0.729338826665 0.6173602856 -0.489590089209 -0.0688776423932 -0.0193626067915 0.41580412098 35 | 0 -0.498922126242 0.157111556629 0.898949206013 -1.13198246946 0.530546674161 1.80107923674 -1.09787613258 6.5772991908 -0.315078576905 -1.24676297389 -1.16851007187 36 | 1 -0.547239363876 1.2708903523 -0.719788863774 -2.1259611218 1.60790385722 2.08320131928 -1.15900146343 1.94907671243 -1.60763348309 -1.84417023222 -0.972073127335 37 | 0 -0.025949567127 -0.73913231801 -0.19748987088 0.491212027038 -0.803663877039 -0.52534285929 -0.282747482741 -0.68625676676 -0.499729277789 -0.0410865070942 0.0104580449618 38 | 0 -0.308818636224 0.278931112405 -0.546553217661 -0.0057772991361 0.0453792009974 -0.269262815137 -0.209695258064 0.441298851201 -0.592054628231 -0.888318618899 0.0717796308209 39 | 1 -0.0822683855157 -0.852250476945 1.32707021673 1.03379670424 -0.739450535003 -0.915973435115 1.4611880237 -0.227367852473 2.4751986809 2.08785572257 1.76695770771 40 | 0 -0.476148396401 -0.13873593597 0.262599812686 -1.75891854605 -0.247148246057 0.394809163765 -0.722178977098 2.05396560712 0.0234477080487 0.306495897749 -0.0345804079291 41 | -------------------------------------------------------------------------------- /data/countries-classify-gdp-normalised.train.txt: -------------------------------------------------------------------------------- 1 | 0 -0.114714355829 -1.20030635059 -0.344372962348 -0.245152892018 -0.939225376893 -0.395132667348 0.1425208252 -0.489590089209 -0.622829745045 
-0.649355715569 0.0475281561873 2 | 0 -0.292727545256 -1.0958895885 0.133429094014 0.462714932647 -0.960629824239 -0.768401884247 0.237935975798 2.84063231733 0.0131893357774 -0.127982108305 0.732805539404 3 | 1 -0.361884064635 -1.6875845737 0.746882005436 0.390902254783 -0.910686113766 -0.594788294992 0.941249996029 0.24463217365 -0.253528343277 0.100118844873 0.455645829306 4 | 0 -0.602151310923 2.17583562378 -1.65163247641 -1.51042388297 3.01346256623 1.67520938453 -1.21602692453 -1.02714567452 -0.715155095487 -0.779699117385 -1.40617452328 5 | 0 -0.472938971154 -0.0169163801943 0.830259760298 0.32364911202 -0.311361588093 -0.395132667348 0.317697078252 0.297076620997 -0.448437416432 -0.0410865070942 0.241886402893 6 | 0 4.20543581603 0.513868827116 -0.610058554266 0.736287038798 -0.76798979813 -0.6512127115 -0.790013185727 0.664187752426 -0.653604861859 -0.388668911937 -0.106295482917 7 | 1 -0.152084375837 -1.41784127162 0.225879039819 1.13866601159 -1.02484316628 -0.924654114578 1.17680614907 0.821521094467 1.11083516881 1.06683240834 0.762253758602 8 | 1 -0.43829476437 0.261528318723 0.839331951241 0.732867387471 -0.190069719803 -0.360409949497 0.275952949865 -0.699367878597 -0.407403927347 0.284771997446 0.110928439872 9 | 1 0.8821858838 -1.33952870005 1.52838645386 1.35182427764 -0.996303903148 -0.946355813235 0.952058743558 -0.660034543086 1.64427052692 1.41441481319 1.28366046322 10 | 0 -0.153403317719 -0.939264445357 -0.0821434431693 0.732867387471 -0.739450535003 -0.399473007079 0.763092019521 0.637965528752 -0.745930212301 -0.356083061483 0.438669797063 11 | 1 -0.274086499981 -1.32212590637 0.231495158023 1.14094577915 -0.967764640021 -0.872570037801 2.12499420814 1.09685444304 -0.263786715548 0.284771997446 0.485094048504 12 | 1 0.221352035868 -0.269256888588 -1.14402179266 0.543646680717 -0.475462351075 -0.6512127115 -0.745660049316 -0.122478957779 1.31600261424 0.469425150019 1.44787759146 13 | 0 -0.462695189198 1.04465403443 -0.159473070736 
-1.89114506403 1.5365556994 1.60142360909 -0.837348045594 -0.72559010227 -1.27936557041 -1.00780007056 -1.25754762874 14 | 0 -0.201896414271 -1.01757701693 0.337337385697 0.410280278968 -0.746585350784 -0.412494026273 0.542817199195 1.1886322259 -0.643346489587 -0.0953962578509 -0.0983271412516 15 | 0 -0.355069531575 1.37530711439 -0.582409972343 -0.876648503716 1.39385938376 0.911309591801 -0.9968702505 0.0872988316087 -0.335595321448 -0.529874263905 -0.789840617945 16 | 0 -0.572299259646 0.783612129192 -1.3449060207 0.0420978194403 0.480602963688 -0.555725237409 -1.10458501036 -0.620701207576 0.454299343444 -0.312635260878 -1.0877873063 17 | 1 0.210184994595 -0.234451301223 0.752930132731 1.36436299917 -0.903551297984 -0.889931396726 0.833907951606 -0.673145654923 2.19822262957 1.99009817121 1.49672699036 18 | 1 2.12577222021 -0.800042095898 -0.627770927061 0.289452598752 -0.953495008457 -0.486279801707 0.204391586916 -0.0831456222691 0.331198876189 0.773559754256 -0.0210688720619 19 | 0 -0.589753257226 0.479063239751 0.749906069083 0.00220188729329 0.0311095694338 0.0606030044482 -0.108689375985 0.0872988316087 -0.561279511417 -0.649355715569 -0.927381124081 20 | 0 -0.535236992744 0.479063239751 -0.908576837219 -0.38535859642 0.295097753361 1.29325948816 -0.644281451805 -1.04025678635 -1.07419812498 -1.00780007056 -1.08259056174 21 | 1 -0.621100109303 -0.81744488958 0.762434332767 -0.197277773442 -0.889281666421 -0.664233730694 1.53163124036 -0.502701201045 -1.04342300817 -0.519012313753 0.390513297433 22 | 0 0.585995501688 -0.243152698064 -1.06582814691 0.539087145614 -0.710911271875 -0.251901456212 -0.364372162354 -0.987812339005 -0.58179625596 -0.366945011634 -0.0868943032101 23 | 0 -0.594061800709 -0.399777841205 1.56769928128 0.582402729088 -0.418383824821 -0.499300820901 1.50777745271 -0.27981229982 -0.50998765006 -0.323497211029 0.478165055752 24 | 0 0.43924123488 -0.982771429562 -0.181073525363 0.267794807015 -0.339900851221 -0.425515045468 
-0.312564717303 0.572409969569 -0.376628810533 -0.0736723575482 0.155966892763 25 | 1 -0.646687581827 0.226722731358 1.42513628074 1.26747287824 -0.653832745621 -0.872570037801 1.82123113104 -0.542034536556 2.0443470455 1.69682551712 1.39798884364 26 | 0 0.973852344643 1.6102448291 -1.95015075936 -2.00513344159 2.44267730369 1.85316331351 -1.16310133318 -1.15825679288 -1.15626510315 -1.49658782737 -1.41275706639 27 | 1 0.0641341634573 4.96898400979 1.83770496413 0.862814137893 -0.396979377475 -0.772742223979 -0.828402875225 -1.15825679288 1.21341889152 0.979936807132 1.59719738527 28 | 0 -0.46331069541 0.957640066015 -1.40365925728 -1.21519398508 1.85762240958 2.0311172425 -1.15713788627 -0.109367845943 -1.45375789902 -1.57262147843 -1.26620886968 29 | 1 -0.367379655813 -0.939264445357 0.266055885427 1.11358856853 -0.525406061548 -0.902952415921 1.45708815395 0.703521087936 1.48013657058 1.60992991591 1.28192821503 30 | 0 0.803620912324 -0.260555491747 -1.77734712234 0.361265276616 -0.311361588093 -0.594788294992 -0.747523626476 -0.581367872066 -0.253528343277 -0.312635260878 -0.821367534969 31 | 0 0.667989722059 -0.59120857171 0.384426376786 0.129868870164 -0.389844561693 -0.481939461976 -0.36735388581 -0.424034530025 -0.397145555076 -0.204015759364 -0.571577346243 32 | 0 -0.365577101907 -0.225749904382 -1.21184722019 -2.5659562592 0.25228885867 2.1700081139 -0.877601312252 2.09329894263 0.105514686219 -0.356083061483 -1.29600353851 33 | 0 -0.328514835005 -1.287320319 -0.325796571368 0.596081334396 -1.09619132409 -0.759721204785 0.125375915327 2.46041007406 -0.315078576905 -0.562460114359 0.809370909318 34 | 1 -0.632267150576 0.914133081809 1.30071766209 -0.916544435863 0.994309699979 0.789780079322 -1.01140615235 1.10996555488 -0.571537883688 -0.996938120412 -1.15638433455 35 | 0 -0.402463509891 1.32309873334 -1.25288808398 -1.74296017319 2.13588022507 1.80541957647 -1.11129388813 -0.791145661454 -0.540762766874 -0.73625131678 -1.32579820735 36 | 0 
-0.641807496861 -0.669521143281 -1.20363904743 -0.552921511438 -0.104451930421 0.225535914241 -0.800076502391 1.52952113365 -0.622829745045 -0.204015759364 -0.266355215498 37 | 0 -0.607515007912 1.43621689228 -0.897776609905 -2.20689286987 2.89930551372 2.4043864594 -1.0538957116 -0.0700345104324 -0.633088117316 -1.12728152223 -1.37984435082 38 | 0 1.32658136879 1.24478616177 -1.59503928528 -0.890327109024 1.38672456798 0.633527848992 -1.02407847704 -0.909145667985 0.669725161142 -0.117120158154 -1.17717131281 39 | 0 -0.124298666843 -0.486791809617 0.301048621923 0.363545044167 -0.318496403875 -0.230199757555 -0.513085619732 1.79174337039 0.823600745212 0.0566710442679 -0.251457881081 40 | 0 -0.368302915131 -0.486791809617 -1.58380704888 0.224479223541 1.0656578578 -0.386451987885 -1.00320641284 -0.476478977372 0.146548175305 0.0458090941165 -1.00741099037 41 | 0 -0.334845756041 1.18387638388 -1.38033076628 -1.55601923399 1.25116306812 1.0632214824 -1.13030237517 0.0610766079352 -1.13574835861 -0.638493765418 -0.342920585413 42 | 0 -0.33796725183 1.04465403443 0.437131486077 -0.19499800589 0.972905252634 0.182132516927 -0.68565286476 0.782187758957 -1.26910719814 -1.25762492404 -1.20904467947 43 | 1 6.81896312121 0.505167430275 1.40007975337 0.646236220523 -0.475462351075 -0.694616108814 -0.0341462895799 -1.07959012186 0.392749109816 0.654078302592 1.59373288889 44 | 1 -0.210073853944 -0.765236508534 0.498044768126 1.17400240864 -0.974899455802 -0.907292755652 1.36502744224 -0.660034543086 1.37755284786 1.75113526788 1.35121814256 45 | 0 -0.522047573917 1.00984844706 -1.07490033785 -2.42119101969 1.82194833067 1.68823040372 -1.1004851406 -0.935367891658 -0.42792067189 -0.747113266931 -1.28699584793 46 | 0 -0.632530938953 0.0787989850584 -0.121456270591 -0.209816494974 -0.14012600933 -0.325687231646 0.32850582578 -0.528923424719 -0.910064168642 -0.33435916118 0.392245545621 47 | 1 0.480875833642 -0.51289600014 1.01127157008 1.14208566292 -0.575349772021 
-0.859549018607 0.998648172561 -0.188034516963 1.67504564373 1.60992991591 1.55978082441 48 | 0 -0.651699560981 0.157111556629 0.561550104731 -0.41157592326 -0.240013430275 -0.0522458285679 0.996039164537 -0.594478983903 -0.694638350944 -0.73625131678 -0.88684651648 49 | 0 0.173386516069 -0.156138729653 -1.68532918563 -0.362560920908 -0.204339351366 0.40783018296 -0.741187464131 -0.869812332474 -0.858772307285 -1.12728152223 -1.06873257623 50 | 0 -0.617319142573 1.44491828912 -1.48746902124 -2.33569973652 2.64245214558 2.82973975308 -1.19739115293 -1.13203456921 -1.28962394268 -1.6703790298 -1.38226949828 51 | 0 0.133114823919 1.26218895546 -0.264451280226 -2.1681368215 2.30711580383 2.32626034423 -0.893255360397 -0.712478990433 -1.16652347543 -1.13814347238 -0.315897513678 52 | 1 -0.10249216105 -1.14809796954 0.193910366971 0.68613215267 -1.06051724518 -0.863889358338 1.45857901568 0.100409943445 0.597916555243 0.664940252743 0.796898722364 53 | 0 4.50342875205 -0.130034539129 -1.18635868373 -0.0958281174107 -0.389844561693 0.386128484303 -0.791131332023 -0.568256760229 -0.899805796371 -0.953490319807 -1.23676065048 54 | 0 -0.455880656138 -0.0169163801943 0.556365995621 0.309970506713 -0.625293482493 -0.395132667348 0.774646197914 0.152854390792 -0.848513935014 -0.551598164207 -0.554254864362 55 | 0 -0.557131427995 -0.408479238046 1.23202821637 0.298571668957 -0.696641640312 -0.490620141438 -0.325982472855 -0.345367859004 -0.0791360146645 -0.182291859062 0.272027521366 56 | 0 -0.443746390818 0.104903175582 0.0651716573905 0.387482603456 -0.147260825112 -0.156413982121 -0.615582363539 -0.174923405126 -0.807480445929 -1.01866202071 -0.98731691139 57 | 0 -0.429413889026 1.47972387648 -1.25893621128 -1.21519398508 1.83621796223 0.585784111947 -1.13514767578 -0.76492343778 -0.879289051828 -0.801423017688 -1.00186779617 58 | 0 -0.266480601791 1.07945962179 -1.68792124018 -0.966699321991 1.41526383111 0.967734008309 -0.996497535068 1.00507666018 -0.622829745045 
-0.529874263905 -1.40374937581 59 | 1 -0.274658041463 0.278931112405 0.734785750844 0.44903632734 -0.568214956239 -0.733679166396 0.294588721466 -0.778034549617 0.300423759375 1.04510850804 0.824614693374 60 | 1 0.248126556085 -0.88705606431 0.528285404605 1.29483008886 -0.982034271584 -0.911633095383 1.09890862378 -0.122478957779 -0.0381025255792 0.393391498959 0.554383976028 61 | 1 0.372238987242 -1.06978539797 0.765458396415 1.14094577915 -1.01770835049 -0.902952415921 0.825708212102 -0.515812312882 1.81866285553 1.65337771652 1.45515303385 62 | 0 -0.17156075097 0.00918781032917 -0.106335952352 -0.0205957882193 -0.618158666712 0.620506829798 -0.548866301206 -0.542034536556 -1.16652347543 -0.910042519202 0.422733113732 63 | 1 -0.574497496117 -0.860951873786 1.56769928128 0.687272036446 -0.511136429984 -0.781422903442 1.0724458281 -0.371590082678 1.34677773105 0.425977349413 0.454260030756 64 | 0 -0.349661869856 0.174514350311 -1.51641363044 -2.51580137307 0.494872595252 1.3670452636 -1.05911372765 1.7261878112 -0.345853693719 -0.801423017688 -0.735101575201 65 | 1 0.220384811821 1.00114705022 1.2652929165 1.16944287354 -0.903551297984 -0.976738191354 -0.602537323418 -0.555145648392 2.16744751276 1.75113526788 1.73231274394 66 | 1 0.111440212315 2.27155098903 1.81092040039 0.402301092539 -0.097317114639 -0.638191692306 -0.466496190729 -0.751812325944 -0.171461365106 -0.138844058456 1.28816430851 67 | 1 0.176244223482 -0.826146286422 -0.101151843242 0.311110390489 -0.489731982639 -0.564405916872 -1.22944468008 -0.502701201045 0.331198876189 0.360805648505 0.176060971745 68 | 1 0.482986140654 -0.878354667469 -1.83048424072 -0.116346025372 -0.703776456093 -0.251901456212 -0.851883947443 -0.620701207576 -0.304820204634 0.382529548808 0.607044320947 69 | 0 -0.312599602954 -1.38303568426 0.549885859232 0.00904118994706 -0.946360192675 -0.6512127115 1.76383295451 -0.188034516963 -1.06393975271 -0.681941566023 -0.287488643393 70 | 1 0.944571834849 -0.425882031729 
1.77808770935 1.11244868476 -0.675237192966 -0.902952415921 1.30017495707 -0.24047896431 1.58272029329 1.67510161682 1.38586310632
0 -0.17991404956 0.722702351304 -0.165521198032 -0.742142218192 0.851613384343 1.05888114267 -0.823184859177 -0.673145654923 -0.0996527592072 -0.127982108305 -0.862248592208
1 -0.58790673859 -0.0169163801943 1.00651947006 1.20477927058 -0.646697929839 -0.955036492697 1.44292496753 -0.804256773291 2.29054798001 1.85975476939 1.83624763523
0 0.573289694886 -1.15679936639 -0.287347762131 0.178883872516 -0.525406061548 -0.134712283464 -1.05128670358 1.24107667324 0.997993073825 0.784421704408 0.191304755801
0 -0.465421002423 0.861924700762 -1.01269102852 -0.913124784536 1.11560156827 0.959053328846 -0.897727945582 1.10996555488 -1.26910719814 -1.44227807662 -0.850469304529
0 -0.389581844171 0.226722731358 0.502364859052 0.591521799293 -0.118721561984 -0.191136699972 0.203646156052 -0.673145654923 -0.684379978673 -0.605907914964 -0.237946345213
1 -0.190113866786 1.53193225753 1.22122798906 0.689551803997 -0.675237192966 -0.738019506128 -1.0538957116 -0.673145654923 1.20316051925 1.18631386001 1.48979799761
0 0.263821964489 -0.0691247612412 0.598702886691 0.259815620585 -0.175800088239 -0.0609265080307 -0.0140196562506 0.703521087936 -0.858772307285 -0.649355715569 0.10399944712
0 -0.133663154209 0.548674414481 0.00469038443341 0.460435165096 0.180940700852 -0.516662179827 -0.325609757423 -0.122478957779 -1.20755696451 -0.529874263905 -0.613151302758
0 -0.357443626964 -0.321465269635 0.439723540632 0.458155397545 -0.468327535293 -0.451557083856 0.0307061955924 1.17552111406 -0.222753226463 -0.0736723575482 -0.0193366238737
1 -0.134674342986 -0.730430921169 1.29207748023 1.24695497028 -0.539675693112 -0.902952415921 0.845462129999 0.074187719772 1.44936145376 1.39269091288 1.42050807008
0 -0.086576928999 -0.895757461151 -0.944433591901 0.373803998148 -0.967764640021 -0.59044795526 0.448892910324 -1.13203456921 -0.35611206599 0.176152495933 -0.536932382481
0 -0.644533310085 -0.382375047523 0.594814804858 -0.0171761368924 -0.332766035439 -0.247561116481 -0.831011883249 0.0217432724249 -0.386887182804 -0.0410865070942 -0.253536578906
0 -0.353398871857 0.740105144986 1.1490824706 0.307690739162 0.651838542452 -0.33870825084 0.207373310372 0.467521074875 0.0337060803201 -0.0953962578509 -0.0345804079291
0 -0.05610937151 1.03595263759 -0.264019271133 0.0386781681134 0.80166967387 -0.000161751791232 -0.614836932675 -0.843590108801 -0.540762766874 -0.812284967839 -0.900704501984
0 -0.312423744036 -0.617312762234 -0.145648779774 0.351006322636 -0.68950682453 -0.2779434946 -0.24025792349 0.742854423446 0.249131898018 0.567182701381 0.121321929001
0 0.699204679948 -0.826146286422 -0.728429045625 0.199401780477 -0.411249009039 -0.555725237409 0.68556720966 0.113521055282 0.403007482088 0.284771997446 0.0031826025717
0 0.994735591118 0.957640066015 -1.21746333839 -1.17415816916 1.47234235736 1.45385205822 -0.855611101763 1.39841001528 -0.756188584572 -1.7355507307 -1.24784703888
1 -0.251005017035 -1.40043847794 -0.151264897977 0.430798186929 -1.09619132409 -0.586107615529 0.911805476899 -0.306034523494 -0.284303460091 -0.388668911937 0.277224265931
1 -0.242124141692 -0.0517219675589 0.690288814311 0.434217838256 -0.504001614203 -0.551384897678 0.980385116392 -0.0175900630853 0.167064919847 0.382529548808 0.108503292409
1 -0.509737449679 -0.521597396981 1.13482617055 0.901570186264 -0.639563114057 -0.790103582904 2.25246288589 -0.16181229329 1.40832796468 1.47958651409 1.35225749147
0 0.345244643377 0.305035302928 -0.855439718835 -0.507326160413 0.402119990088 1.95733146707 -0.926054318416 -0.568256760229 -1.09471486953 -0.910042519202 -1.10996008311
0 0.742070291134 0.331139493452 -0.312404289499 -0.266810683755 0.230884411325 -0.030544129911 -0.230194606825 -0.306034523494 -0.602313000502 0.0349471439652 -0.199490435437
1 2.22179118927 -0.730430921169 -0.494280117463 0.483232840608 -0.675237192966 -0.317006552183 0.986348563304 0.297076620997 1.69556238827 1.5230343147 1.0854912705
1 -0.12825549249 -0.199645713859 0.620303341318 0.996180539641 -0.939225376893 -0.942015473503 0.453365495509 0.32329884467 1.26471075288 1.44700066364 0.658318867315
1 -0.557131427995 -0.382375047523 1.42556828983 0.977942399231 -0.675237192966 -0.711977467739 1.3471371015 -0.384701194514 1.59297866556 1.21889971046 0.672869752095
0 1.1462380487 -0.0691247612412 -1.06669216509 -0.543802441233 -0.182934904021 0.850544835561 -0.413943314813 -0.751812325944 -0.592054628231 -0.0845343076996 -1.01919027805
0 -0.294574063891 -1.5831678116 0.0163546299323 0.439917257135 -1.01770835049 -0.803124602099 0.671031307811 1.90974337692 -0.325336949176 -0.17142990891 0.211398834783
0 0.113902237162 1.74946717856 -1.74321840403 -1.44887015908 2.33565506696 0.919990271264 -0.943571943721 -0.673145654923 -0.981872774541 -0.671079615872 -0.946089404513
0 -0.240937093997 -0.81744488958 0.472988240758 0.669033896036 -0.632428298275 -0.551384897678 -1.01215158321 0.218409949976 0.751792139313 0.664940252743 0.456685178219
1 -0.558230546231 -0.547701587505 1.25319666191 1.26177345936 -0.632428298275 -0.950696152966 1.4436703984 -0.174923405126 2.36235658591 2.05526987211 1.80160267147
0 -0.299410184128 0.496466033434 -0.721948909237 -2.61611114532 1.47947717314 4.04503487787 -1.20559089243 -0.778034549617 -0.715155095487 -1.36624442556 -1.40998546929
1 -0.323195102745 -1.44394546214 0.0759718847043 0.700950641753 -0.946360192675 -0.876910377532 0.910687330603 0.847743318141 -0.0483608978505 0.545458801078 0.72760879484
0 0.632598114875 -0.391076444364 -1.70131352205 0.435357722032 -0.589619403584 -0.403813346811 -0.901827815334 1.4770766863 0.977476329282 0.578044651532 0.229760665577
1 -0.446911851336 -2.45330749572 0.469964177111 0.322509228245 -0.732315719221 -0.846527999413 1.57374808418 0.506854410385 0.310682131646 0.849593405316 0.900833613651
1 0.236212114412 -0.164840126494 -1.04725175593 0.0694550300554 -0.482597166857 -0.746700185591 -0.601791892554 -0.555145648392 1.080060052 0.860455355467 1.29405395235
0 0.0613204207743 1.32309873334 -1.74969854042 -1.89114506403 1.97177946209 0.946032309652 -1.25106217514 -0.2011456288 -0.376628810533 -0.595045964813 -1.30431832981
0 -0.557746934207 -0.0778261580824 0.917093587901 0.396601673661 -0.21860898293 -0.438536064662 0.32105151714 -0.751812325944 -0.407403927347 -0.225739659667 -0.131586306463
0 -0.533082721002 -0.103930348606 -0.901664691738 -0.128884746904 0.195210332416 -0.0262037901796 0.259926186288 -0.148701181453 -1.12548998634 -0.725389366629 -0.702535309265
0 -0.495580806806 1.2708903523 -0.999730755747 -0.790017336768 1.27970233125 0.724674983351 -1.12881151344 -0.882923444311 -0.294561832362 -0.996938120412 -1.38400174647
1 -0.0620446099819 -1.00887562009 0.737377805399 0.813799135541 -0.967764640021 -0.911633095383 1.12574413488 -0.306034523494 0.228615153475 0.947350956678 1.14334835999
1 -0.320337395332 0.0526947945349 0.862660442239 0.836596811053 -0.539675693112 -0.759721204785 -0.374062763587 -0.607590095739 0.6492084166 0.849593405316 0.633028043769
0 -0.575948332188 0.296333906087 -0.864511909779 -0.39333778285 -0.332766035439 0.498977317319 -0.955498837546 -0.948479003495 0.833859117484 0.469425150019 -0.574002493707
1 1.59441450043 -0.773937905375 1.17154694341 1.15576426823 -1.10332613988 -0.907292755652 2.47571942968 -0.804256773291 0.474816087987 1.25148556092 1.45861753022
1 -0.585928325766 -0.617312762234 1.29337350751 1.15918391956 -0.489731982639 -0.846527999413 1.72954313476 -0.319145635331 2.37261495818 1.89234061984 1.64604678418
0 -0.278175219817 1.20998057441 -0.466199526447 -1.34400085172 1.72919572551 1.48857477608 -0.820203135721 -1.13203456921 -0.951097657727 -0.627631815267 -1.32337305988
1 -0.175209823512 -1.40913987478 0.585742613914 0.470694119076 -1.11046095566 -0.820485961024 0.936777410845 0.20529883814 0.279907014832 0.621492452138 1.0394134687
0 -0.450604888608 -0.486791809617 -0.160769098014 -0.133444282006 -0.104451930421 -0.438536064662 1.0269745454 -0.0831456222691 -0.50998765006 -1.02952397087 -0.286102844843
1 -0.645500534132 -0.860951873786 1.6191083633 1.3256069508 -0.546810508893 -0.976738191354 1.73550658167 -0.437145641862 1.9007298337 1.5664821153 1.87089259899
0 0.941010691765 0.0439933976938 -0.0739352704108 -0.959860019337 0.352176279616 1.40176798145 -1.24286243564 -0.27981229982 -1.27936557041 -1.82244633191 -1.07843316608
0 -0.57885000433 1.6102448291 -0.723244936515 -1.69850470594 2.12874540929 1.39742764172 -1.19217313688 0.860854429977 -0.376628810533 -0.595045964813 -0.988356260303
1 -0.642818685638 -0.173541523335 1.05490448842 1.17856194374 -0.825068324384 -0.846527999413 0.954667751582 -0.253590076147 1.96228006733 1.84889281924 1.55111958347
--------------------------------------------------------------------------------
/minimal_training_example.py:
--------------------------------------------------------------------------------
import theano
import numpy

# symbolic inputs: a float32 vector and a float32 scalar target
x = theano.tensor.fvector('x')
target = theano.tensor.fscalar('target')

# shared weight vector, modified in place by the update below
W = theano.shared(numpy.asarray([0.2, 0.7]), 'W')
y = (x * W).sum()

# squared error and one gradient descent step with learning rate 0.1
cost = theano.tensor.sqr(target - y)
gradients = theano.tensor.grad(cost, [W])
W_updated = W - (0.1 * gradients[0])
updates = [(W, W_updated)]

# f returns y computed with the current W, then applies the update to W
f = theano.function([x, target], y, updates=updates)

for i in range(10):
    output = f([1.0, 1.0], 20.0)
    print(output)
--------------------------------------------------------------------------------
/minimal_working_example.py:
--------------------------------------------------------------------------------
import theano
import numpy

x = theano.tensor.fvector('x')
W = theano.shared(numpy.asarray([0.2, 0.7]), 'W')
y = (x * W).sum()

# compile a function that computes the dot product of x and W
f = theano.function([x], y)

output = f([1.0, 1.0])
print(output)
--------------------------------------------------------------------------------
/rnnclassifier.py:
--------------------------------------------------------------------------------
import sys
import theano
import collections
import numpy
import random

floatX = theano.config.floatX

class RnnClassifier(object):
    def __init__(self, n_words, n_classes):
        # network parameters
        random_seed = 42
        word_embedding_size = 200
        recurrent_size = 100
        l2_regularisation = 0.0001

        # random number generator
        self.rng = numpy.random.RandomState(random_seed)

        # this is where we keep shared weights that are optimised during training
        self.params = collections.OrderedDict()

        # setting up variables for the network
        input_indices = theano.tensor.ivector('input_indices')
        target_class = theano.tensor.iscalar('target_class')
        learningrate = theano.tensor.fscalar('learningrate')

        # creating the matrix of word embeddings
        word_embeddings = self.create_parameter_matrix('word_embeddings', (n_words, word_embedding_size))

        # extract the relevant word embeddings, given the input word indices
        input_vectors = word_embeddings[input_indices]

        # gated recurrent unit
        # from: Learning Phrase Representations using RNN Encoder-Decoder for Statistical Machine Translation (Cho et al, 2014)
        def gru_step(x, h_prev, W_xm, W_hm, W_xh, W_hh):
            # m holds both gates side by side: reset gate r and update gate z
            m = theano.tensor.nnet.sigmoid(theano.tensor.dot(x, W_xm) + theano.tensor.dot(h_prev, W_hm))
            r = _slice(m, 0, 2)
            z = _slice(m, 1, 2)
            # candidate state, computed from the reset-gated previous state
            _h = theano.tensor.tanh(theano.tensor.dot(x, W_xh) + theano.tensor.dot(r * h_prev, W_hh))
            # interpolate between the previous state and the candidate state
            h = z * h_prev + (1.0 - z) * _h
            return h

        W_xm = self.create_parameter_matrix('W_xm', (word_embedding_size, recurrent_size*2))
        W_hm = self.create_parameter_matrix('W_hm', (recurrent_size, recurrent_size*2))
        W_xh = self.create_parameter_matrix('W_xh', (word_embedding_size, recurrent_size))
        W_hh = self.create_parameter_matrix('W_hh', (recurrent_size, recurrent_size))
        initial_hidden_vector = theano.tensor.alloc(numpy.array(0, dtype=floatX), recurrent_size)

        hidden_vector, _ = theano.scan(
            gru_step,
            sequences = input_vectors,
            outputs_info = initial_hidden_vector,
            non_sequences = [W_xm, W_hm, W_xh, W_hh]
        )
        # keep only the hidden state after the last word
        hidden_vector = hidden_vector[-1]

        # hidden->output weights
        W_output = self.create_parameter_matrix('W_output', (n_classes,recurrent_size))
        output = theano.tensor.nnet.softmax([theano.tensor.dot(W_output, hidden_vector)])[0]
        predicted_class = theano.tensor.argmax(output)

        # calculating the cost function
        cost = -1.0 * theano.tensor.log(output[target_class])
        for m in self.params.values():
            cost += l2_regularisation * (theano.tensor.sqr(m).sum())

        # calculating gradient descent updates based on the cost function
        # (list() so that grad() receives a list under Python 3, where values() returns a view)
        params = list(self.params.values())
        gradients = theano.tensor.grad(cost, params)
        updates = [(p, p - (learningrate * g)) for p, g in zip(params, gradients)]

        # defining Theano functions for training and testing the network
        self.train = theano.function([input_indices, target_class, learningrate], [cost, predicted_class], updates=updates, allow_input_downcast = True)
        self.test = theano.function([input_indices, target_class], [cost, predicted_class], allow_input_downcast = True)

    def create_parameter_matrix(self, name, size):
        """Create a shared variable tensor and save it to self.params"""
        vals = numpy.asarray(self.rng.normal(loc=0.0, scale=0.1, size=size), dtype=floatX)
        self.params[name] = theano.shared(vals, name)
        return self.params[name]


def _slice(M, slice_num, total_slices):
    """ Helper function for extracting a slice from a tensor"""
    # floor division keeps the slice bounds integer (plain / gives floats under Python 3)
    if M.ndim == 3:
        l = M.shape[2] // total_slices
        return M[:, :, slice_num*l:(slice_num+1)*l]
    elif M.ndim == 2:
        l = M.shape[1] // total_slices
        return M[:, slice_num*l:(slice_num+1)*l]
    elif M.ndim == 1:
        l = M.shape[0] // total_slices
        return M[slice_num*l:(slice_num+1)*l]

def read_dataset(path):
    """Read a dataset, where the first column contains a real-valued score,
    followed by a tab and a string of words.
    """
    dataset = []
    with open(path, "r") as f:
        for line in f:
            line_parts = line.strip().split("\t")
            dataset.append((float(line_parts[0]), line_parts[1].lower()))
    return dataset

def score_to_class_index(score, n_classes):
    """Maps a real-valued score between [0.0, 1.0] to a class id, given n_classes."""
    for i in range(n_classes):
        if score <= (i + 1.0) * (1.0 / float(n_classes)):
            return i

def create_dictionary(sentences, min_freq):
    """Creates a dictionary that maps words to ids.
    If min_freq is positive, removes all words that have a smaller frequency.
    """
    counter = collections.Counter()
    for sentence in sentences:
        for word in sentence:
            counter.update([word])

    word2id = collections.OrderedDict()
    # special tokens: unknown word, sentence start, sentence end
    word2id["<unk>"] = 0
    word2id["<s>"] = 1
    word2id["</s>"] = 2

    word_count_list = counter.most_common()
    for (word, count) in word_count_list:
        if min_freq < 0 or count >= min_freq:
            word2id[word] = len(word2id)

    return word2id

def sentence2ids(words, word2id):
    """Takes a list of words and converts them to ids using the word2id dictionary."""
    ids = [word2id["<s>"],]
    for word in words:
        if word in word2id:
            ids.append(word2id[word])
        else:
            ids.append(word2id["<unk>"])
    ids.append(word2id["</s>"])
    return ids

if __name__ == "__main__":
    path_train = sys.argv[1]
    path_test = sys.argv[2]

    # training parameters
    min_freq = 2
    epochs = 3
    learningrate = 0.1
    n_classes = 5

    # reading the datasets
    sentences_train = read_dataset(path_train)
    sentences_test = read_dataset(path_test)

    # creating the dictionary from the training data
    word2id = create_dictionary([sentence.split() for label, sentence in sentences_train], min_freq)

    # mapping training and test data to the dictionary indices
    data_train = [(score_to_class_index(score, n_classes), sentence2ids(sentence.split(), word2id)) for score, sentence in sentences_train]
    data_test = [(score_to_class_index(score, n_classes), sentence2ids(sentence.split(), word2id)) for score, sentence in sentences_test]

    # shuffling the training data
    random.seed(1)
    random.shuffle(data_train)

    # creating the classifier
    rnn_classifier = RnnClassifier(len(word2id), n_classes)

    # training
    for epoch in range(epochs):
        cost_sum = 0.0
        correct = 0
        for target_class, sentence in data_train:
            cost, predicted_class = rnn_classifier.train(sentence, target_class, learningrate)
            cost_sum += cost
            if predicted_class == target_class:
                correct += 1
        print("Epoch: " + str(epoch) + "\tCost: " + str(cost_sum) + "\tAccuracy: " + str(float(correct)/len(data_train)))


    # testing
    cost_sum = 0.0
    correct = 0
    for target_class, sentence in data_test:
        cost, predicted_class = rnn_classifier.test(sentence, target_class)
        cost_sum += cost
        if predicted_class == target_class:
            correct += 1
    print("Test_cost: " + str(cost_sum) + "\tTest_accuracy: " + str(float(correct)/len(data_test)))
--------------------------------------------------------------------------------
/stanford_sentiment_extractor.py:
--------------------------------------------------------------------------------
# -*- coding: UTF-8 -*-
import sys

if __name__ == "__main__":
    split_type = sys.argv[1]
    sentence_type = sys.argv[2]
    data_path = sys.argv[3]

    # clean-up rules: bracket tokens and stray encoding bytes in the dataset files
    # (the escape sequences are written for Python 2, where lines are read as raw byte strings)
    dataset_fixes = {"\x83\xc2": "", "-LRB-":"(", "-RRB-":")", "\xc3\x82\xc2\xa0":" "}

    # read dataset split
    dataset_split = {}
    with open(data_path + "datasetSplit.txt", "r") as f:
        next(f)
        for line in f:
            line_parts = line.strip().split(",")
            dataset_split[line_parts[0].strip()] = line_parts[1].strip()

    # read relevant sentences
    sentences = []
    with open(data_path + "datasetSentences.txt", "r") as f:
        next(f)
        for line in f:
            line_parts = line.strip().split("\t")
            if len(line_parts) != 2:
                raise ValueError("Unexpected file format")
            if dataset_split[line_parts[0]] == split_type:
                sentence = line_parts[1]
                for fix in dataset_fixes:
                    sentence = sentence.replace(fix, dataset_fixes[fix])
                sentences.append(sentence)


    # read sentiment labels
    sentiment_labels = {}
    with open(data_path + "sentiment_labels.txt", "r") as f:
        next(f)
        for line in f:
            line_parts = line.strip().split("|")
            if len(line_parts) != 2:
                raise ValueError("Unexpected file format")
            sentiment_labels[line_parts[0]] = float(line_parts[1])

    # read the phrases
    phrases = {}
    with open(data_path + "dictionary.txt", "r") as f:
        for line in f:
            line_parts = line.strip().split("|")
            if len(line_parts) != 2:
                raise ValueError("Unexpected file format")
            phrases[line_parts[0]] = sentiment_labels[line_parts[1]]

    # print the labels and sentences/phrases
    if sentence_type == "full":
        for sentence in sentences:
            print(str(phrases[sentence]) + "\t" + sentence)
    elif sentence_type == "all":
        for phrase in phrases:
            print_phrase = False
            for sentence in sentences:
                if sentence.find(phrase) >= 0:
                    print_phrase = True
                    break
            if print_phrase:
                print(str(phrases[phrase]) + "\t" + phrase)
--------------------------------------------------------------------------------
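The extractor prints one line per sentence or phrase in the form `score<TAB>text`. `read_dataset` and `score_to_class_index` in rnnclassifier.py then turn each score in [0.0, 1.0] into one of `n_classes` equal-width classes; with `n_classes = 5` the upper bounds are roughly 0.2, 0.4, 0.6, 0.8 and 1.0, each inclusive. The Python 3 sketch below reproduces just that mapping without Theano, so the bins can be checked by hand. The `parse_line` helper and the sample lines are illustrative assumptions, not part of the repository.

```python
# Standalone sketch of the label pipeline used by rnnclassifier.py.
# parse_line and the sample lines below are hypothetical illustrations.

def parse_line(line):
    """Split a 'score<TAB>text' line into (float score, lowercased text),
    mirroring what read_dataset does for each line of a file."""
    parts = line.strip().split("\t")
    return float(parts[0]), parts[1].lower()


def score_to_class_index(score, n_classes):
    """Same rule as rnnclassifier.py: class i covers scores up to and
    including (i + 1) / n_classes; scores above 1.0 fall through to None."""
    for i in range(n_classes):
        if score <= (i + 1.0) * (1.0 / float(n_classes)):
            return i
    return None


if __name__ == "__main__":
    sample = ["0.1\tA dull film", "0.5\tIt is fine", "0.95\tA triumph"]
    for line in sample:
        score, text = parse_line(line)
        print(score_to_class_index(score, 5), text)
```

Running it prints `0 a dull film`, `2 it is fine` and `4 a triumph`, one per line. With five classes, a score of exactly 0.5 lands in the middle class, which acts as the neutral bin.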