├── convote_v1.1
│   └── readme.txt
├── README.md
├── utils.py
├── demo.py
├── rnn_theano.py
├── rnn.py
├── con_speech.ipynb
└── con_util.py

/convote_v1.1/readme.txt:
--------------------------------------------------------------------------------
1 | Download the dataset from http://www.cs.cornell.edu/home/llee/data/convote.html
2 | Put the sub-directory 'data_stage_three' here, or adjust the path to the data in the code.
3 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Political Speech Generator
2 |
3 | Link to original paper: http://arxiv.org/abs/1601.03313
4 |
5 | ## Running the demo script
6 |
7 | ```sh
8 | $ python demo.py
9 | $ python demo.py [class]          # Example: python demo.py RY
10 | $ python demo.py [class] [lambda] # Example: python demo.py RY 0.25
11 | ```
12 |
13 | `[class]` is one of `DY`, `DN`, `RY`, `RN` (first letter: party D/R, second letter: vote Y/N); the default is `DN`.
14 | `[lambda]` is a weight in [0, 1] that balances the n-gram language model against the topic model; the default is `0.5`.
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | def softmax(x):
4 |     xt = np.exp(x - np.max(x))
5 |     return xt / np.sum(xt)
6 |
7 | def save_model_parameters_theano(outfile, model):
8 |     U, V, W = model.U.get_value(), model.V.get_value(), model.W.get_value()
9 |     np.savez(outfile, U=U, V=V, W=W)
10 |     print "Saved model parameters to %s." % outfile
11 |
12 | def load_model_parameters_theano(path, model):
13 |     npzfile = np.load(path)
14 |     U, V, W = npzfile["U"], npzfile["V"], npzfile["W"]
15 |     model.hidden_dim = U.shape[0]
16 |     model.word_dim = U.shape[1]
17 |     model.U.set_value(U)
18 |     model.V.set_value(V)
19 |     model.W.set_value(W)
20 |     print "Loaded model parameters from %s. hidden_dim=%d word_dim=%d" % (path, U.shape[0], U.shape[1])
21 |
--------------------------------------------------------------------------------
/demo.py:
--------------------------------------------------------------------------------
1 | import con_util
2 | reload(con_util)
3 | from con_util import *
4 | import os
5 | import pickle
6 | import sys
7 |
8 | # Demo script for conspeech
9 | #
10 | # Usage:
11 | #   python demo.py
12 | #   python demo.py [class]          Example: python demo.py RY
13 | #   python demo.py [class] [lambda] Example: python demo.py RY 0.25
14 |
15 | lambd = 0.5
16 | speech_class = 'DN'
17 |
18 | if len(sys.argv) >= 2:
19 |     speech_class = sys.argv[1]
20 |     if speech_class not in ['RY','RN','DN','DY']:
21 |         print 'Invalid parameter:', speech_class
22 |         sys.exit()
23 |
24 | if len(sys.argv) >= 3:
25 |     lambd = float(sys.argv[2])
26 |     if (lambd < 0.0) or (lambd > 1.0):
27 |         print 'Invalid parameter:', lambd
28 |         sys.exit()
29 |
30 |
31 | # Dataset from http://www.cs.cornell.edu/home/llee/data/convote.html
32 | PATH_TO_DATA = os.path.join('convote_v1.1', 'data_stage_three')
33 | TRAIN_DIR = os.path.join(PATH_TO_DATA, "training_set")
34 | TEST_DIR = os.path.join(PATH_TO_DATA, "test_set")
35 | DEV_DIR = os.path.join(PATH_TO_DATA, "development_set")
36 |
37 | (dataset,vocab) = construct_dataset([TRAIN_DIR,TEST_DIR,DEV_DIR])
38 |
39 | if sum([len(x) for x in dataset.values()]) == 0:
40 |     print 'No data found!'
41 | sys.exit() 42 | 43 | print '# Class: ' + speech_class + ', Lambda: ' + str(lambd) + ' #' 44 | 45 | 46 | class_words = get_class_words(dataset) 47 | 48 | jk = pickle.load( open( "jk.p", "rb" ) ) 49 | jk_trend = get_jk_trend(jk,print_n=0) 50 | 51 | ngram_probs = get_n_gram_probs(dataset,n=6, verbose = False) 52 | gen_sp = generate_speech_wba(dataset,ngram_probs,None,None,jk_trend,jk,speech_class,lamb=lambd) -------------------------------------------------------------------------------- /rnn_theano.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import theano as theano 3 | import theano.tensor as T 4 | from utils import * 5 | import operator 6 | 7 | class RNNTheano: 8 | 9 | def __init__(self, word_dim, hidden_dim=100, bptt_truncate=4): 10 | # Assign instance variables 11 | self.word_dim = word_dim 12 | self.hidden_dim = hidden_dim 13 | self.bptt_truncate = bptt_truncate 14 | # Randomly initialize the network parameters 15 | U = np.random.uniform(-np.sqrt(1./word_dim), np.sqrt(1./word_dim), (hidden_dim, word_dim)) 16 | V = np.random.uniform(-np.sqrt(1./hidden_dim), np.sqrt(1./hidden_dim), (word_dim, hidden_dim)) 17 | W = np.random.uniform(-np.sqrt(1./hidden_dim), np.sqrt(1./hidden_dim), (hidden_dim, hidden_dim)) 18 | # Theano: Created shared variables 19 | self.U = theano.shared(name='U', value=U.astype(theano.config.floatX)) 20 | self.V = theano.shared(name='V', value=V.astype(theano.config.floatX)) 21 | self.W = theano.shared(name='W', value=W.astype(theano.config.floatX)) 22 | # We store the Theano graph here 23 | self.theano = {} 24 | self.__theano_build__() 25 | 26 | def __theano_build__(self): 27 | U, V, W = self.U, self.V, self.W 28 | x = T.ivector('x') 29 | y = T.ivector('y') 30 | def forward_prop_step(x_t, s_t_prev, U, V, W): 31 | s_t = T.tanh(U[:,x_t] + W.dot(s_t_prev)) 32 | o_t = T.nnet.softmax(V.dot(s_t)) 33 | return [o_t[0], s_t] 34 | [o,s], updates = theano.scan( 35 | forward_prop_step, 36 | sequences=x, 37 | outputs_info=[None, dict(initial=T.zeros(self.hidden_dim))], 38 | non_sequences=[U, V, W], 39 | truncate_gradient=self.bptt_truncate, 40 | strict=True) 41 | 42 | prediction = T.argmax(o, axis=1) 43 | o_error = T.sum(T.nnet.categorical_crossentropy(o, y)) 44 | 45 | # Gradients 46 | dU = T.grad(o_error, U) 47 | dV = T.grad(o_error, V) 48 | dW = T.grad(o_error, W) 49 | 50 | # Assign functions 51 | self.forward_propagation = theano.function([x], o) 52 | self.predict = theano.function([x], prediction) 53 | self.ce_error = theano.function([x, y], o_error) 54 | self.bptt = theano.function([x, y], [dU, dV, dW]) 55 | 56 | # SGD 57 | learning_rate = T.scalar('learning_rate') 58 | self.sgd_step = theano.function([x,y,learning_rate], [], 59 | updates=[(self.U, self.U - learning_rate * dU), 60 | (self.V, self.V - learning_rate * dV), 61 | (self.W, self.W - learning_rate * dW)]) 62 | 63 | def calculate_total_loss(self, X, Y): 64 | return np.sum([self.ce_error(x,y) for x,y in zip(X,Y)]) 65 | 66 | def calculate_loss(self, X, Y): 67 | # Divide calculate_loss by the number of words 68 | num_words = np.sum([len(y) for y in Y]) 69 | return self.calculate_total_loss(X,Y)/float(num_words) 70 | 71 | 72 | def gradient_check_theano(model, x, y, h=0.001, error_threshold=0.01): 73 | # Overwrite the bptt attribute. 
# We need to backpropagate all the way to get the correct gradient
74 |     model.bptt_truncate = 1000
75 |     # Calculate the gradients using backprop
76 |     bptt_gradients = model.bptt(x, y)
77 |     # List of all parameters we want to check.
78 |     model_parameters = ['U', 'V', 'W']
79 |     # Gradient check for each parameter
80 |     for pidx, pname in enumerate(model_parameters):
81 |         # Get the actual parameter value from the model, e.g. model.W
82 |         parameter_T = operator.attrgetter(pname)(model)
83 |         parameter = parameter_T.get_value()
84 |         print "Performing gradient check for parameter %s with size %d." % (pname, np.prod(parameter.shape))
85 |         # Iterate over each element of the parameter matrix, e.g. (0,0), (0,1), ...
86 |         it = np.nditer(parameter, flags=['multi_index'], op_flags=['readwrite'])
87 |         while not it.finished:
88 |             ix = it.multi_index
89 |             # Save the original value so we can reset it later
90 |             original_value = parameter[ix]
91 |             # Estimate the gradient using (f(x+h) - f(x-h))/(2*h)
92 |             parameter[ix] = original_value + h
93 |             parameter_T.set_value(parameter)
94 |             gradplus = model.calculate_total_loss([x],[y])
95 |             parameter[ix] = original_value - h
96 |             parameter_T.set_value(parameter)
97 |             gradminus = model.calculate_total_loss([x],[y])
98 |             estimated_gradient = (gradplus - gradminus)/(2*h)
99 |             parameter[ix] = original_value
100 |             parameter_T.set_value(parameter)
101 |             # The gradient for this parameter calculated using backpropagation
102 |             backprop_gradient = bptt_gradients[pidx][ix]
103 |             # Calculate the relative error: (|x - y|/(|x| + |y|))
104 |             relative_error = np.abs(backprop_gradient - estimated_gradient)/(np.abs(backprop_gradient) + np.abs(estimated_gradient))
105 |             # If the error is too large, fail the gradient check
106 |             if relative_error > error_threshold:
107 |                 print "Gradient Check ERROR: parameter=%s ix=%s" % (pname, ix)
108 |                 print "+h Loss: %f" % gradplus
109 |                 print "-h Loss: %f" % gradminus
110 |                 print "Estimated_gradient: %f" % estimated_gradient
111 |                 print "Backpropagation gradient: %f" % backprop_gradient
112 |                 print "Relative Error: %f" % relative_error
113 |                 return
114 |             it.iternext()
115 |         print "Gradient check for parameter %s passed."
% (pname) -------------------------------------------------------------------------------- /rnn.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import csv 3 | import itertools 4 | import operator 5 | import numpy as np 6 | import nltk 7 | import sys 8 | from datetime import datetime 9 | from utils import * 10 | import os 11 | import random 12 | from collections import defaultdict 13 | from scipy.stats import futil 14 | import re 15 | from rnn_theano import RNNTheano, gradient_check_theano 16 | from utils import load_model_parameters_theano, save_model_parameters_theano 17 | 18 | 19 | START_OF_SPEECH = "__START__" 20 | END_OF_SPEECH = "__END__" 21 | END_OF_SENTENCE = "__STOP__" 22 | REFERENCE = "" 23 | NUMBER = "" 24 | 25 | max_vocab_size = 6000 26 | unknown_token = "UNKNOWN_TOKEN" 27 | sentence_start_token = "SENTENCE_START" 28 | sentence_end_token = "SENTENCE_END" 29 | 30 | 31 | 32 | PATH_TO_DATA = 'convote_v1.1\data_stage_three' 33 | TRAIN_DIR = os.path.join(PATH_TO_DATA, "training_set") 34 | TEST_DIR = os.path.join(PATH_TO_DATA, "test_set") 35 | DEV_DIR = os.path.join(PATH_TO_DATA, "development_set") 36 | 37 | classes = ['DY','DN','RY','RN'] 38 | 39 | 40 | 41 | def construct_dataset(paths): 42 | print "[constructing dataset...]" 43 | 44 | class_sentences = dict() 45 | for c in classes: 46 | class_sentences[c] = [] 47 | 48 | #for l in labels: 49 | # dataset[l] = [] 50 | 51 | for p in paths: 52 | for f in sorted(os.listdir(p)): 53 | #006_400102_0002030_DON.txt 54 | vote = f[21:22] 55 | party = f[19:20] 56 | label = party + vote 57 | if label not in classes: 58 | continue; 59 | with open(os.path.join(p,f),'r') as doc: 60 | content = doc.read() 61 | 62 | content = content.replace('; center ', '; ') 63 | content = content.replace(' /center ', ' ') 64 | content = content.replace(' em ', ' ') 65 | content = content.replace(' /em ', ' ') 66 | content = content.replace(' pre ', ' ') 67 | content = content.replace(' /pre ', ' ') 68 | 69 | content = content.replace(' & lt ;', '') 70 | content = content.replace(' & gt ;', '') 71 | content = content.replace(' p ; ', ' ') 72 | content = content.replace(' & amp ; ', ' ') 73 | 74 | content = content.replace(' p nbsp ; ', ' ') 75 | content = content.replace(' nbsp ;', '') 76 | content = content.replace(' p ; ', ' ') 77 | content = content.replace(' p lt ;', '') 78 | content = content.replace(' p gt ;', '') 79 | 80 | content = content.replace(' b ', ' ') 81 | content = content.replace(' p ', ' ') 82 | 83 | content = content.replace(" n't", "n't") 84 | content = content.replace(" 's", "'s") 85 | content = content.replace(" h. con . res. ", " h.con.res. ") 86 | content = content.replace('.these ', '. these ') 87 | 88 | content = re.sub(r'[a-z]\.[a-z] \. ',lambda pat: pat.group(0).replace(' ','') + ' ',content) 89 | 90 | content = re.sub(r'xz[0-9]{7}',REFERENCE,content) 91 | #content = re.sub(r' [0-9]+ ', ' ' + NUMBER + ' ',content) 92 | #content = re.sub(r' [0-9]+\.[0-9]+ ', ' ' + NUMBER + ' ',content) 93 | 94 | #content = content.replace(' no . ' + NUMBER, ' no. ' + NUMBER) 95 | content = re.sub(r' no . [0-9]', lambda pat: pat.group(0).replace(' . ','. ') + ' ',content) 96 | 97 | content = content.replace(chr(0xc3), '') 98 | content = content.replace(chr(0x90), '') 99 | 100 | #lines = content.split(" . ") 101 | lines = re.split(r' \. | \! | \? 
',content) 102 | lines = [x.strip() for x in lines] 103 | lines = filter(lambda a: (a.strip() != ''), lines) 104 | 105 | if len(lines) <= 1: 106 | continue 107 | 108 | 109 | for idx,line in enumerate(lines): 110 | lines[idx] = sentence_start_token + ' ' + lines[idx] + ' ' + sentence_end_token 111 | 112 | 113 | #lines.insert(0,START_OF_SPEECH) 114 | #lines.append(END_OF_SPEECH) 115 | 116 | class_sentences[label].extend(lines) 117 | 118 | print "[dataset constructed.]" 119 | return class_sentences 120 | 121 | 122 | 123 | def generate_sentence(model,word_to_index,index_to_word): 124 | # We start the sentence with the start token 125 | new_sentence = [word_to_index[sentence_start_token]] 126 | # Repeat until we get an end token 127 | while not new_sentence[-1] == word_to_index[sentence_end_token]: 128 | next_word_probs = model.forward_propagation(new_sentence) 129 | sampled_word = word_to_index[unknown_token] 130 | # We don't want to sample unknown words 131 | while sampled_word == word_to_index[unknown_token]: 132 | samples = np.random.multinomial(1, next_word_probs[-1]) 133 | sampled_word = np.argmax(samples) 134 | new_sentence.append(sampled_word) 135 | sentence_str = [index_to_word[x] for x in new_sentence[1:-1]] 136 | return sentence_str 137 | 138 | 139 | 140 | 141 | if __name__=='__main__': 142 | 143 | # Download NLTK model data (you need to do this once) 144 | nltk.download("book") 145 | 146 | dataset = construct_dataset([TRAIN_DIR,TEST_DIR,DEV_DIR]) 147 | print "Sentences",sum([len(x) for x in dataset.values()]) 148 | 149 | for label, sentences in dataset.iteritems(): 150 | print 'Processing',label,'...' 151 | 152 | # Tokenize the sentences into words 153 | tokenized_sentences = [nltk.word_tokenize(sent) for sent in sentences] 154 | 155 | # Count the word frequencies 156 | word_freq = nltk.FreqDist(itertools.chain(*tokenized_sentences)) 157 | print "Found %d unique words tokens." % len(word_freq.items()) 158 | 159 | vocabulary_size = min(max_vocab_size,len(word_freq.items())) 160 | 161 | # Get the most common words and build index_to_word and word_to_index vectors 162 | vocab = word_freq.most_common(vocabulary_size-1) 163 | index_to_word = [x[0] for x in vocab] 164 | index_to_word.append(unknown_token) 165 | word_to_index = dict([(w,i) for i,w in enumerate(index_to_word)]) 166 | 167 | print "Using vocabulary size %d." % vocabulary_size 168 | print "The least frequent word in our vocabulary is '%s' and appeared %d times." 
% (vocab[-1][0], vocab[-1][1]) 169 | 170 | # Replace all words not in our vocabulary with the unknown token 171 | for i, sent in enumerate(tokenized_sentences): 172 | tokenized_sentences[i] = [w if w in word_to_index else unknown_token for w in sent] 173 | 174 | #print "\nExample sentence: '%s'" % sentences[0] 175 | #print "\nExample sentence after Pre-processing: '%s'" % tokenized_sentences[0] 176 | 177 | # Create the training data 178 | X_train = np.asarray([[word_to_index[w] for w in sent[:-1]] for sent in tokenized_sentences]) 179 | y_train = np.asarray([[word_to_index[w] for w in sent[1:]] for sent in tokenized_sentences]) 180 | 181 | model = RNNTheano(vocabulary_size, hidden_dim=50) 182 | losses = train_with_sgd(model, X_train, y_train, nepoch=50) 183 | save_model_parameters_theano('./data/trained-model-'+label+'-dim50-t50.npz', model) 184 | #load_model_parameters_theano('./data/trained-model-theano.npz', model) 185 | 186 | 187 | -------------------------------------------------------------------------------- /con_speech.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "#Political Speech Generator" 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": {}, 13 | "source": [ 14 | "Author: Valentin Kassarnig
\n", 15 | "Email: valentin.kassarnig@gmail.com\n" 16 | ] 17 | }, 18 | { 19 | "cell_type": "code", 20 | "execution_count": 1, 21 | "metadata": { 22 | "collapsed": false 23 | }, 24 | "outputs": [ 25 | { 26 | "name": "stdout", 27 | "output_type": "stream", 28 | "text": [ 29 | "[constructing dataset...]\n", 30 | "[dataset constructed.]\n", 31 | "tokens 24938\n", 32 | "speeches 2771\n" 33 | ] 34 | } 35 | ], 36 | "source": [ 37 | "import con_util\n", 38 | "reload(con_util)\n", 39 | "from con_util import *\n", 40 | "import os\n", 41 | "\n", 42 | "# Dataset from http://www.cs.cornell.edu/home/llee/data/convote.html\n", 43 | "PATH_TO_DATA = 'convote_v1.1\\data_stage_three'\n", 44 | "TRAIN_DIR = os.path.join(PATH_TO_DATA, \"training_set\")\n", 45 | "TEST_DIR = os.path.join(PATH_TO_DATA, \"test_set\")\n", 46 | "DEV_DIR = os.path.join(PATH_TO_DATA, \"development_set\")\n", 47 | "\n", 48 | "\n", 49 | "(dataset,vocab) = construct_dataset([TRAIN_DIR,TEST_DIR,DEV_DIR])\n", 50 | "\n", 51 | "print \"tokens\",len(vocab)\n", 52 | "print \"speeches\",sum([len(x) for x in dataset.values()])\n", 53 | "\n", 54 | "class_words = get_class_words(dataset)" 55 | ] 56 | }, 57 | { 58 | "cell_type": "code", 59 | "execution_count": 3, 60 | "metadata": { 61 | "collapsed": false, 62 | "scrolled": true 63 | }, 64 | "outputs": [ 65 | { 66 | "name": "stdout", 67 | "output_type": "stream", 68 | "text": [ 69 | "DY\n", 70 | "cell research enhancement\n", 71 | "research enhancement act\n", 72 | "spinal cord\n", 73 | "vitro fertilization\n", 74 | "clean coal\n", 75 | "stem cell\n", 76 | "air national guard\n", 77 | "human embryonic stem\n", 78 | "stem cell research\n", 79 | "heart disease\n", 80 | "medical research\n", 81 | "embryonic stem cell\n", 82 | "federal funding\n", 83 | "embryonic stem\n", 84 | "critical habitat\n", 85 | "cord blood\n", 86 | "free trade agreement\n", 87 | "nation's history\n", 88 | "adult stem\n", 89 | "middle east\n", 90 | "sickle cell\n", 91 | "title vii\n", 92 | "gentleman from delaware\n", 93 | "brac commission\n", 94 | "energy and water\n", 95 | "brac process\n", 96 | "clean air\n", 97 | "adult stem cell\n", 98 | "affordable housing\n", 99 | "health care system\n", 100 | "DN\n", 101 | "school of law\n", 102 | "cbc alternative\n", 103 | "cbc budget\n", 104 | "cbc alternative budget\n", 105 | "professor of law\n", 106 | "republican budget\n", 107 | "gun industry\n", 108 | "big oil\n", 109 | "judicial conference\n", 110 | "employment service\n", 111 | "democratic alternative\n", 112 | "middle class\n", 113 | "social security trust\n", 114 | "republican party\n", 115 | "republican bill\n", 116 | "bald eagle\n", 117 | "security trust fund\n", 118 | "minimum wage\n", 119 | "republican leadership\n", 120 | "house republican\n", 121 | "strong opposition\n", 122 | "war in iraq\n", 123 | "congressional black\n", 124 | "republican majority\n", 125 | "tax break\n", 126 | "national debt\n", 127 | "fiscal responsibility\n", 128 | "today in opposition\n", 129 | "maximum pell grant\n", 130 | "social security\n", 131 | "RY\n", 132 | "death tax repeal\n", 133 | "head start program\n", 134 | "public law\n", 135 | "death tax\n", 136 | "budget request\n", 137 | "community protection\n", 138 | "community protection act\n", 139 | "gang deterrence\n", 140 | "democrat substitute\n", 141 | "federal jurisdiction\n", 142 | "committee on homeland\n", 143 | "deterrence and community\n", 144 | "chairman sensenbrenner\n", 145 | "lawsuit abuse\n", 146 | "personal injury\n", 147 | "bankruptcy abuse\n", 148 | "chinese government\n", 
149 | "pension protection act\n", 150 | "commission report\n", 151 | "chamber of commerce\n", 152 | "san diego\n", 153 | "pension protection\n", 154 | "sex offender\n", 155 | "good bill\n", 156 | "tax relief\n", 157 | "mandatory spending\n", 158 | "global war\n", 159 | "chairman for yielding\n", 160 | "consumer protection act\n", 161 | "driver's license\n", 162 | "RN\n", 163 | "inner cell\n", 164 | "inner cell mass\n", 165 | "human embryo\n", 166 | "human life\n", 167 | "adult stem cell\n", 168 | "world trade organization\n", 169 | "adult stem\n", 170 | "world trade\n", 171 | "sickle cell\n", 172 | "associate professor\n", 173 | "embryonic stem\n", 174 | "gentleman from maryland\n", 175 | "umbilical cord blood\n", 176 | "embryonic stem cell\n", 177 | "umbilical cord\n", 178 | "federal funding\n", 179 | "stem cell research\n", 180 | "rule of law\n", 181 | "central american\n", 182 | "brac process\n", 183 | "stem cell\n", 184 | "cord blood\n", 185 | "central america\n", 186 | "bone marrow\n", 187 | "base bill\n", 188 | "human embryonic stem\n", 189 | "time of war\n", 190 | "federal money\n", 191 | "security council\n", 192 | "air force\n" 193 | ] 194 | } 195 | ], 196 | "source": [ 197 | "import pickle\n", 198 | "\n", 199 | "#jk = jk_pos_tag_filter(dataset)\n", 200 | "#pickle.dump( jk, open( \"jk.p\", \"wb\" ) )\n", 201 | "jk = pickle.load( open( \"jk.p\", \"rb\" ) )\n", 202 | "\n", 203 | "jk_trend = get_jk_trend(jk,print_n=30)" 204 | ] 205 | }, 206 | { 207 | "cell_type": "code", 208 | "execution_count": 4, 209 | "metadata": { 210 | "collapsed": false 211 | }, 212 | "outputs": [ 213 | { 214 | "name": "stdout", 215 | "output_type": "stream", 216 | "text": [ 217 | "DN 487396\n", 218 | "RN 80952\n", 219 | "RY 446941\n", 220 | "DY 99824\n" 221 | ] 222 | } 223 | ], 224 | "source": [ 225 | "ngram_probs = get_n_gram_probs(dataset,n=6)" 226 | ] 227 | }, 228 | { 229 | "cell_type": "code", 230 | "execution_count": null, 231 | "metadata": { 232 | "collapsed": true 233 | }, 234 | "outputs": [], 235 | "source": [ 236 | "# Only first time\n", 237 | "#create_corpus_pos_tags(dataset)" 238 | ] 239 | }, 240 | { 241 | "cell_type": "code", 242 | "execution_count": 7, 243 | "metadata": { 244 | "collapsed": false, 245 | "scrolled": true 246 | }, 247 | "outputs": [ 248 | { 249 | "name": "stdout", 250 | "output_type": "stream", 251 | "text": [ 252 | "__START__ mr. speaker , the gentleman from georgia and i are on the committee on rules that his state is impacted by virtue of education formulas .\n", 253 | "i do not disagree with what the gentleman says , but i find it curious that the leadership of the house will not even allow democrats the opportunity to offer a substitute and have a straight up-or-down vote on it .\n", 254 | "is that asking too much .\n", 255 | "the republicans say it is .\n", 256 | "so a yes vote on the motion to recommit we address two other abuses of power that should be addressed in this bill .\n", 257 | "one is what i will call the tauzin rule , and the democratic motion to commit would forbid a member of congress to negotiate with an outside entity that has business before his or her committee and before the congress , in the current congress or in a previous congress , called the tauzin rule because mr. 
tauzin , who managed the medicare bill , was at the time being courted by the pharmaceutical industry which was to benefit from provisions in the prescription drug bill under medicare .\n", 258 | "this is yet again another example of republicans being the handmaidens of the pharmaceutical industry .\n", 259 | "this bill also runs counter to the principles of federalism that my colleagues on the other side of our oceans .\n", 260 | "the wild horse is an icon of american history .\n", 261 | "the gentleman from iowa asked what is the difference between a bald eagle and a pigeon or a turkey .\n", 262 | "and if you do not know the difference , we can not explain it to you .\n", 263 | "shakespeare once said that `` horses are as full of spirit as the month of may and as gorgeous as the sun in midsummer '' .\n", 264 | "does everything have to be converted to the bottom line .\n", 265 | "there are so many things we could be doing rather than selling these beautiful creatures for horse meat .\n", 266 | "we are not just about dollars and cents .\n", 267 | "we are about the things that made our country great .\n", 268 | "the wild horse is one of those things we emphasize and plus-up .\n", 269 | "__END__\n" 270 | ] 271 | } 272 | ], 273 | "source": [ 274 | "import con_util\n", 275 | "reload(con_util)\n", 276 | "from con_util import *\n", 277 | "\n", 278 | "lambd = 0.25\n", 279 | "speech_class = 'DN'\n", 280 | "\n", 281 | "gen_sp = generate_speech_wba(dataset,ngram_probs,None,None,jk_trend,jk,speech_class,lamb=lambd)" 282 | ] 283 | }, 284 | { 285 | "cell_type": "code", 286 | "execution_count": 8, 287 | "metadata": { 288 | "collapsed": false 289 | }, 290 | "outputs": [ 291 | { 292 | "name": "stdout", 293 | "output_type": "stream", 294 | "text": [ 295 | "0.588235294118\n", 296 | "0.336282987086\n" 297 | ] 298 | } 299 | ], 300 | "source": [ 301 | "print evaluate_grammar(gen_sp,verbose=False)\n", 302 | "print evaluate_content(gen_sp,dataset,speech_class,jk,jk_trend)" 303 | ] 304 | }, 305 | { 306 | "cell_type": "code", 307 | "execution_count": null, 308 | "metadata": { 309 | "collapsed": true 310 | }, 311 | "outputs": [], 312 | "source": [] 313 | } 314 | ], 315 | "metadata": { 316 | "kernelspec": { 317 | "display_name": "Python 2", 318 | "language": "python", 319 | "name": "python2" 320 | }, 321 | "language_info": { 322 | "codemirror_mode": { 323 | "name": "ipython", 324 | "version": 2 325 | }, 326 | "file_extension": ".py", 327 | "mimetype": "text/x-python", 328 | "name": "python", 329 | "nbconvert_exporter": "python", 330 | "pygments_lexer": "ipython2", 331 | "version": "2.7.10" 332 | } 333 | }, 334 | "nbformat": 4, 335 | "nbformat_minor": 0 336 | } 337 | -------------------------------------------------------------------------------- /con_util.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import operator 3 | import os 4 | import random 5 | from collections import defaultdict 6 | from scipy.stats import futil 7 | from sklearn import preprocessing 8 | import numpy as np 9 | from sklearn.feature_extraction import DictVectorizer 10 | import re 11 | import sys 12 | from nltk import pos_tag 13 | 14 | 15 | # Project: Political Speech Generator 16 | # Author: Valentin Kassarnig 17 | # Email: valentin.kassarnig@gmail.com 18 | 19 | 20 | START_OF_SPEECH = "__START__" 21 | END_OF_SPEECH = "__END__" 22 | END_OF_SENTENCE = "__STOP__" 23 | REFERENCE = "" 24 | NUMBER = "" 25 | 26 | 27 | 28 | classes = ['DY','DN','RY','RN'] 29 | 30 | 31 | 
vocab_count = defaultdict(float) 32 | 33 | def construct_dataset(paths): 34 | print "[constructing dataset...]" 35 | dataset = dict() 36 | 37 | for c in classes: 38 | dataset[c] = [] 39 | 40 | vocab = set() 41 | vocab.add(START_OF_SPEECH) 42 | vocab.add(END_OF_SPEECH) 43 | vocab.add(END_OF_SENTENCE) 44 | #for l in labels: 45 | # dataset[l] = [] 46 | 47 | for p in paths: 48 | for f in sorted(os.listdir(p)): 49 | #006_400102_0002030_DON.txt 50 | vote = f[21:22] 51 | party = f[19:20] 52 | label = party + vote 53 | if label not in classes: 54 | continue; 55 | with open(os.path.join(p,f),'r') as doc: 56 | content = doc.read() 57 | 58 | content = content.replace('; center ', '; ') 59 | content = content.replace(' /center ', ' ') 60 | content = content.replace(' em ', ' ') 61 | content = content.replace(' /em ', ' ') 62 | content = content.replace(' pre ', ' ') 63 | content = content.replace(' /pre ', ' ') 64 | 65 | content = content.replace(' & lt ;', '') 66 | content = content.replace(' & gt ;', '') 67 | content = content.replace(' p ; ', ' ') 68 | content = content.replace(' & amp ; ', ' ') 69 | 70 | content = content.replace(' p nbsp ; ', ' ') 71 | content = content.replace(' nbsp ;', '') 72 | content = content.replace(' p ; ', ' ') 73 | content = content.replace(' p lt ;', '') 74 | content = content.replace(' p gt ;', '') 75 | 76 | content = content.replace(' b ', ' ') 77 | content = content.replace(' p ', ' ') 78 | 79 | content = content.replace(" n't", "n't") 80 | content = content.replace(" 's", "'s") 81 | content = content.replace(" h. con . res. ", " h.con.res. ") 82 | content = content.replace('.these ', '. these ') 83 | 84 | content = re.sub(r'[a-z]\.[a-z] \. ',lambda pat: pat.group(0).replace(' ','') + ' ',content) 85 | 86 | content = re.sub(r'xz[0-9]{7}',REFERENCE,content) 87 | #content = re.sub(r' [0-9]+ ', ' ' + NUMBER + ' ',content) 88 | #content = re.sub(r' [0-9]+\.[0-9]+ ', ' ' + NUMBER + ' ',content) 89 | 90 | #content = content.replace(' no . ' + NUMBER, ' no. ' + NUMBER) 91 | content = re.sub(r' no . [0-9]', lambda pat: pat.group(0).replace(' . ','. ') + ' ',content) 92 | 93 | content = content.replace(chr(0xc3), '') 94 | content = content.replace(chr(0x90), '') 95 | 96 | #lines = content.split(" . ") 97 | lines = re.split(r' \. | \! | \? 
',content) 98 | lines = [x.strip() for x in lines] 99 | lines = filter(lambda a: (a.strip() != ''), lines) 100 | 101 | if len(lines) <= 1: 102 | continue 103 | 104 | 105 | for idx,line in enumerate(lines): 106 | lines[idx] = lines[idx] + ' ' + END_OF_SENTENCE 107 | 108 | words = line.split(); 109 | for word in words: 110 | vocab.add(word) 111 | vocab_count[word] += 1 112 | 113 | lines.insert(0,START_OF_SPEECH) 114 | lines.append(END_OF_SPEECH) 115 | 116 | dataset[label].append(lines) 117 | 118 | print "[dataset constructed.]" 119 | return (dataset,vocab) 120 | 121 | 122 | def get_class_words(dataset): 123 | class_words = dict() 124 | 125 | for c in classes: 126 | class_words[c] = defaultdict(float) 127 | 128 | for key,speeches in dataset.iteritems(): 129 | for speech in speeches: 130 | for sentence in speech: 131 | for word in sentence.split(): 132 | class_words[key][word] += 1 133 | return class_words 134 | 135 | 136 | def jk_pos_tag_filter(dataset): 137 | #Justeson and Katz Filter 138 | import nltk 139 | from nltk import pos_tag 140 | import sys 141 | import pickle 142 | 143 | jk_trigram_filter_ = [['NN','NN','NN'],['JJ','JJ','NN'],['JJ','NN','NN'],['NN','JJ','NN'],['NN','IN','NN'],['NN','CC','NN']] 144 | jk_bigram_filter = [['NN','NN'],['JJ','NN']] 145 | #nltk.download('maxent_treebank_pos_tagger'); 146 | 147 | jk = dict() 148 | for c in classes: 149 | jk[c] =defaultdict(float) 150 | 151 | speech_cnt = 0 152 | for key,speeches in dataset.iteritems(): 153 | print key 154 | sys.stdout.flush() 155 | for idx,speech in enumerate(speeches): 156 | for sentence in speech: 157 | words = sentence.split() 158 | if len(words) < 3: 159 | continue 160 | 161 | tags = pos_tag(words) 162 | if ([tags[0][1], tags[1][1]] in jk_bigram_filter) and (tags[2][1] is not 'NN'): 163 | tw = tags[0][0]+' '+tags[1][0] 164 | jk[key][tw]+=1 165 | 166 | for i in range(len(tags)-2): 167 | t = [tags[i][1], tags[i+1][1] ,tags[i+2][1]] 168 | if t in jk_trigram_filter_: 169 | tw = tags[i][0]+' '+tags[i+1][0]+' '+tags[i+2][0] 170 | jk[key][tw]+=1 171 | else: 172 | t = [tags[i+1][1], tags[i+2][1]] 173 | if t in jk_bigram_filter: 174 | tw = tags[i+1][0]+' '+tags[i+2][0] 175 | jk[key][tw]+=1 176 | 177 | if idx % 100 == 0: 178 | print idx,'/',len(speeches),'...' 
179 | sys.stdout.flush() 180 | 181 | 182 | return jk 183 | 184 | 185 | def get_jk_trend(jk,print_n=10,thresh=1.0,min_occ=20): 186 | jk_trend = dict() 187 | totsum = 0 188 | 189 | for c in classes: 190 | jk_trend[c] =defaultdict(float) 191 | totsum += sum(jk[c].values()) 192 | 193 | for c in classes: 194 | sorted_jk = sorted(jk[c].items(), key=operator.itemgetter(1),reverse=True) 195 | class_sum = sum(jk[c].values()) 196 | for f in sorted_jk: 197 | #if f[1] < 2: 198 | # continue 199 | 200 | p = f[1]/class_sum 201 | 202 | other_p = 0 203 | for c2 in classes: 204 | other_p += jk[c2][f[0]] 205 | other_p = other_p / totsum 206 | 207 | jk_trend[c][f[0]] = p/other_p 208 | 209 | for c in classes: 210 | if print_n > 0: 211 | print c 212 | 213 | remlist = [] 214 | for word, ratio in jk_trend[c].iteritems(): 215 | if (ratio > thresh) and (sum([jk[x][word] for x in classes]) >= min_occ): 216 | pass 217 | else: 218 | remlist.append(word) 219 | for r in remlist: 220 | del jk_trend[c][r] 221 | sorted_jk = sorted(jk_trend[c].items(), key=operator.itemgetter(1),reverse=True) 222 | for sj in sorted_jk[:print_n]: 223 | print sj[0] 224 | #print len(jk_trend[c]) 225 | return jk_trend 226 | 227 | def longest_common_substring(s1, s2): 228 | m = [[0] * (1 + len(s2)) for i in range(1 + len(s1))] 229 | longest, x_longest = 0, 0 230 | for x in range(1, 1 + len(s1)): 231 | for y in range(1, 1 + len(s2)): 232 | if s1[x - 1] == s2[y - 1]: 233 | m[x][y] = m[x - 1][y - 1] + 1 234 | if m[x][y] > longest: 235 | longest = m[x][y] 236 | x_longest = x 237 | else: 238 | m[x][y] = 0 239 | return len(s1[x_longest - longest: x_longest]) 240 | 241 | def generate_speech_sba(label,dataset,jk_trend,rand_set_size=20,sim_thresh = 0.1,max_sentences=30): 242 | from nltk import trigrams 243 | 244 | print label 245 | random.seed() 246 | last_speech = dataset[label][random.randint(0,len(dataset[label])-1)] 247 | last_idx = 1 248 | last_sentence = last_speech[last_idx] 249 | 250 | speech_cnt = 0 251 | 252 | max_struc_sim = 0 253 | max_text_sim = 0 254 | 255 | 256 | 257 | print last_sentence 258 | sys.stdout.flush() 259 | for i in range(max_sentences): 260 | D=[] 261 | random.seed() 262 | while len(D) < rand_set_size: 263 | idx = random.randint(0,len(dataset[label])-1) 264 | sp = dataset[label][idx] 265 | if sp != last_speech: 266 | D.append(sp) 267 | 268 | max_similarity = 0.0 269 | max_struc_sim = 0 270 | max_text_sim = 0 271 | 272 | last_topics = [] 273 | full_speech = " ".join(last_speech) 274 | for topic in jk_trend[label].keys(): 275 | if topic in full_speech: 276 | last_topics.append(topic) 277 | """ 278 | if (last_idx-1)/(len(last_speech)-2) <= 1/3: 279 | last_part = 1 280 | elif (last_idx-1)/(len(last_speech)-2) >= 2/3: 281 | last_part = 3 282 | else: 283 | last_part = 2 284 | #print last_part 285 | """ 286 | 287 | last_tags = [x[1] for x in pos_tag(last_sentence.split())] 288 | last_tg = list(trigrams(last_sentence.split())) 289 | 290 | for speech in D: 291 | topic_cnt = 0 292 | full_speech = " ".join(speech) 293 | for topic in last_topics: 294 | if topic in full_speech: 295 | topic_cnt += 1 296 | 297 | for idx,sentence in enumerate(speech): 298 | #print sentence 299 | similarity = 0.0 300 | struc_sim = 0 301 | text_sim = 0 302 | 303 | if (last_idx != 1) and (idx <= 1): 304 | continue 305 | 306 | if (len(sentence.split()) <= 1): 307 | continue 308 | 309 | 310 | tg = list(trigrams(sentence.split())) 311 | 312 | #for tg1 in last_tg: 313 | # for tg2 in tg: 314 | # if tg1 == tg2: 315 | # text_sim += 1 316 | # break 317 | text_sim = 
len(set(last_tg) & set(tg)) 318 | text_sim = text_sim/(min(len(set(last_tg)),len(set(tg)))+0.01) 319 | 320 | 321 | tags = [x[1] for x in pos_tag(sentence.split())] 322 | struc_sim = (longest_common_substring(last_tags,tags)) / (max(len(last_tags),len(tags))) 323 | similarity = ((struc_sim)+(text_sim*3)) 324 | #similarity = (text_sim) 325 | 326 | """ 327 | if (idx-1)/(len(speech)-2) <= 1/3: 328 | part = 1 329 | elif (idx-1)/(len(speech)-2) >= 2/3: 330 | part = 3 331 | else: 332 | part = 2 333 | 334 | if (last_part == 1) and (part == 3): 335 | continue 336 | if (part < last_part): 337 | continue 338 | 339 | #if part == last_part: 340 | # similarity += similarity 341 | 342 | #Same topics 343 | for i in range(topic_cnt): 344 | similarity += similarity 345 | """ 346 | if similarity > max_similarity: 347 | max_similarity = similarity 348 | max_struc_sim = struc_sim 349 | max_text_sim = text_sim 350 | 351 | if similarity > sim_thresh: 352 | last_speech = speech 353 | last_idx = idx+1 354 | 355 | 356 | if max_similarity <= sim_thresh: 357 | last_idx += 1 358 | else: 359 | speech_cnt += 1 360 | 361 | last_sentence = last_speech[last_idx] 362 | 363 | #print last_speech[last_idx-1] 364 | #print 'Similarity:',max_similarity,'/ Struc:',max_struc_sim,'/ Text:',max_text_sim 365 | print last_sentence 366 | sys.stdout.flush() 367 | if last_sentence == END_OF_SPEECH: 368 | break 369 | 370 | print speech_cnt 371 | 372 | 373 | def get_n_gram_class_probs(dataset,n=6): 374 | from nltk.util import ngrams 375 | 376 | class_tokens = dict() 377 | for c in classes: 378 | class_tokens[c] = [] 379 | 380 | 381 | for key,speeches in dataset.iteritems(): 382 | for speech in speeches: 383 | for sentence in speech: 384 | class_tokens[key].extend(sentence.split()) 385 | 386 | #print len(tokens) 387 | n_gram_count = dict() 388 | n_gram_class_probs = dict() 389 | for c,tokens in class_tokens.iteritems(): 390 | n_grams = ngrams(tokens,n) 391 | 392 | for ng in n_grams: 393 | if (END_OF_SPEECH in ng[:-1]): 394 | continue 395 | 396 | if ng not in n_gram_count: 397 | n_gram_count[ng] = defaultdict(float) 398 | n_gram_class_probs[ng] = defaultdict(float) 399 | 400 | n_gram_count[ng][c] += 1 401 | 402 | for n_gram,class_counts in n_gram_count.iteritems(): 403 | for c in classes: 404 | n_gram_class_probs[n_gram][c] = class_counts[c]/sum(class_counts.values()) 405 | 406 | 407 | return n_gram_class_probs 408 | 409 | def get_n_gram_probs(dataset,n=6,verbose=True): 410 | from nltk.util import ngrams 411 | from nltk import trigrams 412 | from nltk import bigrams 413 | 414 | class_tokens = dict() 415 | for c in classes: 416 | class_tokens[c] = [] 417 | 418 | 419 | for key,speeches in dataset.iteritems(): 420 | for speech in speeches: 421 | for sentence in speech: 422 | class_tokens[key].extend(sentence.split()) 423 | 424 | #print len(tokens) 425 | class_n_gram_probs = dict() 426 | for c,tokens in class_tokens.iteritems(): 427 | n_grams = ngrams(tokens,n) 428 | 429 | 430 | n_gram_count = defaultdict(float) 431 | for ng in n_grams: 432 | if (END_OF_SPEECH in ng[:-1]): 433 | continue 434 | n_gram_count[ng] += 1 435 | 436 | prob = dict() 437 | for key, value in n_gram_count.iteritems(): 438 | n_1_gram = tuple(key[:-1]) 439 | word = key[-1] 440 | if n_1_gram not in prob: 441 | prob[n_1_gram] = defaultdict(float) 442 | prob[n_1_gram][word] += value 443 | 444 | for n_1_gram, words in prob.iteritems(): 445 | n_1_gram_sum = sum(words.values()) 446 | for word,cnt in words.iteritems(): 447 | prob[n_1_gram][word] = prob[n_1_gram][word]/n_1_gram_sum 
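        # At this point prob maps each (n-1)-gram prefix tuple to a dict of
        # next-word -> P(word | prefix) for the current class c. The loop below
        # converts each of those dicts into a list of (word, probability) pairs
        # sorted by descending probability, which is the form that
        # generate_speech_wba iterates over when it samples the next word.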
448 | 449 | for key, value in prob.iteritems(): 450 | prob[key] = sorted(value.items(), key=operator.itemgetter(1), reverse=True) 451 | 452 | #n_gram_probs = sorted(prob.items(), key=lambda x: len(x[1]), reverse= True) 453 | class_n_gram_probs[c] = prob 454 | if verbose == True: 455 | print c,len(prob) 456 | return class_n_gram_probs 457 | 458 | 459 | def get_corpus_n_gram_probs(dataset,n=6): 460 | from nltk.util import ngrams 461 | 462 | 463 | 464 | all_tokens = [] 465 | 466 | for key,speeches in dataset.iteritems(): 467 | for speech in speeches: 468 | for sentence in speech: 469 | all_tokens.extend(sentence.split()) 470 | 471 | 472 | n_grams = ngrams(all_tokens,n) 473 | 474 | n_gram_count = defaultdict(float) 475 | 476 | for ng in n_grams: 477 | if (END_OF_SPEECH in ng[:-1]): 478 | continue 479 | n_gram_count[ng] += 1 480 | 481 | n_gram_probs = dict() 482 | for key, value in n_gram_count.iteritems(): 483 | n_1_gram = tuple(key[:-1]) 484 | word = key[-1] 485 | if n_1_gram not in n_gram_probs: 486 | n_gram_probs[n_1_gram] = defaultdict(float) 487 | n_gram_probs[n_1_gram][word] += value 488 | 489 | for n_1_gram, words in n_gram_probs.iteritems(): 490 | n_1_gram_sum = sum(words.values()) 491 | for word,cnt in words.iteritems(): 492 | n_gram_probs[n_1_gram][word] = n_gram_probs[n_1_gram][word]/n_1_gram_sum 493 | 494 | for key, value in n_gram_probs.iteritems(): 495 | n_gram_probs[key] = sorted(value.items(), key=operator.itemgetter(1), reverse=True) 496 | 497 | print len(n_gram_probs) 498 | return n_gram_probs 499 | 500 | def get_start_key(dataset,label,n=5): 501 | cnt = 0 502 | probs = [] 503 | sentences = [] 504 | 505 | for speech in dataset[label]: 506 | sent = speech[1] 507 | words = sent.split()[:n-1] 508 | start = " ".join(words) 509 | 510 | cnt+=1 511 | 512 | if start in sentences: 513 | idx = sentences.index(start) 514 | probs[idx] +=1 515 | else: 516 | sentences.append(start) 517 | probs.append(1) 518 | 519 | for i in range(len(probs)): 520 | probs[i] = probs[i] / cnt 521 | 522 | idx = np.random.multinomial(1, probs)[0] 523 | result = START_OF_SPEECH + " " + sentences[idx] 524 | result = tuple(result.split()) 525 | 526 | return result 527 | 528 | 529 | def get_word_prob_for_topics(dataset, c, word, topics): 530 | count = 0.0 531 | totlen = 0.001 532 | for speech in dataset[c]: 533 | full_speech = " ".join(speech) 534 | speech_prob = 0 535 | for t,prob in topics.iteritems(): 536 | if t in full_speech: 537 | speech_prob += prob 538 | 539 | if speech_prob > 0.0: 540 | count+=full_speech.count(word)*speech_prob 541 | totlen += len(full_speech.split())*speech_prob 542 | 543 | p_w = count/totlen 544 | return p_w 545 | 546 | def get_n_topics_from_ngram(dataset, jk_trend,jk, c, ngram, n=3): 547 | topics = defaultdict(float) 548 | ngram_key = " ".join(ngram) 549 | for speech in dataset[c]: 550 | full_speech = " ".join(speech) 551 | 552 | if ngram_key in full_speech: 553 | for key in jk_trend[c].keys(): 554 | topics[key] += full_speech.count(key) 555 | 556 | for key,cnt in topics.iteritems(): 557 | topics[key] = cnt/jk[c][key] 558 | result = [] 559 | for t in sorted(topics.items(), key=operator.itemgetter(1),reverse=True)[:n]: 560 | result.append(t[0]) 561 | return result 562 | 563 | def get_topics_from_speech(speech, jk_trend,jk, c, n=3): 564 | 565 | topics = defaultdict(float) 566 | for key in jk_trend[c].keys(): 567 | if key in speech: 568 | topics[key] += speech.count(key) 569 | 570 | for key,cnt in topics.iteritems(): 571 | topics[key] = cnt/jk[c][key] 572 | 573 | if n is None: 574 | 
n=len(topics) 575 | result = dict() 576 | sorted_topics = sorted(topics.items(), key=operator.itemgetter(1),reverse=True)[:n] 577 | for t in sorted_topics: 578 | result[t[0]] = t[1]/sum([pair[1] for pair in sorted_topics]) 579 | return result 580 | 581 | 582 | import pickle 583 | def create_corpus_pos_tags(dataset): 584 | all_pos_tags = set() 585 | for label,speeches in dataset.iteritems(): 586 | print label,'...', 587 | sys.stdout.flush() 588 | for sp in speeches: 589 | for sent in sp[1:-1]: 590 | tags = pos_tag(sent.split()[:-1]) 591 | tag_sequence = [x[1] for x in tags] 592 | tag_sequence = " ".join(tag_sequence) 593 | all_pos_tags.add(tag_sequence) 594 | print 'Done!' 595 | sys.stdout.flush() 596 | pickle.dump( all_pos_tags, open( "all_pos_tags.p", "wb" ) ) 597 | return all_pos_tags 598 | 599 | def evaluate_grammar(speech,verbose=True): 600 | sp = speech.replace(START_OF_SPEECH,'') 601 | sp = sp.replace(END_OF_SPEECH,'') 602 | sentences = sp.split(END_OF_SENTENCE) 603 | if len(sentences[-1].strip())== 0: 604 | sentences = sentences[:-1] 605 | 606 | 607 | all_pos_tags = pickle.load( open( "all_pos_tags.p", "rb" ) ) 608 | 609 | 610 | acc_cnt = 0 611 | for sent in sentences: 612 | tags = pos_tag(sent.split()) 613 | tag_sequence = [x[1] for x in tags] 614 | tag_sequence = " ".join(tag_sequence) 615 | 616 | if tag_sequence in all_pos_tags: 617 | acc_cnt += 1 618 | elif verbose == True: 619 | print sent 620 | 621 | 622 | return acc_cnt/len(sentences) 623 | 624 | def evaluate_content(gen_speech, dataset, label,jk,jk_trend): 625 | gen_topics = get_topics_from_speech(gen_speech, jk_trend,jk, label, n=None) 626 | sorted_gen_topics = sorted(gen_topics.items(), key=operator.itemgetter(1),reverse=True) 627 | num_topics = len(sorted_gen_topics) 628 | if num_topics==0: 629 | return 1.0 630 | 631 | max_cnt = 0 632 | 633 | for speech in dataset[label]: 634 | sp = " ".join(speech) 635 | topics = get_topics_from_speech(sp, jk_trend,jk, label, n=num_topics) 636 | sorted_topics = sorted(topics.items(), key=operator.itemgetter(1),reverse=True) 637 | sorted_topics = [t[0] for t in sorted_topics] 638 | 639 | cnt =0 640 | for i in range(num_topics): 641 | if i < len(sorted_topics): 642 | if sorted_topics[i] == sorted_gen_topics[i][0]: 643 | cnt += sorted_gen_topics[i][1] 644 | 645 | if cnt > max_cnt: 646 | max_cnt = cnt 647 | return max_cnt 648 | 649 | def generate_speech_wba(dataset,n_gram_probs,ngram_class_probs,corpus_ngram_props,jk_trend,jk,label,lamb=0.3,max_words=900): 650 | wordcnt = 0 651 | next_word = '' 652 | nn = len(n_gram_probs[label].keys()[0]) 653 | tuple_key = get_start_key(dataset,label,n=nn) 654 | print " ".join(tuple_key), 655 | my_speech = " ".join(tuple_key) 656 | current_sentence = my_speech 657 | topic_cnt = defaultdict(float) 658 | 659 | #all_pos_tags = pickle.load( open( "all_pos_tags.p", "rb" ) ) 660 | 661 | sen_count = 0 662 | topics = [] 663 | speech_sentences = [] 664 | while (next_word != END_OF_SPEECH) and ((wordcnt < max_words) or (next_word != END_OF_SENTENCE)): 665 | 666 | topics = get_topics_from_speech(my_speech,jk_trend,jk,label) 667 | 668 | words = [] 669 | probs = [] 670 | topic_probs = dict() 671 | 672 | for (word,ngram_prob) in n_gram_probs[label][tuple_key]: 673 | topic_prob = get_word_prob_for_topics(dataset,label,word,topics) 674 | topic_probs[word] = topic_prob 675 | 676 | sum_probs = sum(topic_probs.values()) 677 | if sum_probs > 0: 678 | for word,prob in topic_probs.iteritems(): 679 | topic_probs[word] = topic_probs[word]/sum_probs 680 | 681 | 682 | for 
(word,ngram_prob) in n_gram_probs[label][tuple_key]:
683 |         #for (word,ngram_prob) in corpus_ngram_props[tuple_key]:
684 |             topic_prob = topic_probs[word]
685 |             lang_prob = ngram_prob
686 |
687 |             prob = lamb*lang_prob + (1-lamb)*topic_prob
688 |             phrase = " ".join(tuple_key) + ' ' + word
689 |             prob = prob/(1+my_speech.count(phrase)**2)
690 |
691 |             if prob <= 0:
692 |                 continue
693 |
694 |             words.append(word)
695 |             probs.append(prob)
696 |
697 |         if len(probs) > 1:
698 |             probs = [p/sum(probs) for p in probs]
699 |             ni = np.argmax(np.random.multinomial(1, probs))
700 |         else:
701 |             ni = 0
702 |
703 |         if len(words) > 0:
704 |             next_word = words[ni]
705 |         else:
706 |             next_word = END_OF_SENTENCE
707 |
708 |         if next_word == END_OF_SENTENCE:
709 |             print '.'
710 |             speech_sentences.append(current_sentence)
711 |             current_sentence = ''
712 |             sen_count += 1
713 |         else:
714 |             print next_word,
715 |             current_sentence = current_sentence + ' ' + next_word
716 |         my_speech = my_speech + ' ' + next_word
717 |         tuple_key = tuple_key[1:] + (next_word,)
718 |         wordcnt += 1
719 |     return my_speech

--------------------------------------------------------------------------------
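A note on rnn.py: it calls `train_with_sgd(model, X_train, y_train, nepoch=50)`, but that function is not defined in rnn.py and is not provided by utils.py or rnn_theano.py, so the training script stops with a NameError as shipped. The sketch below is one minimal way to fill that gap, assuming a plain SGD loop over the training sentences built on `RNNTheano.sgd_step` and `RNNTheano.calculate_loss`; the `learning_rate` and `evaluate_loss_after` defaults are illustrative assumptions rather than values taken from this repository.

```python
# NOT part of the original repository: a minimal train_with_sgd sketch matching
# the call in rnn.py. Assumes model is an RNNTheano instance.
def train_with_sgd(model, X_train, y_train, learning_rate=0.005, nepoch=50, evaluate_loss_after=5):
    losses = []
    num_examples_seen = 0
    for epoch in range(nepoch):
        # Periodically report the per-word cross-entropy loss on the training data
        if epoch % evaluate_loss_after == 0:
            loss = model.calculate_loss(X_train, y_train)
            losses.append((num_examples_seen, loss))
            print "Loss after num_examples_seen=%d epoch=%d: %f" % (num_examples_seen, epoch, loss)
        # One SGD step per training sentence
        for i in range(len(y_train)):
            model.sgd_step(X_train[i], y_train[i], learning_rate)
            num_examples_seen += 1
    return losses
```

Note also that rnn.py saves the trained parameters to ./data/, a directory that does not appear in the tree above, so it has to be created before training.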