├── README.md
├── data
│   └── README.md
├── figure1.png
├── model
│   ├── __init__.py
│   ├── autoencoder.py
│   ├── cnn_layer.py
│   ├── lstm_layer.py
│   ├── optimizers.py
│   └── utils.py
├── sentence_retrieval.py
├── train_autoencoder.py
└── vector_compositionality.py
/README.md:
--------------------------------------------------------------------------------
1 | # ConvSent
2 |
3 | The training code for the EMNLP 2017 paper “[Learning Generic Sentence Representations Using Convolutional Neural Networks](https://arxiv.org/pdf/1611.07897.pdf)”
4 |
5 | Illustration of the CNN-LSTM encoder-decoder models.
6 |
7 | ![Illustration of the CNN-LSTM encoder-decoder models](figure1.png)
8 |
9 | ## Dependencies
10 |
11 | This code is written in Python. To use it you will need:
12 | 
13 | * Python 2.7 (the code does not run under Python 3)
14 | * Theano 0.7 (a more recent version also works)
15 | * Recent versions of NumPy and SciPy
16 |
17 | ## Getting started
18 |
19 | The ./data folder provides 1M sentences randomly sampled from the BookCorpus dataset for demo purposes (see ./data/README.md for the download link).
20 | 
21 | We provide the CNN-LSTM auto-encoder training code here. The training code for the future predictor and the composite model is similar, and the hierarchical model can also be obtained by adapting this code.
22 |
23 | 1. Run `train_autoencoder.py` to start training.
24 | ```
25 | THEANO_FLAGS=mode=FAST_RUN,device=gpu,floatX=float32 python train_autoencoder.py
26 | ```
27 |
28 | 2. After training, you can run `sentence_retrieval.py` and `vector_compositionality.py` to reproduce the qualitative analysis experiments in the paper, for example the sentence-vector arithmetic below (a usage sketch follows the example).
29 |
30 | ```
31 | you needed me ? - you got me ? + i got you . = i needed you .
32 | ```
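
A minimal sketch of how `vector_compositionality.py` produces results like the one above, assuming a trained model has been saved as `bookcorpus_result.npz` in the repository root (the path the script loads) and the data pickle is in `./data`:

```
import cPickle
from vector_compositionality import (find_sent_embedding, prepare_data_for_cnn,
                                     get_idx_from_sent, generate)

# vocabulary; the padding token is appended exactly as in the scripts
d = cPickle.load(open("./data/bookcorpus_1M.p", "rb"))
wordtoix, ixtoword = d[6], d[7]
ixtoword[len(ixtoword)] = ''

# CNN encoder function and trained parameters (loads ./bookcorpus_result.npz)
f_embed, params = find_sent_embedding()

sents = [get_idx_from_sent(s, wordtoix)
         for s in ["you needed me ?", "you got me ?", "i got you ."]]
emb = f_embed(prepare_data_for_cnn(sents))

# arithmetic in the sentence-embedding space, then LSTM beam-search decoding
query = emb[0] - emb[1] + emb[2]
pred = generate([query], params)   # beam size 5, as in the script
print ' '.join(ixtoword[w] for w in pred[0][0][1])
```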
33 |
34 | ## Citing ConvSent
35 |
36 | Please cite our EMNLP paper in your publications if it helps your research:
37 |
38 |     @inproceedings{ConvSent_EMNLP2017,
39 |       author    = {Gan, Zhe and Pu, Yunchen and Henao, Ricardo and Li, Chunyuan and He, Xiaodong and Carin, Lawrence},
40 |       title     = {Learning Generic Sentence Representations Using Convolutional Neural Networks},
41 |       booktitle = {EMNLP},
42 |       year      = {2017}
43 |     }
--------------------------------------------------------------------------------
/data/README.md:
--------------------------------------------------------------------------------
1 |
2 | The randomly sampled 1M sentences from the BookCorpus dataset can be downloaded [here](https://www.dropbox.com/sh/joh8a379du99qwr/AADQMhGsxPxfxlUHlrcNCAVoa?dl=0).
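
The scripts in this repository load the data from `./data/bookcorpus_1M.p` (relative to the repository root). Assuming the downloaded file is saved under that name, a minimal sketch of how it is unpacked, following the `__main__` block of `train_autoencoder.py`:

```
import cPickle

x = cPickle.load(open("./data/bookcorpus_1M.p", "rb"))
train, val, test = x[0], x[1], x[2]                 # sentences as lists of word indices
train_text, val_text, test_text = x[3], x[4], x[5]  # the corresponding raw sentences
wordtoix, ixtoword = x[6], x[7]                     # vocabulary mappings

print len(train), len(val), len(test)
```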
3 |
4 |
--------------------------------------------------------------------------------
/figure1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhegan27/ConvSent/6ce4494de195a330574c062a34871354fc1567c8/figure1.png
--------------------------------------------------------------------------------
/model/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhegan27/ConvSent/6ce4494de195a330574c062a34871354fc1567c8/model/__init__.py
--------------------------------------------------------------------------------
/model/autoencoder.py:
--------------------------------------------------------------------------------
1 |
2 | import numpy as np
3 | import theano
4 | import theano.tensor as tensor
5 | from theano import config
6 |
7 | from collections import OrderedDict
8 | from theano.sandbox.rng_mrg import MRG_RandomStreams as RandomStreams
9 |
10 | from utils import dropout, numpy_floatX
11 | from utils import _p
12 | from utils import uniform_weight, zero_bias
13 |
14 | from cnn_layer import param_init_encoder, encoder
15 | from lstm_layer import param_init_decoder, decoder
16 |
17 | # Set the random number generators' seeds for consistency
18 | SEED = 123
19 | np.random.seed(SEED)
20 |
21 | """ init. parameters. """
22 | def init_params(options):
23 |
24 | n_words = options['n_words']
25 | n_x = options['n_x']
26 | n_h = options['n_h']
27 |
28 | params = OrderedDict()
29 | # word embedding
30 | params['Wemb'] = uniform_weight(n_words,n_x)
31 | #params['Wemb'] = W.astype(config.floatX)
32 | params['Wemb'][-1] = np.zeros((n_x,)).astype(theano.config.floatX)  # the last row is the padding token; keep its embedding at zero
33 | # encoding words into sentences
34 | length = len(options['filter_shapes'])
35 | for idx in range(length):
36 | params = param_init_encoder(options['filter_shapes'][idx],params,prefix=_p('cnn_encoder',idx))
37 |
38 | options['n_z'] = options['feature_maps'] * length
39 | params = param_init_decoder(options,params,prefix='decoder')
40 |
41 | params['Vhid'] = uniform_weight(n_h,n_x)
42 | params['bhid'] = zero_bias(n_words)
43 |
44 | return params
45 |
46 | def init_tparams(params):
47 | tparams = OrderedDict()
48 | for kk, pp in params.iteritems():
49 | tparams[kk] = theano.shared(params[kk], name=kk)
50 | #tparams[kk].tag.test_value = params[kk]
51 | return tparams
52 |
53 | """ Building model... """
54 |
55 | def build_model(tparams,options):
56 |
57 | trng = RandomStreams(SEED)
58 |
59 | # Used for dropout.
60 | use_noise = theano.shared(numpy_floatX(0.))
61 |
62 | x = tensor.matrix('x', dtype='int32')
63 |
64 | layer0_input = tparams['Wemb'][tensor.cast(x.flatten(),dtype='int32')].reshape((x.shape[0],1,x.shape[1],tparams['Wemb'].shape[1]))
65 | layer0_input = dropout(layer0_input, trng, use_noise)
66 |
67 | layer1_inputs = []
68 | for i in xrange(len(options['filter_hs'])):
69 | filter_shape = options['filter_shapes'][i]
70 | pool_size = options['pool_sizes'][i]
71 | conv_layer = encoder(tparams, layer0_input,filter_shape, pool_size,prefix=_p('cnn_encoder',i))
72 | layer1_input = conv_layer
73 | layer1_inputs.append(layer1_input)
74 | layer1_input = tensor.concatenate(layer1_inputs,1)
75 | layer1_input = dropout(layer1_input, trng, use_noise)
76 |
77 | # the sentence to reconstruct: n_steps * n_samples
78 | y = tensor.matrix('y', dtype='int32')
79 | y_mask = tensor.matrix('y_mask', dtype=config.floatX)
80 |
81 | n_steps = y.shape[0]
82 | n_samples = y.shape[1]
83 |
84 | n_x = tparams['Wemb'].shape[1]
85 |
86 | # n_steps * n_samples * n_x
87 | y_emb = tparams['Wemb'][y.flatten()].reshape([n_steps,n_samples,n_x])
88 | y_emb = dropout(y_emb, trng, use_noise)
89 |
90 | # n_steps * n_samples * n_x
91 | h_decoder = decoder(tparams, y_emb, layer1_input, mask=y_mask,prefix='decoder')
92 | h_decoder = dropout(h_decoder, trng, use_noise)
93 |
94 | # reconstruct the original sentence
95 | shape_w = h_decoder.shape
96 | h_decoder = h_decoder.reshape((shape_w[0]*shape_w[1], shape_w[2]))
97 |
98 | # (n_steps * n_samples) * n_words
99 | Vhid = tensor.dot(tparams['Vhid'],tparams['Wemb'].T)  # output projection tied to the word embeddings (n_h -> n_x -> vocabulary)
100 | pred_w = tensor.dot(h_decoder, Vhid) + tparams['bhid']
101 | pred_w = tensor.nnet.softmax(pred_w)
102 |
103 | x_vec = y.reshape((shape_w[0]*shape_w[1],))
104 | x_index = tensor.arange(shape_w[0]*shape_w[1])
105 | x_pred_word = pred_w[x_index, x_vec]
106 |
107 | x_mask_reshape = y_mask.reshape((shape_w[0]*shape_w[1],))
108 | x_index_list = theano.tensor.eq(x_mask_reshape, 1.).nonzero()[0]
109 |
110 | x_pred_word_prob = x_pred_word[x_index_list]
111 |
112 | cost_x = -tensor.log(x_pred_word_prob + 1e-6).sum()
113 |
114 | # the cross-entropy loss
115 | num_words = y_mask.sum()
116 | cost = cost_x / num_words
117 |
118 | return use_noise, x, y, y_mask, cost
119 |
--------------------------------------------------------------------------------
/model/cnn_layer.py:
--------------------------------------------------------------------------------
1 |
2 | import numpy as np
3 | import theano
4 | import theano.tensor as tensor
5 | import theano.tensor.shared_randomstreams
6 | from utils import _p
7 | from theano.tensor.nnet import conv
8 | from theano.tensor.signal import pool
9 |
10 | rng = np.random.RandomState(3435)
11 |
12 | #def ReLU(x):
13 | # y = T.maximum(0.0, x)
14 | # return(y)
15 |
16 | def param_init_encoder(filter_shape, params, prefix='cnn_encoder'):
17 |
18 | """ filter_shape: (number of filters, num input feature maps, filter height,
19 | filter width)
20 | image_shape: (batch_size, num input feature maps, image height, image width)
21 | """
22 |
23 | W = np.asarray(rng.uniform(low=-0.01,high=0.01,size=filter_shape),dtype=theano.config.floatX)
24 | b = np.zeros((filter_shape[0],), dtype=theano.config.floatX)
25 |
26 | params[_p(prefix,'W')] = W
27 | params[_p(prefix,'b')] = b
28 |
29 | return params
30 |
31 |
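# Note on pooling: train_autoencoder.py builds pool_size = (img_h - filter_h + 1, 1),
# i.e. the pooling window covers the whole convolution output (max-over-time pooling).
# Each of the `feature_maps` filters therefore contributes a single value per sentence,
# so the flattened output of encoder() has shape (batch_size, feature_maps);
# autoencoder.build_model concatenates these outputs over the filter sizes to form the
# sentence code of dimension n_z = feature_maps * len(filter_hs).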
32 | def encoder(tparams, layer0_input, filter_shape, pool_size,
33 | prefix='cnn_encoder'):
34 |
35 | """ filter_shape: (number of filters, num input feature maps, filter height,
36 | filter width)
37 | image_shape: (batch_size, num input feature maps, image height, image width)
38 | """
39 |
40 | conv_out = conv.conv2d(input=layer0_input, filters=tparams[_p(prefix,'W')],
41 | filter_shape=filter_shape)
42 |
43 | conv_out_tanh = tensor.tanh(conv_out + tparams[_p(prefix,'b')].dimshuffle('x', 0, 'x', 'x'))
44 | output = pool.pool_2d(input=conv_out_tanh, ds=pool_size, ignore_border=True)
45 |
46 | return output.flatten(2)
47 |
--------------------------------------------------------------------------------
/model/lstm_layer.py:
--------------------------------------------------------------------------------
1 |
2 | import numpy as np
3 | import theano
4 | import theano.tensor as tensor
5 | from utils import _p, numpy_floatX
6 | from utils import ortho_weight, uniform_weight, zero_bias
7 |
8 | def param_init_encoder(options, params, prefix='encoder'):
9 |
10 | n_x = options['n_x']
11 | n_h = options['n_h']
12 |
13 | W = np.concatenate([uniform_weight(n_x,n_h),
14 | uniform_weight(n_x,n_h),
15 | uniform_weight(n_x,n_h),
16 | uniform_weight(n_x,n_h)], axis=1)
17 | params[_p(prefix, 'W')] = W
18 |
19 | U = np.concatenate([ortho_weight(n_h),
20 | ortho_weight(n_h),
21 | ortho_weight(n_h),
22 | ortho_weight(n_h)], axis=1)
23 | params[_p(prefix, 'U')] = U
24 |
25 | params[_p(prefix,'b')] = zero_bias(4*n_h)
26 | params[_p(prefix, 'b')][n_h:2*n_h] = 3*np.ones((n_h,)).astype(theano.config.floatX)  # initialize the forget-gate bias to 3 so memory is retained early in training
27 |
28 | return params
29 |
30 | def encoder(tparams, state_below, mask, seq_output=False, prefix='encoder'):
31 |
32 | """ state_below: size of n_steps * n_samples * n_x
33 | """
34 |
35 | n_steps = state_below.shape[0]
36 | n_samples = state_below.shape[1]
37 |
38 | n_h = tparams[_p(prefix,'U')].shape[0]
39 |
40 | def _slice(_x, n, dim):
41 | if _x.ndim == 3:
42 | return _x[:, :, n*dim:(n+1)*dim]
43 | return _x[:, n*dim:(n+1)*dim]
44 |
45 | state_below_ = tensor.dot(state_below, tparams[_p(prefix, 'W')]) + \
46 | tparams[_p(prefix, 'b')]
47 |
48 | def _step(m_, x_, h_, c_, U):
49 | preact = tensor.dot(h_, U)
50 | preact += x_
51 |
52 | i = tensor.nnet.sigmoid(_slice(preact, 0, n_h))
53 | f = tensor.nnet.sigmoid(_slice(preact, 1, n_h))
54 | o = tensor.nnet.sigmoid(_slice(preact, 2, n_h))
55 | c = tensor.tanh(_slice(preact, 3, n_h))
56 |
57 | c = f * c_ + i * c
58 | c = m_[:, None] * c + (1. - m_)[:, None] * c_
59 |
60 | h = o * tensor.tanh(c)
61 | h = m_[:, None] * h + (1. - m_)[:, None] * h_
62 |
63 | return h, c
64 |
65 | seqs = [mask, state_below_]
66 |
67 | rval, updates = theano.scan(_step,
68 | sequences=seqs,
69 | outputs_info=[tensor.alloc(numpy_floatX(0.),
70 | n_samples,n_h),
71 | tensor.alloc(numpy_floatX(0.),
72 | n_samples,n_h)],
73 | non_sequences = [tparams[_p(prefix, 'U')]],
74 | name=_p(prefix, '_layers'),
75 | n_steps=n_steps,
76 | strict=True)
77 |
78 | h_rval = rval[0]
79 | if seq_output:
80 | return h_rval
81 | else:
82 | # size of n_samples * n_h
83 | return h_rval[-1]
84 |
85 | def param_init_decoder(options, params, prefix='decoder'):
86 |
87 | n_x = options['n_x']
88 | n_h = options['n_h']
89 | n_z = options['n_z']
90 |
91 | W = np.concatenate([uniform_weight(n_x,n_h),
92 | uniform_weight(n_x,n_h),
93 | uniform_weight(n_x,n_h),
94 | uniform_weight(n_x,n_h)], axis=1)
95 | params[_p(prefix, 'W')] = W
96 |
97 | U = np.concatenate([ortho_weight(n_h),
98 | ortho_weight(n_h),
99 | ortho_weight(n_h),
100 | ortho_weight(n_h)], axis=1)
101 | params[_p(prefix, 'U')] = U
102 |
103 | C = np.concatenate([uniform_weight(n_z,n_h),
104 | uniform_weight(n_z,n_h),
105 | uniform_weight(n_z,n_h),
106 | uniform_weight(n_z,n_h)], axis=1)
107 | params[_p(prefix,'C')] = C
108 |
109 | params[_p(prefix,'b')] = zero_bias(4*n_h)
110 | params[_p(prefix, 'b')][n_h:2*n_h] = 3*np.ones((n_h,)).astype(theano.config.floatX)  # initialize the forget-gate bias to 3 so memory is retained early in training
111 |
112 |
113 | C0 = uniform_weight(n_z, n_h)
114 | params[_p(prefix,'C0')] = C0
115 |
116 | params[_p(prefix,'b0')] = zero_bias(n_h)
117 |
118 | return params
119 |
120 |
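# The decoder below is an LSTM conditioned on the sentence code z:
# the initial hidden state is h0 = tanh(z C0 + b0), and z is also added to the
# pre-activation of every gate at every step through the C matrix (alongside the
# input projection W x_t and the recurrent term U h_{t-1}). The scan runs for
# n_steps-1 steps and h0 is prepended to the output, so hidden state t is computed
# from words 0..t-1 (and z) and is used in build_model to predict word t.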
121 | def decoder(tparams, state_below, z, mask=None, prefix='decoder'):
122 |
123 | """ state_below: size of n_steps * n_samples * n_x
124 | z: size of n_samples * n_z
125 | """
126 |
127 | n_steps = state_below.shape[0]
128 | n_samples = state_below.shape[1]
129 |
130 | n_h = tparams[_p(prefix,'U')].shape[0]
131 |
132 | # n_samples * n_h
133 | state_belowx0 = tensor.dot(z, tparams[_p(prefix, 'C0')]) + \
134 | tparams[_p(prefix, 'b0')]
135 | h0 = tensor.tanh(state_belowx0)
136 |
137 | def _slice(_x, n, dim):
138 | if _x.ndim == 3:
139 | return _x[:, :, n*dim:(n+1)*dim]
140 | return _x[:, n*dim:(n+1)*dim]
141 |
142 | # n_steps * n_samples * n_h
143 | state_below_ = tensor.dot(state_below, tparams[_p(prefix, 'W')]) + \
144 | tensor.dot(z, tparams[_p(prefix, 'C')]) + tparams[_p(prefix, 'b')]
145 |
146 | def _step(m_, x_, h_, c_, U):
147 | preact = tensor.dot(h_, U)
148 | preact += x_
149 |
150 | i = tensor.nnet.sigmoid(_slice(preact, 0, n_h))
151 | f = tensor.nnet.sigmoid(_slice(preact, 1, n_h))
152 | o = tensor.nnet.sigmoid(_slice(preact, 2, n_h))
153 | c = tensor.tanh(_slice(preact, 3, n_h))
154 |
155 | c = f * c_ + i * c
156 | c = m_[:, None] * c + (1. - m_)[:, None] * c_
157 |
158 | h = o * tensor.tanh(c)
159 | h = m_[:, None] * h + (1. - m_)[:, None] * h_
160 |
161 | return h, c
162 |
163 | seqs = [mask[:n_steps-1], state_below_[:n_steps-1]]
164 |
165 | rval, updates = theano.scan(_step,
166 | sequences=seqs,
167 | outputs_info = [h0,tensor.alloc(numpy_floatX(0.),
168 | n_samples,n_h)],
169 | non_sequences = [tparams[_p(prefix, 'U')]],
170 | name=_p(prefix, '_layers'),
171 | n_steps=n_steps-1,
172 | strict=True)
173 |
174 | h0x = tensor.shape_padleft(h0)
175 | h_rval = rval[0]
176 |
177 | return tensor.concatenate((h0x,h_rval))
178 |
--------------------------------------------------------------------------------
/model/optimizers.py:
--------------------------------------------------------------------------------
1 | import theano
2 | import theano.tensor as tensor
3 | from theano.sandbox.rng_mrg import MRG_RandomStreams as RandomStreams
4 | from utils import numpy_floatX
5 |
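# Every optimizer in this module returns the same two compiled functions:
#   f_grad_shared, f_update = <Optimizer>(tparams, cost, inps, lr)
# f_grad_shared(*inps) evaluates the cost and stores the gradients in shared
# variables, and f_update(lrate) then applies one parameter update using them;
# see the training loop in train_autoencoder.py, which uses Adam with lr = 0.0002.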
6 | def SGD(tparams, cost, inps, lr):
7 | """ default: lr=0.01 """
8 |
9 | grads = tensor.grad(cost, tparams.values())
10 | gshared = [theano.shared(p.get_value() * 0., name='%s_grad'%k)
11 | for k, p in tparams.iteritems()]
12 | gsup = [(gs, g) for gs, g in zip(gshared, grads)]
13 | f_grad_shared = theano.function(inps, cost, updates=gsup)
14 |
15 | updates = []
16 |
17 | for p, g in zip(tparams.values(), grads):
18 | updated_p = p - lr * g
19 | updates.append((p, updated_p))
20 |
21 | f_update = theano.function([lr], [], updates=updates)
22 |
23 | return f_grad_shared, f_update
24 |
25 | def Momentum(tparams, cost, inps, lr, momentum=0.9):
26 | """ default: lr=0.01 """
27 |
28 | grads = tensor.grad(cost, tparams.values())
29 | gshared = [theano.shared(p.get_value() * 0., name='%s_grad'%k)
30 | for k, p in tparams.iteritems()]
31 | gsup = [(gs, g) for gs, g in zip(gshared, grads)]
32 | f_grad_shared = theano.function(inps, cost, updates=gsup)
33 |
34 | updates = []
35 |
36 | for p, g in zip(tparams.values(), gshared):
37 | m = theano.shared(p.get_value() * 0.)
38 | m_new = momentum * m - lr * g
39 | updates.append((m, m_new))
40 |
41 | updated_p = p + m_new
42 | updates.append((p, updated_p))
43 |
44 | f_update = theano.function([lr], [], updates=updates)
45 |
46 | return f_grad_shared, f_update
47 |
48 | def NAG(tparams, cost, inps, lr, momentum=0.9):
49 | """ default: lr=0.01 """
50 |
51 | grads = tensor.grad(cost, tparams.values())
52 | gshared = [theano.shared(p.get_value() * 0., name='%s_grad'%k)
53 | for k, p in tparams.iteritems()]
54 | gsup = [(gs, g) for gs, g in zip(gshared, grads)]
55 | f_grad_shared = theano.function(inps, cost, updates=gsup)
56 |
57 | updates = []
58 |
59 | for p, g in zip(tparams.values(), gshared):
60 | m = theano.shared(p.get_value() * 0.)
61 | m_new = momentum * m - lr * g
62 | updates.append((m, m_new))
63 |
64 | updated_p = p + momentum * m_new - lr * g
65 | updates.append((p, updated_p))
66 |
67 | f_update = theano.function([lr], [], updates=updates)
68 |
69 | return f_grad_shared, f_update
70 |
71 | def Adagrad(tparams, cost, inps, lr, epsilon=1e-6):
72 | """ default: lr=0.01 """
73 |
74 | grads = tensor.grad(cost, tparams.values())
75 | gshared = [theano.shared(p.get_value() * 0., name='%s_grad'%k)
76 | for k, p in tparams.iteritems()]
77 | gsup = [(gs, g) for gs, g in zip(gshared, grads)]
78 | f_grad_shared = theano.function(inps, cost, updates=gsup)
79 |
80 | updates = []
81 |
82 | for p, g in zip(tparams.values(), gshared):
83 | acc = theano.shared(p.get_value() * 0.)
84 | acc_t = acc + g ** 2
85 | updates.append((acc, acc_t))
86 | p_t = p - (lr / tensor.sqrt(acc_t + epsilon)) * g
87 | updates.append((p, p_t))
88 |
89 | f_update = theano.function([lr], [], updates=updates)
90 |
91 | return f_grad_shared, f_update
92 |
93 | def Adadelta(tparams, cost, inps, lr, rho=0.95, epsilon=1e-6):
94 | """ default: lr=0.5 """
95 |
96 | grads = tensor.grad(cost, tparams.values())
97 | gshared = [theano.shared(p.get_value() * 0., name='%s_grad'%k)
98 | for k, p in tparams.iteritems()]
99 | gsup = [(gs, g) for gs, g in zip(gshared, grads)]
100 | f_grad_shared = theano.function(inps, cost, updates=gsup)
101 |
102 | updates = []
103 |
104 | for p, g in zip(tparams.values(), gshared):
105 | acc = theano.shared(p.get_value() * 0.)
106 | acc_delta = theano.shared(p.get_value() * 0.)
107 | acc_new = rho * acc + (1 - rho) * g ** 2
108 | updates.append((acc,acc_new))
109 |
110 | update = g * tensor.sqrt(acc_delta + epsilon) / tensor.sqrt(acc_new + epsilon)
111 | updated_p = p - lr * update
112 | updates.append((p, updated_p))
113 |
114 | acc_delta_new = rho * acc_delta + (1 - rho) * update ** 2
115 | updates.append((acc_delta,acc_delta_new))
116 |
117 | f_update = theano.function([lr], [], updates=updates)
118 |
119 | return f_grad_shared, f_update
120 |
121 |
122 | def RMSprop_v1(tparams, cost, inps, lr, rho=0.9, epsilon=1e-6):
123 | """ default: lr=0.001
124 | This is the implementation of the RMSprop algorithm used in
125 | http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf.
126 | """
127 |
128 | grads = tensor.grad(cost, tparams.values())
129 | norm = tensor.sqrt(sum([tensor.sum(g**2) for g in grads]))
130 | # tensor.ge is symbolic, so a Python `if` cannot branch on it; use switch to rescale only when the global gradient norm is >= 5
131 | grads = [tensor.switch(tensor.ge(norm, 5.), g * 5. / norm, g) for g in grads]
132 |
133 | gshared = [theano.shared(p.get_value() * 0., name='%s_grad'%k)
134 | for k, p in tparams.iteritems()]
135 | gsup = [(gs, g) for gs, g in zip(gshared, grads)]
136 | f_grad_shared = theano.function(inps, cost, updates=gsup)
137 |
138 | updates = []
139 |
140 | for p, g in zip(tparams.values(), gshared):
141 | acc = theano.shared(p.get_value() * 0.)
142 | acc_new = rho * acc + (1 - rho) * g ** 2
143 | updates.append((acc, acc_new))
144 |
145 | updated_p = p - lr * (g / tensor.sqrt(acc_new + epsilon))
146 | updates.append((p, updated_p))
147 |
148 | f_update = theano.function([lr], [], updates=updates)
149 |
150 | return f_grad_shared, f_update
151 |
152 | def RMSprop_v2(tparams, cost, inps, lr, rho=0.95, momentum=0.9, epsilon=1e-4):
153 | """ default: lr=0.0001
154 | This is the implementation of the RMSprop algorithm used in
155 | http://arxiv.org/pdf/1308.0850v5.pdf
156 | """
157 |
158 | grads = tensor.grad(cost, tparams.values())
159 | norm = tensor.sqrt(sum([tensor.sum(g**2) for g in grads]))
160 | # tensor.ge is symbolic, so a Python `if` cannot branch on it; use switch to rescale only when the global gradient norm is >= 5
161 | grads = [tensor.switch(tensor.ge(norm, 5.), g * 5. / norm, g) for g in grads]
162 |
163 | gshared = [theano.shared(p.get_value() * 0., name='%s_grad'%k)
164 | for k, p in tparams.iteritems()]
165 | gsup = [(gs, g) for gs, g in zip(gshared, grads)]
166 | f_grad_shared = theano.function(inps, cost, updates=gsup)
167 |
168 | updates = []
169 |
170 | for p, g in zip(tparams.values(), gshared):
171 | acc = theano.shared(p.get_value() * 0.)
172 | acc2 = theano.shared(p.get_value() * 0.)
173 | acc_new = rho * acc + (1.-rho) * g
174 | acc2_new = rho * acc2 + (1.-rho) * (g ** 2)  # running average of the squared gradient
175 | updates.append((acc, acc_new))
176 | updates.append((acc2, acc2_new))
177 |
178 | updir = theano.shared(p.get_value() * 0.)
179 | updir_new = momentum * updir - lr * g / tensor.sqrt(acc2_new -acc_new ** 2 + epsilon)
180 | updates.append((updir, updir_new))
181 |
182 | updated_p = p + updir_new
183 | updates.append((p, updated_p))
184 |
185 | f_update = theano.function([lr], [], updates=updates)
186 |
187 | return f_grad_shared, f_update
188 |
189 | def Adam(tparams, cost, inps, lr, b1=0.1, b2=0.001, e=1e-8):
190 | """ default: lr=0.0002
191 | This is the implementation of the Adam algorithm
192 | Reference: http://arxiv.org/pdf/1412.6980v8.pdf
193 | """
194 |
195 | grads = tensor.grad(cost, tparams.values())
196 | norm = tensor.sqrt(sum([tensor.sum(g**2) for g in grads]))
197 | # tensor.ge is symbolic, so a Python `if` cannot branch on it; use switch to rescale only when the global gradient norm is >= 5
198 | grads = [tensor.switch(tensor.ge(norm, 5.), g * 5. / norm, g) for g in grads]
199 |
200 | gshared = [theano.shared(p.get_value() * 0., name='%s_grad'%k)
201 | for k, p in tparams.iteritems()]
202 | gsup = [(gs, g) for gs, g in zip(gshared, grads)]
203 | f_grad_shared = theano.function(inps, cost, updates=gsup)
204 |
205 | updates = []
206 |
207 | i = theano.shared(numpy_floatX(0.))
208 | i_t = i + 1.
209 | fix1 = 1. - b1**(i_t)
210 | fix2 = 1. - b2**(i_t)
211 | lr_t = lr * (tensor.sqrt(fix2) / fix1)
212 |
213 | for p, g in zip(tparams.values(), gshared):
214 | m = theano.shared(p.get_value() * 0.)
215 | v = theano.shared(p.get_value() * 0.)
216 | m_t = (b1 * g) + ((1. - b1) * m)
217 | v_t = (b2 * tensor.sqr(g)) + ((1. - b2) * v)
218 | g_t = m_t / (tensor.sqrt(v_t) + e)
219 | p_t = p - (lr_t * g_t)
220 | updates.append((m, m_t))
221 | updates.append((v, v_t))
222 | updates.append((p, p_t))
223 | updates.append((i, i_t))
224 |
225 | f_update = theano.function([lr], [], updates=updates)
226 |
227 | return f_grad_shared, f_update
--------------------------------------------------------------------------------
/model/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import theano
3 | from theano import config
4 | from collections import OrderedDict
5 |
6 | def numpy_floatX(data):
7 | return np.asarray(data, dtype=config.floatX)
8 |
9 | def zipp(params, tparams):
10 | """
11 | When we reload the model. Needed for the GPU stuff.
12 | """
13 | for kk, vv in params.iteritems():
14 | tparams[kk].set_value(vv)
15 |
16 | def unzip(zipped):
17 | """
18 | When we pickle the model. Needed for the GPU stuff.
19 | """
20 | new_params = OrderedDict()
21 | for kk, vv in zipped.iteritems():
22 | new_params[kk] = vv.get_value()
23 | return new_params
24 |
25 |
26 | def get_minibatches_idx(n, minibatch_size, shuffle=False):
27 | idx_list = np.arange(n, dtype="int32")
28 |
29 | if shuffle:
30 | np.random.shuffle(idx_list)
31 |
32 | minibatches = []
33 | minibatch_start = 0
34 | for i in range(n // minibatch_size):
35 | minibatches.append(idx_list[minibatch_start:
36 | minibatch_start + minibatch_size])
37 | minibatch_start += minibatch_size
38 |
39 | if (minibatch_start != n):
40 | # Make a minibatch out of what is left
41 | minibatches.append(idx_list[minibatch_start:])
42 |
43 | return zip(range(len(minibatches)), minibatches)
44 |
45 | def _p(pp, name):
46 | return '%s_%s' % (pp, name)
47 |
48 | def dropout(X, trng, p=0.):  # inverted dropout: keep units with probability 1-p and rescale by 1/(1-p)
49 | if p != 0:
50 | retain_prob = 1 - p
51 | X = X / retain_prob * trng.binomial(X.shape, p=retain_prob, dtype=theano.config.floatX)
52 | return X
53 |
54 | """ used for initialization of the parameters. """
55 |
56 | def ortho_weight(ndim):
57 | W = np.random.randn(ndim, ndim)
58 | u, s, v = np.linalg.svd(W)
59 | return u.astype(config.floatX)
60 |
61 | def uniform_weight(nin,nout=None, scale=0.05):
62 | if nout == None:
63 | nout = nin
64 | W = np.random.uniform(low=-scale, high=scale, size=(nin, nout))
65 | return W.astype(config.floatX)
66 |
67 | def normal_weight(nin,nout=None, scale=0.05):
68 | if nout == None:
69 | nout = nin
70 | W = np.random.randn(nin, nout) * scale
71 | return W.astype(config.floatX)
72 |
73 | def zero_bias(ndim):
74 | b = np.zeros((ndim,))
75 | return b.astype(config.floatX)
76 |
77 |
78 |
79 |
80 |
81 |
82 |
83 |
84 |
85 |
86 |
--------------------------------------------------------------------------------
/sentence_retrieval.py:
--------------------------------------------------------------------------------
1 | '''
2 | Learning Generic Sentence Representations Using Convolutional Neural Networks
3 | https://arxiv.org/pdf/1611.07897.pdf
4 | Developed by Zhe Gan, zhe.gan@duke.edu, April 19, 2016
5 | '''
6 |
7 | import cPickle
8 | import numpy as np
9 | import theano
10 | import theano.tensor as tensor
11 |
12 | from model.autoencoder import init_params, init_tparams
13 | from model.cnn_layer import encoder
14 | from model.utils import get_minibatches_idx
15 | from model.utils import _p
16 |
17 | from scipy import spatial
18 |
19 |
20 | def prepare_data_for_cnn(seqs_x, maxlen=40, n_words=21103, filter_h=5):
21 |
22 | lengths_x = [len(s) for s in seqs_x]
23 |
24 | if maxlen != None:
25 | new_seqs_x = []
26 | new_lengths_x = []
27 | for l_x, s_x in zip(lengths_x, seqs_x):
28 | if l_x < maxlen:
29 | new_seqs_x.append(s_x)
30 | new_lengths_x.append(l_x)
31 | lengths_x = new_lengths_x
32 | seqs_x = new_seqs_x
33 |
34 | if len(lengths_x) < 1 :
35 | return None, None
36 |
37 | pad = filter_h -1
38 | x = []
39 | for rev in seqs_x:
40 | xx = []
41 | for i in xrange(pad):
42 | xx.append(n_words-1)
43 | for idx in rev:
44 | xx.append(idx)
45 | while len(xx) < maxlen + 2*pad:
46 | xx.append(n_words-1)
47 | x.append(xx)
48 | x = np.array(x,dtype='int32')
49 | return x
50 |
51 | def prepare_data_for_rnn(seqs_x, maxlen=40):
52 |
53 | lengths_x = [len(s) for s in seqs_x]
54 |
55 | if maxlen != None:
56 | new_seqs_x = []
57 | new_lengths_x = []
58 | for l_x, s_x in zip(lengths_x, seqs_x):
59 | if l_x < maxlen:
60 | new_seqs_x.append(s_x)
61 | new_lengths_x.append(l_x)
62 | lengths_x = new_lengths_x
63 | seqs_x = new_seqs_x
64 |
65 | if len(lengths_x) < 1 :
66 | return None, None
67 |
68 | n_samples = len(seqs_x)
69 | maxlen_x = np.max(lengths_x)
70 |
71 | x = np.zeros((maxlen_x, n_samples)).astype('int32')
72 | x_mask = np.zeros((maxlen_x, n_samples)).astype(theano.config.floatX)
73 | for idx, s_x in enumerate(seqs_x):
74 | x[:lengths_x[idx], idx] = s_x
75 | x_mask[:lengths_x[idx], idx] = 1.
76 |
77 | return x, x_mask
78 |
79 | def find_sent_embedding(whole, n_words=21102, img_w=300, img_h=48, feature_maps=200,
80 | filter_hs=[3,4,5],n_x=300, n_h=600):
81 |
82 | options = {}
83 | options['n_words'] = n_words
84 | options['img_w'] = img_w
85 | options['img_h'] = img_h
86 | options['feature_maps'] = feature_maps
87 | options['filter_hs'] = filter_hs
88 | options['n_x'] = n_x
89 | options['n_h'] = n_h
90 |
91 | filter_w = img_w
92 | filter_shapes = []
93 | pool_sizes = []
94 | for filter_h in filter_hs:
95 | filter_shapes.append((feature_maps, 1, filter_h, filter_w))
96 | pool_sizes.append((img_h-filter_h+1, img_w-filter_w+1))
97 |
98 | options['filter_shapes'] = filter_shapes
99 | options['pool_sizes'] = pool_sizes
100 |
101 | params = init_params(options)
102 | tparams = init_tparams(params)
103 |
104 | data = np.load('./bookcorpus_result.npz')
105 |
106 | for kk, pp in params.iteritems():
107 | params[kk] = data[kk]
108 |
109 | for kk, pp in params.iteritems():
110 | tparams[kk].set_value(params[kk])
111 |
112 | x = tensor.matrix('x', dtype='int32')
113 |
114 | layer0_input = tparams['Wemb'][tensor.cast(x.flatten(),dtype='int32')].reshape((x.shape[0],1,x.shape[1],tparams['Wemb'].shape[1]))
115 |
116 | layer1_inputs = []
117 | for i in xrange(len(options['filter_hs'])):
118 | filter_shape = options['filter_shapes'][i]
119 | pool_size = options['pool_sizes'][i]
120 | conv_layer = encoder(tparams, layer0_input,filter_shape, pool_size,prefix=_p('cnn_encoder',i))
121 | layer1_input = conv_layer
122 | layer1_inputs.append(layer1_input)
123 | layer1_input = tensor.concatenate(layer1_inputs,1)
124 |
125 | f_embed = theano.function([x], layer1_input, name='f_embed')
126 |
127 | kf = get_minibatches_idx(len(whole), 100)
128 | sent_emb = np.zeros((len(whole),600))  # 600 = feature_maps * len(filter_hs) = 200 * 3
129 |
130 | for i, train_index in kf:
131 | sents = [whole[t] for t in train_index]
132 | x = prepare_data_for_cnn(sents)
133 | sent_emb[train_index[0]:train_index[-1]+1] = f_embed(x)
134 | if i % 500 == 0:
135 | print i,
136 |
137 | np.savez('./bookcorpus_embedding.npz', sent_emb=sent_emb)
138 |
139 | return sent_emb
140 |
141 |
142 | if __name__ == '__main__':
143 |
144 | x = cPickle.load(open("./data/bookcorpus_1M.p","rb"))
145 | train, val, test = x[0], x[1], x[2]
146 | train_text, val_text, test_text = x[3], x[4], x[5]
147 | wordtoix, ixtoword = x[6], x[7]
148 | del x
149 |
150 | n_words = len(ixtoword)
151 |
152 | ixtoword[n_words] = ''
153 | wordtoix[''] = n_words
154 | n_words = n_words + 1
155 |
156 | whole = train + val + test
157 | whole_text = train_text + val_text + test_text
158 | del train, val, test
159 | del train_text, val_text, test_text
160 |
161 | sent_emb = find_sent_embedding(whole)
162 |
163 | """ sentence retrieval """
164 | x = np.load('./bookcorpus_embedding.npz')
165 | sent_emb = x['sent_emb']
166 |
167 | idx = 0
168 | print whole_text[idx]
169 | target_emb = sent_emb[idx]
170 |
171 | cos_similarity = []
172 | for i in range(len(whole)):
173 | vector = sent_emb[i]
174 | result = 1 - spatial.distance.cosine(target_emb, vector)
175 | cos_similarity.append(result)
176 | top_indices = np.argsort(cos_similarity)[::-1]
177 |
178 | print whole_text[top_indices[0]], cos_similarity[top_indices[0]]
179 | print whole_text[top_indices[1]], cos_similarity[top_indices[1]]
180 | print whole_text[top_indices[2]], cos_similarity[top_indices[2]]
181 | print whole_text[top_indices[3]], cos_similarity[top_indices[3]]
182 | print whole_text[top_indices[4]], cos_similarity[top_indices[4]]
183 |
184 |
185 |
186 |
187 |
188 |
189 |
190 |
191 |
192 |
193 |
194 |
195 |
196 |
197 |
198 |
199 |
--------------------------------------------------------------------------------
/train_autoencoder.py:
--------------------------------------------------------------------------------
1 | '''
2 | Learning Generic Sentence Representations Using Convolutional Neural Networks
3 | https://arxiv.org/pdf/1611.07897.pdf
4 | Developed by Zhe Gan, zhe.gan@duke.edu, April 19, 2016
5 | '''
6 |
7 | #import os
8 | import time
9 | import logging
10 | import cPickle
11 |
12 | import numpy as np
13 | import theano
14 | import theano.tensor as tensor
15 |
16 | from model.autoencoder import init_params, init_tparams, build_model
17 | from model.optimizers import Adam
18 | from model.utils import get_minibatches_idx, unzip
19 |
20 | #theano.config.optimizer='fast_compile'
21 | #theano.config.exception_verbosity='high'
22 | #theano.config.compute_test_value = 'warn'
23 |
24 | def prepare_data_for_cnn(seqs_x, maxlen=40, n_words=21103, filter_h=5):
25 |
26 | lengths_x = [len(s) for s in seqs_x]
27 |
28 | if maxlen != None:
29 | new_seqs_x = []
30 | new_lengths_x = []
31 | for l_x, s_x in zip(lengths_x, seqs_x):
32 | if l_x < maxlen:
33 | new_seqs_x.append(s_x)
34 | new_lengths_x.append(l_x)
35 | lengths_x = new_lengths_x
36 | seqs_x = new_seqs_x
37 |
38 | if len(lengths_x) < 1 :
39 | return None, None
40 |
41 | pad = filter_h -1
42 | x = []
43 | for rev in seqs_x:
44 | xx = []
45 | for i in xrange(pad):
46 | # pad with the special padding token (index n_words-1).
47 | xx.append(n_words-1)
48 | for idx in rev:
49 | xx.append(idx)
50 | while len(xx) < maxlen + 2*pad:
51 | xx.append(n_words-1)
52 | x.append(xx)
53 | x = np.array(x,dtype='int32')
54 | return x
55 |
56 | def prepare_data_for_rnn(seqs_x, maxlen=40):
57 |
58 | lengths_x = [len(s) for s in seqs_x]
59 |
60 | if maxlen != None:
61 | new_seqs_x = []
62 | new_lengths_x = []
63 | for l_x, s_x in zip(lengths_x, seqs_x):
64 | if l_x < maxlen:
65 | new_seqs_x.append(s_x)
66 | new_lengths_x.append(l_x)
67 | lengths_x = new_lengths_x
68 | seqs_x = new_seqs_x
69 |
70 | if len(lengths_x) < 1 :
71 | return None, None
72 |
73 | n_samples = len(seqs_x)
74 | maxlen_x = np.max(lengths_x)
75 |
76 | x = np.zeros((maxlen_x, n_samples)).astype('int32')
77 | x_mask = np.zeros((maxlen_x, n_samples)).astype(theano.config.floatX)
78 | for idx, s_x in enumerate(seqs_x):
79 | x[:lengths_x[idx], idx] = s_x
80 | x_mask[:lengths_x[idx], idx] = 1.
81 |
82 | return x, x_mask
83 |
84 | def calu_cost(f_cost, prepare_data_for_cnn, prepare_data_for_rnn, data, kf):
85 |
86 | total_negll = 0.
87 | total_len = 0.
88 |
89 | for _, train_index in kf:
90 | sents = [data[t] for t in train_index]
91 |
92 | x = prepare_data_for_cnn(sents)
93 | y, y_mask = prepare_data_for_rnn(sents)
94 | negll = f_cost(x, y, y_mask) * np.sum(y_mask)
95 | length = np.sum(y_mask)
96 | total_negll += negll
97 | total_len += length
98 |
99 | return total_negll/total_len
100 |
101 | """ Training the model. """
102 |
103 | def train_model(train, val, test, n_words=21103, img_w=300, max_len=40,
104 | feature_maps=200, filter_hs=[3,4,5], n_x=300, n_h=600,
105 | max_epochs=8, lrate=0.0002, batch_size=64, valid_batch_size=64, dispFreq=10,
106 | validFreq=500, saveFreq=1000, saveto = 'bookcorpus_result.npz'):
107 |
108 | """ train, valid, test : datasets
109 | n_words : vocabulary size
110 | img_w : word embedding dimension, must be 300.
111 | max_len : the maximum length of a sentence
112 | feature_maps : the number of feature maps we used
113 | filter_hs: the filter window sizes we used
114 | n_x: word embedding dimension
115 | n_h: the number of hidden units in LSTM
116 | max_epochs : the maximum number of epoch to run
117 | lrate : learning rate
118 | batch_size : batch size during training
119 | valid_batch_size : The batch size used for validation/test set
120 | dispFreq : Display to stdout the training progress every N updates
121 | validFreq : Compute the validation error after this number of update.
122 | saveFreq: save the result after this number of update.
123 | saveto: where to save the result.
124 | """
125 |
126 | img_h = max_len + 2*(filter_hs[-1]-1)
127 |
128 | options = {}
129 | options['n_words'] = n_words
130 | options['img_w'] = img_w
131 | options['img_h'] = img_h
132 | options['feature_maps'] = feature_maps
133 | options['filter_hs'] = filter_hs
134 | options['n_x'] = n_x
135 | options['n_h'] = n_h
136 | options['max_epochs'] = max_epochs
137 | options['lrate'] = lrate
138 | options['batch_size'] = batch_size
139 | options['valid_batch_size'] = valid_batch_size
140 | options['dispFreq'] = dispFreq
141 | options['validFreq'] = validFreq
142 | options['saveFreq'] = saveFreq
143 |
144 | logger.info('Model options {}'.format(options))
145 |
146 | logger.info('Building model...')
147 |
148 | filter_w = img_w
149 | filter_shapes = []
150 | pool_sizes = []
151 | for filter_h in filter_hs:
152 | filter_shapes.append((feature_maps, 1, filter_h, filter_w))
153 | pool_sizes.append((img_h-filter_h+1, img_w-filter_w+1))
154 |
155 | options['filter_shapes'] = filter_shapes
156 | options['pool_sizes'] = pool_sizes
157 |
158 | params = init_params(options)
159 | tparams = init_tparams(params)
160 |
161 | use_noise, x, y, y_mask, cost = build_model(tparams,options)
162 |
163 | f_cost = theano.function([x, y, y_mask], cost, name='f_cost')
164 |
165 | lr = tensor.scalar(name='lr')
166 | f_grad_shared, f_update = Adam(tparams, cost, [x, y, y_mask], lr)
167 |
168 | logger.info('Training model...')
169 |
170 | history_cost = []
171 | uidx = 0 # the number of update done
172 | start_time = time.time()
173 |
174 | kf_valid = get_minibatches_idx(len(val), valid_batch_size)
175 |
176 | zero_vec_tensor = tensor.vector()
177 | zero_vec = np.zeros(img_w).astype(theano.config.floatX)
178 | set_zero = theano.function([zero_vec_tensor], updates=[(tparams['Wemb'], tensor.set_subtensor(tparams['Wemb'][n_words-1,:], zero_vec_tensor))])
179 |
180 | try:
181 | for eidx in xrange(max_epochs):
182 | n_samples = 0
183 |
184 | kf = get_minibatches_idx(len(train), batch_size, shuffle=True)
185 |
186 | for _, train_index in kf:
187 | uidx += 1
188 | use_noise.set_value(0.)
189 |
190 | sents = [train[t] for t in train_index]
191 |
192 | x = prepare_data_for_cnn(sents)
193 | y, y_mask = prepare_data_for_rnn(sents)
194 | n_samples += y.shape[1]
195 |
196 | cost = f_grad_shared(x, y, y_mask)
197 | f_update(lrate)
198 | # the special token does not need to update.
199 | set_zero(zero_vec)
200 |
201 | if np.isnan(cost) or np.isinf(cost):
202 |
203 | logger.info('NaN detected')
204 | return 1., 1., 1.
205 |
206 | if np.mod(uidx, dispFreq) == 0:
207 | logger.info('Epoch {} Update {} Cost {}'.format(eidx, uidx, np.exp(cost)))  # exp of the per-word loss, i.e. perplexity
208 |
209 | if np.mod(uidx, saveFreq) == 0:
210 |
211 | logger.info('Saving ...')
212 |
213 | params = unzip(tparams)
214 | np.savez(saveto, history_cost=history_cost, **params)
215 |
216 | logger.info('Done ...')
217 |
218 | if np.mod(uidx, validFreq) == 0:
219 | use_noise.set_value(0.)
220 |
221 | valid_cost = calu_cost(f_cost, prepare_data_for_cnn, prepare_data_for_rnn, val, kf_valid)
222 | history_cost.append([valid_cost])
223 |
224 | logger.info('Valid {}'.format(np.exp(valid_cost)))
225 |
226 | logger.info('Seen {} samples'.format(n_samples))
227 |
228 | except KeyboardInterrupt:
229 | logger.info('Training interrupted')
230 |
231 | end_time = time.time()
232 |
233 | # if best_p is not None:
234 | # zipp(best_p, tparams)
235 | # else:
236 | # best_p = unzip(tparams)
237 |
238 |
239 | use_noise.set_value(0.)
240 | valid_cost = calu_cost(f_cost, prepare_data_for_cnn, prepare_data_for_rnn, val, kf_valid)
241 | logger.info('Valid {}'.format(np.exp(valid_cost)))
242 |
243 | params = unzip(tparams)
244 | np.savez(saveto, history_cost=history_cost, **params)
245 |
246 |
247 | logger.info('The code ran for {} epochs, with {} sec/epoch'.format(eidx + 1,
248 | (end_time - start_time) / (1. * (eidx + 1))))
249 |
250 |
251 | return valid_cost
252 |
253 | if __name__ == '__main__':
254 |
255 | logger = logging.getLogger('train_autoencoder')
256 | logger.setLevel(logging.INFO)
257 | fh = logging.FileHandler('train_autoencoder.log')
258 | fh.setLevel(logging.INFO)
259 | ch = logging.StreamHandler()
260 | ch.setLevel(logging.INFO)
261 | formatter = logging.Formatter('%(asctime)s - %(name)s - %(message)s')
262 | fh.setFormatter(formatter)
263 | ch.setFormatter(formatter)
264 | logger.addHandler(fh)
265 | logger.addHandler(ch)
266 | x = cPickle.load(open("./data/bookcorpus_1M.p","rb"))
267 | train, val, test = x[0], x[1], x[2]
268 | train_text, val_text, test_text = x[3], x[4], x[5]
269 | wordtoix, ixtoword = x[6], x[7]
270 | del x
271 | del train_text, val_text, test_text
272 |
273 | n_words = len(ixtoword)
274 | ixtoword[n_words] = ''
275 | wordtoix[''] = n_words
276 | n_words = n_words + 1
277 |
278 | valid_cost = train_model(train, val, test, n_words=n_words)
279 |
--------------------------------------------------------------------------------
/vector_compositionality.py:
--------------------------------------------------------------------------------
1 | '''
2 | Learning Generic Sentence Representations Using Convolutional Neural Networks
3 | https://arxiv.org/pdf/1611.07897.pdf
4 | Developed by Zhe Gan, zhe.gan@duke.edu, April 19, 2016
5 | '''
6 |
7 | import cPickle
8 | import numpy as np
9 | import theano
10 | import theano.tensor as tensor
11 |
12 | from model.autoencoder import init_params, init_tparams
13 | from model.cnn_layer import encoder
14 | from model.utils import _p
15 |
16 | def prepare_data_for_cnn(seqs_x, maxlen=40, n_words=21103, filter_h=5):
17 |
18 | lengths_x = [len(s) for s in seqs_x]
19 |
20 | if maxlen != None:
21 | new_seqs_x = []
22 | new_lengths_x = []
23 | for l_x, s_x in zip(lengths_x, seqs_x):
24 | if l_x < maxlen:
25 | new_seqs_x.append(s_x)
26 | new_lengths_x.append(l_x)
27 | lengths_x = new_lengths_x
28 | seqs_x = new_seqs_x
29 |
30 | if len(lengths_x) < 1 :
31 | return None, None
32 |
33 | pad = filter_h -1
34 | x = []
35 | for rev in seqs_x:
36 | xx = []
37 | for i in xrange(pad):
38 | xx.append(n_words-1)
39 | for idx in rev:
40 | xx.append(idx)
41 | while len(xx) < maxlen + 2*pad:
42 | xx.append(n_words-1)
43 | x.append(xx)
44 | x = np.array(x,dtype='int32')
45 | return x
46 |
47 | def prepare_data_for_rnn(seqs_x, maxlen=40):
48 |
49 | lengths_x = [len(s) for s in seqs_x]
50 |
51 | if maxlen != None:
52 | new_seqs_x = []
53 | new_lengths_x = []
54 | for l_x, s_x in zip(lengths_x, seqs_x):
55 | if l_x < maxlen:
56 | new_seqs_x.append(s_x)
57 | new_lengths_x.append(l_x)
58 | lengths_x = new_lengths_x
59 | seqs_x = new_seqs_x
60 |
61 | if len(lengths_x) < 1 :
62 | return None, None
63 |
64 | n_samples = len(seqs_x)
65 | maxlen_x = np.max(lengths_x)
66 |
67 | x = np.zeros((maxlen_x, n_samples)).astype('int32')
68 | x_mask = np.zeros((maxlen_x, n_samples)).astype(theano.config.floatX)
69 | for idx, s_x in enumerate(seqs_x):
70 | x[:lengths_x[idx], idx] = s_x
71 | x_mask[:lengths_x[idx], idx] = 1.
72 |
73 | return x, x_mask
74 |
75 | def find_sent_embedding(n_words=21102, img_w=300, img_h=48, feature_maps=200,
76 | filter_hs=[3,4,5],n_x=300, n_h=600):
77 |
78 | options = {}
79 | options['n_words'] = n_words
80 | options['img_w'] = img_w
81 | options['img_h'] = img_h
82 | options['feature_maps'] = feature_maps
83 | options['filter_hs'] = filter_hs
84 | options['n_x'] = n_x
85 | options['n_h'] = n_h
86 |
87 | filter_w = img_w
88 | filter_shapes = []
89 | pool_sizes = []
90 | for filter_h in filter_hs:
91 | filter_shapes.append((feature_maps, 1, filter_h, filter_w))
92 | pool_sizes.append((img_h-filter_h+1, img_w-filter_w+1))
93 |
94 | options['filter_shapes'] = filter_shapes
95 | options['pool_sizes'] = pool_sizes
96 |
97 | params = init_params(options)
98 | tparams = init_tparams(params)
99 |
100 | data = np.load('./bookcorpus_result.npz')
101 |
102 | for kk, pp in params.iteritems():
103 | params[kk] = data[kk]
104 |
105 | for kk, pp in params.iteritems():
106 | tparams[kk].set_value(params[kk])
107 |
108 | x = tensor.matrix('x', dtype='int32')
109 |
110 | layer0_input = tparams['Wemb'][tensor.cast(x.flatten(),dtype='int32')].reshape((x.shape[0],1,x.shape[1],tparams['Wemb'].shape[1]))
111 |
112 | layer1_inputs = []
113 | for i in xrange(len(options['filter_hs'])):
114 | filter_shape = options['filter_shapes'][i]
115 | pool_size = options['pool_sizes'][i]
116 | conv_layer = encoder(tparams, layer0_input,filter_shape, pool_size,prefix=_p('cnn_encoder',i))
117 | layer1_input = conv_layer
118 | layer1_inputs.append(layer1_input)
119 | layer1_input = tensor.concatenate(layer1_inputs,1)
120 |
121 | f_embed = theano.function([x], layer1_input, name='f_embed')
122 |
123 | return f_embed, params
124 |
125 | def predict(z, params, beam_size, max_step, prefix='decoder'):
126 |
127 | """ z: a vector of size (n_z,)
128 | """
129 | n_h = params[_p(prefix,'U')].shape[0]
130 |
131 | def _slice(_x, n, dim):
132 | return _x[n*dim:(n+1)*dim]
133 |
134 | def sigmoid(x):
135 | return 1/(1+np.exp(-x))
136 |
137 | Vhid = np.dot(params['Vhid'],params['Wemb'].T)
138 |
139 | def _step(x_prev, h_prev, c_prev):
140 | preact = np.dot(h_prev, params[_p(prefix, 'U')]) + \
141 | np.dot(x_prev, params[_p(prefix, 'W')]) + \
142 | np.dot(z, params[_p(prefix, 'C')]) + params[_p(prefix, 'b')]
143 |
144 | i = sigmoid(_slice(preact, 0, n_h))
145 | f = sigmoid(_slice(preact, 1, n_h))
146 | o = sigmoid(_slice(preact, 2, n_h))
147 | c = np.tanh(_slice(preact, 3, n_h))
148 |
149 | c = f * c_prev + i * c
150 | h = o * np.tanh(c)
151 |
152 | y = np.dot(h, Vhid) + params['bhid']
153 |
154 | return y, h, c
155 |
156 | h0 = np.tanh(np.dot(z, params[_p(prefix, 'C0')]) + params[_p(prefix, 'b0')])
157 | y0 = np.dot(h0, Vhid) + params['bhid']
158 | c0 = np.zeros(h0.shape)
159 |
160 | maxy0 = np.amax(y0)
161 | e0 = np.exp(y0 - maxy0) # for numerical stability shift into good numerical range
162 | p0 = e0 / np.sum(e0)
163 | y0 = np.log(1e-20 + p0) # and back to log domain
164 |
165 | beams = []
166 | nsteps = 1
167 | # generate the first word
168 | top_indices = np.argsort(-y0) # we do -y because we want decreasing order
169 |
170 | for i in xrange(beam_size):
171 | wordix = top_indices[i]
172 | # log probability, indices of words predicted in this beam so far, and the hidden and cell states
173 | beams.append((y0[wordix], [wordix], h0, c0))
174 |
175 | # perform BEAM search.
176 | if beam_size > 1:
177 | # generate the rest n words
178 | while True:
179 | beam_candidates = []
180 | for b in beams:
181 | ixprev = b[1][-1] if b[1] else 0 # start off with the word where this beam left off
182 | if ixprev == 0 and b[1]:
183 | # this beam predicted end token. Keep in the candidates but don't expand it out any more
184 | beam_candidates.append(b)
185 | continue
186 | (y1, h1, c1) = _step(params['Wemb'][ixprev], b[2], b[3])
187 | y1 = y1.ravel() # make into 1D vector
188 | maxy1 = np.amax(y1)
189 | e1 = np.exp(y1 - maxy1) # for numerical stability shift into good numerical range
190 | p1 = e1 / np.sum(e1)
191 | y1 = np.log(1e-20 + p1) # and back to log domain
192 | top_indices = np.argsort(-y1) # we do -y because we want decreasing order
193 | for i in xrange(beam_size):
194 | wordix = top_indices[i]
195 | beam_candidates.append((b[0] + y1[wordix], b[1] + [wordix], h1, c1))
196 | beam_candidates.sort(reverse = True) # decreasing order
197 | beams = beam_candidates[:beam_size] # truncate to get new beams
198 | nsteps += 1
199 | if nsteps >= max_step: # bad things are probably happening, break out
200 | break
201 | # strip the intermediates
202 | predictions = [(b[0], b[1]) for b in beams]
203 | else:
204 | nsteps = 1
205 | h, c = h0, c0
206 | # generate the first word
207 | top_indices = np.argsort(-y0) # we do -y because we want decreasing order
208 | ixprev = top_indices[0]
209 | predix = [ixprev]
210 | predlogprob = y0[ixprev]
211 | while True:
212 | (y1, h, c) = _step(params['Wemb'][ixprev], h, c)
213 | ixprev, ixlogprob = ymax(y1)
214 | predix.append(ixprev)
215 | predlogprob += ixlogprob
216 | nsteps += 1
217 | if nsteps >= max_step:
218 | break
219 | predictions = [(predlogprob, predix)]
220 |
221 | return predictions
222 |
223 | def ymax(y):
224 | """ simple helper function here that takes unnormalized logprobs """
225 | y1 = y.ravel() # make sure 1d
226 | maxy1 = np.amax(y1)
227 | e1 = np.exp(y1 - maxy1) # for numerical stability shift into good numerical range
228 | p1 = e1 / np.sum(e1)
229 | y1 = np.log(1e-20 + p1) # guard against zero probabilities just in case
230 | ix = np.argmax(y1)
231 | return (ix, y1[ix])
232 |
233 | def generate(z_emb, params):
234 |
235 | predset = []
236 | for i in xrange(len(z_emb)):
237 | pred = predict(z_emb[i], params, beam_size=5, max_step=40)
238 | predset.append(pred)
239 | #print i,
240 |
241 | return predset
242 |
243 | def get_idx_from_sent(sent, word_idx_map):
244 | """
245 | Transforms sentence into a list of indices.
246 | """
247 | x = []
248 | words = sent.split()
249 | for word in words:
250 | if word in word_idx_map:
251 | x.append(word_idx_map[word])
252 | else:
253 | x.append(1)  # words not in the vocabulary are mapped to index 1
254 | x.append(0)  # index 0 marks the end of the sentence (see predict)
255 | return x
256 |
257 |
258 | if __name__ == '__main__':
259 |
260 | print "loading data..."
261 | x = cPickle.load(open("./data/bookcorpus_1M.p","rb"))
262 | train, val, test = x[0], x[1], x[2]
263 | train_text, val_text, test_text = x[3], x[4], x[5]
264 | wordtoix, ixtoword = x[6], x[7]
265 | del x
266 |
267 | n_words = len(ixtoword)
268 | ixtoword[n_words] = ''
269 | wordtoix[''] = n_words
270 | n_words = n_words + 1
271 |
272 | f_embed, params = find_sent_embedding()
273 |
274 | x1 = []
275 | x1.append(get_idx_from_sent("you needed me ?",wordtoix))
276 | x1.append(get_idx_from_sent("you got me ?",wordtoix))
277 | x1.append(get_idx_from_sent("i got you .",wordtoix))
278 |
279 | x2 = []
280 | x2.append(get_idx_from_sent("this is great .",wordtoix))
281 | x2.append(get_idx_from_sent("this is awesome .",wordtoix))
282 | x2.append(get_idx_from_sent("you are awesome .",wordtoix))
283 |
284 | x3 = []
285 | x3.append(get_idx_from_sent("its lovely to see you .",wordtoix))
286 | x3.append(get_idx_from_sent("its great to meet you .",wordtoix))
287 | x3.append(get_idx_from_sent("its great to meet him .",wordtoix))
288 |
289 | x4 = []
290 | x4.append(get_idx_from_sent("he had thought he was going crazy .",wordtoix))
291 | x4.append(get_idx_from_sent("i felt like i was going crazy .",wordtoix))
292 | x4.append(get_idx_from_sent("i felt like to say the right thing .",wordtoix))
293 |
294 | sent_emb = f_embed(prepare_data_for_cnn(x1))
295 | sent_emb_x1 = sent_emb[0] - sent_emb[1] + sent_emb[2]
296 |
297 | sent_emb = f_embed(prepare_data_for_cnn(x2))
298 | sent_emb_x2 = sent_emb[0] - sent_emb[1] + sent_emb[2]
299 |
300 | sent_emb = f_embed(prepare_data_for_cnn(x3))
301 | sent_emb_x3 = sent_emb[0] - sent_emb[1] + sent_emb[2]
302 |
303 | sent_emb = f_embed(prepare_data_for_cnn(x4))
304 | sent_emb_x4 = sent_emb[0] - sent_emb[1] + sent_emb[2]
305 |
306 | sent_emb = np.stack((sent_emb_x1,sent_emb_x2,sent_emb_x3,sent_emb_x4))
307 |
308 | predset = generate(sent_emb, params)
309 |
310 | predset_text = []
311 | for sent in predset:
312 | rev = []
313 | for sen in sent:
314 | smal = []
315 | for w in sen[1]:
316 | smal.append(ixtoword[w])
317 | rev.append(' '.join(smal))
318 | predset_text.append(rev)
319 |
320 | for i in range(4):
321 | print predset_text[i][0]
322 |
323 |
324 |
325 |
326 |
327 |
328 |
329 |
330 |
331 |
332 |
333 |
334 |
335 |
336 |
337 |
338 |
339 |
340 |
--------------------------------------------------------------------------------