├── .idea
│   ├── SRU-tensorflow.iml
│   ├── misc.xml
│   ├── modules.xml
│   ├── vcs.xml
│   └── workspace.xml
├── README.md
└── SRU_tensorflow.py
/.idea/SRU-tensorflow.iml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
--------------------------------------------------------------------------------
/.idea/misc.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
--------------------------------------------------------------------------------
/.idea/modules.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
--------------------------------------------------------------------------------
/.idea/vcs.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/.idea/workspace.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 |
31 |
32 |
33 |
34 |
35 |
36 |
37 |
38 |
39 |
40 |
41 |
42 |
43 |
44 |
45 |
46 |
47 |
48 |
49 |
50 |
51 |
52 | BasicRNN
53 |
54 |
55 |
56 |
57 |
58 |
59 |
64 |
65 |
66 |
67 |
68 | true
69 | DEFINITION_ORDER
70 |
71 |
72 |
73 |
74 |
75 |
76 |
77 |
78 |
79 |
80 |
81 |
82 |
83 |
84 |
85 |
86 |
87 |
88 |
89 |
90 |
91 |
92 |
93 |
94 |
95 |
96 |
97 |
98 |
99 |
100 |
101 |
102 |
103 |
104 |
105 |
106 |
107 |
108 |
109 |
110 |
111 |
112 |
113 |
114 |
115 |
116 |
117 |
118 |
119 |
120 |
121 |
122 |
123 |
124 |
125 |
126 |
127 |
128 |
129 |
130 |
131 |
132 | 1505547314989
133 |
134 |
135 | 1505547314989
136 |
137 |
138 |
139 |
140 |
141 |
142 |
143 |
144 |
145 |
146 |
147 |
148 |
149 |
150 |
151 |
152 |
153 |
154 |
155 |
156 |
157 |
158 |
159 |
160 |
161 |
162 |
163 |
164 |
165 |
166 |
167 |
168 |
169 |
170 |
171 |
172 |
173 |
174 |
175 |
176 |
177 |
178 |
179 |
180 |
181 |
182 |
183 |
184 |
185 |
186 |
187 |
188 |
189 |
190 |
191 |
192 |
193 |
194 |
195 |
196 |
197 |
198 |
199 |
200 |
201 |
202 |
203 |
204 |
205 |
206 |
207 |
208 |
209 |
210 |
211 |
212 |
213 |
214 |
215 |
216 |
217 |
218 |
219 |
220 |
221 |
222 |
223 |
224 |
225 |
226 |
227 |
228 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # SRU-tensorflow
2 | An unofficial TensorFlow implementation of the Simple Recurrent Unit (SRU) from the paper "Training RNNs as Fast as CNNs".
3 |
--------------------------------------------------------------------------------
/SRU_tensorflow.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import collections
3 |
4 | from tensorflow.python.ops import variable_scope as vs
5 | from tensorflow.contrib.rnn import RNNCell
6 |
class SRUCell(RNNCell):
    """Simple Recurrent Unit (SRU) cell.

    For one time step, with ``x`` the input, ``c_prev`` the previous memory
    cell and ``g`` the activation, the cell computes::

        f = sigmoid(x . Wf + bf)                 # forget gate
        r = sigmoid(x . Wr + br)                 # reset gate
        c = f * c_prev + (1 - f) * (x . U)       # new memory cell
        h = r * g(c) + (1 - r) * x               # output (highway to x)

    The gates depend only on the current input, never on the previous
    state.

    Every weight matrix is ``(num_units, num_units)``, and the highway term
    adds ``x`` directly to ``g(c)``. The input's last dimension must
    therefore equal ``num_units``.

    NOTE(review): the weights are plain ``tf.Variable`` objects built eagerly
    in ``__init__`` rather than through ``tf.get_variable``. Consequently
    ``reuse`` and the ``scope`` argument of ``__call__`` do not select or
    share these weights; the scope only names the ops built on each call.
    """

    def __init__(self, num_units, activation=tf.nn.tanh,
                 state_is_tuple=False, reuse=None):
        """Create the cell and its weights.

        Args:
          num_units: width of the memory cell and of the output (and, by
            construction, of the input).
          activation: function ``g`` applied to the memory cell; tanh by
            default.
          state_is_tuple: if True the state is ``LSTMStateTuple(c, h)``,
            otherwise it is just ``c``.
          reuse: forwarded to the base ``RNNCell`` as ``_reuse``.
        """
        super(SRUCell, self).__init__(_reuse=reuse)
        self.hidden_dim = num_units
        self.state_is_tuple = state_is_tuple
        self.g = activation

        # Square weights use an orthogonal initializer. Biases come from
        # self.init_matrix (normal, stddev 0.1). The creation order
        # Wr, br, Wf, bf, U is kept so default variable names stay stable.
        orthogonal = tf.orthogonal_initializer()
        square = [self.hidden_dim, self.hidden_dim]
        vector = [self.hidden_dim]

        self.Wr = tf.Variable(orthogonal(square))
        self.br = tf.Variable(self.init_matrix(vector))

        self.Wf = tf.Variable(orthogonal(square))
        self.bf = tf.Variable(self.init_matrix(vector))

        self.U = tf.Variable(orthogonal(square))

    @property
    def state_size(self):
        """Width(s) of the recurrent state: ``LSTMStateTuple`` of ints or an int."""
        if self.state_is_tuple:
            return LSTMStateTuple(self.hidden_dim, self.hidden_dim)
        return self.hidden_dim

    @property
    def output_size(self):
        """Width of the per-step output ``h`` (equal to ``num_units``)."""
        return self.hidden_dim

    def __call__(self, inputs, state, scope=None):
        """Run one SRU step.

        Args:
          inputs: input tensor for this step. Presumably ``(batch,
            num_units)`` (TODO confirm); its last dimension must equal
            ``num_units``.
          state: the memory cell ``c``, or ``LSTMStateTuple(c, h)`` when
            ``state_is_tuple`` is True. The ``h`` part is not read.
          scope: optional variable-scope name for the ops; defaults to the
            class name.

        Returns:
          ``(h, new_state)``, where ``new_state`` is ``c`` or
          ``LSTMStateTuple(c, h)`` depending on ``state_is_tuple``.
        """
        with vs.variable_scope(scope or type(self).__name__):
            if self.state_is_tuple:
                c_prev, _ = state  # previous h is unused by SRU
            else:
                c_prev = state

            forget = tf.sigmoid(tf.matmul(inputs, self.Wf) + self.bf)
            reset = tf.sigmoid(tf.matmul(inputs, self.Wr) + self.br)

            # Interpolate between the old cell and the projected input.
            projected = tf.matmul(inputs, self.U)
            c = forget * c_prev + (1.0 - forget) * projected

            # Highway connection: mix the activated cell with the raw input.
            h = reset * self.g(c) + (1.0 - reset) * inputs

            new_state = LSTMStateTuple(c, h) if self.state_is_tuple else c
            return h, new_state

    def init_matrix(self, shape):
        """Return a normal tensor (mean 0, stddev 0.1) of ``shape``; used for biases."""
        return tf.random_normal(shape, stddev=0.1)
61 |
62 |
# Bare namedtuple base; the subclass below adds a checked `dtype` property.
_LSTMStateTuple = collections.namedtuple("LSTMStateTuple", ("c", "h"))


class LSTMStateTuple(_LSTMStateTuple):
    """Immutable ``(c, h)`` pair used as a cell state, size, or zero state.

    Field order is ``c`` (memory cell) first, then ``h`` (hidden/output).
    Only used when ``state_is_tuple=True``.
    """

    # No per-instance __dict__: instances stay as light as a plain tuple.
    __slots__ = ()

    @property
    def dtype(self):
        """Return the dtype shared by ``c`` and ``h``.

        Raises:
          TypeError: if ``c.dtype`` and ``h.dtype`` differ.
        """
        c_dtype = self.c.dtype
        h_dtype = self.h.dtype
        if c_dtype != h_dtype:
            raise TypeError("Inconsistent internal state: %s vs %s" %
                            (str(c_dtype), str(h_dtype)))
        return c_dtype
82 |
--------------------------------------------------------------------------------