├── .gitignore ├── LICENSE.txt ├── layers.py ├── layers_new.py ├── linear_recurrent_net ├── build.sh ├── linear_recurrence.cu ├── linear_recurrence.h └── linear_recurrent_net │ ├── __init__.py │ └── tensorflow_binding │ ├── __init__.py │ ├── gradient_test.py │ └── linear_recurrence_op.cpp ├── paper ├── 1k_20_smoothing.png ├── 1k_synthetic.png ├── 1k_synthetic_new.png ├── 1m_20_smoothing.png ├── 1m_synthetic.png ├── 1m_synthetic_new.png ├── 8k_5_smoothing.png ├── 8k_synthetic.png ├── 8k_synthetic_new.png ├── between_chain_plots │ ├── 1k_0_smoothing.png │ ├── 1k_20_smoothing.png │ ├── 1k_5_smoothing.png │ ├── 1m_0_smoothing.png │ ├── 1m_20_smoothing.png │ ├── 8k_0_smoothing.png │ ├── 8k_20_smoothing.png │ └── 8k_5_smoothing.png ├── cudnn_heatmap.png ├── heatmap.png ├── heatmap_accurate_to_caption.png ├── iclr2018_conference.log ├── iclr2018_conference.sty ├── iclr_reviews.txt ├── lc2.png ├── learning_curves.png ├── main-arxiv.aux ├── main-arxiv.log ├── main-arxiv.out ├── main-arxiv.pdf ├── main-arxiv.tex ├── main.aux ├── main.bbl ├── main.blg ├── main.dvi ├── main.out ├── main.pdf ├── main.tex ├── main.tex.blg ├── medical_training.png ├── nips_2017.sty ├── references.bib ├── references.bib.blg ├── synthetic_diagram.pdf ├── to_do.aux ├── to_do.log ├── to_do.org ├── to_do.out ├── to_do.pdf ├── to_do.tex └── to_do.toc ├── plr_slr.py └── poster ├── 8k_for_poster.png ├── beamerposter.sty ├── beamerthemeconfposter.sty ├── cudnn_heatmap_gilr.png ├── cumsum.png ├── logo.png ├── main.aux ├── main.log ├── main.nav ├── main.out ├── main.pdf ├── main.snm ├── main.tex ├── main.toc ├── placeholder.jpg └── sample.bib /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | *~ 3 | paper/main.log 4 | paper/main.aux 5 | paper/main.synctex.gz 6 | paper/missfont.log -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | Copyright (c) <2017> 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy 4 | of this software and associated documentation files (the "Software"), to deal 5 | in the Software without restriction, including without limitation the rights 6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | copies of the Software, and to permit persons to whom the Software is 8 | furnished to do so, subject to the following conditions: 9 | 10 | The above copyright notice and this permission notice shall be included in all 11 | copies or substantial portions of the Software. 12 | 13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 19 | SOFTWARE. 20 | -------------------------------------------------------------------------------- /layers.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from linear_recurrent_net.tensorflow_binding import linear_recurrence 3 | 4 | def vscope(name): 5 | return tf.variable_scope(None, default_name=name) 6 | 7 | # contracts on the innermost dimension 8 | def matmul(X, W): 9 | res = tf.tensordot(X, W, [[-1], [0]]) 10 | res.set_shape(X.get_shape().as_list()[:-1] + 11 | W.get_shape().as_list()[1:]) 12 | return res 13 | 14 | def embedding_layer(X, size, dims, name='embedding'): 15 | with vscope(name): 16 | W = tf.get_variable('W', [dims, size]) 17 | return tf.nn.embedding_lookup(W, X) 18 | 19 | def fc_layer(X, hidden_size, nonlin=tf.nn.elu, 20 | use_bias=True, use_layer_norm=False, ln_eps=1e-3, 21 | name='fc'): 22 | n_dims = X.get_shape()[-1].value 23 | with vscope(name): 24 | W = tf.get_variable('W', [n_dims, hidden_size]) 25 | 26 | if use_bias: 27 | b = tf.get_variable('b', [hidden_size]) 28 | else: 29 | b = 0 30 | 31 | prod = matmul(X, W) 32 | if use_layer_norm: 33 | idx = ([None] * (len(X.shape) - 1)) + [slice(None)] 34 | g = tf.get_variable('g', [hidden_size])[idx] 35 | 36 | mu, sigma = tf.nn.moments(prod, [-1], keep_dims=True) 37 | prod = g * (prod - mu) / (sigma + ln_eps) 38 | 39 | return nonlin(prod + b) 40 | 41 | def gilr_layer(X, hidden_size, nonlin=tf.nn.elu, 42 | name='gilr'): 43 | """ 44 | g_t = sigmoid(Ux_t + b) 45 | h_t = g_t h_{t-1} + (1-g_t) f(Vx_t + c) 46 | """ 47 | with vscope(name): 48 | n_dims = X.get_shape()[-1].value 49 | act = fc_layer(X, 2 * hidden_size, nonlin=tf.identity) 50 | gate, impulse = tf.split(act, 2, len(act.shape) - 1) 51 | gate = tf.sigmoid(gate) 52 | impulse = nonlin(impulse) 53 | return linear_recurrence(gate, (1-gate) * impulse) 54 | 55 | def linear_surrogate_lstm(X, hidden_size, name='lin_sur_lstm'): 56 | with vscope(name): 57 | # 2 * hidden_size * n_dims params 58 | h_tilde = gilr_layer(X, hidden_size, nonlin=tf.tanh) 59 | 60 | # 4 * hidden_size * (hidden_size + n_dims) 61 | preact = fc_layer(tf.concat([h_tilde, X], axis=-1), 4 * hidden_size, 62 | nonlin=tf.identity) 63 | 64 | f, i, o, z = tf.split(preact, 4, len(preact.shape) - 1) 65 | 66 | f = tf.sigmoid(f) 67 | i = tf.sigmoid(i) 68 | o = tf.sigmoid(o) 69 | z = tf.tanh(z) 70 | 71 | c = linear_recurrence(f, i * z) 72 | h = o * c 73 | return h 74 | 75 | def SRU(X, name='SRU'): 76 | size = X.get_shape()[-1].value 77 | with vscope(name): 78 | preact = fc_layer(X, 3 * size, nonlin=tf.identity) 79 | x_tilde, f_pre, r_pre = tf.split(preact, 3, len(preact.shape) - 1) 80 | 81 | f = tf.sigmoid(f_pre) 82 | r = tf.sigmoid(r_pre) 83 | 84 | c = linear_recurrence(f, (1 - f) * x_tilde) 85 | h = r * c + (1 - r) * X 86 | return h 87 | -------------------------------------------------------------------------------- /layers_new.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from linear_recurrent_net.tensorflow_binding import linear_recurrence 3 | 4 | def vscope(name): 5 | return tf.variable_scope(None, default_name=name) 6 | 7 | # contracts on the innermost dimension 8 | def matmul(X, W): 9 | res = tf.tensordot(X, W, [[-1], [0]]) 10 | res.set_shape(X.get_shape().as_list()[:-1] + 11 | W.get_shape().as_list()[1:]) 12 | return res 13 | 14 | def embedding_layer(X, size, dims, name='embedding'): 15 | with vscope(name): 16 | W = tf.get_variable('W', [dims, size]) 17 | return tf.nn.embedding_lookup(W, X) 18 | 19 | def fc_layer(X, hidden_size, nonlin=tf.nn.elu, 20 | use_bias=True, use_layer_norm=False, ln_eps=1e-3, 21 | name='fc', sn=0.05, forget_bias=5.0): 22 | n_dims = X.get_shape()[-1].value 23 | with vscope(name): 24 | W = tf.get_variable('W', initializer=tf.random_uniform([n_dims, hidden_size], maxval=sn, minval=-sn)) 25 | 26 | if use_bias and name == 'pre_fc': 27 | b = tf.get_variable('b', initializer=tf.concat([tf.constant(forget_bias, shape=[hidden_size/4]), 28 | tf.zeros([3*(hidden_size/4)])],axis=0)) 29 | elif use_bias and name == 'sru_pre': 30 | b = tf.get_variable('b', initializer=tf.concat([tf.zeros([(hidden_size/3)]), 31 | tf.constant(forget_bias, shape=[hidden_size/3]), 32 | tf.zeros([(hidden_size/3)])],axis=0)) 33 | elif use_bias: 34 | b = tf.get_variable('b', initializer=tf.zeros([hidden_size])) 35 | else: 36 | b = 0 37 | 38 | prod = matmul(X, W) 39 | if use_layer_norm: 40 | idx = ([None] * (len(X.shape) - 1)) + [slice(None)] 41 | g = tf.get_variable('g', [hidden_size])[idx] 42 | 43 | mu, sigma = tf.nn.moments(prod, [-1], keep_dims=True) 44 | prod = g * (prod - mu) / (sigma + ln_eps) 45 | 46 | return nonlin(prod + b) 47 | 48 | def gilr_layer(X, hidden_size, nonlin=tf.nn.elu, 49 | name='gilr'): 50 | """ 51 | g_t = sigmoid(Ux_t + b) 52 | h_t = g_t h_{t-1} + (1-g_t) f(Vx_t + c) 53 | """ 54 | with vscope(name): 55 | n_dims = X.get_shape()[-1].value 56 | act = fc_layer(X, 2 * hidden_size, nonlin=tf.identity) 57 | gate, impulse = tf.split(act, 2, len(act.shape) - 1) 58 | gate = tf.sigmoid(gate) 59 | impulse = nonlin(impulse) 60 | return linear_recurrence(gate, (1-gate) * impulse) 61 | 62 | def linear_surrogate_lstm(X, hidden_size, name='lin_sur_lstm'): 63 | with vscope(name): 64 | # 2 * hidden_size * n_dims params 65 | h_tilde = gilr_layer(X, hidden_size, nonlin=tf.tanh) 66 | 67 | # 4 * hidden_size * (hidden_size + n_dims) 68 | preact = fc_layer(tf.concat([h_tilde, X], axis=-1), 4 * hidden_size, 69 | nonlin=tf.identity, name='pre_fc') 70 | 71 | f, i, o, z = tf.split(preact, 4, len(preact.shape) - 1) 72 | 73 | f = tf.sigmoid(f) 74 | i = tf.sigmoid(i) 75 | o = tf.sigmoid(o) 76 | z = tf.tanh(z) 77 | 78 | c = linear_recurrence(f, i * z) 79 | h = o * c 80 | return h 81 | 82 | def SRU(X, name='SRU'): 83 | size = X.get_shape()[-1].value 84 | with vscope(name): 85 | preact = fc_layer(X, 3 * size, nonlin=tf.identity, name='sru_pre') 86 | x_tilde, f_pre, r_pre = tf.split(preact, 3, len(preact.shape) - 1) 87 | 88 | f = tf.sigmoid(f_pre) 89 | r = tf.sigmoid(r_pre) 90 | 91 | c = linear_recurrence(f, (1 - f) * x_tilde) 92 | h = r * c + (1 - r) * X 93 | return h 94 | 95 | 96 | def QRNN(X, n, name='qrnn'): 97 | size = X.get_shape()[-1].value 98 | length = X.get_shape()[0].value 99 | bs = X.get_shape()[1].value 100 | with vscope(name): 101 | stack_list = [] 102 | for m in range(1, n-1): 103 | stack_list.append(tf.slice(tf.pad(X, [[m,0], [0,0], [0,0]]), 104 | [0,0,0], [length, bs, size])) 105 | X_stacked = tf.concat([X] + stack_list, axis=-1) 106 | 107 | preact = fc_layer(X_stacked, 3 * n * size, nonlin=tf.identity, name='qrnn_pre') 108 | 109 | z, f, o = tf.split(preact, 3, len(preact.shape) - 1) 110 | 111 | z = tf.tanh(tf.add_n(tf.split(z, n, len(preact.shape) - 1))) 112 | f = tf.sigmoid(tf.add_n(tf.split(f, n, len(preact.shape) - 1))) 113 | o = tf.sigmoid(tf.add_n(tf.split(o, n, len(preact.shape) - 1))) 114 | 115 | c = linear_recurrence(f, (1 - f) * z) 116 | h = o * c 117 | return h 118 | 119 | def s_gilr_layer(X, hidden_size, nonlin=tf.nn.elu, 120 | name='gilr'): 121 | """ 122 | g_t = sigmoid(Ux_t + b) 123 | h_t = g_t h_{t-1} + (1-g_t) f(Vx_t + c) 124 | """ 125 | with vscope(name): 126 | n_dims = X.get_shape()[-1].value 127 | act = fc_layer(X, 2 * hidden_size, nonlin=tf.identity) 128 | gate, impulse = tf.split(act, 2, len(act.shape) - 1) 129 | gate = tf.sigmoid(gate) 130 | impulse = nonlin(impulse) 131 | return linear_recurrence(gate, (1-gate) * impulse, serial=True) 132 | 133 | def s_linear_surrogate_lstm(X, hidden_size, name='lin_sur_lstm'): 134 | with vscope(name): 135 | # 2 * hidden_size * n_dims params 136 | h_tilde = gilr_layer(X, hidden_size, nonlin=tf.tanh) 137 | 138 | # 4 * hidden_size * (hidden_size + n_dims) 139 | preact = fc_layer(tf.concat([h_tilde, X], axis=-1), 4 * hidden_size, 140 | nonlin=tf.identity, name='pre_fc') 141 | 142 | f, i, o, z = tf.split(preact, 4, len(preact.shape) - 1) 143 | 144 | f = tf.sigmoid(f) 145 | i = tf.sigmoid(i) 146 | o = tf.sigmoid(o) 147 | z = tf.tanh(z) 148 | 149 | c = linear_recurrence(f, i * z, serial=True) 150 | h = o * c 151 | return h 152 | 153 | def s_SRU(X, name='SRU'): 154 | size = X.get_shape()[-1].value 155 | with vscope(name): 156 | preact = fc_layer(X, 3 * size, nonlin=tf.identity, name='sru_pre') 157 | x_tilde, f_pre, r_pre = tf.split(preact, 3, len(preact.shape) - 1) 158 | 159 | f = tf.sigmoid(f_pre) 160 | r = tf.sigmoid(r_pre) 161 | 162 | c = linear_recurrence(f, (1 - f) * x_tilde, serial=True) 163 | h = r * c + (1 - r) * X 164 | return h 165 | 166 | def s_QRNN(X, n, name='qrnn'): 167 | size = X.get_shape()[-1].value 168 | length = X.get_shape()[0].value 169 | bs = X.get_shape()[1].value 170 | with vscope(name): 171 | stack_list = [] 172 | for m in range(1, n-1): 173 | stack_list.append(tf.slice(tf.pad(X, [[m,0], [0,0], [0,0]]), 174 | [0,0,0], [length, bs, size])) 175 | X_stacked = tf.concat([X] + stack_list, axis=-1) 176 | 177 | preact = fc_layer(X_stacked, 3 * n * size, nonlin=tf.identity, name='qrnn_pre') 178 | 179 | z, f, o = tf.split(preact, 3, len(preact.shape) - 1) 180 | 181 | z = tf.tanh(tf.add_n(tf.split(z, n, len(preact.shape) - 1))) 182 | f = tf.sigmoid(tf.add_n(tf.split(f, n, len(preact.shape) - 1))) 183 | o = tf.sigmoid(tf.add_n(tf.split(o, n, len(preact.shape) - 1))) 184 | 185 | c = linear_recurrence(f, (1 - f) * z, serial=True) 186 | h = o * c 187 | return h 188 | 189 | 190 | def linear_surrogate_lstm_cpu(X, hidden_size, name='lin_sur_lstm'): 191 | with vscope(name): 192 | # 2 * hidden_size * n_dims params 193 | h_tilde = gilr_layer(X, hidden_size, nonlin=tf.tanh) 194 | 195 | # 4 * hidden_size * (hidden_size + n_dims) 196 | preact = fc_layer(tf.concat([h_tilde, X], axis=-1), 4 * hidden_size, 197 | nonlin=tf.identity, name='pre_fc') 198 | 199 | f, i, o, z = tf.split(preact, 4, len(preact.shape) - 1) 200 | 201 | f = tf.sigmoid(f) 202 | i = tf.sigmoid(i) 203 | o = tf.sigmoid(o) 204 | z = tf.tanh(z) 205 | 206 | c = linear_recurrence_cpu(f, i * z) 207 | h = o * c 208 | return h 209 | 210 | def gilr_layer_cpu(X, hidden_size, nonlin=tf.nn.elu, 211 | name='gilr'): 212 | """ 213 | g_t = sigmoid(Ux_t + b) 214 | h_t = g_t h_{t-1} + (1-g_t) f(Vx_t + c) 215 | """ 216 | with vscope(name): 217 | n_dims = X.get_shape()[-1].value 218 | act = fc_layer(X, 2 * hidden_size, nonlin=tf.identity) 219 | gate, impulse = tf.split(act, 2, len(act.shape) - 1) 220 | gate = tf.sigmoid(gate) 221 | impulse = nonlin(impulse) 222 | return s_linear_recurrence_cpu(gate, (1-gate) * impulse) 223 | 224 | def linear_recurrence_cpu(f, b): 225 | """Compute the linear recurrence using native tf operations 226 | so that we evaluate without a GPU. We evaluate the recurrence 227 | which is stepwise h_t = f * h_{t-1} + b, returning all h.""" 228 | fs = tf.unstack(f, axis=0) 229 | bs = tf.unstack(b, axis=0) 230 | h = tf.identity(b) 231 | 232 | hs = [bs[0]] 233 | for index in range(1, len(bs)): 234 | print fs[index], bs[index] 235 | to_append = tf.add(tf.multiply(fs[index], hs[index-1]), bs[index]) 236 | hs.append(to_append) 237 | return tf.stack(hs) 238 | 239 | -------------------------------------------------------------------------------- /linear_recurrent_net/build.sh: -------------------------------------------------------------------------------- 1 | #! /bin/sh 2 | rm -rf lib/ 3 | 4 | mkdir lib 5 | nvcc -c linear_recurrence.cu -o lib/linear_recurrence.o -O3 --compiler-options '-fPIC' 6 | nvcc lib/linear_recurrence.o -shared -o lib/liblinear_recurrence.so --compiler-options '-fPIC' 7 | 8 | # building tensorflow op 9 | export TF_INC=$(python -c 'import tensorflow as tf; print(tf.sysconfig.get_include())') 10 | g++ -std=c++11 -shared -o lib/tf_linear_recurrence.so linear_recurrent_net/tensorflow_binding/linear_recurrence_op.cpp lib/linear_recurrence.o -O3 -I $TF_INC -fPIC -lcudart 11 | -------------------------------------------------------------------------------- /linear_recurrent_net/linear_recurrence.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | #define CEIL_DIV(x, y) ((x + y - 1) / y) 5 | 6 | #define gpuErrChk(ans) { gpuAssert((ans), __FILE__, __LINE__); } 7 | void gpuAssert(cudaError_t code, const char *file, int line) { 8 | if (code != cudaSuccess) { 9 | fprintf(stderr,"GPUassert: %s %s %d\n", cudaGetErrorString(code), file, line); 10 | } 11 | } 12 | 13 | __device__ int2 divide_work(int n_jobs, int n_workers, int worker_idx) { 14 | // Each worker will do a continuous slice of either n_jobs / n_workers 15 | // or ceil_div(n_jobs, n_workers). The return value is an int2 representing 16 | // a half open interval of jobs for the worker to perform (perform jobs 17 | // i for a <= i < b) 18 | 19 | int cd = CEIL_DIV(n_jobs, n_workers); 20 | int d = n_jobs / n_workers; 21 | 22 | int doing_cd = n_jobs % n_workers; 23 | 24 | int2 retval; 25 | if (worker_idx < doing_cd) { 26 | retval.x = worker_idx * cd; 27 | retval.y = retval.x + cd; 28 | } else { 29 | retval.x = doing_cd * cd + (worker_idx - doing_cd) * d; 30 | retval.y = retval.x + d; 31 | } 32 | 33 | return retval; 34 | } 35 | 36 | __device__ int2 compute_warp_start_stop(int block_idx, int warp_idx, 37 | int n_blocks, int n_steps) { 38 | int2 block_ss = divide_work(n_steps, n_blocks, block_idx); 39 | int block_start = block_ss.x; 40 | int block_stop = block_ss.y; 41 | int block_jobs = block_stop - block_start; 42 | 43 | int2 warp_ss = divide_work(block_jobs, 32, warp_idx); 44 | int warp_start = block_start + warp_ss.x; 45 | int warp_stop = block_start + warp_ss.y; 46 | 47 | int2 retval; 48 | retval.x = warp_start; 49 | retval.y = warp_stop; 50 | return retval; 51 | } 52 | 53 | // decay storage, h_storage: 54 | // each a n_dims x 33 x n_blocks matrix on GPU with 33rd column for block reduction 55 | __global__ void reduction_kernel(float *decays, float *impulses, 56 | float *initial_state, 57 | float *_decay_storage, float *_h_storage, 58 | int n_dims, int n_steps) { 59 | int warp = threadIdx.x / 32; 60 | int lane = threadIdx.x % 32; 61 | 62 | float *decay_storage = &_decay_storage[blockIdx.x * 33 * n_dims]; 63 | float *h_storage = &_h_storage[blockIdx.x * 33 * n_dims]; 64 | 65 | int2 start_stop = compute_warp_start_stop(blockIdx.x, warp, gridDim.x, n_steps); 66 | int warp_start = start_stop.x; 67 | int warp_stop = start_stop.y; 68 | 69 | /* 70 | * Reduce within warps. 71 | * After this loop exits, the storage arrays should contain the reduction 72 | * from warp_start to warp_stop (including initial state) at index 73 | * (feature_idx, warp, block). 74 | */ 75 | for (int i = lane; i < n_dims; i += 32) { 76 | float cum_decay = 1.0; 77 | float h = 0.0; 78 | if (blockIdx.x == 0 && warp == 0 && initial_state != NULL) { 79 | h = initial_state[i]; 80 | } 81 | 82 | for (int t = warp_start; t < warp_stop; t++) { 83 | cum_decay *= decays[i + t * n_dims]; 84 | h = decays[i + t * n_dims] * h + impulses[i + t * n_dims]; 85 | } 86 | 87 | // TODO: store into shared memory, work in shared memory sized blocks 88 | // store into global memory 89 | decay_storage[i + warp * n_dims] = cum_decay; 90 | h_storage[i + warp * n_dims] = h; 91 | } 92 | 93 | __syncthreads(); 94 | 95 | /* 96 | * Reduce over warps. 97 | * After this loop exits, the storage arrays should contain the reduction 98 | * from block_start to block_finish (including initial state) at index 99 | * (feature_idx, 32, block). 100 | */ 101 | // TODO: parallel reduction (or scan). Need to worry about changing the warp 102 | // reduction values (as I use them again later) 103 | for (int i = lane + 32 * warp; i < n_dims; i += blockDim.x) { 104 | float cum_decay = 1.0; 105 | float h = 0.0; 106 | for (int t = 0; t < 32; t++) { 107 | cum_decay *= decay_storage[i + t * n_dims]; 108 | h = decay_storage[i + t * n_dims] * h + h_storage[i + t * n_dims]; 109 | } 110 | decay_storage[i + 32 * n_dims] = cum_decay; 111 | h_storage[i + 32 * n_dims] = h; 112 | } 113 | } 114 | 115 | __global__ void block_scan_kernel(float *decay_storage, float *h_storage, 116 | int n_dims, int n_blocks) { 117 | /* 118 | * Scan over blocks. 119 | * After this loop exits, the storage arrays should contain the cumulative sum 120 | * from block_idx 0 to i (inclusive) at index (feature_idx, 32, i) 121 | * This means (feature_idx, 32, 2) contains the reduction of blocks 0, 1, and 2. 122 | */ 123 | // TODO: parallel scan (tricky because number of blocks isn't necessarily 124 | // smaller than number of warps that can fit in a single block) 125 | for (int i = threadIdx.x + blockIdx.x * blockDim.x; 126 | i < n_dims; 127 | i += blockDim.x * gridDim.x) { 128 | 129 | for (int t = 1; t < n_blocks; t++) { 130 | int cur_idx = i + 32 * n_dims + t * 33 * n_dims; 131 | int prev_idx = i + 32 * n_dims + (t - 1) * 33 * n_dims; 132 | 133 | // TODO: remove unneccessary reads from global memory (prev_idx accesses) 134 | h_storage[cur_idx] = decay_storage[cur_idx] * h_storage[prev_idx] + h_storage[cur_idx]; 135 | decay_storage[cur_idx] *= decay_storage[prev_idx]; 136 | } 137 | } 138 | } 139 | 140 | __global__ void warp_scan_kernel(float *decays, float *impulses, 141 | float *initial_state, float *out, 142 | float *decay_storage, float *h_storage, 143 | int n_dims, int n_steps) { 144 | int warp = threadIdx.x / 32; 145 | int lane = threadIdx.x % 32; 146 | 147 | // Note: Due to the index ordering of the storage arrays, the following 148 | // indices are equivalent: 149 | // 150 | // i + (t - 1) * n_dims + blockIdx.x * 33 * n_dims 151 | // i + 32 * n_dims + (blockIdx.x - 1) * 33 * n_dims 152 | // 153 | // when t is 0. This means something that looks like negative indexing 154 | // (t-1) can be used to safely access the stored value for the previous 155 | // warp (even if the previous warp belonged to the previous block). 156 | 157 | /* 158 | * Scan over warps. 159 | * After this loop executes, the storage arrays should contain the cumulative 160 | * sum from the beginning of sequence (including initial condition) up to 161 | * and including the indexed warp and block. 162 | */ 163 | // TODO: parallel scan 164 | for (int i = lane + 32 * warp; i < n_dims; i += blockDim.x) { 165 | for (int t = 0; t < 32; t++) { 166 | if (t == 0 && blockIdx.x == 0) { 167 | // the reduction over warp 0 (including initial condition) is correct val 168 | // for scan, so there's no work to do 169 | continue; 170 | } 171 | 172 | int cur_idx = i + t * n_dims + blockIdx.x * 33 * n_dims; 173 | int prev_idx = i + (t - 1) * n_dims + blockIdx.x * 33 * n_dims; 174 | h_storage[cur_idx] = decay_storage[cur_idx] * h_storage[prev_idx] + h_storage[cur_idx]; 175 | decay_storage[cur_idx] *= decay_storage[prev_idx]; 176 | } 177 | } 178 | 179 | __syncthreads(); 180 | 181 | int2 start_stop = compute_warp_start_stop(blockIdx.x, warp, gridDim.x, n_steps); 182 | int warp_start = start_stop.x; 183 | int warp_stop = start_stop.y; 184 | 185 | /* 186 | * Scan within warps. 187 | * This loop writes to the output array. Each warp reads in it's initial state 188 | * (either from the "initial_state" or the storage arrays) and then writes 189 | * to output for indices warp_start up to warp_stop. 190 | */ 191 | for (int i = lane; i < n_dims; i += 32) { 192 | float h = 0.0; 193 | if (blockIdx.x == 0 && warp == 0) { 194 | if (initial_state != NULL) { 195 | h = initial_state[i]; 196 | } 197 | } else { 198 | h = h_storage[i + (warp - 1) * n_dims + blockIdx.x * 33 * n_dims]; 199 | } 200 | 201 | for (int t = warp_start; t < warp_stop; t++) { 202 | h = decays[i + t * n_dims] * h + impulses[i + t * n_dims]; 203 | out[i + t * n_dims] = h; 204 | } 205 | } 206 | } 207 | 208 | __global__ void serial_linear_recurrence(float *decays, float *impulses, 209 | float *initial_state, float *out, 210 | int n_dims, int n_steps) { 211 | // computes h_t = lambda_t h{t-1} + x_t 212 | 213 | for (int dim_idx = threadIdx.x + blockIdx.x * blockDim.x; 214 | dim_idx < n_dims; 215 | dim_idx += blockDim.x * gridDim.x) { 216 | float val = initial_state[dim_idx]; 217 | 218 | for (int step = 0; step < n_steps; step++) { 219 | int idx = dim_idx + step * n_dims; 220 | val = decays[idx] * val + impulses[idx]; 221 | out[idx] = val; 222 | } 223 | } 224 | } 225 | 226 | extern "C" { 227 | /* 228 | * This is the main method for the prefix sum kernels. 229 | * decays, impulses, out: 230 | * each a n_dims x n_steps column major matrix located on GPU 231 | * initial_state: 232 | * array of size n_dims located on GPU 233 | */ 234 | void compute_linear_recurrence(float *decays, float *impulses, float *initial_state, 235 | float *out, int n_dims, int n_steps) { 236 | 237 | // TODO: query 238 | int n_SMs = 15; 239 | int n_blocks_per_sm = 2; 240 | 241 | // we want at least 32 elements per block, but no reason to run 242 | // with more than the maximum number of concurrent blocks 243 | int n_blocks = min(CEIL_DIV(n_steps, 32), n_SMs * n_blocks_per_sm); 244 | 245 | // TODO: make user pass in working memory? This allows integration 246 | // with CNMeM (used by Theano) 247 | int reduction_mem_sz = 2 * n_blocks * 33 * n_dims * sizeof(float); 248 | float *d_reduction_mem; 249 | gpuErrChk(cudaMalloc(&d_reduction_mem, reduction_mem_sz)); 250 | float *d_decay_storage = &d_reduction_mem[0 * n_blocks * 33 * n_dims]; 251 | float *d_h_storage = &d_reduction_mem[1 * n_blocks * 33 * n_dims]; 252 | 253 | // TODO: run kernels on non-default stream? 254 | reduction_kernel<<>>(decays, impulses, initial_state, 255 | d_decay_storage, d_h_storage, 256 | n_dims, n_steps); 257 | 258 | block_scan_kernel<<>>(d_decay_storage, d_h_storage, 259 | n_dims, n_blocks); 260 | 261 | warp_scan_kernel<<>>(decays, impulses, 262 | initial_state, out, 263 | d_decay_storage, d_h_storage, 264 | n_dims, n_steps); 265 | 266 | gpuErrChk(cudaFree(d_reduction_mem)); 267 | } 268 | 269 | void compute_serial_linear_recurrence(float *decays, float *impulses, 270 | float *initial_state, float *out, 271 | int n_dims, int n_steps) { 272 | // TODO: query 273 | int n_SMs = 15; 274 | int n_blocks_per_sm = 2; 275 | 276 | int n_blocks = n_SMs * n_blocks_per_sm; 277 | serial_linear_recurrence<<>>(decays, impulses, initial_state, 278 | out, n_dims, n_steps); 279 | } 280 | } 281 | 282 | void test() { 283 | int n_dims = 100; 284 | int n_steps = 1000000; 285 | int n_elements = n_dims * n_steps; 286 | 287 | float *decays = (float *) calloc(n_elements, sizeof(float)); 288 | for (int i = 0; i < n_elements; i++) { 289 | decays[i] = .999; 290 | } 291 | float *d_decays; 292 | gpuErrChk(cudaMalloc(&d_decays, n_elements * sizeof(float))); 293 | gpuErrChk(cudaMemcpy(d_decays, decays, n_elements * sizeof(float), 294 | cudaMemcpyHostToDevice)); 295 | 296 | float *impulses = (float *) calloc(n_elements, sizeof(float)); 297 | for (int i = 0; i < n_dims; i++) { 298 | impulses[i + 0 * n_dims] = 2.0; 299 | } 300 | float *d_impulses; 301 | gpuErrChk(cudaMalloc(&d_impulses, n_elements * sizeof(float))); 302 | gpuErrChk(cudaMemcpy(d_impulses, impulses, 303 | n_elements * sizeof(float), cudaMemcpyHostToDevice)); 304 | 305 | float *out = (float *) calloc(n_elements, sizeof(float)); 306 | float *d_out; 307 | gpuErrChk(cudaMalloc(&d_out, n_elements * sizeof(float))); 308 | gpuErrChk(cudaMemset(d_out, 0, n_elements * sizeof(float))); 309 | 310 | compute_linear_recurrence(d_decays, d_impulses, NULL, d_out, n_dims, n_steps); 311 | gpuErrChk(cudaMemcpy(out, d_out, n_elements * sizeof(float), 312 | cudaMemcpyDeviceToHost)); 313 | 314 | gpuErrChk(cudaFree(d_decays)); 315 | gpuErrChk(cudaFree(d_impulses)); 316 | gpuErrChk(cudaFree(d_out)); 317 | } 318 | -------------------------------------------------------------------------------- /linear_recurrent_net/linear_recurrence.h: -------------------------------------------------------------------------------- 1 | extern "C" { 2 | void compute_linear_recurrence( 3 | const float *decays, /* n_steps x n_dims row major matrix */ 4 | const float *impulses, /* n_steps x n_dims row major matrix */ 5 | const float *initial_state, /* n_dims vector */ 6 | float *out, /* n_steps x n_dims row major matrix */ 7 | int n_dims, /* dimensionality of recurrent vector */ 8 | int n_steps /* length of input and output sequences */ 9 | ); 10 | } 11 | -------------------------------------------------------------------------------- /linear_recurrent_net/linear_recurrent_net/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /linear_recurrent_net/linear_recurrent_net/tensorflow_binding/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tensorflow as tf 3 | from tensorflow.python.framework import ops 4 | 5 | dir = os.path.dirname(os.path.abspath(__file__)) 6 | _lr_module = tf.load_op_library('%s/../../lib/tf_linear_recurrence.so' % dir) 7 | 8 | def linear_recurrence(decays, impulses, initial_state=None): 9 | ''' 10 | Compute r[i] = decays[i] * r[i - 1] + impulses[i] with r[0] = initial_state. 11 | 12 | decays and impulses must have the same shape and are [n_steps, ...]. 13 | initial_state must be None (to zero initialize) or [...] 14 | ''' 15 | 16 | if initial_state is None: 17 | initial_state = tf.zeros_like(impulses[0, :]) 18 | 19 | shape = tf.shape(decays) 20 | rank = shape.get_shape()[0].value 21 | if rank > 2: 22 | tail = tf.reduce_prod(shape[1:]) 23 | decays = tf.reshape(decays, [shape[0], tail]) 24 | impulses = tf.reshape(impulses, [shape[0], tail]) 25 | initial_state = tf.reshape(initial_state, [tail]) 26 | 27 | resp = _lr_module.linear_recurrence(decays, impulses, initial_state) 28 | 29 | if rank > 2: 30 | resp = tf.reshape(resp, shape) 31 | return resp 32 | 33 | @ops.RegisterGradient("LinearRecurrence") 34 | def _linear_recurrence_grad(op, dl_dresp): 35 | decays = op.inputs[0] 36 | impulses = op.inputs[1] 37 | initial_state = op.inputs[2] 38 | 39 | n_steps = tf.shape(impulses)[0] 40 | 41 | # forwards goes from h_0 to h_{T-1} 42 | forwards_tail = linear_recurrence(decays, impulses, initial_state)[:-1, :] 43 | forwards = tf.concat([tf.expand_dims(initial_state, 0), forwards_tail], 44 | axis=0) 45 | 46 | reverse = lambda x: tf.reverse(x, axis=[0]) 47 | 48 | # recur on 49 | # decays from T, T-1, ..., 2 50 | # output gradients from T-1, T-2, ..., 1 51 | dl_dh_head = reverse( 52 | linear_recurrence( 53 | reverse(decays)[:-1, :], 54 | reverse(dl_dresp)[1:, :], 55 | dl_dresp[-1, :], 56 | ) 57 | ) 58 | 59 | dl_dh = tf.concat([dl_dh_head, dl_dresp[-1:, :]], axis=0) 60 | 61 | dl_dinit = decays[0, :] * dl_dh[0, :] 62 | dl_dimpulses = dl_dh 63 | dl_ddecays = dl_dh * forwards 64 | 65 | return [dl_ddecays, dl_dimpulses, dl_dinit] 66 | -------------------------------------------------------------------------------- /linear_recurrent_net/linear_recurrent_net/tensorflow_binding/gradient_test.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import numpy as np 3 | import tensorflow as tf 4 | from linear_recurrent_net.tensorflow_binding import linear_recurrence 5 | 6 | n_dims = 20 7 | n_steps = 30 8 | 9 | np.random.seed(2016) 10 | decays = np.random.uniform(size=(n_steps, n_dims)).astype(np.float32) 11 | impulses = np.random.randn(n_steps, n_dims).astype(np.float32) 12 | initial_state = np.random.randn(n_dims).astype(np.float32) 13 | 14 | with tf.Session() as sess: 15 | inp = tf.constant(decays) 16 | response = linear_recurrence(inp, impulses, initial_state) 17 | print('Decays grad err:', 18 | tf.test.compute_gradient_error(inp, decays.shape, 19 | response, impulses.shape) 20 | ) 21 | 22 | inp = tf.constant(impulses) 23 | response = linear_recurrence(decays, inp, initial_state) 24 | print('Impulses grad err:', 25 | tf.test.compute_gradient_error(inp, impulses.shape, 26 | response, impulses.shape) 27 | ) 28 | 29 | inp = tf.constant(initial_state) 30 | response = linear_recurrence(decays, impulses, inp) 31 | print('Initial state grad err:', 32 | tf.test.compute_gradient_error(inp, initial_state.shape, 33 | response, impulses.shape) 34 | ) 35 | -------------------------------------------------------------------------------- /linear_recurrent_net/linear_recurrent_net/tensorflow_binding/linear_recurrence_op.cpp: -------------------------------------------------------------------------------- 1 | #include "tensorflow/core/framework/op.h" 2 | #include "tensorflow/core/framework/op_kernel.h" 3 | #include "tensorflow/core/framework/shape_inference.h" 4 | #include "../../linear_recurrence.h" 5 | 6 | using namespace tensorflow; 7 | 8 | REGISTER_OP("LinearRecurrence") 9 | .Input("decays: float") 10 | .Input("impulses: float") 11 | .Input("initial_state: float") 12 | .Output("response: float") 13 | .SetShapeFn([](::tensorflow::shape_inference::InferenceContext *c) { 14 | c->set_output(0, c->input(1)); 15 | return Status::OK(); 16 | }); 17 | 18 | class GpuLinearRecurrenceOp : public OpKernel { 19 | public: 20 | explicit GpuLinearRecurrenceOp(OpKernelConstruction *ctx): OpKernel(ctx) {} 21 | 22 | void Compute(OpKernelContext *ctx) override { 23 | const Tensor& decays_tensor = ctx->input(0); 24 | const Tensor& impulses_tensor = ctx->input(1); 25 | const Tensor& initial_state_tensor = ctx->input(2); 26 | 27 | int n_steps = impulses_tensor.dim_size(0); 28 | int n_dims = impulses_tensor.dim_size(1); 29 | 30 | OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(impulses_tensor.shape()), 31 | errors::InvalidArgument("Impulses must be a matrix")); 32 | 33 | 34 | OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(decays_tensor.shape()), 35 | errors::InvalidArgument("Decays must be a matrix")); 36 | 37 | OP_REQUIRES(ctx, 38 | decays_tensor.dim_size(0) == n_steps && 39 | decays_tensor.dim_size(1) == n_dims, 40 | errors::InvalidArgument("Decay shape must match impulse shape")); 41 | 42 | OP_REQUIRES(ctx, 43 | TensorShapeUtils::IsVector(initial_state_tensor.shape()) && 44 | initial_state_tensor.dim_size(0) == n_dims, 45 | errors::InvalidArgument("Initial state must be a vector of length n_dims")); 46 | 47 | Tensor *response_tensor = NULL; 48 | OP_REQUIRES_OK(ctx, ctx->allocate_output(0, impulses_tensor.shape(), &response_tensor)); 49 | 50 | auto decays = decays_tensor.flat(); 51 | auto impulses = impulses_tensor.flat(); 52 | auto initial_state = initial_state_tensor.flat(); 53 | auto response = response_tensor->template flat(); 54 | 55 | compute_linear_recurrence(decays.data(), impulses.data(), 56 | initial_state.data(), response.data(), 57 | n_dims, n_steps); 58 | } 59 | }; 60 | REGISTER_KERNEL_BUILDER(Name("LinearRecurrence").Device(DEVICE_GPU), GpuLinearRecurrenceOp); 61 | -------------------------------------------------------------------------------- /paper/1k_20_smoothing.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eamartin/parallelizing_linear_rnns/da10de34701313e5556d67f4df4aacb76983e20c/paper/1k_20_smoothing.png -------------------------------------------------------------------------------- /paper/1k_synthetic.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eamartin/parallelizing_linear_rnns/da10de34701313e5556d67f4df4aacb76983e20c/paper/1k_synthetic.png -------------------------------------------------------------------------------- /paper/1k_synthetic_new.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eamartin/parallelizing_linear_rnns/da10de34701313e5556d67f4df4aacb76983e20c/paper/1k_synthetic_new.png -------------------------------------------------------------------------------- /paper/1m_20_smoothing.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eamartin/parallelizing_linear_rnns/da10de34701313e5556d67f4df4aacb76983e20c/paper/1m_20_smoothing.png -------------------------------------------------------------------------------- /paper/1m_synthetic.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eamartin/parallelizing_linear_rnns/da10de34701313e5556d67f4df4aacb76983e20c/paper/1m_synthetic.png -------------------------------------------------------------------------------- /paper/1m_synthetic_new.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eamartin/parallelizing_linear_rnns/da10de34701313e5556d67f4df4aacb76983e20c/paper/1m_synthetic_new.png -------------------------------------------------------------------------------- /paper/8k_5_smoothing.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eamartin/parallelizing_linear_rnns/da10de34701313e5556d67f4df4aacb76983e20c/paper/8k_5_smoothing.png -------------------------------------------------------------------------------- /paper/8k_synthetic.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eamartin/parallelizing_linear_rnns/da10de34701313e5556d67f4df4aacb76983e20c/paper/8k_synthetic.png -------------------------------------------------------------------------------- /paper/8k_synthetic_new.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eamartin/parallelizing_linear_rnns/da10de34701313e5556d67f4df4aacb76983e20c/paper/8k_synthetic_new.png -------------------------------------------------------------------------------- /paper/between_chain_plots/1k_0_smoothing.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eamartin/parallelizing_linear_rnns/da10de34701313e5556d67f4df4aacb76983e20c/paper/between_chain_plots/1k_0_smoothing.png -------------------------------------------------------------------------------- /paper/between_chain_plots/1k_20_smoothing.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eamartin/parallelizing_linear_rnns/da10de34701313e5556d67f4df4aacb76983e20c/paper/between_chain_plots/1k_20_smoothing.png -------------------------------------------------------------------------------- /paper/between_chain_plots/1k_5_smoothing.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eamartin/parallelizing_linear_rnns/da10de34701313e5556d67f4df4aacb76983e20c/paper/between_chain_plots/1k_5_smoothing.png -------------------------------------------------------------------------------- /paper/between_chain_plots/1m_0_smoothing.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eamartin/parallelizing_linear_rnns/da10de34701313e5556d67f4df4aacb76983e20c/paper/between_chain_plots/1m_0_smoothing.png -------------------------------------------------------------------------------- /paper/between_chain_plots/1m_20_smoothing.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eamartin/parallelizing_linear_rnns/da10de34701313e5556d67f4df4aacb76983e20c/paper/between_chain_plots/1m_20_smoothing.png -------------------------------------------------------------------------------- /paper/between_chain_plots/8k_0_smoothing.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eamartin/parallelizing_linear_rnns/da10de34701313e5556d67f4df4aacb76983e20c/paper/between_chain_plots/8k_0_smoothing.png -------------------------------------------------------------------------------- /paper/between_chain_plots/8k_20_smoothing.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eamartin/parallelizing_linear_rnns/da10de34701313e5556d67f4df4aacb76983e20c/paper/between_chain_plots/8k_20_smoothing.png -------------------------------------------------------------------------------- /paper/between_chain_plots/8k_5_smoothing.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eamartin/parallelizing_linear_rnns/da10de34701313e5556d67f4df4aacb76983e20c/paper/between_chain_plots/8k_5_smoothing.png -------------------------------------------------------------------------------- /paper/cudnn_heatmap.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eamartin/parallelizing_linear_rnns/da10de34701313e5556d67f4df4aacb76983e20c/paper/cudnn_heatmap.png -------------------------------------------------------------------------------- /paper/heatmap.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eamartin/parallelizing_linear_rnns/da10de34701313e5556d67f4df4aacb76983e20c/paper/heatmap.png -------------------------------------------------------------------------------- /paper/heatmap_accurate_to_caption.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eamartin/parallelizing_linear_rnns/da10de34701313e5556d67f4df4aacb76983e20c/paper/heatmap_accurate_to_caption.png -------------------------------------------------------------------------------- /paper/iclr2018_conference.log: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eamartin/parallelizing_linear_rnns/da10de34701313e5556d67f4df4aacb76983e20c/paper/iclr2018_conference.log -------------------------------------------------------------------------------- /paper/iclr2018_conference.sty: -------------------------------------------------------------------------------- 1 | %%%% ICLR Macros (LaTex) 2 | %%%% Adapted by Hugo Larochelle from the NIPS stylefile Macros 3 | %%%% Style File 4 | %%%% Dec 12, 1990 Rev Aug 14, 1991; Sept, 1995; April, 1997; April, 1999; October 2014 5 | 6 | % This file can be used with Latex2e whether running in main mode, or 7 | % 2.09 compatibility mode. 8 | % 9 | % If using main mode, you need to include the commands 10 | % \documentclass{article} 11 | % \usepackage{iclr14submit_e,times} 12 | % 13 | 14 | % Change the overall width of the page. If these parameters are 15 | % changed, they will require corresponding changes in the 16 | % maketitle section. 17 | % 18 | \usepackage{eso-pic} % used by \AddToShipoutPicture 19 | \RequirePackage{fancyhdr} 20 | \RequirePackage{natbib} 21 | 22 | % modification to natbib citations 23 | \setcitestyle{authoryear,round,citesep={;},aysep={,},yysep={;}} 24 | 25 | \renewcommand{\topfraction}{0.95} % let figure take up nearly whole page 26 | \renewcommand{\textfraction}{0.05} % let figure take up nearly whole page 27 | 28 | % Define iclrfinal, set to true if iclrfinalcopy is defined 29 | \newif\ificlrfinal 30 | \iclrfinalfalse 31 | \def\iclrfinalcopy{\iclrfinaltrue} 32 | \font\iclrtenhv = phvb at 8pt 33 | 34 | % Specify the dimensions of each page 35 | 36 | \setlength{\paperheight}{11in} 37 | \setlength{\paperwidth}{8.5in} 38 | 39 | 40 | \oddsidemargin .5in % Note \oddsidemargin = \evensidemargin 41 | \evensidemargin .5in 42 | \marginparwidth 0.07 true in 43 | %\marginparwidth 0.75 true in 44 | %\topmargin 0 true pt % Nominal distance from top of page to top of 45 | %\topmargin 0.125in 46 | \topmargin -0.625in 47 | \addtolength{\headsep}{0.25in} 48 | \textheight 9.0 true in % Height of text (including footnotes & figures) 49 | \textwidth 5.5 true in % Width of text line. 50 | \widowpenalty=10000 51 | \clubpenalty=10000 52 | 53 | % \thispagestyle{empty} \pagestyle{empty} 54 | \flushbottom \sloppy 55 | 56 | % We're never going to need a table of contents, so just flush it to 57 | % save space --- suggested by drstrip@sandia-2 58 | \def\addcontentsline#1#2#3{} 59 | 60 | % Title stuff, taken from deproc. 61 | \def\maketitle{\par 62 | \begingroup 63 | \def\thefootnote{\fnsymbol{footnote}} 64 | \def\@makefnmark{\hbox to 0pt{$^{\@thefnmark}$\hss}} % for perfect author 65 | % name centering 66 | % The footnote-mark was overlapping the footnote-text, 67 | % added the following to fix this problem (MK) 68 | \long\def\@makefntext##1{\parindent 1em\noindent 69 | \hbox to1.8em{\hss $\m@th ^{\@thefnmark}$}##1} 70 | \@maketitle \@thanks 71 | \endgroup 72 | \setcounter{footnote}{0} 73 | \let\maketitle\relax \let\@maketitle\relax 74 | \gdef\@thanks{}\gdef\@author{}\gdef\@title{}\let\thanks\relax} 75 | 76 | % The toptitlebar has been raised to top-justify the first page 77 | 78 | \usepackage{fancyhdr} 79 | \pagestyle{fancy} 80 | \fancyhead{} 81 | 82 | % Title (includes both anonimized and non-anonimized versions) 83 | \def\@maketitle{\vbox{\hsize\textwidth 84 | %\linewidth\hsize \vskip 0.1in \toptitlebar \centering 85 | {\LARGE\sc \@title\par} 86 | %\bottomtitlebar % \vskip 0.1in % minus 87 | \ificlrfinal 88 | \lhead{Published as a conference paper at ICLR 2018} 89 | \def\And{\end{tabular}\hfil\linebreak[0]\hfil 90 | \begin{tabular}[t]{l}\bf\rule{\z@}{24pt}\ignorespaces}% 91 | \def\AND{\end{tabular}\hfil\linebreak[4]\hfil 92 | \begin{tabular}[t]{l}\bf\rule{\z@}{24pt}\ignorespaces}% 93 | \begin{tabular}[t]{l}\bf\rule{\z@}{24pt}\@author\end{tabular}% 94 | \else 95 | \lhead{Under review as a conference paper at ICLR 2018} 96 | \def\And{\end{tabular}\hfil\linebreak[0]\hfil 97 | \begin{tabular}[t]{l}\bf\rule{\z@}{24pt}\ignorespaces}% 98 | \def\AND{\end{tabular}\hfil\linebreak[4]\hfil 99 | \begin{tabular}[t]{l}\bf\rule{\z@}{24pt}\ignorespaces}% 100 | \begin{tabular}[t]{l}\bf\rule{\z@}{24pt}Anonymous authors\\Paper under double-blind review\end{tabular}% 101 | \fi 102 | \vskip 0.3in minus 0.1in}} 103 | 104 | \renewenvironment{abstract}{\vskip.075in\centerline{\large\sc 105 | Abstract}\vspace{0.5ex}\begin{quote}}{\par\end{quote}\vskip 1ex} 106 | 107 | % sections with less space 108 | \def\section{\@startsection {section}{1}{\z@}{-2.0ex plus 109 | -0.5ex minus -.2ex}{1.5ex plus 0.3ex 110 | minus0.2ex}{\large\sc\raggedright}} 111 | 112 | \def\subsection{\@startsection{subsection}{2}{\z@}{-1.8ex plus 113 | -0.5ex minus -.2ex}{0.8ex plus .2ex}{\normalsize\sc\raggedright}} 114 | \def\subsubsection{\@startsection{subsubsection}{3}{\z@}{-1.5ex 115 | plus -0.5ex minus -.2ex}{0.5ex plus 116 | .2ex}{\normalsize\sc\raggedright}} 117 | \def\paragraph{\@startsection{paragraph}{4}{\z@}{1.5ex plus 118 | 0.5ex minus .2ex}{-1em}{\normalsize\bf}} 119 | \def\subparagraph{\@startsection{subparagraph}{5}{\z@}{1.5ex plus 120 | 0.5ex minus .2ex}{-1em}{\normalsize\sc}} 121 | \def\subsubsubsection{\vskip 122 | 5pt{\noindent\normalsize\rm\raggedright}} 123 | 124 | 125 | % Footnotes 126 | \footnotesep 6.65pt % 127 | \skip\footins 9pt plus 4pt minus 2pt 128 | \def\footnoterule{\kern-3pt \hrule width 12pc \kern 2.6pt } 129 | \setcounter{footnote}{0} 130 | 131 | % Lists and paragraphs 132 | \parindent 0pt 133 | \topsep 4pt plus 1pt minus 2pt 134 | \partopsep 1pt plus 0.5pt minus 0.5pt 135 | \itemsep 2pt plus 1pt minus 0.5pt 136 | \parsep 2pt plus 1pt minus 0.5pt 137 | \parskip .5pc 138 | 139 | 140 | %\leftmargin2em 141 | \leftmargin3pc 142 | \leftmargini\leftmargin \leftmarginii 2em 143 | \leftmarginiii 1.5em \leftmarginiv 1.0em \leftmarginv .5em 144 | 145 | %\labelsep \labelsep 5pt 146 | 147 | \def\@listi{\leftmargin\leftmargini} 148 | \def\@listii{\leftmargin\leftmarginii 149 | \labelwidth\leftmarginii\advance\labelwidth-\labelsep 150 | \topsep 2pt plus 1pt minus 0.5pt 151 | \parsep 1pt plus 0.5pt minus 0.5pt 152 | \itemsep \parsep} 153 | \def\@listiii{\leftmargin\leftmarginiii 154 | \labelwidth\leftmarginiii\advance\labelwidth-\labelsep 155 | \topsep 1pt plus 0.5pt minus 0.5pt 156 | \parsep \z@ \partopsep 0.5pt plus 0pt minus 0.5pt 157 | \itemsep \topsep} 158 | \def\@listiv{\leftmargin\leftmarginiv 159 | \labelwidth\leftmarginiv\advance\labelwidth-\labelsep} 160 | \def\@listv{\leftmargin\leftmarginv 161 | \labelwidth\leftmarginv\advance\labelwidth-\labelsep} 162 | \def\@listvi{\leftmargin\leftmarginvi 163 | \labelwidth\leftmarginvi\advance\labelwidth-\labelsep} 164 | 165 | \abovedisplayskip 7pt plus2pt minus5pt% 166 | \belowdisplayskip \abovedisplayskip 167 | \abovedisplayshortskip 0pt plus3pt% 168 | \belowdisplayshortskip 4pt plus3pt minus3pt% 169 | 170 | % Less leading in most fonts (due to the narrow columns) 171 | % The choices were between 1-pt and 1.5-pt leading 172 | %\def\@normalsize{\@setsize\normalsize{11pt}\xpt\@xpt} % got rid of @ (MK) 173 | \def\normalsize{\@setsize\normalsize{11pt}\xpt\@xpt} 174 | \def\small{\@setsize\small{10pt}\ixpt\@ixpt} 175 | \def\footnotesize{\@setsize\footnotesize{10pt}\ixpt\@ixpt} 176 | \def\scriptsize{\@setsize\scriptsize{8pt}\viipt\@viipt} 177 | \def\tiny{\@setsize\tiny{7pt}\vipt\@vipt} 178 | \def\large{\@setsize\large{14pt}\xiipt\@xiipt} 179 | \def\Large{\@setsize\Large{16pt}\xivpt\@xivpt} 180 | \def\LARGE{\@setsize\LARGE{20pt}\xviipt\@xviipt} 181 | \def\huge{\@setsize\huge{23pt}\xxpt\@xxpt} 182 | \def\Huge{\@setsize\Huge{28pt}\xxvpt\@xxvpt} 183 | 184 | \def\toptitlebar{\hrule height4pt\vskip .25in\vskip-\parskip} 185 | 186 | \def\bottomtitlebar{\vskip .29in\vskip-\parskip\hrule height1pt\vskip 187 | .09in} % 188 | %Reduced second vskip to compensate for adding the strut in \@author 189 | 190 | 191 | %% % Vertical Ruler 192 | %% % This code is, largely, from the CVPR 2010 conference style file 193 | %% % ----- define vruler 194 | %% \makeatletter 195 | %% \newbox\iclrrulerbox 196 | %% \newcount\iclrrulercount 197 | %% \newdimen\iclrruleroffset 198 | %% \newdimen\cv@lineheight 199 | %% \newdimen\cv@boxheight 200 | %% \newbox\cv@tmpbox 201 | %% \newcount\cv@refno 202 | %% \newcount\cv@tot 203 | %% % NUMBER with left flushed zeros \fillzeros[] 204 | %% \newcount\cv@tmpc@ \newcount\cv@tmpc 205 | %% \def\fillzeros[#1]#2{\cv@tmpc@=#2\relax\ifnum\cv@tmpc@<0\cv@tmpc@=-\cv@tmpc@\fi 206 | %% \cv@tmpc=1 % 207 | %% \loop\ifnum\cv@tmpc@<10 \else \divide\cv@tmpc@ by 10 \advance\cv@tmpc by 1 \fi 208 | %% \ifnum\cv@tmpc@=10\relax\cv@tmpc@=11\relax\fi \ifnum\cv@tmpc@>10 \repeat 209 | %% \ifnum#2<0\advance\cv@tmpc1\relax-\fi 210 | %% \loop\ifnum\cv@tmpc<#1\relax0\advance\cv@tmpc1\relax\fi \ifnum\cv@tmpc<#1 \repeat 211 | %% \cv@tmpc@=#2\relax\ifnum\cv@tmpc@<0\cv@tmpc@=-\cv@tmpc@\fi \relax\the\cv@tmpc@}% 212 | %% % \makevruler[][][][][] 213 | %% \def\makevruler[#1][#2][#3][#4][#5]{\begingroup\offinterlineskip 214 | %% \textheight=#5\vbadness=10000\vfuzz=120ex\overfullrule=0pt% 215 | %% \global\setbox\iclrrulerbox=\vbox to \textheight{% 216 | %% {\parskip=0pt\hfuzz=150em\cv@boxheight=\textheight 217 | %% \cv@lineheight=#1\global\iclrrulercount=#2% 218 | %% \cv@tot\cv@boxheight\divide\cv@tot\cv@lineheight\advance\cv@tot2% 219 | %% \cv@refno1\vskip-\cv@lineheight\vskip1ex% 220 | %% \loop\setbox\cv@tmpbox=\hbox to0cm{{\iclrtenhv\hfil\fillzeros[#4]\iclrrulercount}}% 221 | %% \ht\cv@tmpbox\cv@lineheight\dp\cv@tmpbox0pt\box\cv@tmpbox\break 222 | %% \advance\cv@refno1\global\advance\iclrrulercount#3\relax 223 | %% \ifnum\cv@refno<\cv@tot\repeat}}\endgroup}% 224 | %% \makeatother 225 | %% % ----- end of vruler 226 | 227 | %% % \makevruler[][][][][] 228 | %% \def\iclrruler#1{\makevruler[12pt][#1][1][3][0.993\textheight]\usebox{\iclrrulerbox}} 229 | %% \AddToShipoutPicture{% 230 | %% \ificlrfinal\else 231 | %% \iclrruleroffset=\textheight 232 | %% \advance\iclrruleroffset by -3.7pt 233 | %% \color[rgb]{.7,.7,.7} 234 | %% \AtTextUpperLeft{% 235 | %% \put(\LenToUnit{-35pt},\LenToUnit{-\iclrruleroffset}){%left ruler 236 | %% \iclrruler{\iclrrulercount}} 237 | %% } 238 | %% \fi 239 | %% } 240 | %%% To add a vertical bar on the side 241 | %\AddToShipoutPicture{ 242 | %\AtTextLowerLeft{ 243 | %\hspace*{-1.8cm} 244 | %\colorbox[rgb]{0.7,0.7,0.7}{\small \parbox[b][\textheight]{0.1cm}{}}} 245 | %} 246 | 247 | -------------------------------------------------------------------------------- /paper/iclr_reviews.txt: -------------------------------------------------------------------------------- 1 | ------------------------------------------------------------------------ 2 | Title: 3 | Faster RNNs, with novel insights on need for nonlinear recurrence; 4 | novel and clear presentation 5 | 6 | Rating: 7: Good paper, accept 7 | 8 | Review: 9 | This paper abstracts two recently-proposed RNN variants into a 10 | family of RNNs called the Linear Surrogate RNNs which satisfy 11 | Blelloch's criteria for parallelizable sequential computation. The 12 | authors then propose an efficient parallel algorithm for this class of 13 | RNNs, which produces speedups over the existing implements of 14 | Quasi-RNN, SRU, and LSTM. Apart from efficiency results, the paper 15 | also contributes a comparison of model convergence on a long-term 16 | dependency task due to (Hochreiter and Schmidhuber, 1997). A novel 17 | linearized version of the LSTM outperforms traditional LSTM on this 18 | long-term dependency task, and raises questions about whether RNNs and 19 | LSTMs truly need the nonlinear structure. 20 | 21 | The paper is written very well, with explanation (as opposed to 22 | obfuscation) as the goal. Linear Surrogate RNNs is an important 23 | concept that is useful to understand RNN variants today, and 24 | potentially other future novel architectures. 25 | 26 | The paper provides argument and experimental evidence against the 27 | rotation used typically in RNNs. While this is an interesting insight, 28 | and worthy of further discussion, such a claim needs backing up with 29 | more large-scale experiments on real datasets. 30 | 31 | While the experiments on toy tasks is clearly useful, the paper could 32 | be significantly improved by adding experiments on real tasks such as 33 | language modelling. 34 | 35 | Confidence: 36 | 4: The reviewer is confident but not absolutely certain 37 | that the evaluation is correct 38 | 39 | Rebuttal 1: 40 | 41 | We feel the very impressive performance of SRUs and QRNNs on a variety 42 | of large-scale tasks demonstrates the applicability and usefulness of 43 | our work. We could have replicated their results in language modelling 44 | and machine translation with faster training times, but we believe 45 | that showing large speedup factors for these models is sufficient 46 | evidence for the value of parallel linear recurrence. 47 | 48 | We argue more strongly that a non-linearity recurrence is unnecessary 49 | than we do that "rotation free" RNNs are just as powerful as RNNs with 50 | non-diagonal weight matrices. However, SRUs are "rotation free" linear 51 | recurrences with performance equal or superior to LSTM and other 52 | non-linear RNNs on 6 sequence classification datasets, the SQuAD 53 | question answering dataset, Penn Treebank language modelling, 54 | Switchboard-1 speech recognition, and WMT English->German translation. 55 | 56 | ------------------------------------------------------------------------ 57 | Title: Authors propose a method to make recurrent learning over 1000s 58 | and more time steps possible. 59 | 60 | Rating: 7: Good paper, accept 61 | 62 | Review: 63 | # Summary and Assessment 64 | The paper addresses an important issue–that 65 | of making learning of recurrent networks tractable for sequence 66 | lengths well beyond 1’000s of time steps. A key problem here is that 67 | processing such sequences with ordinary RNNs requires a reduce 68 | operation, where the output of the net at time step t depends on the 69 | outputs of *all* its predecessor. The authors now make a crucial 70 | observation, namely that a certain class of RNNs allows evaluation in 71 | a non-linear fashion through a so-called SCAN operator. Here, if 72 | certain conditions are satisfied, the calculation of the output can be 73 | parallelised massively. In the following, the authors explore the 74 | landscape of RNNs satisfying the necessary conditions. The performance 75 | is investigated in terms of wall clock time. Further, experimental 76 | results of problems with previously untacked sequence lengths are 77 | reported. 78 | 79 | The paper is certainly relevant, as it can pave the way towards the 80 | application of recurrent architectures to problems that have extremely 81 | long term dependencies. To me, the execution seems sound. The 82 | experiments back up the claim. 83 | 84 | ## Minor - I challenge the claim that thousands and millions of time 85 | steps are a common issue in “robotics, remote sensing, control 86 | systems, speech recognition, medicine and finance”, as claimed in the 87 | first paragraph of the introduction. IMHO, most problems in these 88 | domains get away with a few hundred time steps; nevertheless, I’d 89 | appreciate a few examples where this is a case to better justify the 90 | method. 91 | 92 | Confidence: 93 | 2: The reviewer is willing to defend the evaluation, but 94 | it is quite likely that the reviewer did not understand central parts 95 | of the paper 96 | 97 | Rebuttal 1: 98 | 99 | We agree that you can often "get away with" backprop through time 100 | (BPTT) truncated at several hundred time steps for many sequential 101 | problems, even when the inherent sequence length of the data is very 102 | long. 103 | 104 | Some problems which can benefit from additional sequence length: 105 | 106 | * Medical waveforms are often sampled at greater than 1KHz. This means 107 | relatively short recordings create very long sequences. These 108 | sequences may be used for a sequence classification task which makes 109 | it difficult to use truncated BPTT. Sequence classification on very 110 | long sequences must either handle the entire sequence, classify 111 | subsequences (suboptimal as label may only be determined by part of 112 | the sequence), or down-sample the sequence data (suboptimal because 113 | it loses information). The 2016 PhysioNet Challenge 114 | (https://physionet.org/challenge/2016/) involved classifying EEGs 115 | sampled at 2KHz for 5-120s for a total of 10K-240K events per 116 | sequence. It would be difficult to apply neural nets to such a 117 | problem without a technique to parallelize over timesteps. An even 118 | more extreme dataset is 90 minutes @ 30KHz (= 160 million steps) of 119 | neural recordings of a mouse: http://data.cortexlab.net/dualPhase3/ 120 | 121 | * Example future machine learning task: Generate a (text) review of a 122 | 2+ hour movie, including comments on dialogue and 123 | cinematography. Even with significant downsampling of both frames 124 | and audio, a 2 hour movie contains 7200 frames at 1 frame/sec and an 125 | average of 9000 words 126 | (http://kaylinwalker.com/long-winded-actors-and-movies-with-the-most-dialogue/). 127 | We believe parallel sequential methods would be hugely useful for 128 | such a task. 129 | 130 | * I am not an expert, but I believe reinforcement learning on long 131 | episodes with sparse rewards could benefit from less episode 132 | truncation. 133 | 134 | ------------------------------------------------------------------------ 135 | Title: simple but effective method for RNN speed up 136 | 137 | Rating: 6: Marginally above acceptance threshold 138 | 139 | Review: 140 | This paper focuses on accelerating RNN by applying the method 141 | from Blelloch (1990). The application is straightforward and thus 142 | technical novelty of this paper is limited. But the results are 143 | impressive. 144 | 145 | One concern is the proposed technique is only applied for few types of 146 | RNNs which may limit its applications in practice. Could the authors 147 | comment on this potential limitation? 148 | 149 | Confidence: 3: The reviewer is fairly confident that the evaluation is 150 | correct 151 | 152 | Rebuttal 1: 153 | 154 | We contest the limited technical novelty of this work. It is true that 155 | parallel scan is "a key primitive in many parallel algorithms"[1] and 156 | has been heavily studied and optimized. Parallel linear recurrence is 157 | a lesser known application of the widely popular parallel scan 158 | algorithm. Neural nets are hugely dependent on high performance 159 | parallel computational primitives such as matrix multiplication and 160 | convolution. We believe the first application of this classic parallel 161 | algorithm to a field dependent on fast parallel algorithms is a novel 162 | idea; otherwise someone else would have published this paper in the 163 | previous 30+ years that both parallel linear recurrence and RNNs have 164 | existed. 165 | 166 | Beyond the new architectures introduced in the paper, we applied 167 | parallel linear recurrence (PLR) to SRU and QRNN and note that it 168 | could also be applied to strongly-typed RNNs. Further, we show that 169 | PLR can also accelerate (the currently uninvestigated) architectures 170 | involving on h_t = A_t h_{t-1} + x_t for square matrices A_t. 171 | 172 | The broader question is "how limiting is it that PLR cannot accelerate 173 | LSTMs, GRUs, vanilla RNNs, or other non-linear RNN models?". We do not 174 | think this will limit the applicability of PLR within RNNs. A 175 | significant amount of recent research (listed below in [2]) has matched 176 | or surpassed the performance of non-linear RNNs with models with only 177 | linear sequential dependency. Given this body of research, our belief 178 | has shifted from "RNNs depend on sequential non-linearity" to "there 179 | is no evidence that sequential non-linearity is necessary, and there 180 | is a fair amount of evidence it is not necessary". With this in mind, 181 | we believe PLR's incompatibility with non-linear RNNs is not a major 182 | practical limitation as we expect linear surrogate RNNs to continue 183 | growing in popularity due to their fast training times and good 184 | performance. We also think this work will accelerate the growing 185 | popularity of linear surrogate RNNs. 186 | 187 | [1] 188 | http://people.cs.vt.edu/yongcao/teaching/cs5234/spring2013/slides/Lecture10.pdf 189 | 190 | [2] 191 | Sequential models with linear dependendencies with experimental 192 | performance on par with non-linear RNNs. Most models listed trained in 193 | significantly less time than non-linear RNN. 194 | 195 | Strongly-typed RN https://arxiv.org/abs/1602.02218 (language 196 | modelling) 197 | 198 | ByteNet https://arxiv.org/abs/1610.10099 (state of the art (SotA) 199 | character level language model on Hutter Prize, SotA character to 200 | character machine on WMT) 201 | 202 | Quasi-RNN https://arxiv.org/abs/1611.01576 (sentiment classification, 203 | language modelling, machine translation) 204 | 205 | Convolutional Sequence to Sequence Learning 206 | https://arxiv.org/abs/1705.03122 (machine translation, outperforms 207 | LSTM) 208 | 209 | Attention Is All You Need https://arxiv.org/abs/1706.03762 (SotA 210 | machine translation on WMT) 211 | 212 | WaveNet https://arxiv.org/abs/1609.03499 (high fidelity audio 213 | generation) 214 | 215 | Simple Recurrent Unit https://arxiv.org/abs/1709.02755 (matches or 216 | outperforms LSTM on sequence classification, question answering, 217 | language modelling, machine translation, speech recognition). This 218 | work significantly accelerates already fast SRU training. 219 | -------------------------------------------------------------------------------- /paper/lc2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eamartin/parallelizing_linear_rnns/da10de34701313e5556d67f4df4aacb76983e20c/paper/lc2.png -------------------------------------------------------------------------------- /paper/learning_curves.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eamartin/parallelizing_linear_rnns/da10de34701313e5556d67f4df4aacb76983e20c/paper/learning_curves.png -------------------------------------------------------------------------------- /paper/main-arxiv.aux: -------------------------------------------------------------------------------- 1 | \relax 2 | \providecommand\hyper@newdestlabel[2]{} 3 | \bibstyle{abbrvnat} 4 | \providecommand\HyperFirstAtBeginDocument{\AtBeginDocument} 5 | \HyperFirstAtBeginDocument{\ifx\hyper@anchor\@undefined 6 | \global\let\oldcontentsline\contentsline 7 | \gdef\contentsline#1#2#3#4{\oldcontentsline{#1}{#2}{#3}} 8 | \global\let\oldnewlabel\newlabel 9 | \gdef\newlabel#1#2{\newlabelxx{#1}#2} 10 | \gdef\newlabelxx#1#2#3#4#5#6{\oldnewlabel{#1}{{#2}{#3}}} 11 | \AtEndDocument{\ifx\hyper@anchor\@undefined 12 | \let\contentsline\oldcontentsline 13 | \let\newlabel\oldnewlabel 14 | \fi} 15 | \fi} 16 | \global\let\hyper@last\relax 17 | \gdef\HyperFirstAtBeginDocument#1{#1} 18 | \providecommand\HyField@AuxAddToFields[1]{} 19 | \providecommand\HyField@AuxAddToCoFields[2]{} 20 | \citation{sutskever2014sequence} 21 | \citation{amodei2015deep} 22 | \citation{hausknecht2015deep} 23 | \citation{hochreiter1997long} 24 | \citation{cho2014learning} 25 | \citation{keskar2017large} 26 | \citation{diamos2016persistent} 27 | \citation{balduzzi2016strongly} 28 | \citation{bradbury2017quasi} 29 | \citation{kalchbrenner2016neural} 30 | \citation{gehring2017convolutional} 31 | \citation{van2016wavenet} 32 | \@writefile{toc}{\contentsline {section}{\numberline {1}Introduction}{1}{section.1}} 33 | \citation{ladner1980parallel} 34 | \citation{blelloch1990prefix} 35 | \citation{blelloch1990prefix} 36 | \@writefile{toc}{\contentsline {section}{\numberline {2}Parallel linear recurrence}{2}{section.2}} 37 | \@writefile{loa}{\contentsline {algorithm}{\numberline {1}{\ignorespaces Parallel linear recurrence on $p$ processors}}{2}{algorithm.1}} 38 | \@writefile{toc}{\contentsline {subsection}{\numberline {2.1}Theoretical performance}{2}{subsection.2.1}} 39 | \citation{abadi2016tensorflow} 40 | \@writefile{toc}{\contentsline {subsection}{\numberline {2.2}Backpropagation}{3}{subsection.2.2}} 41 | \@writefile{toc}{\contentsline {subsection}{\numberline {2.3}Implementation}{3}{subsection.2.3}} 42 | \@writefile{toc}{\contentsline {section}{\numberline {3}Models}{3}{section.3}} 43 | \citation{bradbury2017quasi} 44 | \citation{balduzzi2016strongly} 45 | \@writefile{toc}{\contentsline {subsection}{\numberline {3.1}Gated impulse linear recurrent layer}{4}{subsection.3.1}} 46 | \@writefile{toc}{\contentsline {subsubsection}{\numberline {3.1.1}Impact on effective "batch size"}{4}{subsubsection.3.1.1}} 47 | \@writefile{toc}{\contentsline {subsection}{\numberline {3.2}Linear surrogate RNNs}{4}{subsection.3.2}} 48 | \citation{orchard2015converting} 49 | \citation{lecun1998mnist} 50 | \citation{kingma2014adam} 51 | \citation{glorot2010understanding} 52 | \citation{abadi2016tensorflow} 53 | \@writefile{toc}{\contentsline {section}{\numberline {4}Experiments}{5}{section.4}} 54 | \@writefile{toc}{\contentsline {subsection}{\numberline {4.1}Computational performance}{5}{subsection.4.1}} 55 | \@writefile{lof}{\contentsline {figure}{\numberline {1}{\ignorespaces Throughput comparison between LSTM-256-256 with LS-LSTM-256-256. The LSTM only has a single row of data because its throughput is independent of sequence length. Entries are missing from the LS-LSTM table because there was not enough memory on the GPU to handle such large batch sizes and sequences.}}{6}{figure.1}} 56 | \newlabel{fig:tp_perf}{{1}{6}{Throughput comparison between LSTM-256-256 with LS-LSTM-256-256. The LSTM only has a single row of data because its throughput is independent of sequence length. Entries are missing from the LS-LSTM table because there was not enough memory on the GPU to handle such large batch sizes and sequences}{figure.1}{}} 57 | \@writefile{toc}{\contentsline {subsection}{\numberline {4.2}Training performance}{6}{subsection.4.2}} 58 | \@writefile{toc}{\contentsline {subsection}{\numberline {4.3}Test performance}{6}{subsection.4.3}} 59 | \@writefile{lof}{\contentsline {figure}{\numberline {2}{\ignorespaces The LSTM models had a direct relationship between sequence length and training curve: the shorter the sequence length, the faster the initial learning and the higher the final loss. The faster initial learning is explained by the decreased latency of optimization steps and the higher final loss is explained by the inability to learn long dependencies. The LS-LSTM has 234 units per hidden layer to have the same number of parameters as the 256 unit LSTMs.}}{7}{figure.2}} 60 | \newlabel{fig:learning_curves}{{2}{7}{The LSTM models had a direct relationship between sequence length and training curve: the shorter the sequence length, the faster the initial learning and the higher the final loss. The faster initial learning is explained by the decreased latency of optimization steps and the higher final loss is explained by the inability to learn long dependencies. The LS-LSTM has 234 units per hidden layer to have the same number of parameters as the 256 unit LSTMs}{figure.2}{}} 61 | \@writefile{lot}{\contentsline {table}{\numberline {1}{\ignorespaces N-MNIST test results}}{7}{table.1}} 62 | \newlabel{test-results}{{1}{7}{N-MNIST test results}{table.1}{}} 63 | \@writefile{toc}{\contentsline {section}{\numberline {5}Conclusion}{7}{section.5}} 64 | -------------------------------------------------------------------------------- /paper/main-arxiv.out: -------------------------------------------------------------------------------- 1 | \BOOKMARK [1][-]{section.1}{Introduction}{}% 1 2 | \BOOKMARK [1][-]{section.2}{Parallel linear recurrence}{}% 2 3 | \BOOKMARK [2][-]{subsection.2.1}{Theoretical performance}{section.2}% 3 4 | \BOOKMARK [2][-]{subsection.2.2}{Backpropagation}{section.2}% 4 5 | \BOOKMARK [2][-]{subsection.2.3}{Implementation}{section.2}% 5 6 | \BOOKMARK [1][-]{section.3}{Models}{}% 6 7 | \BOOKMARK [2][-]{subsection.3.1}{Gated impulse linear recurrent layer}{section.3}% 7 8 | \BOOKMARK [3][-]{subsubsection.3.1.1}{Impact on effective "batch size"}{subsection.3.1}% 8 9 | \BOOKMARK [2][-]{subsection.3.2}{Linear surrogate RNNs}{section.3}% 9 10 | \BOOKMARK [1][-]{section.4}{Experiments}{}% 10 11 | \BOOKMARK [2][-]{subsection.4.1}{Computational performance}{section.4}% 11 12 | \BOOKMARK [2][-]{subsection.4.2}{Training performance}{section.4}% 12 13 | \BOOKMARK [2][-]{subsection.4.3}{Test performance}{section.4}% 13 14 | \BOOKMARK [1][-]{section.5}{Conclusion}{}% 14 15 | -------------------------------------------------------------------------------- /paper/main-arxiv.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eamartin/parallelizing_linear_rnns/da10de34701313e5556d67f4df4aacb76983e20c/paper/main-arxiv.pdf -------------------------------------------------------------------------------- /paper/main-arxiv.tex: -------------------------------------------------------------------------------- 1 | \documentclass{article} 2 | 3 | % if you need to pass options to natbib, use, e.g.: 4 | \PassOptionsToPackage{numbers, compress}{natbib} 5 | % before loading nips_2017 6 | % 7 | % to avoid loading the natbib package, add option nonatbib: 8 | % \usepackage[nonatbib]{nips_2017} 9 | 10 | %\usepackage{nips_2017} 11 | \usepackage[final]{nips_2017} 12 | \bibliographystyle{abbrvnat} 13 | 14 | % to compile a camera-ready version, add the [final] option, e.g.: 15 | %\usepackage[final]{nips_2017} 16 | 17 | \usepackage[utf8]{inputenc} % allow utf-8 input 18 | \usepackage[T1]{fontenc} % use 8-bit T1 fonts 19 | \usepackage{hyperref} % hyperlinks 20 | \usepackage{url} % simple URL typesetting 21 | \usepackage{booktabs} % professional-quality tables 22 | \usepackage{amsfonts} % blackboard math symbols 23 | \usepackage{nicefrac} % compact symbols for 1/2, etc. 24 | \usepackage{microtype} % microtypography 25 | \usepackage{amsmath} 26 | \usepackage{algorithm} 27 | \usepackage{algpseudocode} 28 | \usepackage{graphicx} 29 | 30 | \providecommand{\scan}{\text{SCAN}} 31 | \providecommand{\reduce}{\text{REDUCE}} 32 | % declaration of the new block 33 | \algblock{ParFor}{EndParFor} 34 | % customising the new block 35 | \algnewcommand\algorithmicparfor{\textbf{parfor}} 36 | \algnewcommand\algorithmicpardo{\textbf{do}} 37 | \algnewcommand\algorithmicendparfor{\textbf{end\ parfor}} 38 | \algrenewtext{ParFor}[1]{\algorithmicparfor\ #1\ \algorithmicpardo} 39 | \algrenewtext{EndParFor}{\algorithmicendparfor} 40 | 41 | \title{Parallelizing Linear Recurrent Neural Nets Over Sequence Length} 42 | 43 | % The \author macro works with any number of authors. There are two 44 | % commands used to separate the names and addresses of multiple 45 | % authors: \And and \AND. 46 | % 47 | % Using \And between authors leaves it to LaTeX to determine where to 48 | % break the lines. Using \AND forces a line break at that point. So, 49 | % if LaTeX puts 3 of 4 authors names on the first line, and the last 50 | % on the second line, try using \AND instead of \And before the third 51 | % author name. 52 | 53 | \author{ 54 | Eric Martin \\%\thanks{Use footnote for providing further 55 | %information about author (webpage, alternative 56 | %address)---\emph{not} for acknowledging funding agencies.} \\ 57 | %Some Affiliation\\ 58 | \texttt{eric@ericmart.in} 59 | %% examples of more authors 60 | \And 61 | Chris Cundy \\ 62 | \texttt{c.cundy@berkeley.edu} 63 | %%\texttt{some@email.com} 64 | %% Affiliation \\ 65 | %% Address \\ 66 | %% \texttt{email} \\ 67 | %% \AND 68 | %% Coauthor \\ 69 | %% Affiliation \\ 70 | %% Address \\ 71 | %% \texttt{email} \\ 72 | %% \And 73 | %% Coauthor \\ 74 | %% Affiliation \\ 75 | %% Address \\ 76 | %% \texttt{email} \\ 77 | %% \And 78 | %% Coauthor \\ 79 | %% Affiliation \\ 80 | %% Address \\ 81 | %% \texttt{email} \\ 82 | } 83 | 84 | \begin{document} 85 | % \nipsfinalcopy is no longer used 86 | 87 | \maketitle 88 | 89 | \begin{abstract} 90 | Recurrent neural networks (RNNs) are widely used to model sequential data 91 | but their non-linear dependencies between sequence elements prevent parallelizing 92 | training over sequence length. We show the training of RNNs with only 93 | linear sequential dependencies can be parallelized over the 94 | sequence length using the parallel scan algorithm, leading to rapid training 95 | on long sequences with small minibatch size. We abstract prior linear sequence models 96 | into a new framework of linear surrogate RNNs and develop 97 | a linear surrogate long short-term memory (LS-LSTM) powered by a parallel linear recurrence 98 | CUDA kernel we implemented. 99 | We evaluate the LS-LSTM on a long sequence noisy autoregressive task and find 100 | the LS-LSTM achieves slightly superior train and test performance to a similar 101 | sized LSTM in 4x less training time. We analyze latency and throughput of 102 | the LS-LSTM and find the LS-LSTM reaches up to 103 | 175x the throughput of the LSTM in the small minibatch long sequence regime. 104 | \end{abstract} 105 | 106 | \section{Introduction} 107 | 108 | Recurrent neural networks (RNNs) are widely used for sequence modelling tasks in domains such as 109 | natural language processing \cite{sutskever2014sequence}, 110 | speech recognition \cite{amodei2015deep}, 111 | and reinforcement learning \cite{hausknecht2015deep}. Most RNNs, including popular variants such as long short-term memories (LSTMs) \cite{hochreiter1997long} and gated recurrent units (GRUs) \cite{cho2014learning}, contain a non-linear dependency 112 | between sequential inputs. These non-linear dependencies create a very flexible class of 113 | models but limit the feasibility of training RNNs on long sequences as each sequence element 114 | must be processed sequentially. 115 | Modelling sequences of thousands to millions of elements is important to domains 116 | such as robotics, remote sensing, control systems, speech recognition, medicine, and finance. 117 | 118 | The RNN serial evaluation inefficiency problem is usually mitigated by parallelizing the forward 119 | and backward pass over a minibatch of inputs. Without minibatches, RNN evaluation is a sequence of matrix-vector multiplications. Minibatches transform RNN computation into a sequence of more efficient matrix-matrix multiplications, but minibatches within RNNs present many issues. 120 | RNN model size is often limited by GPU memory size, and running a forward 121 | and backward pass on a minibatch requires memory linear in 122 | the minibatch size. 123 | Grouping data into minibatches increases the latency of each pass and reduces the rate of optimization steps. Finally, training with larger minibatches damages generalization ability 124 | \cite{keskar2017large}. Persistent RNNs \cite{diamos2016persistent} use a novel implementation that can achieve high GPU utilization with very small 125 | minibatch sizes when the recurrent state is larger than 500 elements, but even 126 | persistent RNNs become limited by the serial evaluation 127 | inefficiency at smaller hidden sizes. 128 | 129 | Numerous prior works have shown strong performance from neural sequential models with only linear dependence on earlier sequence elements. \citet{balduzzi2016strongly} investigated RNNs with only elementwise linear recurrence relations 130 | $h_t = \alpha_t \odot h_{t-1} + (1-\alpha_t) \odot x_t$ and developed 131 | linear variants of LSTM and GRU that perform similarly to 132 | standard non-linear RNNs on text generation tasks. \citet{bradbury2017quasi}, \citet{kalchbrenner2016neural}, \citet{gehring2017convolutional}, and \citet{van2016wavenet} have successfully applied networks of convolutions over sequences for tasks 133 | such as machine translation, language modelling, and audio generation. 134 | These works have observed up to an order of magnitude 135 | increase in training throughput compared to RNN alternatives. Convolutional sequence models typically rely on either an attention mechanism or a (possibly linear) recurrent layer 136 | to integrate information at scales larger than the filter width. Introduction of a recurrent 137 | layer prevents full parallelization over the sequence length while attention mechanisms are 138 | expensive to apply on long sequences in online inference use cases. 139 | One dimensional convolution can be viewed as a learnable linear 140 | finite impulse response (FIR) filter with a parallel evaluation algorithm, while 141 | linear recurrence is a learnable linear infinite impulse response (IIR). This work 142 | parallelizes evaluation of linear recurrences through application of the parallel 143 | scan algorithm. 144 | 145 | Scans and reductions are computations involving repeated application of a binary 146 | operator $\oplus$ over an array of data. Computing the sum or maximum 147 | of an array is an example of a reduction, while a cumulative sum is a common 148 | example of a scan operation. Throughout this work, the scan of $\oplus$ with 149 | initial value $b$ is defined as 150 | \begin{align*} 151 | \scan(\oplus, [a_1, a_2, ..., a_n], b) = [(a_1 \oplus b), (a_2 \oplus a_1 \oplus b), ..., (a_n \oplus a_{n-1} ... \oplus a_1 \oplus b)] 152 | \end{align*} 153 | The reduction of $\oplus$ over array $A$ and initial value $b$ is denoted 154 | $\reduce(\oplus, A, b)$ and is the final element of $\scan(\oplus, A, b)$. 155 | Despite their dependent computation graph, algorithms exist to parallelize scans 156 | and reductions when $\oplus$ is associative \cite{ladner1980parallel}. 157 | 158 | \citet{blelloch1990prefix} shows that first order recurrences of the form 159 | $h_t = (\Lambda_t \otimes h_{t-1}) \oplus x_t$ can be parallelized with 160 | the parallel scan algorithm if three conditions are met: 161 | 162 | \begin{enumerate} 163 | \item $\oplus$ is associative: $(a \oplus b) \oplus c = a \oplus (b \oplus c)$ 164 | \item $\otimes$ is semiassociative: there exists a binary associative operator 165 | $\odot$ such that $a \otimes (b \otimes c) = (a \odot b) \otimes c$ 166 | \item $\otimes$ distributes over $\oplus$: $a\otimes(b\oplus c) = (a\otimes b) \oplus (a \otimes c)$ 167 | \end{enumerate} 168 | 169 | Our primary contribution is the application of the parallel linear recurrence algorithm to RNNs. 170 | 171 | \section{Parallel linear recurrence} 172 | 173 | Using elementwise vector addition as $x \oplus y= x+y$, matrix-vector multiplication as $A \otimes x=Ax$, 174 | and matrix-matrix multiplication as $A \odot B=AB$ satisfies Blelloch's three conditions, thus allowing linear 175 | recurrence $h_t = \Lambda_t h_{t-1} + x_t$ to be evaluated in parallel over $t$ for 176 | vectors $x_t$ and square matrices $\Lambda_t$. As the method is 177 | essential to this work, Algorithm 1 presents the parallel linear recurrence algorithm 178 | for the interested reader. 179 | \begin{algorithm} 180 | \caption{Parallel linear recurrence on $p$ processors} 181 | \begin{algorithmic}[1] 182 | \State Let $y = [(\Lambda_1, x_1), (\Lambda_2, x_2), ..., (\Lambda_T, x_T)]$ 183 | \State Let binary operator $\bullet$ act as $(\Lambda, x) \bullet h = \Lambda h + x$ 184 | \State Let $S_0=1, S_i < E_i, E_i + 1 = S_{i+1}, E_{p-1}=T$ for $i$ in $0,p-1$ 185 | 186 | \\ 187 | \ParFor{$i \gets 0,p-1$} 188 | \State $P_i = \reduce(\odot, \Lambda_{S_i:E_i}, I)$ 189 | \State $R_i = \reduce(\bullet, y_{S_i:E_i}, 0)$ 190 | \EndParFor 191 | 192 | \\ 193 | \State Let $z = [(P_0, R_0), (P_1, R_1), ..., (P_p, R_p)]$. 194 | \State $C = \scan(\bullet, z, h_0)$ \Comment{compute $C_i = P_i C_{i-1} + R_i$ with $C_{-1}=h_0$} 195 | 196 | \\ 197 | \ParFor{$i \gets 0,p-1$} 198 | \State $h_{S_i:E_i} = \scan(\bullet, y_{S_i:E_i}, C_{i-1})$ 199 | \EndParFor 200 | 201 | \State \Return $h$ 202 | \end{algorithmic} 203 | \end{algorithm} 204 | 205 | \subsection{Theoretical performance} 206 | The cost of a serial scan over a sequence of length $T$ is 207 | $C_\text{sscan} = (C_\otimes + C_\oplus)T$, compared to the parallel scan cost 208 | $C_\text{pscan} = 2(C_\odot + C_\otimes + C_\oplus)(T/p + \lg p)$ \cite{blelloch1990prefix}. 209 | If $h_t$ is a vector of dimension $n$ then 210 | $C_\odot=n^3, C_\otimes=n^2, C_\oplus=n$ giving 211 | $C_\text{pscan} = 2(n^3 + n^2 + n)(T/p + \lg p)$ and 212 | $C_\text{sscan} = (n^2 + n)T$. The $n^3$ cost of the matrix 213 | multiplication in the parallel algorithm can destroy any parallel speedups for 214 | sufficiently large hidden states and lead to a slower algorithm overall. 215 | 216 | To avoid this problem, we will only consider diagonal matrices $\Lambda_t$, in which 217 | case matrix-matrix and matrix-vector have cost $n$ and 218 | $C_\text{pscan}=6n(T/p + \lg p)$ and $C_\text{sscan}=2nT$. Assuming $p \ll T$, then 219 | $C_\text{pscan} \le C_\text{sscan}$ when $p \ge 3$. 220 | As we are only considering diagonal matrices, the 221 | linear recurrence will be written $h_t = \lambda_t \odot h_{t-1} + x_t$ where 222 | $\odot$ indicates elementwise multiplication. 223 | 224 | Limiting $\Lambda_t$ to be diagonal may seem like a severe constraint but there are 225 | several reasons to do so beyond the unfavorable parallelization performance. Relatively few neural 226 | network models use separate recurrent matrices for each sequence element and using these 227 | separate matrices would require potentially prohibitive $n^2T$ memory. Applying 228 | the same matrix $\Lambda$ to each sequence element is also unappealing considering that a matrix 229 | multiplication can be thought of as a rotation and a scaling. The same rotation at every 230 | element seems unlikely to be useful, and the scaling is exactly what's captured in diagonal 231 | vectors $\lambda_t$. Recurrent coefficient vectors $\lambda_t$ provide enough flexibility 232 | to implement schemes such as exponential moving averages or a gating mechanism. 233 | 234 | \subsection{Backpropagation} 235 | \begin{align*} 236 | \nabla_{h_T}L &= \frac{\partial L}{\partial h_T} \\ 237 | \nabla_{h_t}L &= \frac{\partial h_{t+1}}{\partial h_t} \odot \nabla_{h_{t+1}} L + \frac{\partial L}{\partial h_t} \\ 238 | &= \lambda_{t+1} \odot \nabla_{h_{t+1}} L + \frac{\partial L}{\partial h_t} \\ 239 | \nabla_{\lambda_t}L &= \frac{\partial h_t}{\partial\lambda_t} \odot \nabla_{h_t}L = h_{t-1} \odot \nabla_{h_t}L \\ 240 | \nabla_{x_t}L &= \nabla_{h_t} L \\ 241 | \nabla_{h_0}L &= \frac{\partial h_1}{\partial h_0} \odot \nabla_{h_1} L = \lambda_1 \odot \nabla_{h_1} L 242 | \end{align*} 243 | 244 | The backpropagation equations center around a linear recurrence over $\frac{\partial L}{\partial h_t}$ in the reverse order of the original sequence. This allows for parallelizing both the forwards and backwards pass of a linear RNN over the sequence length. 245 | 246 | \subsection{Implementation} 247 | A modern high-end NVIDIA GPU consists of between 640 and 3200 concurrently 248 | executing warps. Each warp operates on 32 single precision floating point numbers 249 | in parallel. 250 | 251 | This work implemented parallel linear recurrence as a CUDA kernel with 252 | bindings into the TensorFlow \cite{abadi2016tensorflow} framework. Each warp acts as a processor, which means algorithmic $p$ is up to 3200 and the theoretical parallelization speedup factor is up to several hundred. 253 | The 32 lanes of each warp work on different 254 | elements of the recurrence vector in parallel. These implementation details mean that 255 | peak performance is only obtained on sequences of at least several thousand 256 | steps on at least a 32 element vector. 257 | 258 | \section{Models} 259 | Parallel linear recurrence can be used to construct a wide variety of differentiable modules that can be evaluated in parallel. Common applications of linear recurrence include gating schemes and exponential moving averages. Although linear recurrence values can depend only linearly on previous elements, the stacking of linear recurrent layers separated by non-linearities allows for a non-linear dependence on the past. In this sense the non-linear depth of a linear recurrent network is the number of layers and not the sequence length. 260 | 261 | \subsection{Gated impulse linear recurrent layer} 262 | A gated impulse linear recurrent (GILR) layer transforms its $m$ dimensional inputs $x_t$ into a sequence of $n$ dimensional hidden states $h_t$: 263 | \begin{align*} 264 | g_t &= \sigma(Ux_t + b_g) \\ 265 | i_t &= \tau(Vx_t + b_z) \\ 266 | h_t &= g_t \odot h_{t-1} + (1-g_t)\odot i_t 267 | \end{align*} 268 | A GILR layer applies the same non-linear transform to each sequence element and then accumulates the sequence elements with a non-linear gating mechanism. Gate $g_t$ uses the sigmoid activation function to have values in [0,1] for reasonable gating semantics, while impulse $i_t$ can use any activation function $\tau$. Stacking GILR layers allows for rich non-linear dependence on previous events while still taking advantage of fast parallel sequence evaluation. 269 | 270 | \subsubsection{Impact on effective "batch size"} 271 | Consider evaluating a vanilla RNN $h_t = \sigma(Uh_{t-1} + Vx_t + b)$ from $m$ inputs to $n$ hidden units on a sequence of length $T$ with minibatch size $b$ using a serial evaluation strategy. At each of $T$ iterations, the naive approach performs two $(b, m) \mathbf{x} (m, n)$ matrix multiplications. Larger matrix multiplications achieve higher throughput due to less IO overhead, so the better approach computes $Vx_t$ for all $t$ ahead of time in a single $(bT, m) \mathbf{x} (m, n)$ matrix multiply. The non-linear recurrence forces even the better approach to perform $T$ potentially small $(b, m) \mathbf{x} (m, n)$ matrix multiplications in serial which makes performance heavily dependent on minibatch size. 272 | 273 | Now consider the GILR, noting that it has the same two matrix-vector multiplications per iteration as the vanilla RNN. $g$ and $i$ can each be evaluated for all $t$ with a single $(bT, m) \mathbf{x} (m, n)$ matrix multiplication each. Given $g$ and $i$, $h$ can be computed using a parallel linear recurrence over $T$ vectors each of $bn$ elements. Rather than $T$ small operations, the GILR can be evaluated over all sequence elements with two large matrix multiplies and a parallel linear recurrence. GILR performance is much less dependent on batch size as the matrix multiplies see an "effective batch size" of $bT$ and $T$ is typically large. 274 | 275 | \subsection{Linear surrogate RNNs} 276 | RNNs learn a transition function $s_t = f(s_{t-1}, x_t)$ which combines previous state $s_{t-1}$ with input $x_t$ to compute current state $s_t$. Non-linear $f$ prevents application of the parallel linear recurrence algorithm and forces slow serial evaluation. To work around this inefficiency, note that $s_t$ serves a dual purpose. In $s_t = f(s_{t-1}, x_t)$, $s_{t-1}$ serves as an input to $f$ summarizing the previous inputs while $s_t$ serves as the output of $f$ to be passed to other layers of the network. If we decouple these uses, we can instead compute $s_t = f(\tilde{s}_{t-1}, x_t)$ with $\tilde{s}_t$ as a linearly computable surrogate for $s_t$. With this linear surrogate, non-linear $f$ can still be evaluated. We refer to this class of model as a linear surrogate RNN (LS-RNN). Quasi-RNNs \cite{bradbury2017quasi} are LS-RNNs using $\tilde{h}_{t-1} = W_k x_{t-k} + ... W_1 x_{t-1}$ 277 | and strongly typed RNNs\cite{balduzzi2016strongly} are LS-RNNs with $\tilde{h}_t=x_{t-1}$. Although not a rule, LS-RNNs can often be parallelized over sequence length with either 278 | convolution or linear recurrence. 279 | 280 | As an example LS-RNN, consider an LSTM: 281 | \begin{align*} 282 | f_t, i_t, o_t &= \sigma(U_{f,i,o} h_{t-1} + V_{f,i,o} x_t + b_{f,i,o}) \\ 283 | z_t &= \tau(U_z h_{t-1} + V_z x_t + b_z) \\ 284 | c_t &= f_t \odot c_{t-1} + i_t \odot z_t \\ 285 | h_t &= o_t \odot c_t 286 | \end{align*} 287 | An LSTM has state $s_t = (h_t, c_t)$. $c_t$ depends only linearly on $c_{t-1}$, so no surrogate is needed for $c_t$. $h_t$ has a non-linear dependence on $h_{t-1}$, so $h_t$ needs a linear surrogate. With a GILR layer as surrogate, the linear surrogate LSTM (LS-LSTM) is 288 | \begin{align*} 289 | f_t, i_t, o_t &= \sigma(U_{f,i,o} \tilde{h}_{t-1} + V_{f,i,o} x_t + b_{f,i,o}) \\ 290 | z_t &= \tau(U_z \tilde{h}_{t-1} + V_z x_t + b_z) \\ 291 | c_t &= f_t \odot c_{t-1} + i_t \odot z_t \\ 292 | h_t &= o_t \odot c_t \\ 293 | g_t &= \sigma(V_g x_t + b_g) \\ 294 | \tilde{h}_t &= g_t \odot \tilde{h}_{t-1} + (1-g_t)\odot \tau(Wx_t + c) 295 | \end{align*} 296 | 297 | For $m$ inputs and hidden size $n$, the LS-LSTM contains $2n(n+m)$ more parameters than the equivalently sized LSTM to handle the mapping from $x$ to $\tilde{h}$. More generally, a LS-RNN contains all of the same parameters as the underlying RNN as well as some additional parameters to compute the linear surrogate. 298 | 299 | \section{Experiments} 300 | Experiments were performed on the N-MNIST \cite{orchard2015converting} dataset. N-MNIST captures the MNIST digit dataset \cite{lecun1998mnist} by panning 301 | an event driven camera over each digit. Each example in N-MNIST is a sequence of single pixel events (x, y, polarity, timestamp) where 302 | x and y each indicate a pixel position in [0, 33], polarity indicates whether the pixel was switching on or off, and timestamp is the time of the event in microseconds. Videos produced from the event data show that positive polarity events are often located on the leading edge of the digit motion and negative polarity events on the trailing edge. Between the 60,000 digits in the training set, N-MNIST contains approximately 250 million pixel events with sequence lengths ranging from 500 to 8000 and averaging 4000 events. 303 | 304 | We attempt to forecast 50 events ahead in N-MNIST using a two layer LSTM with 256 units per layer and a two layer LS-LSTM with 234 units per layer. The LS-LSTM layer size was selected so that it had slightly fewer parameters than the LSTM. Each model transformed the incoming event position into a 40-dimensional embedding vector which was then combined with the polarity to produce a 41 dimensional input. Both models output a $2 \mathbf{x} 34^2$ matrix containing two future event location probability distributions conditioned on the polarity of the future event. The training algorithm only considers the probability distribution of the true future polarity and uses the cross entropy loss function. The Adam \cite{kingma2014adam} optimization algorithm and Glorot \cite{glorot2010understanding} initialization scheme were used. Training on a minibatch of size $b$ with sequence length $T$ consisted of uniformly sampling $b$ N-MNIST sequences and then extracting a single random $T$ element subsequence (and its 50 element ahead forecast) from the sequence. Sequences less than $T$ elements were padded out to length $T$, and there was a "burn-in" of 30 events at the start of each sequence where no predictions were made. An epoch was defined as a pass over as many pixel events as there are in the full training set. 305 | 306 | All experiments were performed using TensorFlow 1.0 \cite{abadi2016tensorflow} on a single Nvidia K80 GPU running for up to 18 hours. The LSTM model was computed using TensorFlow's dynamic\_rnn and BasicLSTMCell routines which are slower than but algorithmically similar to the cuDNN LSTM implementation. 307 | 308 | \subsection{Computational performance} 309 | The computational performance of the LSTM and LS-LSTM models were compared across a wide range of minibatch sizes and sequence lengths. Although the 234 unit LS-LSTM has a similar number of parameters to the 256 unit LSTM, we measured the computational performance of a 256 unit LS-LSTM to observe any impact of the linear surrogate calculation. We define throughput as events/s and a minibatch of size $b$ and length $T$ to be $bT$ events. The serial evaluation of the LSTM causes its throughput to be independent of the sequence length as a doubling of sequence length causes a doubling of runtime. 310 | 311 | Figure \ref{fig:tp_perf} shows the LS-LSTM model achieves a throughput between 1.17x and 175x that of the LSTM, with the greatest relative advantages occurring at small minibatch sizes and long sequences. Notably, the LS-LSTM can achieve a higher throughput by running on one 8192 event sequence at a time than an LSTM running with 256 sequence minibatches. Similar speedups were found for inference. These speedups indicate the use of parallel linear recurrence through a linear surrogate can massively accelerate RNN training. 312 | 313 | \begin{figure}[t] 314 | \centering 315 | \includegraphics[width=12cm]{heatmap.png} 316 | \caption{Throughput comparison between LSTM-256-256 with LS-LSTM-256-256. The LSTM only has a single row of data because its throughput is independent of sequence length. Entries are missing from the LS-LSTM table because there was not enough memory on the GPU to handle such large batch sizes and sequences.} 317 | \label{fig:tp_perf} 318 | \end{figure} 319 | 320 | \subsection{Training performance} 321 | The speed of neural net training is not solely determined by the training throughput but also by the frequency of optimization steps. As an example, a batch method may process inputs at the same rate as a minibatch method, but the minibatch method will generally converge much faster on large datasets. On an infinite dataset, the batch method never takes a single optimization step regardless of its throughput. With this example in mind, it is clear that achieving fast training is a balancing act between training data throughput and optimization step latency and that training performance should be evaluated with training curves and not just throughput numbers such as time per epoch. 322 | 323 | The LSTM and LS-LSTM offer differ latency and throughput tradeoffs. LSTM throughput depends only on minibatch size but LSTM latency depends on both minibatch size and sequence length. Experiments were conducted training the LSTM models with minibatch size 256 and sequence lengths 128, 256, 512, and 1024. Figure \ref{fig:learning_curves} shows the smaller sequence lengths led to faster initial learning but inferior final performance. 324 | 325 | The throughput and latency of the LS-LSTM are both influenced by batch size and sequence length. Several LS-LSTMs were trained with batch size 16 and sequence length 1024. This combination was chosen because of the nearly maximum throughput, the low latency, and the ease of building a minibatch given the distribution of sequence lengths in the data set. The 234 unit LS-LSTM reaches a better training loss in roughly 4.5 hours than any of the LSTMs could reach in 18 hours. This experimental evidence indicates the LS-LSTM is as powerful of a model as the LSTM and can be trained in a fraction of the time. 326 | 327 | \begin{figure}[t] 328 | \centering 329 | \includegraphics[width=12cm]{lc2.png} 330 | \caption{The LSTM models had a direct relationship between sequence length and training curve: the shorter the sequence length, the faster the initial learning and 331 | the higher the final loss. The faster initial learning is explained by the decreased 332 | latency of optimization steps and the higher final loss is explained by the 333 | inability to learn long dependencies. The LS-LSTM has 234 units per hidden 334 | layer to have the same number of parameters as the 256 unit LSTMs.} 335 | \label{fig:learning_curves} 336 | \end{figure} 337 | 338 | \subsection{Test performance} 339 | \begin{table}[h] 340 | \caption{N-MNIST test results} 341 | \label{test-results} 342 | \centering 343 | \begin{tabular}{lllll} 344 | \toprule 345 | & Train Cross Entropy & Test Cross Entropy & Top-5 Accuracy & Top-20 Accuracy \\ 346 | \midrule 347 | LS-LSTM & \textbf{4.567} & 4.856 & 10.51\% & 36.70\% \\ 348 | LSTM, seqlen=128 & 4.664 & \textbf{4.846} & \textbf{10.62\%} & \textbf{37.08\%} \\ 349 | LSTM, seqlen=256 & 4.620 & 4.906 & 10.11\% & 35.40\% \\ 350 | LSTM, seqlen=512 & 4.602 & 4.883 & 10.45\% & 36.32\% \\ 351 | LSTM, seqlen=1024 & 4.611 & 4.880 & 10.32\% & 35.92\% \\ 352 | \bottomrule 353 | \end{tabular} 354 | \end{table} 355 | Beyond cross entropy on train and test, we also evaluated top-$k$ accuracy for $k=5, 20$. A probability distribution is top-$k$ correct if it assigns the realized location one of the 356 | $k$ largest probabilities. N-MNIST contains 1156 pixel positions, so top-5 and top-20 accuracy are equivalent to localizing the future to 0.43\% and 1.7\% of pixels. 357 | 358 | No regularization was attempted and all of the models overfit, as indicated by the best test performance from the model with the worst train performance. Table \ref{test-results} contains the full test results. Our focus was the fast training of powerful models, and we leave regularizing parallel linear recurrences and LS-RNNs to future work. Although not tested, it is possible that the LS-LSTM generalized better than the LSTMs trained on long sequences due to the much smaller minibatch size of LS-LSTM leading to a wider minima. 359 | 360 | \section{Conclusion} 361 | Parallel linear recurrence is an extremely powerful algorithm and the LS-LSTM is just one 362 | of many possible models that can be built with it. Future applications of parallel linear recurrence could include sequences orders of magnitude longer than N-MNIST, the development of parallel computable differentiable memory modules, and the combination of linear recurrence with convolutional sequence models. Besides future research, existing models such as Quasi-RNNs and strongly typed RNNs that already contain linear recurrences can immediately benefit from parallel linear recurrence. This work demonstrates the LS-LSTM significantly accelerates the training of small to medium sized LSTMs. Although similar techniques have been used before, the now explicit concept of linear surrogacy provides a framework for future development and analysis of fast sequence models. 363 | 364 | We intend to expand upon the experiments section and open-source the parallel linear recurrence kernel in the near future. 365 | 366 | \subsubsection*{Acknowledgments} 367 | We would like to acknowledge Kevin Bowers, Alex Meiburg, JD Co-Reyes, Carson McNeil, and several others for fruitful conversations and guidance. 368 | 369 | 370 | \end{document} 371 | -------------------------------------------------------------------------------- /paper/main.aux: -------------------------------------------------------------------------------- 1 | \relax 2 | \providecommand\hyper@newdestlabel[2]{} 3 | \bibstyle{natbib} 4 | \providecommand\HyperFirstAtBeginDocument{\AtBeginDocument} 5 | \HyperFirstAtBeginDocument{\ifx\hyper@anchor\@undefined 6 | \global\let\oldcontentsline\contentsline 7 | \gdef\contentsline#1#2#3#4{\oldcontentsline{#1}{#2}{#3}} 8 | \global\let\oldnewlabel\newlabel 9 | \gdef\newlabel#1#2{\newlabelxx{#1}#2} 10 | \gdef\newlabelxx#1#2#3#4#5#6{\oldnewlabel{#1}{{#2}{#3}}} 11 | \AtEndDocument{\ifx\hyper@anchor\@undefined 12 | \let\contentsline\oldcontentsline 13 | \let\newlabel\oldnewlabel 14 | \fi} 15 | \fi} 16 | \global\let\hyper@last\relax 17 | \gdef\HyperFirstAtBeginDocument#1{#1} 18 | \providecommand\HyField@AuxAddToFields[1]{} 19 | \providecommand\HyField@AuxAddToCoFields[2]{} 20 | \citation{sutskever2014sequence} 21 | \citation{amodei2015deep} 22 | \citation{hausknecht2015deep} 23 | \citation{hochreiter1997long} 24 | \citation{cho2014learning} 25 | \citation{keskar2017large} 26 | \citation{diamos2016persistent} 27 | \citation{balduzzi2016strongly} 28 | \citation{bradbury2017quasi} 29 | \citation{kalchbrenner2016neural} 30 | \citation{gehring2017convolutional} 31 | \citation{van2016wavenet} 32 | \@writefile{toc}{\contentsline {section}{\numberline {1}Introduction}{1}{section.1}} 33 | \citation{ladner1980parallel} 34 | \citation{blelloch1990prefix} 35 | \citation{bradbury2017quasi} 36 | \citation{lei2017} 37 | \citation{blelloch1990prefix} 38 | \@writefile{toc}{\contentsline {section}{\numberline {2}Parallel linear recurrence}{2}{section.2}} 39 | \@writefile{toc}{\contentsline {subsection}{\numberline {2.1}Theoretical performance}{2}{subsection.2.1}} 40 | \newlabel{alg:plr}{{2}{3}{Parallel linear recurrence}{section.2}{}} 41 | \@writefile{loa}{\contentsline {algorithm}{\numberline {1}{\ignorespaces Parallel linear recurrence on $p$ processors}}{3}{algorithm.1}} 42 | \@writefile{toc}{\contentsline {subsection}{\numberline {2.2}Backpropagation}{3}{subsection.2.2}} 43 | \citation{abadi2016tensorflow} 44 | \citation{bradbury2017quasi} 45 | \citation{balduzzi2016strongly} 46 | \@writefile{toc}{\contentsline {subsection}{\numberline {2.3}Implementation}{4}{subsection.2.3}} 47 | \@writefile{toc}{\contentsline {section}{\numberline {3}Models}{4}{section.3}} 48 | \@writefile{toc}{\contentsline {subsection}{\numberline {3.1}Gated impulse linear recurrent layer}{4}{subsection.3.1}} 49 | \@writefile{toc}{\contentsline {subsubsection}{\numberline {3.1.1}Impact on effective "batch size"}{4}{subsubsection.3.1.1}} 50 | \@writefile{toc}{\contentsline {subsection}{\numberline {3.2}Linear surrogate RNNs}{4}{subsection.3.2}} 51 | \newlabel{sec:ls-rnns}{{3.2}{4}{Linear surrogate RNNs}{subsection.3.2}{}} 52 | \citation{hochreiter1997long} 53 | \@writefile{toc}{\contentsline {section}{\numberline {4}Experiments}{5}{section.4}} 54 | \@writefile{toc}{\contentsline {subsection}{\numberline {4.1}Throughput benchmarks}{5}{subsection.4.1}} 55 | \@writefile{toc}{\contentsline {subsubsection}{\numberline {4.1.1}Kernel performance}{5}{subsubsection.4.1.1}} 56 | \citation{hochreiter1997long} 57 | \@writefile{lot}{\contentsline {table}{\numberline {1}{\ignorespaces Parallel kernel speedup on $m$ features (minibatch size $= 1$)}}{6}{table.1}} 58 | \newlabel{table:kernel-throughput}{{1}{6}{Parallel kernel speedup on $m$ features (minibatch size $= 1$)}{table.1}{}} 59 | \@writefile{lot}{\contentsline {table}{\numberline {2}{\ignorespaces Parallel kernel speedup for a variety of LS-RNNs, implemented as two stacked RNN layers with 256 hidden units. We keep the GPU memory usage constant by fixing $bT = 65,536$ for minibatch size $b$ and sequence length $T$}}{6}{table.2}} 60 | \newlabel{table:rnn-throughput}{{2}{6}{Parallel kernel speedup for a variety of LS-RNNs, implemented as two stacked RNN layers with 256 hidden units. We keep the GPU memory usage constant by fixing $bT = 65,536$ for minibatch size $b$ and sequence length $T$}{table.2}{}} 61 | \@writefile{toc}{\contentsline {subsubsection}{\numberline {4.1.2}Accelerating existing RNN architectures}{6}{subsubsection.4.1.2}} 62 | \@writefile{toc}{\contentsline {subsection}{\numberline {4.2}Synthetic Experiment}{6}{subsection.4.2}} 63 | \citation{hochreiter1997long} 64 | \citation{hochreiter1997long} 65 | \@writefile{lof}{\contentsline {figure}{\numberline {1}{\ignorespaces The structure of the synthetic example and the GILR-LSTM architecture we used to tackle it. We feed in one-hot unit vectors \(x\) which are chosen uniformly at random (with replacement). The class is determined by the very first vector \(x_0\), which has a fixed direction. The sign of \(x_0\) determines the class. In the diagram, each rounded block indicates a cell of the RNN, whilst the square indicates a linear unit.}}{7}{figure.1}} 66 | \newlabel{fig:synthetic_diagram}{{1}{7}{The structure of the synthetic example and the GILR-LSTM architecture we used to tackle it. We feed in one-hot unit vectors \(x\) which are chosen uniformly at random (with replacement). The class is determined by the very first vector \(x_0\), which has a fixed direction. The sign of \(x_0\) determines the class. In the diagram, each rounded block indicates a cell of the RNN, whilst the square indicates a linear unit}{figure.1}{}} 67 | \newlabel{table:synth-table}{{4.2}{7}{Synthetic Experiment}{figure.1}{}} 68 | \@writefile{lot}{\contentsline {table}{\numberline {3}{\ignorespaces Performance of the GILR-LSTM compared to the CuDNN LSTM on problem 2b from \citet {hochreiter1997long}. }}{7}{table.3}} 69 | \bibcite{abadi2016tensorflow}{{1}{2016}{{Abadi et~al.}}{{Abadi, Agarwal, Barham, Brevdo, Chen, Citro, Corrado, Davis, Dean, Devin, et~al.}}} 70 | \@writefile{lof}{\contentsline {figure}{\numberline {2}{\ignorespaces Learning curves for GILR-LSTM and CuDNN LSTM architectures for various sequence lengths. Each plot shows the moving mean and standard deviation of classification accuracy over five training runs, with the exception of a single run for CuDNN LSTM on 1 million sequence length.}}{8}{figure.2}} 71 | \newlabel{fig:synthetic_training}{{2}{8}{Learning curves for GILR-LSTM and CuDNN LSTM architectures for various sequence lengths. Each plot shows the moving mean and standard deviation of classification accuracy over five training runs, with the exception of a single run for CuDNN LSTM on 1 million sequence length}{figure.2}{}} 72 | \@writefile{toc}{\contentsline {section}{\numberline {5}Conclusion}{8}{section.5}} 73 | \bibcite{amodei2015deep}{{2}{2015}{{Amodei et~al.}}{{Amodei, Anubhai, Battenberg, Case, Casper, Catanzaro, Chen, Chrzanowski, Coates, Diamos, et~al.}}} 74 | \bibcite{balduzzi2016strongly}{{3}{2016}{{Balduzzi and Ghifary}}{{}}} 75 | \bibcite{blelloch1990prefix}{{4}{1990}{{Blelloch}}{{}}} 76 | \bibcite{bradbury2017quasi}{{5}{2017}{{Bradbury et~al.}}{{Bradbury, Merity, Xiong, and Socher}}} 77 | \bibcite{cho2014learning}{{6}{2014}{{Cho et~al.}}{{Cho, Van~Merri{\"e}nboer, Gulcehre, Bahdanau, Bougares, Schwenk, and Bengio}}} 78 | \bibcite{diamos2016persistent}{{7}{2016}{{Diamos et~al.}}{{Diamos, Sengupta, Catanzaro, Chrzanowski, Coates, Elsen, Engel, Hannun, and Satheesh}}} 79 | \bibcite{gehring2017convolutional}{{8}{2017}{{Gehring et~al.}}{{Gehring, Auli, Grangier, Yarats, and Dauphin}}} 80 | \bibcite{glorot2010understanding}{{9}{2010}{{Glorot and Bengio}}{{}}} 81 | \bibcite{hausknecht2015deep}{{10}{2015}{{Hausknecht and Stone}}{{}}} 82 | \bibcite{hochreiter1997long}{{11}{1997}{{Hochreiter and Schmidhuber}}{{}}} 83 | \bibcite{kalchbrenner2016neural}{{12}{2016}{{Kalchbrenner et~al.}}{{Kalchbrenner, Espeholt, Simonyan, Oord, Graves, and Kavukcuoglu}}} 84 | \bibcite{keskar2017large}{{13}{2017}{{Keskar et~al.}}{{Keskar, Mudigere, Nocedal, Smelyanskiy, and Tang}}} 85 | \bibcite{kingma2014adam}{{14}{2014}{{Kingma and Ba}}{{}}} 86 | \bibcite{ladner1980parallel}{{15}{1980}{{Ladner and Fischer}}{{}}} 87 | \bibcite{lecun1998mnist}{{16}{1998}{{LeCun et~al.}}{{LeCun, Cortes, and Burges}}} 88 | \bibcite{orchard2015converting}{{17}{2015}{{Orchard et~al.}}{{Orchard, Jayawant, Cohen, and Thakor}}} 89 | \bibcite{sutskever2014sequence}{{18}{2014}{{Sutskever et~al.}}{{Sutskever, Vinyals, and Le}}} 90 | \bibcite{van2016wavenet}{{19}{2016}{{van~den Oord et~al.}}{{van~den Oord, Dieleman, Zen, Simonyan, Vinyals, Graves, Kalchbrenner, Senior, and Kavukcuoglu}}} 91 | \bibcite{lei2017}{{20}{2017}{{Lei and Zhang}}{{}}} 92 | \bibcite{physiobank}{{21}{2000}{{Goldberger et~al.}}{{ Amaral, Glass, Hausdorff, Ivanov, Mark, Mietus, Moody, Peng, Stanley}}} 93 | -------------------------------------------------------------------------------- /paper/main.bbl: -------------------------------------------------------------------------------- 1 | \begin{thebibliography}{19} 2 | \providecommand{\natexlab}[1]{#1} 3 | \providecommand{\url}[1]{\texttt{#1}} 4 | \expandafter\ifx\csname urlstyle\endcsname\relax 5 | \providecommand{\doi}[1]{doi: #1}\else 6 | \providecommand{\doi}{doi: \begingroup \urlstyle{rm}\Url}\fi 7 | 8 | \bibitem[Abadi et~al.(2016)Abadi, Agarwal, Barham, Brevdo, Chen, Citro, 9 | Corrado, Davis, Dean, Devin, et~al.]{abadi2016tensorflow} 10 | M.~Abadi, A.~Agarwal, P.~Barham, E.~Brevdo, Z.~Chen, C.~Citro, G.~S. Corrado, 11 | A.~Davis, J.~Dean, M.~Devin, et~al. 12 | \newblock Tensorflow: Large-scale machine learning on heterogeneous distributed 13 | systems. 14 | \newblock \emph{arXiv preprint arXiv:1603.04467}, 2016. 15 | 16 | \bibitem[Amodei et~al.(2015)Amodei, Anubhai, Battenberg, Case, Casper, 17 | Catanzaro, Chen, Chrzanowski, Coates, Diamos, et~al.]{amodei2015deep} 18 | D.~Amodei, R.~Anubhai, E.~Battenberg, C.~Case, J.~Casper, B.~Catanzaro, 19 | J.~Chen, M.~Chrzanowski, A.~Coates, G.~Diamos, et~al. 20 | \newblock Deep speech 2: End-to-end speech recognition in english and mandarin. 21 | \newblock \emph{arXiv preprint arXiv:1512.02595}, 2015. 22 | 23 | \bibitem[Balduzzi and Ghifary(2016)]{balduzzi2016strongly} 24 | D.~Balduzzi and M.~Ghifary. 25 | \newblock Strongly-typed recurrent neural networks. 26 | \newblock In \emph{Proceedings of The 33rd International Conference on Machine 27 | Learning}, pages 1292--1300, 2016. 28 | 29 | \bibitem[Blelloch(1990)]{blelloch1990prefix} 30 | G.~E. Blelloch. 31 | \newblock Prefix sums and their applications. 32 | \newblock 1990. 33 | 34 | \bibitem[Bradbury et~al.(2017)Bradbury, Merity, Xiong, and 35 | Socher]{bradbury2017quasi} 36 | J.~Bradbury, S.~Merity, C.~Xiong, and R.~Socher. 37 | \newblock Quasi-recurrent neural networks. 38 | \newblock In \emph{International Conference on Learning Representations 39 | (ICLR)}, 2017. 40 | 41 | \bibitem[Cho et~al.(2014)Cho, Van~Merri{\"e}nboer, Gulcehre, Bahdanau, 42 | Bougares, Schwenk, and Bengio]{cho2014learning} 43 | K.~Cho, B.~Van~Merri{\"e}nboer, C.~Gulcehre, D.~Bahdanau, F.~Bougares, 44 | H.~Schwenk, and Y.~Bengio. 45 | \newblock Learning phrase representations using rnn encoder-decoder for 46 | statistical machine translation. 47 | \newblock \emph{arXiv preprint arXiv:1406.1078}, 2014. 48 | 49 | \bibitem[Diamos et~al.(2016)Diamos, Sengupta, Catanzaro, Chrzanowski, Coates, 50 | Elsen, Engel, Hannun, and Satheesh]{diamos2016persistent} 51 | G.~Diamos, S.~Sengupta, B.~Catanzaro, M.~Chrzanowski, A.~Coates, E.~Elsen, 52 | J.~Engel, A.~Hannun, and S.~Satheesh. 53 | \newblock Persistent rnns: Stashing recurrent weights on-chip. 54 | \newblock In \emph{International Conference on Machine Learning}, pages 55 | 2024--2033, 2016. 56 | 57 | \bibitem[Gehring et~al.(2017)Gehring, Auli, Grangier, Yarats, and 58 | Dauphin]{gehring2017convolutional} 59 | J.~Gehring, M.~Auli, D.~Grangier, D.~Yarats, and Y.~N. Dauphin. 60 | \newblock Convolutional sequence to sequence learning. 61 | \newblock \emph{arXiv preprint arXiv:1705.03122}, 2017. 62 | 63 | \bibitem[Glorot and Bengio(2010)]{glorot2010understanding} 64 | X.~Glorot and Y.~Bengio. 65 | \newblock Understanding the difficulty of training deep feedforward neural 66 | networks. 67 | \newblock In \emph{Aistats}, volume~9, pages 249--256, 2010. 68 | 69 | \bibitem[Hausknecht and Stone(2015)]{hausknecht2015deep} 70 | M.~Hausknecht and P.~Stone. 71 | \newblock Deep recurrent q-learning for partially observable mdps. 72 | \newblock In \emph{2015 AAAI Fall Symposium Series}, 2015. 73 | 74 | \bibitem[Hochreiter and Schmidhuber(1997)]{hochreiter1997long} 75 | S.~Hochreiter and J.~Schmidhuber. 76 | \newblock Long short-term memory. 77 | \newblock \emph{Neural computation}, 9\penalty0 (8):\penalty0 1735--1780, 1997. 78 | 79 | \bibitem[Kalchbrenner et~al.(2016)Kalchbrenner, Espeholt, Simonyan, Oord, 80 | Graves, and Kavukcuoglu]{kalchbrenner2016neural} 81 | N.~Kalchbrenner, L.~Espeholt, K.~Simonyan, A.~v.~d. Oord, A.~Graves, and 82 | K.~Kavukcuoglu. 83 | \newblock Neural machine translation in linear time. 84 | \newblock \emph{arXiv preprint arXiv:1610.10099}, 2016. 85 | 86 | \bibitem[Keskar et~al.(2017)Keskar, Mudigere, Nocedal, Smelyanskiy, and 87 | Tang]{keskar2017large} 88 | N.~S. Keskar, D.~Mudigere, J.~Nocedal, M.~Smelyanskiy, and P.~T.~P. Tang. 89 | \newblock On large-batch training for deep learning: Generalization gap and 90 | sharp minima. 91 | \newblock 2017. 92 | 93 | \bibitem[Kingma and Ba(2014)]{kingma2014adam} 94 | D.~Kingma and J.~Ba. 95 | \newblock Adam: A method for stochastic optimization. 96 | \newblock \emph{arXiv preprint arXiv:1412.6980}, 2014. 97 | 98 | \bibitem[Ladner and Fischer(1980)]{ladner1980parallel} 99 | R.~E. Ladner and M.~J. Fischer. 100 | \newblock Parallel prefix computation. 101 | \newblock \emph{Journal of the ACM (JACM)}, 27\penalty0 (4):\penalty0 831--838, 102 | 1980. 103 | 104 | \bibitem[LeCun et~al.(1998)LeCun, Cortes, and Burges]{lecun1998mnist} 105 | Y.~LeCun, C.~Cortes, and C.~J. Burges. 106 | \newblock The mnist database of handwritten digits, 1998. 107 | 108 | \bibitem[Orchard et~al.(2015)Orchard, Jayawant, Cohen, and 109 | Thakor]{orchard2015converting} 110 | G.~Orchard, A.~Jayawant, G.~Cohen, and N.~Thakor. 111 | \newblock Converting static image datasets to spiking neuromorphic datasets 112 | using saccades. 113 | \newblock \emph{arXiv preprint arXiv:1507.07629}, 2015. 114 | 115 | \bibitem[Sutskever et~al.(2014)Sutskever, Vinyals, and 116 | Le]{sutskever2014sequence} 117 | I.~Sutskever, O.~Vinyals, and Q.~V. Le. 118 | \newblock Sequence to sequence learning with neural networks. 119 | \newblock In \emph{Advances in neural information processing systems}, pages 120 | 3104--3112, 2014. 121 | 122 | \bibitem[van~den Oord et~al.(2016)van~den Oord, Dieleman, Zen, Simonyan, 123 | Vinyals, Graves, Kalchbrenner, Senior, and Kavukcuoglu]{van2016wavenet} 124 | A.~van~den Oord, S.~Dieleman, H.~Zen, K.~Simonyan, O.~Vinyals, A.~Graves, 125 | N.~Kalchbrenner, A.~Senior, and K.~Kavukcuoglu. 126 | \newblock Wavenet: A generative model for raw audio. 127 | \newblock \emph{CoRR abs/1609.03499}, 2016. 128 | 129 | \end{thebibliography} 130 | -------------------------------------------------------------------------------- /paper/main.blg: -------------------------------------------------------------------------------- 1 | This is BibTeX, Version 0.99d (TeX Live 2015) 2 | Capacity: max_strings=35307, hash_size=35307, hash_prime=30011 3 | The top-level auxiliary file: main.aux 4 | The style file: abbrvnat.bst 5 | Database file #1: references.bib 6 | Warning--empty journal in blelloch1990prefix 7 | Warning--empty journal in keskar2017large 8 | You've used 19 entries, 9 | 2773 wiz_defined-function locations, 10 | 685 strings with 8219 characters, 11 | and the built_in function-call counts, 10703 in all, are: 12 | = -- 880 13 | > -- 863 14 | < -- 12 15 | + -- 289 16 | - -- 270 17 | * -- 892 18 | := -- 1667 19 | add.period$ -- 56 20 | call.type$ -- 19 21 | change.case$ -- 145 22 | chr.to.int$ -- 19 23 | cite$ -- 40 24 | duplicate$ -- 437 25 | empty$ -- 728 26 | format.name$ -- 301 27 | if$ -- 2213 28 | int.to.chr$ -- 1 29 | int.to.str$ -- 1 30 | missing$ -- 18 31 | newline$ -- 102 32 | num.names$ -- 76 33 | pop$ -- 314 34 | preamble$ -- 1 35 | purify$ -- 126 36 | quote$ -- 0 37 | skip$ -- 394 38 | stack$ -- 0 39 | substring$ -- 246 40 | swap$ -- 66 41 | text.length$ -- 5 42 | text.prefix$ -- 0 43 | top$ -- 0 44 | type$ -- 209 45 | warning$ -- 2 46 | while$ -- 73 47 | width$ -- 0 48 | write$ -- 238 49 | (There were 2 warnings) 50 | -------------------------------------------------------------------------------- /paper/main.dvi: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eamartin/parallelizing_linear_rnns/da10de34701313e5556d67f4df4aacb76983e20c/paper/main.dvi -------------------------------------------------------------------------------- /paper/main.out: -------------------------------------------------------------------------------- 1 | \BOOKMARK [1][-]{section.1}{Introduction}{}% 1 2 | \BOOKMARK [1][-]{section.2}{Parallel linear recurrence}{}% 2 3 | \BOOKMARK [2][-]{subsection.2.1}{Theoretical performance}{section.2}% 3 4 | \BOOKMARK [2][-]{subsection.2.2}{Backpropagation}{section.2}% 4 5 | \BOOKMARK [2][-]{subsection.2.3}{Implementation}{section.2}% 5 6 | \BOOKMARK [1][-]{section.3}{Models}{}% 6 7 | \BOOKMARK [2][-]{subsection.3.1}{Gated impulse linear recurrent layer}{section.3}% 7 8 | \BOOKMARK [3][-]{subsubsection.3.1.1}{Impact on effective "batch size"}{subsection.3.1}% 8 9 | \BOOKMARK [2][-]{subsection.3.2}{Linear surrogate RNNs}{section.3}% 9 10 | \BOOKMARK [1][-]{section.4}{Experiments}{}% 10 11 | \BOOKMARK [2][-]{subsection.4.1}{Throughput benchmarks}{section.4}% 11 12 | \BOOKMARK [3][-]{subsubsection.4.1.1}{Kernel performance}{subsection.4.1}% 12 13 | \BOOKMARK [3][-]{subsubsection.4.1.2}{Accelerating existing RNN architectures}{subsection.4.1}% 13 14 | \BOOKMARK [2][-]{subsection.4.2}{Synthetic Experiment}{section.4}% 14 15 | \BOOKMARK [1][-]{section.5}{Conclusion}{}% 15 16 | -------------------------------------------------------------------------------- /paper/main.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eamartin/parallelizing_linear_rnns/da10de34701313e5556d67f4df4aacb76983e20c/paper/main.pdf -------------------------------------------------------------------------------- /paper/main.tex.blg: -------------------------------------------------------------------------------- 1 | [0] Config.pm:324> INFO - This is Biber 2.1 2 | [1] Config.pm:327> INFO - Logfile is 'main.tex.blg' 3 | [93] biber-darwin:276> INFO - === Tue Sep 12, 2017, 09:52:47 4 | [249] Utils.pm:162> ERROR - Cannot find control file 'main.tex.bcf'! - did you pass the "backend=biber" option to BibLaTeX? 5 | [249] Biber.pm:110> INFO - ERRORS: 1 6 | -------------------------------------------------------------------------------- /paper/medical_training.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eamartin/parallelizing_linear_rnns/da10de34701313e5556d67f4df4aacb76983e20c/paper/medical_training.png -------------------------------------------------------------------------------- /paper/nips_2017.sty: -------------------------------------------------------------------------------- 1 | % partial rewrite of the LaTeX2e package for submissions to the 2 | % Conference on Neural Information Processing Systems (NIPS): 3 | % 4 | % - uses more LaTeX conventions 5 | % - line numbers at submission time replaced with aligned numbers from 6 | % lineno package 7 | % - \nipsfinalcopy replaced with [final] package option 8 | % - automatically loads times package for authors 9 | % - loads natbib automatically; this can be suppressed with the 10 | % [nonatbib] package option 11 | % - adds foot line to first page identifying the conference 12 | % 13 | % Roman Garnett (garnett@wustl.edu) and the many authors of 14 | % nips15submit_e.sty, including MK and drstrip@sandia 15 | % 16 | % last revision: March 2017 17 | 18 | \NeedsTeXFormat{LaTeX2e} 19 | \ProvidesPackage{nips_2017}[2017/03/20 NIPS 2017 submission/camera-ready style file] 20 | 21 | % declare final option, which creates camera-ready copy 22 | \newif\if@nipsfinal\@nipsfinalfalse 23 | \DeclareOption{final}{ 24 | \@nipsfinaltrue 25 | } 26 | 27 | % declare nonatbib option, which does not load natbib in case of 28 | % package clash (users can pass options to natbib via 29 | % \PassOptionsToPackage) 30 | \newif\if@natbib\@natbibtrue 31 | \DeclareOption{nonatbib}{ 32 | \@natbibfalse 33 | } 34 | 35 | \ProcessOptions\relax 36 | 37 | % fonts 38 | \renewcommand{\rmdefault}{ptm} 39 | \renewcommand{\sfdefault}{phv} 40 | 41 | % change this every year for notice string at bottom 42 | \newcommand{\@nipsordinal}{31st} 43 | \newcommand{\@nipsyear}{2017} 44 | \newcommand{\@nipslocation}{Long Beach, CA, USA} 45 | 46 | % handle tweaks for camera-ready copy vs. submission copy 47 | \if@nipsfinal 48 | %%\newcommand{\@noticestring}{% 49 | %%\@nipsordinal\/ Conference on Neural Information Processing Systems 50 | %%(NIPS \@nipsyear), \@nipslocation.% 51 | %} 52 | \newcommand{\@noticestring}{} 53 | \else 54 | \newcommand{\@noticestring}{% 55 | Submitted to \@nipsordinal\/ Conference on Neural Information 56 | Processing Systems (NIPS \@nipsyear). Do not distribute.% 57 | } 58 | 59 | % line numbers for submission 60 | \RequirePackage{lineno} 61 | \linenumbers 62 | 63 | % fix incompatibilities between lineno and amsmath, if required, by 64 | % transparently wrapping linenomath environments around amsmath 65 | % environments 66 | \AtBeginDocument{% 67 | \@ifpackageloaded{amsmath}{% 68 | \newcommand*\patchAmsMathEnvironmentForLineno[1]{% 69 | \expandafter\let\csname old#1\expandafter\endcsname\csname #1\endcsname 70 | \expandafter\let\csname oldend#1\expandafter\endcsname\csname end#1\endcsname 71 | \renewenvironment{#1}% 72 | {\linenomath\csname old#1\endcsname}% 73 | {\csname oldend#1\endcsname\endlinenomath}% 74 | }% 75 | \newcommand*\patchBothAmsMathEnvironmentsForLineno[1]{% 76 | \patchAmsMathEnvironmentForLineno{#1}% 77 | \patchAmsMathEnvironmentForLineno{#1*}% 78 | }% 79 | \patchBothAmsMathEnvironmentsForLineno{equation}% 80 | \patchBothAmsMathEnvironmentsForLineno{align}% 81 | \patchBothAmsMathEnvironmentsForLineno{flalign}% 82 | \patchBothAmsMathEnvironmentsForLineno{alignat}% 83 | \patchBothAmsMathEnvironmentsForLineno{gather}% 84 | \patchBothAmsMathEnvironmentsForLineno{multline}% 85 | }{} 86 | } 87 | \fi 88 | 89 | % load natbib unless told otherwise 90 | \if@natbib 91 | \RequirePackage{natbib} 92 | \fi 93 | 94 | % set page geometry 95 | \usepackage[verbose=true,letterpaper]{geometry} 96 | \AtBeginDocument{ 97 | \newgeometry{ 98 | textheight=9in, 99 | textwidth=5.5in, 100 | top=1in, 101 | headheight=12pt, 102 | headsep=25pt, 103 | footskip=30pt 104 | } 105 | \@ifpackageloaded{fullpage} 106 | {\PackageWarning{nips_2016}{fullpage package not allowed! Overwriting formatting.}} 107 | {} 108 | } 109 | 110 | \widowpenalty=10000 111 | \clubpenalty=10000 112 | \flushbottom 113 | \sloppy 114 | 115 | % font sizes with reduced leading 116 | \renewcommand{\normalsize}{% 117 | \@setfontsize\normalsize\@xpt\@xipt 118 | \abovedisplayskip 7\p@ \@plus 2\p@ \@minus 5\p@ 119 | \abovedisplayshortskip \z@ \@plus 3\p@ 120 | \belowdisplayskip \abovedisplayskip 121 | \belowdisplayshortskip 4\p@ \@plus 3\p@ \@minus 3\p@ 122 | } 123 | \normalsize 124 | \renewcommand{\small}{% 125 | \@setfontsize\small\@ixpt\@xpt 126 | \abovedisplayskip 6\p@ \@plus 1.5\p@ \@minus 4\p@ 127 | \abovedisplayshortskip \z@ \@plus 2\p@ 128 | \belowdisplayskip \abovedisplayskip 129 | \belowdisplayshortskip 3\p@ \@plus 2\p@ \@minus 2\p@ 130 | } 131 | \renewcommand{\footnotesize}{\@setfontsize\footnotesize\@ixpt\@xpt} 132 | \renewcommand{\scriptsize}{\@setfontsize\scriptsize\@viipt\@viiipt} 133 | \renewcommand{\tiny}{\@setfontsize\tiny\@vipt\@viipt} 134 | \renewcommand{\large}{\@setfontsize\large\@xiipt{14}} 135 | \renewcommand{\Large}{\@setfontsize\Large\@xivpt{16}} 136 | \renewcommand{\LARGE}{\@setfontsize\LARGE\@xviipt{20}} 137 | \renewcommand{\huge}{\@setfontsize\huge\@xxpt{23}} 138 | \renewcommand{\Huge}{\@setfontsize\Huge\@xxvpt{28}} 139 | 140 | % sections with less space 141 | \providecommand{\section}{} 142 | \renewcommand{\section}{% 143 | \@startsection{section}{1}{\z@}% 144 | {-2.0ex \@plus -0.5ex \@minus -0.2ex}% 145 | { 1.5ex \@plus 0.3ex \@minus 0.2ex}% 146 | {\large\bf\raggedright}% 147 | } 148 | \providecommand{\subsection}{} 149 | \renewcommand{\subsection}{% 150 | \@startsection{subsection}{2}{\z@}% 151 | {-1.8ex \@plus -0.5ex \@minus -0.2ex}% 152 | { 0.8ex \@plus 0.2ex}% 153 | {\normalsize\bf\raggedright}% 154 | } 155 | \providecommand{\subsubsection}{} 156 | \renewcommand{\subsubsection}{% 157 | \@startsection{subsubsection}{3}{\z@}% 158 | {-1.5ex \@plus -0.5ex \@minus -0.2ex}% 159 | { 0.5ex \@plus 0.2ex}% 160 | {\normalsize\bf\raggedright}% 161 | } 162 | \providecommand{\paragraph}{} 163 | \renewcommand{\paragraph}{% 164 | \@startsection{paragraph}{4}{\z@}% 165 | {1.5ex \@plus 0.5ex \@minus 0.2ex}% 166 | {-1em}% 167 | {\normalsize\bf}% 168 | } 169 | \providecommand{\subparagraph}{} 170 | \renewcommand{\subparagraph}{% 171 | \@startsection{subparagraph}{5}{\z@}% 172 | {1.5ex \@plus 0.5ex \@minus 0.2ex}% 173 | {-1em}% 174 | {\normalsize\bf}% 175 | } 176 | \providecommand{\subsubsubsection}{} 177 | \renewcommand{\subsubsubsection}{% 178 | \vskip5pt{\noindent\normalsize\rm\raggedright}% 179 | } 180 | 181 | % float placement 182 | \renewcommand{\topfraction }{0.85} 183 | \renewcommand{\bottomfraction }{0.4} 184 | \renewcommand{\textfraction }{0.1} 185 | \renewcommand{\floatpagefraction}{0.7} 186 | 187 | \newlength{\@nipsabovecaptionskip}\setlength{\@nipsabovecaptionskip}{7\p@} 188 | \newlength{\@nipsbelowcaptionskip}\setlength{\@nipsbelowcaptionskip}{\z@} 189 | 190 | \setlength{\abovecaptionskip}{\@nipsabovecaptionskip} 191 | \setlength{\belowcaptionskip}{\@nipsbelowcaptionskip} 192 | 193 | % swap above/belowcaptionskip lengths for tables 194 | \renewenvironment{table} 195 | {\setlength{\abovecaptionskip}{\@nipsbelowcaptionskip}% 196 | \setlength{\belowcaptionskip}{\@nipsabovecaptionskip}% 197 | \@float{table}} 198 | {\end@float} 199 | 200 | % footnote formatting 201 | \setlength{\footnotesep }{6.65\p@} 202 | \setlength{\skip\footins}{9\p@ \@plus 4\p@ \@minus 2\p@} 203 | \renewcommand{\footnoterule}{\kern-3\p@ \hrule width 12pc \kern 2.6\p@} 204 | \setcounter{footnote}{0} 205 | 206 | % paragraph formatting 207 | \setlength{\parindent}{\z@} 208 | \setlength{\parskip }{5.5\p@} 209 | 210 | % list formatting 211 | \setlength{\topsep }{4\p@ \@plus 1\p@ \@minus 2\p@} 212 | \setlength{\partopsep }{1\p@ \@plus 0.5\p@ \@minus 0.5\p@} 213 | \setlength{\itemsep }{2\p@ \@plus 1\p@ \@minus 0.5\p@} 214 | \setlength{\parsep }{2\p@ \@plus 1\p@ \@minus 0.5\p@} 215 | \setlength{\leftmargin }{3pc} 216 | \setlength{\leftmargini }{\leftmargin} 217 | \setlength{\leftmarginii }{2em} 218 | \setlength{\leftmarginiii}{1.5em} 219 | \setlength{\leftmarginiv }{1.0em} 220 | \setlength{\leftmarginv }{0.5em} 221 | \def\@listi {\leftmargin\leftmargini} 222 | \def\@listii {\leftmargin\leftmarginii 223 | \labelwidth\leftmarginii 224 | \advance\labelwidth-\labelsep 225 | \topsep 2\p@ \@plus 1\p@ \@minus 0.5\p@ 226 | \parsep 1\p@ \@plus 0.5\p@ \@minus 0.5\p@ 227 | \itemsep \parsep} 228 | \def\@listiii{\leftmargin\leftmarginiii 229 | \labelwidth\leftmarginiii 230 | \advance\labelwidth-\labelsep 231 | \topsep 1\p@ \@plus 0.5\p@ \@minus 0.5\p@ 232 | \parsep \z@ 233 | \partopsep 0.5\p@ \@plus 0\p@ \@minus 0.5\p@ 234 | \itemsep \topsep} 235 | \def\@listiv {\leftmargin\leftmarginiv 236 | \labelwidth\leftmarginiv 237 | \advance\labelwidth-\labelsep} 238 | \def\@listv {\leftmargin\leftmarginv 239 | \labelwidth\leftmarginv 240 | \advance\labelwidth-\labelsep} 241 | \def\@listvi {\leftmargin\leftmarginvi 242 | \labelwidth\leftmarginvi 243 | \advance\labelwidth-\labelsep} 244 | 245 | % create title 246 | \providecommand{\maketitle}{} 247 | \renewcommand{\maketitle}{% 248 | \par 249 | \begingroup 250 | \renewcommand{\thefootnote}{\fnsymbol{footnote}} 251 | % for perfect author name centering 252 | \renewcommand{\@makefnmark}{\hbox to \z@{$^{\@thefnmark}$\hss}} 253 | % The footnote-mark was overlapping the footnote-text, 254 | % added the following to fix this problem (MK) 255 | \long\def\@makefntext##1{% 256 | \parindent 1em\noindent 257 | \hbox to 1.8em{\hss $\m@th ^{\@thefnmark}$}##1 258 | } 259 | \thispagestyle{empty} 260 | \@maketitle 261 | \@thanks 262 | \@notice 263 | \endgroup 264 | \let\maketitle\relax 265 | \let\thanks\relax 266 | } 267 | 268 | % rules for title box at top of first page 269 | \newcommand{\@toptitlebar}{ 270 | \hrule height 4\p@ 271 | \vskip 0.25in 272 | \vskip -\parskip% 273 | } 274 | \newcommand{\@bottomtitlebar}{ 275 | \vskip 0.29in 276 | \vskip -\parskip 277 | \hrule height 1\p@ 278 | \vskip 0.09in% 279 | } 280 | 281 | % create title (includes both anonymized and non-anonymized versions) 282 | \providecommand{\@maketitle}{} 283 | \renewcommand{\@maketitle}{% 284 | \vbox{% 285 | \hsize\textwidth 286 | \linewidth\hsize 287 | \vskip 0.1in 288 | \@toptitlebar 289 | \centering 290 | {\LARGE\bf \@title\par} 291 | \@bottomtitlebar 292 | \if@nipsfinal 293 | \def\And{% 294 | \end{tabular}\hfil\linebreak[0]\hfil% 295 | \begin{tabular}[t]{c}\bf\rule{\z@}{24\p@}\ignorespaces% 296 | } 297 | \def\AND{% 298 | \end{tabular}\hfil\linebreak[4]\hfil% 299 | \begin{tabular}[t]{c}\bf\rule{\z@}{24\p@}\ignorespaces% 300 | } 301 | \begin{tabular}[t]{c}\bf\rule{\z@}{24\p@}\@author\end{tabular}% 302 | \else 303 | \begin{tabular}[t]{c}\bf\rule{\z@}{24\p@} 304 | Anonymous Author(s) \\ 305 | Affiliation \\ 306 | Address \\ 307 | \texttt{email} \\ 308 | \end{tabular}% 309 | \fi 310 | \vskip 0.3in \@minus 0.1in 311 | } 312 | } 313 | 314 | % add conference notice to bottom of first page 315 | \newcommand{\ftype@noticebox}{8} 316 | \newcommand{\@notice}{% 317 | % give a bit of extra room back to authors on first page 318 | \enlargethispage{2\baselineskip}% 319 | \@float{noticebox}[b]% 320 | \footnotesize\@noticestring% 321 | \end@float% 322 | } 323 | 324 | % abstract styling 325 | \renewenvironment{abstract}% 326 | {% 327 | \vskip 0.075in% 328 | \centerline% 329 | {\large\bf Abstract}% 330 | \vspace{0.5ex}% 331 | \begin{quote}% 332 | } 333 | { 334 | \par% 335 | \end{quote}% 336 | \vskip 1ex% 337 | } 338 | 339 | \endinput 340 | -------------------------------------------------------------------------------- /paper/references.bib: -------------------------------------------------------------------------------- 1 | @inproceedings{sutskever2014sequence, 2 | title={Sequence to sequence learning with neural networks}, 3 | author={Sutskever, Ilya and Vinyals, Oriol and Le, Quoc V}, 4 | booktitle={Advances in neural information processing systems}, 5 | pages={3104--3112}, 6 | year={2014} 7 | } 8 | 9 | @article{amodei2015deep, 10 | title={Deep speech 2: End-to-end speech recognition in english and mandarin}, 11 | author={Amodei, Dario and Anubhai, Rishita and Battenberg, Eric and Case, Carl and Casper, Jared and Catanzaro, Bryan and Chen, Jingdong and Chrzanowski, Mike and Coates, Adam and Diamos, Greg and others}, 12 | journal={arXiv preprint arXiv:1512.02595}, 13 | year={2015} 14 | } 15 | 16 | @inproceedings{hausknecht2015deep, 17 | title={Deep Recurrent Q-Learning for Partially Observable MDPs}, 18 | author={Hausknecht, Matthew and Stone, Peter}, 19 | booktitle={2015 AAAI Fall Symposium Series}, 20 | year={2015} 21 | } 22 | 23 | @article{hochreiter1997long, 24 | title={Long short-term memory}, 25 | author={Hochreiter, Sepp and Schmidhuber, J{\"u}rgen}, 26 | journal={Neural computation}, 27 | volume={9}, 28 | number={8}, 29 | pages={1735--1780}, 30 | year={1997}, 31 | publisher={MIT Press} 32 | } 33 | 34 | @article{cho2014learning, 35 | title={Learning phrase representations using RNN encoder-decoder for statistical machine translation}, 36 | author={Cho, Kyunghyun and Van Merri{\"e}nboer, Bart and Gulcehre, Caglar and Bahdanau, Dzmitry and Bougares, Fethi and Schwenk, Holger and Bengio, Yoshua}, 37 | journal={arXiv preprint arXiv:1406.1078}, 38 | year={2014} 39 | } 40 | 41 | @incollection{neil2016phased, 42 | title = {Phased LSTM: Accelerating Recurrent Network Training for Long or Event-based Sequences}, 43 | author = {Neil, Daniel and Pfeiffer, Michael and Liu, Shih-Chii}, 44 | booktitle = {Advances in Neural Information Processing Systems 29}, 45 | editor = {D. D. Lee and M. Sugiyama and U. V. Luxburg and I. Guyon and R. Garnett}, 46 | pages = {3882--3890}, 47 | year = {2016}, 48 | publisher = {Curran Associates, Inc.}, 49 | url = {http://papers.nips.cc/paper/6310-phased-lstm-accelerating-recurrent-network-training-for-long-or-event-based-sequences.pdf} 50 | } 51 | 52 | @misc{hochreiter2001gradient, 53 | title={Gradient flow in recurrent nets: the difficulty of learning long-term dependencies}, 54 | author={Hochreiter, Sepp and Bengio, Yoshua and Frasconi, Paolo and Schmidhuber, J{\"u}rgen}, 55 | year={2001}, 56 | publisher={A field guide to dynamical recurrent neural networks. IEEE Press} 57 | } 58 | 59 | @inproceedings{diamos2016persistent, 60 | title={Persistent rnns: Stashing recurrent weights on-chip}, 61 | author={Diamos, Greg and Sengupta, Shubho and Catanzaro, Bryan and Chrzanowski, Mike and Coates, Adam and Elsen, Erich and Engel, Jesse and Hannun, Awni and Satheesh, Sanjeev}, 62 | booktitle={International Conference on Machine Learning}, 63 | pages={2024--2033}, 64 | year={2016} 65 | } 66 | 67 | 68 | @inproceedings{bradbury2017quasi, 69 | title={Quasi-Recurrent Neural Networks}, 70 | author={Bradbury, James and Merity, Stephen and Xiong, Caiming and Socher, Richard}, 71 | arxiv = "https://arxiv.org/abs/1611.01576", 72 | booktitle = "International Conference on Learning Representations (ICLR)", 73 | year={2017} 74 | } 75 | 76 | @article{kalchbrenner2016neural, 77 | title={Neural machine translation in linear time}, 78 | author={Kalchbrenner, Nal and Espeholt, Lasse and Simonyan, Karen and Oord, Aaron van den and Graves, Alex and Kavukcuoglu, Koray}, 79 | journal={arXiv preprint arXiv:1610.10099}, 80 | year={2016} 81 | } 82 | 83 | @article{gehring2017convolutional, 84 | title={Convolutional Sequence to Sequence Learning}, 85 | author={Gehring, Jonas and Auli, Michael and Grangier, David and Yarats, Denis and Dauphin, Yann N}, 86 | journal={arXiv preprint arXiv:1705.03122}, 87 | year={2017} 88 | } 89 | 90 | @article{van2016wavenet, 91 | title={Wavenet: A generative model for raw audio}, 92 | author={van den Oord, A{\"a}ron and Dieleman, Sander and Zen, Heiga and Simonyan, Karen and Vinyals, Oriol and Graves, Alex and Kalchbrenner, Nal and Senior, Andrew and Kavukcuoglu, Koray}, 93 | journal={CoRR abs/1609.03499}, 94 | year={2016} 95 | } 96 | 97 | @article{keskar2017large, 98 | title={On large-batch training for deep learning: Generalization gap and sharp minima}, 99 | author={Keskar, Nitish Shirish and Mudigere, Dheevatsa and Nocedal, Jorge and Smelyanskiy, Mikhail and Tang, Ping Tak Peter}, 100 | arxiv = "https://arxiv.org/abs/1609.04836", 101 | booktitle = "International Conference on Learning Representations (ICLR)", 102 | year={2017} 103 | } 104 | 105 | @inproceedings{balduzzi2016strongly, 106 | title={Strongly-Typed Recurrent Neural Networks}, 107 | author={Balduzzi, David and Ghifary, Muhammad}, 108 | booktitle={Proceedings of The 33rd International Conference on Machine Learning}, 109 | pages={1292--1300}, 110 | year={2016} 111 | } 112 | 113 | @book{goodfellow2016deep, 114 | title={Deep learning}, 115 | author={Goodfellow, Ian and Bengio, Yoshua and Courville, Aaron}, 116 | year={2016}, 117 | publisher={MIT Press} 118 | } 119 | 120 | @article{blelloch1990prefix, 121 | title={Prefix sums and their applications}, 122 | author={Blelloch, Guy E}, 123 | year={1990} 124 | } 125 | 126 | @article{orchard2015converting, 127 | title={Converting static image datasets to spiking neuromorphic datasets using saccades}, 128 | author={Orchard, Garrick and Jayawant, Ajinkya and Cohen, Gregory and Thakor, Nitish}, 129 | journal={arXiv preprint arXiv:1507.07629}, 130 | year={2015} 131 | } 132 | 133 | @misc{lecun1998mnist, 134 | title={The MNIST database of handwritten digits}, 135 | author={LeCun, Yann and Cortes, Corinna and Burges, Christopher JC}, 136 | year={1998} 137 | } 138 | 139 | @article{ladner1980parallel, 140 | title={Parallel prefix computation}, 141 | author={Ladner, Richard E and Fischer, Michael J}, 142 | journal={Journal of the ACM (JACM)}, 143 | volume={27}, 144 | number={4}, 145 | pages={831--838}, 146 | year={1980}, 147 | publisher={ACM} 148 | } 149 | 150 | @article{abadi2016tensorflow, 151 | title={Tensorflow: Large-scale machine learning on heterogeneous distributed systems}, 152 | author={Abadi, Mart{\'\i}n and Agarwal, Ashish and Barham, Paul and Brevdo, Eugene and Chen, Zhifeng and Citro, Craig and Corrado, Greg S and Davis, Andy and Dean, Jeffrey and Devin, Matthieu and others}, 153 | journal={arXiv preprint arXiv:1603.04467}, 154 | year={2016} 155 | } 156 | 157 | @article{kingma2014adam, 158 | title={Adam: A method for stochastic optimization}, 159 | author={Kingma, Diederik and Ba, Jimmy}, 160 | journal={arXiv preprint arXiv:1412.6980}, 161 | year={2014} 162 | } 163 | 164 | @inproceedings{glorot2010understanding, 165 | title={Understanding the difficulty of training deep feedforward neural networks.}, 166 | author={Glorot, Xavier and Bengio, Yoshua}, 167 | booktitle={Aistats}, 168 | volume={9}, 169 | pages={249--256}, 170 | year={2010} 171 | } -------------------------------------------------------------------------------- /paper/references.bib.blg: -------------------------------------------------------------------------------- 1 | [0] Config.pm:324> INFO - This is Biber 2.1 2 | [0] Config.pm:327> INFO - Logfile is 'references.bib.blg' 3 | [86] biber-darwin:276> INFO - === Tue Sep 12, 2017, 09:52:59 4 | [233] Utils.pm:162> ERROR - Cannot find control file 'references.bib.bcf'! - did you pass the "backend=biber" option to BibLaTeX? 5 | [234] Biber.pm:110> INFO - ERRORS: 1 6 | -------------------------------------------------------------------------------- /paper/synthetic_diagram.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eamartin/parallelizing_linear_rnns/da10de34701313e5556d67f4df4aacb76983e20c/paper/synthetic_diagram.pdf -------------------------------------------------------------------------------- /paper/to_do.aux: -------------------------------------------------------------------------------- 1 | \relax 2 | \providecommand\hyper@newdestlabel[2]{} 3 | \providecommand\HyperFirstAtBeginDocument{\AtBeginDocument} 4 | \HyperFirstAtBeginDocument{\ifx\hyper@anchor\@undefined 5 | \global\let\oldcontentsline\contentsline 6 | \gdef\contentsline#1#2#3#4{\oldcontentsline{#1}{#2}{#3}} 7 | \global\let\oldnewlabel\newlabel 8 | \gdef\newlabel#1#2{\newlabelxx{#1}#2} 9 | \gdef\newlabelxx#1#2#3#4#5#6{\oldnewlabel{#1}{{#2}{#3}}} 10 | \AtEndDocument{\ifx\hyper@anchor\@undefined 11 | \let\contentsline\oldcontentsline 12 | \let\newlabel\oldnewlabel 13 | \fi} 14 | \fi} 15 | \global\let\hyper@last\relax 16 | \gdef\HyperFirstAtBeginDocument#1{#1} 17 | \providecommand\HyField@AuxAddToFields[1]{} 18 | \providecommand\HyField@AuxAddToCoFields[2]{} 19 | \@writefile{toc}{\contentsline {section}{\numberline {1}To do for paper}{1}{section.1}} 20 | \newlabel{sec-1}{{1}{1}{To do for paper}{section.1}{}} 21 | \@writefile{toc}{\contentsline {subsection}{\numberline {1.1}Edit all}{1}{subsection.1.1}} 22 | \newlabel{sec-1-1}{{1.1}{1}{Edit all}{subsection.1.1}{}} 23 | \@writefile{toc}{\contentsline {subsection}{\numberline {1.2}Look up templates etc, double-check there's not one out there}{1}{subsection.1.2}} 24 | \newlabel{sec-1-2}{{1.2}{1}{Look up templates etc, double-check there's not one out there}{subsection.1.2}{}} 25 | \@writefile{toc}{\contentsline {subsection}{\numberline {1.3}Do and write up the experiments}{1}{subsection.1.3}} 26 | \newlabel{sec-1-3}{{1.3}{1}{Do and write up the experiments}{subsection.1.3}{}} 27 | \@writefile{toc}{\contentsline {subsubsection}{\numberline {1.3.1}Synthetic Task}{1}{subsubsection.1.3.1}} 28 | \newlabel{sec-1-3-1}{{1.3.1}{1}{Synthetic Task}{subsubsection.1.3.1}{}} 29 | \newlabel{sec-1-3-1-1}{{1}{1}{Synthetic Task}{Item.1}{}} 30 | \@writefile{toc}{\contentsline {subsubsection}{\numberline {1.3.2}Medical Task}{1}{subsubsection.1.3.2}} 31 | \newlabel{sec-1-3-2}{{1.3.2}{1}{Medical Task}{subsubsection.1.3.2}{}} 32 | \newlabel{sec-1-3-2-1}{{1}{1}{Medical Task}{Item.2}{}} 33 | \newlabel{sec-1-3-2-2}{{2}{1}{Medical Task}{Item.3}{}} 34 | \@writefile{toc}{\contentsline {subsubsection}{\numberline {1.3.3}Throughput Task}{1}{subsubsection.1.3.3}} 35 | \newlabel{sec-1-3-3}{{1.3.3}{1}{Throughput Task}{subsubsection.1.3.3}{}} 36 | \newlabel{sec-1-3-3-1}{{1}{1}{Throughput Task}{Item.4}{}} 37 | \newlabel{sec-1-3-3-2}{{2}{1}{Throughput Task}{Item.5}{}} 38 | \@writefile{toc}{\contentsline {subsubsection}{\numberline {1.3.4}PLR/SLR comparison}{1}{subsubsection.1.3.4}} 39 | \newlabel{sec-1-3-4}{{1.3.4}{1}{PLR/SLR comparison}{subsubsection.1.3.4}{}} 40 | \newlabel{sec-1-3-4-1}{{1}{1}{PLR/SLR comparison}{Item.6}{}} 41 | \newlabel{sec-1-3-4-2}{{2}{1}{PLR/SLR comparison}{Item.7}{}} 42 | \newlabel{sec-1-3-4-3}{{3}{2}{PLR/SLR comparison}{Item.8}{}} 43 | \newlabel{sec-1-3-4-4}{{4}{2}{PLR/SLR comparison}{Item.9}{}} 44 | \@writefile{toc}{\contentsline {subsection}{\numberline {1.4}Collate the experiments and smooth}{2}{subsection.1.4}} 45 | \newlabel{sec-1-4}{{1.4}{2}{Collate the experiments and smooth}{subsection.1.4}{}} 46 | \@writefile{toc}{\contentsline {subsection}{\numberline {1.5}Think about the overall message of the piece and rewrite as necessary}{2}{subsection.1.5}} 47 | \newlabel{sec-1-5}{{1.5}{2}{Think about the overall message of the piece and rewrite as necessary}{subsection.1.5}{}} 48 | \@writefile{toc}{\contentsline {subsection}{\numberline {1.6}Send to Adam, Andy to check}{2}{subsection.1.6}} 49 | \newlabel{sec-1-6}{{1.6}{2}{Send to Adam, Andy to check}{subsection.1.6}{}} 50 | \@writefile{toc}{\contentsline {subsection}{\numberline {1.7}Submit}{2}{subsection.1.7}} 51 | \newlabel{sec-1-7}{{1.7}{2}{Submit}{subsection.1.7}{}} 52 | \@writefile{toc}{\contentsline {subsubsection}{\numberline {1.7.1}Then go to party and relax the fuck out}{2}{subsubsection.1.7.1}} 53 | \newlabel{sec-1-7-1}{{1.7.1}{2}{Then go to party and relax the fuck out}{subsubsection.1.7.1}{}} 54 | \@writefile{toc}{\contentsline {section}{\numberline {2}Timeline}{2}{section.2}} 55 | \newlabel{sec-2}{{2}{2}{Timeline}{section.2}{}} 56 | \@writefile{toc}{\contentsline {subsection}{\numberline {2.1}Saturday evening:}{2}{subsection.2.1}} 57 | \newlabel{sec-2-1}{{2.1}{2}{Saturday evening:}{subsection.2.1}{}} 58 | \@writefile{toc}{\contentsline {subsubsection}{\numberline {2.1.1}Done interesting medical results, tested on testing set}{2}{subsubsection.2.1.1}} 59 | \newlabel{sec-2-1-1}{{2.1.1}{2}{Done interesting medical results, tested on testing set}{subsubsection.2.1.1}{}} 60 | \@writefile{toc}{\contentsline {subsubsection}{\numberline {2.1.2}Looked at throughput task}{2}{subsubsection.2.1.2}} 61 | \newlabel{sec-2-1-2}{{2.1.2}{2}{Looked at throughput task}{subsubsection.2.1.2}{}} 62 | \@writefile{toc}{\contentsline {subsubsection}{\numberline {2.1.3}Done PLR / SLR throughput task}{2}{subsubsection.2.1.3}} 63 | \newlabel{sec-2-1-3}{{2.1.3}{2}{Done PLR / SLR throughput task}{subsubsection.2.1.3}{}} 64 | \@writefile{toc}{\contentsline {subsection}{\numberline {2.2}Sunday evening:}{2}{subsection.2.2}} 65 | \newlabel{sec-2-2}{{2.2}{2}{Sunday evening:}{subsection.2.2}{}} 66 | \@writefile{toc}{\contentsline {subsubsection}{\numberline {2.2.1}Fully written up medical task}{2}{subsubsection.2.2.1}} 67 | \newlabel{sec-2-2-1}{{2.2.1}{2}{Fully written up medical task}{subsubsection.2.2.1}{}} 68 | \@writefile{toc}{\contentsline {subsubsection}{\numberline {2.2.2}Written up throuput and PLR/SLR task}{2}{subsubsection.2.2.2}} 69 | \newlabel{sec-2-2-2}{{2.2.2}{2}{Written up throuput and PLR/SLR task}{subsubsection.2.2.2}{}} 70 | \@writefile{toc}{\contentsline {subsubsection}{\numberline {2.2.3}Start editing the paper, smoothing into a cohesive whole}{2}{subsubsection.2.2.3}} 71 | \newlabel{sec-2-2-3}{{2.2.3}{2}{Start editing the paper, smoothing into a cohesive whole}{subsubsection.2.2.3}{}} 72 | \@writefile{toc}{\contentsline {subsection}{\numberline {2.3}Monday evening:}{2}{subsection.2.3}} 73 | \newlabel{sec-2-3}{{2.3}{2}{Monday evening:}{subsection.2.3}{}} 74 | \@writefile{toc}{\contentsline {subsubsection}{\numberline {2.3.1}Finish editing, send off to everyone else}{2}{subsubsection.2.3.1}} 75 | \newlabel{sec-2-3-1}{{2.3.1}{2}{Finish editing, send off to everyone else}{subsubsection.2.3.1}{}} 76 | -------------------------------------------------------------------------------- /paper/to_do.org: -------------------------------------------------------------------------------- 1 | * To do for paper 2 | ** Edit all 3 | ** Look up templates etc, double-check there's not one out there 4 | ** Do and write up the experiments 5 | *** Synthetic Task 6 | **** Already done, written up. Need to review 7 | *** Medical Task 8 | **** Being done, need to check on testing set 9 | **** Need to write up 10 | *** Throughput Task 11 | **** Chat to Eric about this 12 | **** Just generate random data 13 | *** PLR/SLR comparison 14 | **** Have the pure kernel data 15 | 16 | **** Do that for random data 17 | **** Do a 2-layer network, with varying seq_len 18 | **** Report the throughput 19 | 20 | | Sequence Length | LS-LSTM | SRU | QRNN(2) | QRNN(10 | 21 | |-----------------+---------+------+---------+---------| 22 | | 16 | 0.61 | 0.28 | 0.38 | 0.78 | 23 | | 256 | 0.91 | 0.84 | 0.86 | 0.99 | 24 | | 4,096 | 0.98 | 1.38 | 1.18 | 1.05 | 25 | | 65,536 | 1.41 | 9.21 | 6.68 | 2.05 | 26 | For an input_size of 24, output 2, hidden_size 256, 27 | batch_size 65536 / seq_len 28 | 29 | ** Collate the experiments and smooth 30 | ** Think about the overall message of the piece and rewrite as necessary 31 | ** Send to Adam, Andy to check 32 | ** Submit 33 | *** Then go to party and relax the fuck out 34 | 35 | 36 | * Timeline 37 | ** Saturday evening: 38 | *** Done interesting medical results, tested on testing set 39 | *** Looked at throughput task 40 | *** Done PLR / SLR throughput task 41 | ** Sunday evening: 42 | *** Fully written up medical task 43 | *** Written up throuput and PLR/SLR task 44 | *** Start editing the paper, smoothing into a cohesive whole 45 | ** Monday evening: 46 | *** Finish editing, send off to everyone else 47 | 48 | 49 | * Paper structure 50 | ** Introduction, abstract, background 51 | *** Describe 52 | We can apply the PLR method to any architecture that satisfies the 53 | constraints. E.g. SRU, QRNN, etc. Introduce LS-LSTM as good substitute. 54 | Describe how previous papers have shown that linear LSTMs work well, this 55 | approach allows us to speed it up. 56 | ** Experiments 57 | *** Benchmarks 58 | **** Throughput of pure PLR kernel vs SLR kernel 59 | **** Show how it speeds up the SRU, QRNN, and LS-LSTM 60 | **** Show how the LS-LSTM has much better throughput than the CudnnLSTM 61 | *** Synthetic task to show that linear LSTMs can still work well 62 | *** Medical task 63 | 64 | 65 | 66 | 31.4 * (4.3 * 6) = 67 | ~50 * ~74 68 | 135 * 190 69 | -------------------------------------------------------------------------------- /paper/to_do.out: -------------------------------------------------------------------------------- 1 | \BOOKMARK [1][-]{section.1}{To do for paper}{}% 1 2 | \BOOKMARK [2][-]{subsection.1.1}{Edit all}{section.1}% 2 3 | \BOOKMARK [2][-]{subsection.1.2}{Look up templates etc, double-check there's not one out there}{section.1}% 3 4 | \BOOKMARK [2][-]{subsection.1.3}{Do and write up the experiments}{section.1}% 4 5 | \BOOKMARK [3][-]{subsubsection.1.3.1}{Synthetic Task}{subsection.1.3}% 5 6 | \BOOKMARK [3][-]{subsubsection.1.3.2}{Medical Task}{subsection.1.3}% 6 7 | \BOOKMARK [3][-]{subsubsection.1.3.3}{Throughput Task}{subsection.1.3}% 7 8 | \BOOKMARK [3][-]{subsubsection.1.3.4}{PLR/SLR comparison}{subsection.1.3}% 8 9 | \BOOKMARK [2][-]{subsection.1.4}{Collate the experiments and smooth}{section.1}% 9 10 | \BOOKMARK [2][-]{subsection.1.5}{Think about the overall message of the piece and rewrite as necessary}{section.1}% 10 11 | \BOOKMARK [2][-]{subsection.1.6}{Send to Adam, Andy to check}{section.1}% 11 12 | \BOOKMARK [2][-]{subsection.1.7}{Submit}{section.1}% 12 13 | \BOOKMARK [3][-]{subsubsection.1.7.1}{Then go to party and relax the fuck out}{subsection.1.7}% 13 14 | \BOOKMARK [1][-]{section.2}{Timeline}{}% 14 15 | \BOOKMARK [2][-]{subsection.2.1}{Saturday evening:}{section.2}% 15 16 | \BOOKMARK [3][-]{subsubsection.2.1.1}{Done interesting medical results, tested on testing set}{subsection.2.1}% 16 17 | \BOOKMARK [3][-]{subsubsection.2.1.2}{Looked at throughput task}{subsection.2.1}% 17 18 | \BOOKMARK [3][-]{subsubsection.2.1.3}{Done PLR / SLR throughput task}{subsection.2.1}% 18 19 | \BOOKMARK [2][-]{subsection.2.2}{Sunday evening:}{section.2}% 19 20 | \BOOKMARK [3][-]{subsubsection.2.2.1}{Fully written up medical task}{subsection.2.2}% 20 21 | \BOOKMARK [3][-]{subsubsection.2.2.2}{Written up throuput and PLR/SLR task}{subsection.2.2}% 21 22 | \BOOKMARK [3][-]{subsubsection.2.2.3}{Start editing the paper, smoothing into a cohesive whole}{subsection.2.2}% 22 23 | \BOOKMARK [2][-]{subsection.2.3}{Monday evening:}{section.2}% 23 24 | \BOOKMARK [3][-]{subsubsection.2.3.1}{Finish editing, send off to everyone else}{subsection.2.3}% 24 25 | -------------------------------------------------------------------------------- /paper/to_do.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eamartin/parallelizing_linear_rnns/da10de34701313e5556d67f4df4aacb76983e20c/paper/to_do.pdf -------------------------------------------------------------------------------- /paper/to_do.tex: -------------------------------------------------------------------------------- 1 | % Created 2017-10-23 Mon 00:20 2 | \documentclass[11pt]{article} 3 | \usepackage[utf8]{inputenc} 4 | \usepackage[T1]{fontenc} 5 | \usepackage{fixltx2e} 6 | \usepackage{graphicx} 7 | \usepackage{longtable} 8 | \usepackage{float} 9 | \usepackage{wrapfig} 10 | \usepackage{rotating} 11 | \usepackage[normalem]{ulem} 12 | \usepackage{amsmath} 13 | \usepackage{textcomp} 14 | \usepackage{marvosym} 15 | \usepackage{wasysym} 16 | \usepackage{amssymb} 17 | \usepackage{hyperref} 18 | \tolerance=1000 19 | \author{Christopher Cundy} 20 | \date{\today} 21 | \title{to\_do} 22 | \hypersetup{ 23 | pdfkeywords={}, 24 | pdfsubject={}, 25 | pdfcreator={Emacs 25.1.1 (Org mode 8.2.10)}} 26 | \begin{document} 27 | 28 | \maketitle 29 | \tableofcontents 30 | 31 | \section{To do for paper} 32 | \label{sec-1} 33 | \subsection{Edit all} 34 | \label{sec-1-1} 35 | \subsection{Look up templates etc, double-check there's not one out there} 36 | \label{sec-1-2} 37 | \subsection{Do and write up the experiments} 38 | \label{sec-1-3} 39 | \subsubsection{Synthetic Task} 40 | \label{sec-1-3-1} 41 | \begin{enumerate} 42 | \item Already done, written up. Need to review 43 | \label{sec-1-3-1-1} 44 | \end{enumerate} 45 | \subsubsection{Medical Task} 46 | \label{sec-1-3-2} 47 | \begin{enumerate} 48 | \item Being done, need to check on testing set 49 | \label{sec-1-3-2-1} 50 | \item Need to write up 51 | \label{sec-1-3-2-2} 52 | \end{enumerate} 53 | \subsubsection{Throughput Task} 54 | \label{sec-1-3-3} 55 | \begin{enumerate} 56 | \item Chat to Eric about this 57 | \label{sec-1-3-3-1} 58 | \item Just generate random data 59 | \label{sec-1-3-3-2} 60 | \end{enumerate} 61 | \subsubsection{PLR/SLR comparison} 62 | \label{sec-1-3-4} 63 | \begin{enumerate} 64 | \item Have the pure kernel data 65 | \label{sec-1-3-4-1} 66 | 67 | \item Do that for random data 68 | \label{sec-1-3-4-2} 69 | \item Do a 2-layer network, with varying seq$_{\text{len}}$ 70 | \label{sec-1-3-4-3} 71 | \item Report the throughput 72 | \label{sec-1-3-4-4} 73 | 74 | \begin{center} 75 | \begin{tabular}{lrrrr} 76 | Sequence Length & LS-LSTM & SRU & QRNN(2) & QRNN(10\\ 77 | \hline 78 | 16 & 0.61 & 0.28 & 0.38 & 0.78\\ 79 | 256 & 0.91 & 0.84 & 0.86 & 0.99\\ 80 | 4,096 & 0.98 & 1.38 & 1.18 & 1.05\\ 81 | 65,536 & 1.41 & 9.21 & 6.68 & 2.05\\ 82 | \end{tabular} 83 | \end{center} 84 | For an input$_{\text{size}}$ of 24, output 2, hidden$_{\text{size}}$ 256, 85 | batch$_{\text{size}}$ 65536 / seq$_{\text{len}}$ 86 | \end{enumerate} 87 | 88 | \subsection{Collate the experiments and smooth} 89 | \label{sec-1-4} 90 | \subsection{Think about the overall message of the piece and rewrite as necessary} 91 | \label{sec-1-5} 92 | \subsection{Send to Adam, Andy to check} 93 | \label{sec-1-6} 94 | \subsection{Submit} 95 | \label{sec-1-7} 96 | \subsubsection{Then go to party and relax the fuck out} 97 | \label{sec-1-7-1} 98 | 99 | 100 | \section{Timeline} 101 | \label{sec-2} 102 | \subsection{Saturday evening:} 103 | \label{sec-2-1} 104 | \subsubsection{Done interesting medical results, tested on testing set} 105 | \label{sec-2-1-1} 106 | \subsubsection{Looked at throughput task} 107 | \label{sec-2-1-2} 108 | \subsubsection{Done PLR / SLR throughput task} 109 | \label{sec-2-1-3} 110 | \subsection{Sunday evening:} 111 | \label{sec-2-2} 112 | \subsubsection{Fully written up medical task} 113 | \label{sec-2-2-1} 114 | \subsubsection{Written up throuput and PLR/SLR task} 115 | \label{sec-2-2-2} 116 | \subsubsection{Start editing the paper, smoothing into a cohesive whole} 117 | \label{sec-2-2-3} 118 | \subsection{Monday evening:} 119 | \label{sec-2-3} 120 | \subsubsection{Finish editing, send off to everyone else} 121 | \label{sec-2-3-1} 122 | 123 | 124 | \section{Paper structure} 125 | \label{sec-3} 126 | \subsection{Introduction, abstract, background} 127 | \label{sec-3-1} 128 | \subsubsection{Describe} 129 | \label{sec-3-1-1} 130 | We can apply the PLR method to any architecture that satisfies the 131 | constraints. E.g. SRU, QRNN, etc. Introduce LS-LSTM as good substitute. 132 | Describe how previous papers have shown that linear LSTMs work well, this 133 | approach allows us to speed it up. 134 | \subsection{Experiments} 135 | \label{sec-3-2} 136 | \subsubsection{Benchmarks} 137 | \label{sec-3-2-1} 138 | \begin{enumerate} 139 | \item Throughput of pure PLR kernel vs SLR kernel 140 | \label{sec-3-2-1-1} 141 | \item Show how it speeds up the SRU, QRNN, and LS-LSTM 142 | \label{sec-3-2-1-2} 143 | \item Show how the LS-LSTM has much better throughput than the CudnnLSTM 144 | \label{sec-3-2-1-3} 145 | \end{enumerate} 146 | \subsubsection{Synthetic task to show that linear LSTMs can still work well} 147 | \label{sec-3-2-2} 148 | \subsubsection{Medical task} 149 | \label{sec-3-2-3} 150 | 151 | 152 | 153 | 31.4 * (4.3 * 6) = 154 | \textasciitilde{}50 * \textasciitilde{}74 155 | 135 * 190 156 | % Emacs 25.1.1 (Org mode 8.2.10) 157 | \end{document} -------------------------------------------------------------------------------- /paper/to_do.toc: -------------------------------------------------------------------------------- 1 | \contentsline {section}{\numberline {1}To do for paper}{1}{section.1} 2 | \contentsline {subsection}{\numberline {1.1}Edit all}{1}{subsection.1.1} 3 | \contentsline {subsection}{\numberline {1.2}Look up templates etc, double-check there's not one out there}{1}{subsection.1.2} 4 | \contentsline {subsection}{\numberline {1.3}Do and write up the experiments}{1}{subsection.1.3} 5 | \contentsline {subsubsection}{\numberline {1.3.1}Synthetic Task}{1}{subsubsection.1.3.1} 6 | \contentsline {subsubsection}{\numberline {1.3.2}Medical Task}{1}{subsubsection.1.3.2} 7 | \contentsline {subsubsection}{\numberline {1.3.3}Throughput Task}{1}{subsubsection.1.3.3} 8 | \contentsline {subsubsection}{\numberline {1.3.4}PLR/SLR comparison}{1}{subsubsection.1.3.4} 9 | \contentsline {subsection}{\numberline {1.4}Collate the experiments and smooth}{2}{subsection.1.4} 10 | \contentsline {subsection}{\numberline {1.5}Think about the overall message of the piece and rewrite as necessary}{2}{subsection.1.5} 11 | \contentsline {subsection}{\numberline {1.6}Send to Adam, Andy to check}{2}{subsection.1.6} 12 | \contentsline {subsection}{\numberline {1.7}Submit}{2}{subsection.1.7} 13 | \contentsline {subsubsection}{\numberline {1.7.1}Then go to party and relax the fuck out}{2}{subsubsection.1.7.1} 14 | \contentsline {section}{\numberline {2}Timeline}{2}{section.2} 15 | \contentsline {subsection}{\numberline {2.1}Saturday evening:}{2}{subsection.2.1} 16 | \contentsline {subsubsection}{\numberline {2.1.1}Done interesting medical results, tested on testing set}{2}{subsubsection.2.1.1} 17 | \contentsline {subsubsection}{\numberline {2.1.2}Looked at throughput task}{2}{subsubsection.2.1.2} 18 | \contentsline {subsubsection}{\numberline {2.1.3}Done PLR / SLR throughput task}{2}{subsubsection.2.1.3} 19 | \contentsline {subsection}{\numberline {2.2}Sunday evening:}{2}{subsection.2.2} 20 | \contentsline {subsubsection}{\numberline {2.2.1}Fully written up medical task}{2}{subsubsection.2.2.1} 21 | \contentsline {subsubsection}{\numberline {2.2.2}Written up throuput and PLR/SLR task}{2}{subsubsection.2.2.2} 22 | \contentsline {subsubsection}{\numberline {2.2.3}Start editing the paper, smoothing into a cohesive whole}{2}{subsubsection.2.2.3} 23 | \contentsline {subsection}{\numberline {2.3}Monday evening:}{2}{subsection.2.3} 24 | \contentsline {subsubsection}{\numberline {2.3.1}Finish editing, send off to everyone else}{2}{subsubsection.2.3.1} 25 | -------------------------------------------------------------------------------- /plr_slr.py: -------------------------------------------------------------------------------- 1 | def plr_slr(bs_seq_len_list): 2 | """Given a list of pairs (batch size, seq_len), 3 | calculate the throughput of an LS-LSTM, an SRU, a QRNN(2), 4 | and QRNN(10) using the parallel kernel as opposed to the serial 5 | one""" 6 | import tensorflow as tf 7 | import numpy as np 8 | import scipy.io.wavfile 9 | from tensorflow.contrib import rnn 10 | import math 11 | from layers_new import linear_surrogate_lstm 12 | from layers_new import s_linear_surrogate_lstm 13 | from layers_new import SRU 14 | from layers_new import s_SRU 15 | from layers_new import QRNN 16 | from layers_new import s_QRNN 17 | import time 18 | import os 19 | import random 20 | 21 | throughput_list = [] 22 | 23 | #TODO: 24 | #Make LS_LSTM with PLR 25 | #Make SRU with PLR 26 | #Make QRNN with PLR 27 | #Make LS_LSTM with SLR 28 | #Make SRU with SLR 29 | #Make QRNN with SLR 30 | 31 | 32 | for seq_len in seq_len_list: 33 | #First generate the LS-LSTM and work out the throughput 34 | tf.reset_default_graph() 35 | n_hidden = 256 36 | n_classes = 2 37 | n_steps = seq_len 38 | batch_size = 65536 / seq_len 39 | bs = batch_size 40 | print "Batch size is {} and sequence length is {}".format(bs, seq_len) 41 | n_input = 24 42 | n_layers = 2 43 | forget_gate_init = 1.0 # = 1/(n_in). We use uniform p(x) 44 | #Training Parameters 45 | sn = 1.0 / math.sqrt(n_hidden) 46 | learning_rate = 0.001 47 | training_iters = 5000000 48 | 49 | x = tf.placeholder("float", [n_steps, batch_size, n_input]) 50 | y = tf.placeholder("float", [batch_size, n_classes]) 51 | tf.get_variable_scope().reuse == True 52 | W1 = tf.get_variable('W1', initializer= 53 | tf.random_normal([n_hidden, n_classes]), dtype='float') 54 | b1 = tf.get_variable('b1', initializer=tf.zeros([n_classes]), dtype='float') 55 | 56 | layer1 = linear_surrogate_lstm(x, n_hidden, name='ls-lstm') 57 | outputs = linear_surrogate_lstm(layer1, n_hidden, name='ls-lstm2') 58 | pred = tf.matmul(outputs[-1], W1) + b1 59 | #Evaluate network, run adam and clip gradients 60 | ################################################################################ 61 | cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=pred, labels=y)) 62 | optimizer_0 = tf.train.AdamOptimizer(learning_rate=learning_rate) 63 | raw_gradients, variables = zip(*optimizer_0.compute_gradients(cost)) 64 | gradients = raw_gradients 65 | optimizer = optimizer_0.apply_gradients(zip(gradients, variables)) 66 | init = tf.global_variables_initializer() 67 | 68 | #Initialise the model and evaluate 69 | step = 0 70 | times = [] 71 | x_in = np.random.random((n_steps, batch_size, n_input)) 72 | y_in = np.random.random((batch_size, n_classes)) 73 | with tf.device("gpu:0"): 74 | with tf.Session() as sess: 75 | sess.run(init) 76 | while step < 10: 77 | out = sess.run(pred, feed_dict={x: x_in, y: y_in}) 78 | step += 1 79 | if step != 0: 80 | start = time.time() 81 | out = sess.run(pred, feed_dict={x: x_in, y: y_in}) 82 | finish = time.time() 83 | times.append(finish - start) 84 | ls_lstm_tp = (bs * n_steps) / np.mean(times) 85 | 86 | 87 | tf.reset_default_graph() 88 | x = tf.placeholder("float", [n_steps, batch_size, n_input]) 89 | y = tf.placeholder("float", [batch_size, n_classes]) 90 | tf.get_variable_scope().reuse == True 91 | W1 = tf.get_variable('W1', initializer= 92 | tf.random_normal([n_hidden, n_classes]), dtype='float') 93 | b1 = tf.get_variable('b1', initializer=tf.zeros([n_classes]), dtype='float') 94 | layer1 = s_linear_surrogate_lstm(x, n_hidden, name='ls-lstm') 95 | output = s_linear_surrogate_lstm(layer1, n_hidden, name='ls-lstm') 96 | pred = tf.matmul(output[-1], W1) + b1 97 | 98 | #Evaluate network, run adam and clip gradients 99 | ################################################################################ 100 | cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=pred, labels=y)) 101 | optimizer_0 = tf.train.AdamOptimizer(learning_rate=learning_rate) 102 | raw_gradients, variables = zip(*optimizer_0.compute_gradients(cost)) 103 | gradients = raw_gradients 104 | optimizer = optimizer_0.apply_gradients(zip(gradients, variables)) 105 | init = tf.global_variables_initializer() 106 | 107 | #Initialise the model and evaluate 108 | step = 0 109 | times = [] 110 | x_in = np.random.random((n_steps, batch_size, n_input)) 111 | y_in = np.random.random((batch_size, n_classes)) 112 | with tf.device("gpu:0"): 113 | with tf.Session() as sess: 114 | sess.run(init) 115 | while step < 10: 116 | out = sess.run(pred, feed_dict={x: x_in, y: y_in}) 117 | step += 1 118 | if step != 0: 119 | start = time.time() 120 | out = sess.run(pred, feed_dict={x: x_in, y: y_in}) 121 | finish = time.time() 122 | times.append(finish - start) 123 | s_ls_lstm_tp = (bs * n_steps) / np.mean(times) 124 | 125 | 126 | tf.reset_default_graph() 127 | x = tf.placeholder("float", [n_steps, batch_size, n_input]) 128 | y = tf.placeholder("float", [batch_size, n_classes]) 129 | tf.get_variable_scope().reuse == True 130 | W1 = tf.get_variable('W1', initializer= 131 | tf.random_normal([n_input, n_classes]), dtype='float') 132 | b1 = tf.get_variable('b1', initializer=tf.zeros([n_classes]), dtype='float') 133 | layer1 = SRU(x, name='SRU_1') 134 | output = SRU(layer1, name='SRU_2') 135 | pred = tf.matmul(output[-1], W1) + b1 136 | 137 | tf.reset_default_graph() 138 | x = tf.placeholder("float", [n_steps, batch_size, n_input]) 139 | y = tf.placeholder("float", [batch_size, n_classes]) 140 | tf.get_variable_scope().reuse == True 141 | W1 = tf.get_variable('W1', initializer= 142 | tf.random_normal([n_hidden, n_classes]), dtype='float') 143 | b1 = tf.get_variable('b1', initializer=tf.zeros([n_classes]), dtype='float') 144 | layer1 = s_linear_surrogate_lstm(x, n_hidden, name='ls-lstm') 145 | output = s_linear_surrogate_lstm(layer1, n_hidden, name='ls-lstm') 146 | pred = tf.matmul(output[-1], W1) + b1 147 | 148 | #Evaluate network, run adam and clip gradients 149 | ################################################################################ 150 | cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=pred, labels=y)) 151 | optimizer_0 = tf.train.AdamOptimizer(learning_rate=learning_rate) 152 | raw_gradients, variables = zip(*optimizer_0.compute_gradients(cost)) 153 | gradients = raw_gradients 154 | optimizer = optimizer_0.apply_gradients(zip(gradients, variables)) 155 | init = tf.global_variables_initializer() 156 | 157 | #Initialise the model and evaluate 158 | step = 0 159 | times = [] 160 | x_in = np.random.random((n_steps, batch_size, n_input)) 161 | y_in = np.random.random((batch_size, n_classes)) 162 | with tf.device("gpu:0"): 163 | with tf.Session() as sess: 164 | sess.run(init) 165 | while step < 10: 166 | out = sess.run(pred, feed_dict={x: x_in, y: y_in}) 167 | step += 1 168 | if step != 0: 169 | start = time.time() 170 | out = sess.run(pred, feed_dict={x: x_in, y: y_in}) 171 | finish = time.time() 172 | times.append(finish - start) 173 | s_ls_lstm_tp = (bs * n_steps) / np.mean(times) 174 | 175 | tf.reset_default_graph() 176 | x = tf.placeholder("float", [n_steps, batch_size, n_input]) 177 | y = tf.placeholder("float", [batch_size, n_classes]) 178 | tf.get_variable_scope().reuse == True 179 | W1 = tf.get_variable('W1', initializer= 180 | tf.random_normal([n_input, n_classes]), dtype='float') 181 | b1 = tf.get_variable('b1', initializer=tf.zeros([n_classes]), dtype='float') 182 | layer1 = SRU(x, name='SRU_1') 183 | output = SRU(layer1, name='SRU_2') 184 | pred = tf.matmul(output[-1], W1) + b1 185 | 186 | #Evaluate network, run adam and clip gradients 187 | ################################################################################ 188 | cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=pred, labels=y)) 189 | optimizer_0 = tf.train.AdamOptimizer(learning_rate=learning_rate) 190 | raw_gradients, variables = zip(*optimizer_0.compute_gradients(cost)) 191 | gradients = raw_gradients 192 | optimizer = optimizer_0.apply_gradients(zip(gradients, variables)) 193 | init = tf.global_variables_initializer() 194 | 195 | #Initialise the model and evaluate 196 | step = 0 197 | times = [] 198 | x_in = np.random.random((n_steps, batch_size, n_input)) 199 | y_in = np.random.random((batch_size, n_classes)) 200 | with tf.device("gpu:0"): 201 | with tf.Session() as sess: 202 | sess.run(init) 203 | while step < 10: 204 | out = sess.run(pred, feed_dict={x: x_in, y: y_in}) 205 | step += 1 206 | if step != 0: 207 | start = time.time() 208 | out = sess.run(pred, feed_dict={x: x_in, y: y_in}) 209 | finish = time.time() 210 | times.append(finish - start) 211 | sru_tp = (bs * n_steps) / np.mean(times) 212 | 213 | 214 | tf.reset_default_graph() 215 | x = tf.placeholder("float", [n_steps, batch_size, n_input]) 216 | y = tf.placeholder("float", [batch_size, n_classes]) 217 | tf.get_variable_scope().reuse == True 218 | W1 = tf.get_variable('W1', initializer= 219 | tf.random_normal([n_input, n_classes]), dtype='float') 220 | b1 = tf.get_variable('b1', initializer=tf.zeros([n_classes]), dtype='float') 221 | layer1 = s_SRU(x, name='s_SRU_1') 222 | output = s_SRU(layer1, name='s_SRU_2') 223 | pred = tf.matmul(output[-1], W1) + b1 224 | 225 | #Evaluate network, run adam and clip gradients 226 | ################################################################################ 227 | cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=pred, labels=y)) 228 | optimizer_0 = tf.train.AdamOptimizer(learning_rate=learning_rate) 229 | raw_gradients, variables = zip(*optimizer_0.compute_gradients(cost)) 230 | gradients = raw_gradients 231 | optimizer = optimizer_0.apply_gradients(zip(gradients, variables)) 232 | init = tf.global_variables_initializer() 233 | 234 | #Initialise the model and evaluate 235 | step = 0 236 | times = [] 237 | x_in = np.random.random((n_steps, batch_size, n_input)) 238 | y_in = np.random.random((batch_size, n_classes)) 239 | with tf.device("gpu:0"): 240 | with tf.Session() as sess: 241 | sess.run(init) 242 | while step < 10: 243 | out = sess.run(pred, feed_dict={x: x_in, y: y_in}) 244 | step += 1 245 | if step != 0: 246 | start = time.time() 247 | out = sess.run(pred, feed_dict={x: x_in, y: y_in}) 248 | finish = time.time() 249 | times.append(finish - start) 250 | s_sru_tp = (bs * n_steps) / np.mean(times) 251 | 252 | 253 | tf.reset_default_graph() 254 | x = tf.placeholder("float", [n_steps, batch_size, n_input]) 255 | y = tf.placeholder("float", [batch_size, n_classes]) 256 | tf.get_variable_scope().reuse == True 257 | W1 = tf.get_variable('W1', initializer= 258 | tf.random_normal([n_input, n_classes]), dtype='float') 259 | b1 = tf.get_variable('b1', initializer=tf.zeros([n_classes]), dtype='float') 260 | layer1 = QRNN(x, 2, name='QRNN_1') 261 | output = QRNN(layer1, 2, name='QRNN_2') 262 | pred = tf.matmul(output[-1], W1) + b1 263 | 264 | #Evaluate network, run adam and clip gradients 265 | ################################################################################ 266 | cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=pred, labels=y)) 267 | optimizer_0 = tf.train.AdamOptimizer(learning_rate=learning_rate) 268 | raw_gradients, variables = zip(*optimizer_0.compute_gradients(cost)) 269 | gradients = raw_gradients 270 | optimizer = optimizer_0.apply_gradients(zip(gradients, variables)) 271 | init = tf.global_variables_initializer() 272 | 273 | #Initialise the model and evaluate 274 | step = 0 275 | times = [] 276 | x_in = np.random.random((n_steps, batch_size, n_input)) 277 | y_in = np.random.random((batch_size, n_classes)) 278 | with tf.device("gpu:0"): 279 | with tf.Session() as sess: 280 | sess.run(init) 281 | while step < 10: 282 | out = sess.run(pred, feed_dict={x: x_in, y: y_in}) 283 | step += 1 284 | if step != 0: 285 | start = time.time() 286 | out = sess.run(pred, feed_dict={x: x_in, y: y_in}) 287 | finish = time.time() 288 | times.append(finish - start) 289 | qrnn_2_tp = (bs * n_steps) / np.mean(times) 290 | 291 | 292 | tf.reset_default_graph() 293 | x = tf.placeholder("float", [n_steps, batch_size, n_input]) 294 | y = tf.placeholder("float", [batch_size, n_classes]) 295 | tf.get_variable_scope().reuse == True 296 | W1 = tf.get_variable('W1', initializer= 297 | tf.random_normal([n_input, n_classes]), dtype='float') 298 | b1 = tf.get_variable('b1', initializer=tf.zeros([n_classes]), dtype='float') 299 | layer1 = s_QRNN(x, 2, name='s_QRNN_3') 300 | output = s_QRNN(layer1, 2, name='s_QRNN_4') 301 | pred = tf.matmul(output[-1], W1) + b1 302 | 303 | #Evaluate network, run adam and clip gradients 304 | ################################################################################ 305 | cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=pred, labels=y)) 306 | optimizer_0 = tf.train.AdamOptimizer(learning_rate=learning_rate) 307 | raw_gradients, variables = zip(*optimizer_0.compute_gradients(cost)) 308 | gradients = raw_gradients 309 | optimizer = optimizer_0.apply_gradients(zip(gradients, variables)) 310 | init = tf.global_variables_initializer() 311 | 312 | #Initialise the model and evaluate 313 | step = 0 314 | times = [] 315 | x_in = np.random.random((n_steps, batch_size, n_input)) 316 | y_in = np.random.random((batch_size, n_classes)) 317 | with tf.device("gpu:0"): 318 | with tf.Session() as sess: 319 | sess.run(init) 320 | while step < 10: 321 | out = sess.run(pred, feed_dict={x: x_in, y: y_in}) 322 | step += 1 323 | if step != 0: 324 | start = time.time() 325 | out = sess.run(pred, feed_dict={x: x_in, y: y_in}) 326 | finish = time.time() 327 | times.append(finish - start) 328 | s_qrnn_2_tp = (bs * n_steps) / np.mean(times) 329 | print np.mean(times) 330 | print np.std(times) 331 | 332 | tf.reset_default_graph() 333 | x = tf.placeholder("float", [n_steps, batch_size, n_input]) 334 | y = tf.placeholder("float", [batch_size, n_classes]) 335 | tf.get_variable_scope().reuse == True 336 | W1 = tf.get_variable('W1', initializer= 337 | tf.random_normal([n_input, n_classes]), dtype='float') 338 | b1 = tf.get_variable('b1', initializer=tf.zeros([n_classes]), dtype='float') 339 | layer1 = QRNN(x, 10, name='QRNN_2') 340 | output = QRNN(layer1, 10, name='QRNN_6') 341 | pred = tf.matmul(output[-1], W1) + b1 342 | 343 | #Evaluate network, run adam and clip gradients 344 | ################################################################################ 345 | cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=pred, labels=y)) 346 | optimizer_0 = tf.train.AdamOptimizer(learning_rate=learning_rate) 347 | raw_gradients, variables = zip(*optimizer_0.compute_gradients(cost)) 348 | gradients = raw_gradients 349 | optimizer = optimizer_0.apply_gradients(zip(gradients, variables)) 350 | init = tf.global_variables_initializer() 351 | 352 | #Initialise the model and evaluate 353 | step = 0 354 | times = [] 355 | x_in = np.random.random((n_steps, batch_size, n_input)) 356 | y_in = np.random.random((batch_size, n_classes)) 357 | with tf.device("gpu:0"): 358 | with tf.Session() as sess: 359 | sess.run(init) 360 | while step < 10: 361 | out = sess.run(pred, feed_dict={x: x_in, y: y_in}) 362 | step += 1 363 | if step != 0: 364 | start = time.time() 365 | out = sess.run(pred, feed_dict={x: x_in, y: y_in}) 366 | finish = time.time() 367 | times.append(finish - start) 368 | qrnn_10_tp = (bs * n_steps) / np.mean(times) 369 | 370 | 371 | tf.reset_default_graph() 372 | x = tf.placeholder("float", [n_steps, batch_size, n_input]) 373 | y = tf.placeholder("float", [batch_size, n_classes]) 374 | tf.get_variable_scope().reuse == True 375 | W1 = tf.get_variable('W1', initializer= 376 | tf.random_normal([n_input, n_classes]), dtype='float') 377 | b1 = tf.get_variable('b1', initializer=tf.zeros([n_classes]), dtype='float') 378 | layer1 = s_QRNN(x, 10, name='s_QRNN_7') 379 | output = s_QRNN(layer1, 10, name='s_QRNN_8') 380 | pred = tf.matmul(output[-1], W1) + b1 381 | 382 | #Evaluate network, run adam and clip gradients 383 | ################################################################################ 384 | cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=pred, labels=y)) 385 | optimizer_0 = tf.train.AdamOptimizer(learning_rate=learning_rate) 386 | raw_gradients, variables = zip(*optimizer_0.compute_gradients(cost)) 387 | gradients = raw_gradients 388 | optimizer = optimizer_0.apply_gradients(zip(gradients, variables)) 389 | init = tf.global_variables_initializer() 390 | 391 | #Initialise the model and evaluate 392 | step = 0 393 | times = [] 394 | x_in = np.random.random((n_steps, batch_size, n_input)) 395 | y_in = np.random.random((batch_size, n_classes)) 396 | with tf.device("gpu:0"): 397 | with tf.Session() as sess: 398 | sess.run(init) 399 | while step < 10: 400 | out = sess.run(pred, feed_dict={x: x_in, y: y_in}) 401 | step += 1 402 | if step != 0: 403 | start = time.time() 404 | out = sess.run(pred, feed_dict={x: x_in, y: y_in}) 405 | finish = time.time() 406 | times.append(finish - start) 407 | s_qrnn_10_tp = (bs * n_steps) / np.mean(times) 408 | 409 | 410 | throughput_list.append([ls_lstm_tp, s_ls_lstm_tp, sru_tp, 411 | s_sru_tp, qrnn_2_tp, s_qrnn_2_tp, 412 | qrnn_10_tp, s_qrnn_10_tp]) 413 | return throughput_list 414 | 415 | if __name__ == "__main__": 416 | import numpy as np 417 | seq_len_list = [16 ** x for x in range(1, 5)] 418 | out = plr_slr(seq_len_list) 419 | p_ls_lstm, s_ls_lstm, p_sru, s_sru, p_2_qrnn, s_2_qrnn, p_10_qrnn, s_10_qrnn = zip(*out) 420 | print np.array(p_ls_lstm) / np.array(s_ls_lstm) 421 | print np.array(p_sru) / np.array(s_sru) 422 | print np.array(p_2_qrnn) / np.array(s_2_qrnn) 423 | print np.array(p_10_qrnn) / np.array(s_10_qrnn) 424 | # in_list1 = [[1, x] for x in [2**z for z in range(8, 19-1)]] 425 | # in_list2 = [[2, x] for x in [2**z for z in range(8, 19-2)]] 426 | # in_list4 = [[4, x] for x in [2**z for z in range(8, 19-3)]] 427 | # in_list8 = [[8, x] for x in [2**z for z in range(8, 19-4)]] 428 | # in_list16 = [[16, x] for x in [2**z for z in range(8, 19-5)]] 429 | # in_list32 = [[32, x] for x in [2**z for z in range(8, 19-6)]] 430 | # in_list64 = [[64, x] for x in [2**z for z in range(8, 19-7)]] 431 | # in_list128 = [[128, x] for x in [2**z for z in range(8, 19-8)]] 432 | # in_list256 = [[256, x] for x in [2**z for z in range(8, 19-9)]] 433 | 434 | # in_list1.extend(in_list2) 435 | # in_list1.extend(in_list4) 436 | # in_list1.extend(in_list8) 437 | # in_list1.extend(in_list16) 438 | # in_list1.extend(in_list32) 439 | # in_list1.extend(in_list64) 440 | # in_list1.extend(in_list128) 441 | # in_list1.extend(in_list256) 442 | 443 | # out = random_test(in_list1) 444 | # print out 445 | # lstm_times, cudnn_times, speedups = zip(*out) 446 | 447 | 448 | -------------------------------------------------------------------------------- /poster/8k_for_poster.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eamartin/parallelizing_linear_rnns/da10de34701313e5556d67f4df4aacb76983e20c/poster/8k_for_poster.png -------------------------------------------------------------------------------- /poster/beamerposter.sty: -------------------------------------------------------------------------------- 1 | % Copyright 2007 by 2 | % Philippe Dreuw and 3 | % Thomas Deselaers 4 | % Slight modifications made in August 2009 by Nathaniel Johnston (nathaniel@nathanieljohnston.com) 5 | % 6 | % This file may be distributed and/or modified 7 | % 8 | % 1. under the LaTeX Project Public License and/or 9 | % 2. under the GNU Public License. 10 | % 11 | % 12 | % ChangeLog: 13 | % 14 | % 1.07 - bugfixed custom size handling, portrait or landscape settings are ignored now 15 | % 1.06 - added the type1cm package for scalable math fonts 16 | % 1.05 - added version check for xkeyval package 17 | % 1.04 - added custom size handling 18 | % 1.03 - improved predefined size handling 19 | % 1.02 - minor bugfixes 20 | % 1.01 - bugfixed size handling 21 | % 1.00 - first beamerposter release 22 | % 23 | \def\beamerposter@version{1.07} 24 | \def\beamerposter@date{2008/03/11} 25 | \def\beamerposter@msg{beamerposter: latex-beamer poster extension} 26 | \typeout{Package: \beamerposter@date. v.\beamerposter@version. \beamerposter@msg} 27 | 28 | \NeedsTeXFormat{LaTeX2e} 29 | \ProvidesPackage{beamerposter}[\beamerposter@date. v.\beamerposter@version. \beamerposter@msg] 30 | \RequirePackage{xkeyval}[2006/11/18] 31 | \RequirePackage{type1cm} %% get it from ftp://cam.ctan.org/tex-archive/macros/latex/contrib/type1cm.zip 32 | 33 | \newif\ifportrait 34 | \newif\ifcustomsize 35 | \newif\ifdebug 36 | 37 | \DeclareOptionX{size}[a0]{ 38 | \typeout{beamerposter: checking size input, please wait.} 39 | \XKV@cc*+[\val\nr]{#1}{a0b,a0,a1,a2,a3,a4,custom}{% 40 | \typeout{beamerposter: the input \val\ \nr\ was correct, we proceed.} 41 | \ifcase\nr\relax 42 | %a0b 43 | \setlength{\paperwidth}{119cm} 44 | \setlength{\paperheight}{88cm} 45 | \setlength{\textwidth}{116cm} 46 | \setlength{\textheight}{88cm} 47 | \or 48 | %a0 49 | \setlength{\paperwidth}{118.82cm} 50 | \setlength{\paperheight}{83.96cm} 51 | \setlength{\textwidth}{117.82cm} 52 | \setlength{\textheight}{82.96cm} 53 | \or 54 | %a1 55 | \setlength{\paperwidth}{83.96cm} 56 | \setlength{\paperheight}{59.4cm} 57 | \setlength{\textwidth}{82.96cm} 58 | \setlength{\textheight}{58.4cm} 59 | \or 60 | %a2 61 | \setlength{\paperwidth}{59.4cm} 62 | \setlength{\paperheight}{41.98cm} 63 | \setlength{\textwidth}{58.4cm} 64 | \setlength{\textheight}{40.98cm} 65 | \or 66 | %a3 67 | \setlength{\paperwidth}{41.98cm} 68 | \setlength{\paperheight}{29.7cm} 69 | \setlength{\textwidth}{40.98cm} 70 | \setlength{\textheight}{28.7cm} 71 | \or 72 | %a4 73 | \setlength{\paperwidth}{29.7cm} 74 | \setlength{\paperheight}{21.0cm} 75 | \setlength{\textwidth}{28.7cm} 76 | \setlength{\textheight}{20.0cm} 77 | \or 78 | \customsizetrue 79 | \fi 80 | }{% 81 | \PackageWarning{beamerposter}{the input \val\ was incorrect and was ignored.} 82 | }% 83 | \typeout{beamerposter: finished size input check.} 84 | } 85 | \DeclareOptionX{orientation}[portrait]{ 86 | \typeout{beamerposter: checking orientation input, please wait.} 87 | \XKV@cc*+[\val\nr]{#1}{portrait,landscape}{% 88 | \typeout{beamerposter: the input \val\ \nr\ was correct, we proceed.} 89 | \ifcase\nr\relax 90 | \portraittrue 91 | \or 92 | \portraitfalse 93 | \fi 94 | }{% 95 | \PackageWarning{beamerposter}{the input \val\ was incorrect and was ignored.} 96 | }% 97 | \typeout{beamerposter: finished orientation check.} 98 | } 99 | \DeclareOptionX{scale}[1.0]{\edef\myfontscale{#1}\typeout{beamerposter: myfontscale=\myfontscale}} 100 | \DeclareOptionX{width}{\edef\customwidth{#1}\typeout{beamerposter: custom poster width=\customwidth}} 101 | \DeclareOptionX{height}{\edef\customheight{#1}\typeout{beamerposter: custom poster height=\customheight}} 102 | \DeclareOptionX{debug}{\typeout{beamerposter: enabled debug mode}\debugtrue} 103 | \DeclareOptionX*{\PackageWarning{beamerposter}{Unknown option ignored: \CurrentOption}} 104 | %\DeclareOptionX*{\PassOptionsToClass{\CurrentOption}{beamer}} 105 | \ExecuteOptionsX{size=a0,scale=1.0} 106 | \ProcessOptionsX\relax 107 | 108 | \ifdebug 109 | \RequirePackage[debug]{fp} 110 | \else 111 | \RequirePackage{fp} 112 | \fi 113 | 114 | %% swap sizes for portrait orientation 115 | \ifportrait 116 | \newdimen\tmp 117 | \setlength{\tmp}{\paperwidth} 118 | \setlength{\paperwidth}{\paperheight} 119 | \setlength{\paperheight}{\tmp} 120 | \setlength{\tmp}{\textwidth} 121 | \setlength{\textwidth}{\textheight} 122 | \setlength{\textheight}{\tmp} 123 | \else\relax 124 | \fi 125 | 126 | %% overwrite dimensions if custom size 127 | \ifcustomsize 128 | \setlength{\paperwidth}{\customwidth cm} 129 | \setlength{\paperheight}{\customheight cm} 130 | \FPupn{\resulttextwidth}{1 customwidth -} 131 | \FPupn{\resulttextheight}{1 customheight -} 132 | \setlength{\textwidth}{\resulttextwidth cm} 133 | \setlength{\textheight}{\resulttextheight cm} 134 | \fi 135 | 136 | %% Setting proper dimensions for a DIN A0 printer 137 | \setlength{\headheight}{0 cm} 138 | \setlength{\headsep}{0 cm} 139 | \setlength{\topmargin}{-12.7 mm} % -1in +1.47cm 140 | \setlength{\oddsidemargin}{-25.4 mm} % -1in +0.4cm 141 | 142 | %% For the page layout 143 | \ifdebug 144 | \typeout{beamerposter: paperwidth=\the\paperwidth, paperheight=\the\paperheight} 145 | \typeout{beamerposter: textwidth=\the\textwidth, textwidth=\the\textheight} 146 | \fi 147 | \geometry{ 148 | paperwidth=\the\paperwidth, 149 | paperheight=\the\paperheight, 150 | hmargin=1cm,% 151 | vmargin=0cm,% 152 | head=0.5cm, % 153 | headsep=0pt,% 154 | foot=0.5cm % 155 | } 156 | 157 | %% scalable vector fonts 158 | \edef\fontSizeX{14.4}\edef\fontSizeY{18} 159 | \FPupn{\resultscriptsizeX}{myfontscale fontSizeX * 2 round} 160 | \FPupn{\resultscriptsizeY}{myfontscale fontSizeY * 2 round} 161 | \renewcommand*{\tiny}{\fontsize{\resultscriptsizeX}{\resultscriptsizeY}\selectfont} 162 | 163 | \edef\fontSizeX{17.28}\edef\fontSizeY{22} 164 | \FPupn{\resultfootnotesizeX}{myfontscale fontSizeX * 2 round} 165 | \FPupn{\resultfootnotesizeY}{myfontscale fontSizeY * 2 round} 166 | \renewcommand*{\scriptsize}{\fontsize{\resultfootnotesizeX}{\resultfootnotesizeY}\selectfont} 167 | 168 | \edef\fontSizeX{20.74}\edef\fontSizeY{25} 169 | \FPupn{\resultsmallX}{myfontscale fontSizeX * 2 round} 170 | \FPupn{\resultsmallY}{myfontscale fontSizeY * 2 round} 171 | \renewcommand*{\footnotesize}{\fontsize{\resultsmallX}{\resultsmallY}\selectfont} 172 | 173 | \edef\fontSizeX{24.88}\edef\fontSizeY{30} 174 | \FPupn{\resultnormalsizeX}{myfontscale fontSizeX * 2 round} 175 | \FPupn{\resultnormalsizeY}{myfontscale fontSizeY * 2 round} 176 | \renewcommand*{\small}{\fontsize{\resultnormalsizeX}{\resultnormalsizeY}\selectfont} 177 | 178 | \edef\fontSizeX{29.86}\edef\fontSizeY{37} 179 | \FPupn{\resultlargeX}{myfontscale fontSizeX * 2 round} 180 | \FPupn{\resultlargeY}{myfontscale fontSizeY * 2 round} 181 | \renewcommand*{\normalsize}{\fontsize{\resultlargeX}{\resultlargeY}\selectfont} 182 | 183 | \edef\fontSizeX{35.83}\edef\fontSizeY{45} 184 | \FPupn{\resultLargeX}{myfontscale fontSizeX * 2 round} 185 | \FPupn{\resultLargeY}{myfontscale fontSizeY * 2 round} 186 | \renewcommand*{\large}{\fontsize{\resultLargeX}{\resultLargeY}\selectfont} 187 | 188 | \edef\fontSizeX{43}\edef\fontSizeY{54} 189 | \FPupn{\resultLARGEX}{myfontscale fontSizeX * 2 round} 190 | \FPupn{\resultLARGEY}{myfontscale fontSizeY * 2 round} 191 | \renewcommand*{\Large}{\fontsize{\resultLARGEX}{\resultLARGEY}\selectfont} 192 | 193 | \edef\fontSizeX{51.6}\edef\fontSizeY{64} 194 | \FPupn{\resulthugeX}{myfontscale fontSizeX * 2 round} 195 | \FPupn{\resulthugeY}{myfontscale fontSizeY * 2 round} 196 | \renewcommand*{\LARGE}{\fontsize{\resulthugeX}{\resulthugeY}\selectfont} 197 | 198 | \edef\fontSizeX{61.92}\edef\fontSizeY{77} 199 | \FPupn{\resultHugeX}{myfontscale fontSizeX * 2 round} 200 | \FPupn{\resultHugeY}{myfontscale fontSizeY * 2 round} 201 | \renewcommand*{\huge}{\fontsize{\resultHugeX}{\resultHugeY}\selectfont} 202 | 203 | \edef\fontSizeX{74.3}\edef\fontSizeY{93} 204 | \FPupn{\resultveryHugeX}{myfontscale fontSizeX * 2 round} 205 | \FPupn{\resultveryHugeY}{myfontscale fontSizeY * 2 round} 206 | \renewcommand*{\Huge}{\fontsize{\resultveryHugeX}{\resultveryHugeY}\selectfont} 207 | 208 | \edef\fontSizeX{80.3}\edef\fontSizeY{101} 209 | \FPupn{\resultVeryHugeX}{myfontscale fontSizeX * 2 round} 210 | \FPupn{\resultVeryHugeY}{myfontscale fontSizeY * 2 round} 211 | \newcommand*{\veryHuge}{\fontsize{\resultVeryHugeX}{\resultVeryHugeY}\selectfont} 212 | 213 | \edef\fontSizeX{107}\edef\fontSizeY{134} 214 | \FPupn{\resultVERYHugeX}{myfontscale fontSizeX * 2 round} 215 | \FPupn{\resultVERYHugeY}{myfontscale fontSizeY * 2 round} 216 | \newcommand*{\VeryHuge}{\fontsize{\resultVERYHugeX}{\resultVERYHugeY}\selectfont} 217 | 218 | % set the normalfont (default) 219 | \renewcommand*{\normalfont}{\normalsize} 220 | -------------------------------------------------------------------------------- /poster/beamerthemeconfposter.sty: -------------------------------------------------------------------------------- 1 | %============================================================================== 2 | % Beamer style for the poster template posted at 3 | % http://www.nathanieljohnston.com/2009/08/latex-poster-template/ 4 | % 5 | % Created by the Computational Physics and Biophysics Group at Jacobs University 6 | % https://teamwork.jacobs-university.de:8443/confluence/display/CoPandBiG/LaTeX+Poster 7 | % Modified by Nathaniel Johnston (nathaniel@nathanieljohnston.com) in August 2009 8 | % ============================================================================= 9 | 10 | \ProvidesPackage{beamerthemeconfposter} 11 | \RequirePackage{tikz} % for drawing the nice rounded boxes 12 | \usetikzlibrary{arrows,backgrounds} 13 | \RequirePackage[T1]{fontenc} 14 | \RequirePackage{lmodern} 15 | \usepackage{exscale} 16 | \RequirePackage{textcomp} 17 | \RequirePackage{amsmath,amssymb} 18 | \usefonttheme{professionalfonts} 19 | \newcommand{\makeruleinbox}{{\usebeamercolor[bg]{block alerted title}\centering\hspace*{-0.7cm}\rule{\inboxrule}{0.5cm}}} 20 | \usepackage{ragged2e} 21 | 22 | % Spacing before and inside list environments to add white space before lists and between items inside lists 23 | \makeatletter 24 | \def\@listi{\leftmargin\leftmarginii 25 | \topsep 1ex % Spacing before lists 26 | \parsep 0\p@ \@plus\p@ 27 | \itemsep 6pt} % Spacing between items 28 | \makeatother 29 | 30 | \usecaptiontemplate{\small\structure{\insertcaptionname~\insertcaptionnumber: }\insertcaption} % A fix for figure numbering 31 | 32 | %----------------------------------------------------------- 33 | % Define a whole bunch of custom colours and fonts 34 | %----------------------------------------------------------- 35 | 36 | \definecolor{lgreen} {RGB}{180,210,100} 37 | \definecolor{dblue} {RGB}{20,66,129} 38 | \definecolor{ddblue} {RGB}{11,36,69} 39 | \definecolor{lred} {RGB}{220,0,0} 40 | \definecolor{nred} {RGB}{224,0,0} 41 | \definecolor{norange}{RGB}{230,120,20} 42 | \definecolor{nyellow}{RGB}{255,221,0} 43 | \definecolor{ngreen} {RGB}{98,158,31} 44 | \definecolor{dgreen} {RGB}{78,138,21} 45 | \definecolor{nblue} {RGB}{28,130,185} 46 | \definecolor{jblue} {RGB}{20,50,100} 47 | 48 | 49 | %---------------------------------------------------------------------------- 50 | % More colours added due to conflict with Colordvi package 51 | % Addition done by Nishan Mudalige (math.mudalige@uoguelph.ca) in April 2011 52 | %---------------------------------------------------------------------------- 53 | 54 | \definecolor{GreenYellow} {RGB}{217, 229, 6} % GreenYellow Approximate PANTONE 388 55 | \definecolor{Yellow} {RGB}{254, 223, 0} % Yellow Approximate PANTONE YELLOW 56 | \definecolor{Goldenrod} {RGB}{249, 214, 22} % Goldenrod Approximate PANTONE 109 57 | \definecolor{Dandelion} {RGB}{253, 200, 47} % Dandelion Approximate PANTONE 123 58 | \definecolor{Apricot} {RGB}{255, 170, 123} % Apricot Approximate PANTONE 1565 59 | \definecolor{Peach} {RGB}{255, 127, 69} % Peach Approximate PANTONE 164 60 | \definecolor{Melon} {RGB}{255, 129, 141} % Melon Approximate PANTONE 177 61 | \definecolor{YellowOrange} {RGB}{240, 171, 0} % YellowOrange Approximate PANTONE 130 62 | \definecolor{Orange} {RGB}{255, 88, 0} % Orange Approximate PANTONE ORANGE-021 63 | \definecolor{BurntOrange} {RGB}{199, 98, 43} % BurntOrange Approximate PANTONE 388 64 | \definecolor{Bittersweet} {RGB}{189, 79, 25} % Bittersweet Approximate PANTONE 167 65 | \definecolor{RedOrange} {RGB}{222, 56, 49} % RedOrange Approximate PANTONE 179 66 | \definecolor{Mahogany} {RGB}{152, 50, 34} % Mahogany Approximate PANTONE 484 67 | \definecolor{Maroon} {RGB}{152, 30, 50} % Maroon Approximate PANTONE 201 68 | \definecolor{BrickRed} {RGB}{170, 39, 47} % BrickRed Approximate PANTONE 1805 69 | \definecolor{Red} {RGB}{255, 0, 0} % Red Approx PANTONE LUMINOUS VIVID RED 70 | \definecolor{BrilliantRed} {RGB}{237, 41, 57} % Red VERY-Approx PANTONE RED 71 | \definecolor{OrangeRed} {RGB}{231, 58, 0} % OrangeRed No PANTONE match (TRIED PANTONE VIVID ORANGE RED) 72 | \definecolor{RubineRed} {RGB}{202, 0, 93} % RubineRed Approximate PANTONE RUBINE-RED 73 | \definecolor{WildStrawberry} {RGB}{203, 0, 68} % WildStrawberry Approximate PANTONE 206 74 | \definecolor{Salmon} {RGB}{250, 147, 171} % Salmon Approximate PANTONE 183 75 | \definecolor{CarnationPink} {RGB}{226, 110, 178} % CarnationPink Approximate PANTONE 218 76 | \definecolor{Magenta} {RGB}{255, 0, 144} % Magenta Approximate PANTONE PROCESS-MAGENTA 77 | \definecolor{VioletRed} {RGB}{215, 31, 133} % VioletRed Approximate PANTONE 219 78 | \definecolor{Rhodamine} {RGB}{224, 17, 157} % Rhodamine Approximate PANTONE RHODAMINE-RED 79 | \definecolor{Mulberry} {RGB}{163, 26, 126} % Mulberry Approximate PANTONE 241 80 | \definecolor{RedViolet} {RGB}{161, 0, 107} % RedViolet Approximate PANTONE 234 81 | \definecolor{Fuchsia} {RGB}{155, 24, 137} % Fuchsia Approximate PANTONE 248 82 | \definecolor{Lavender} {RGB}{240, 146, 205} % Lavender Approximate PANTONE 223 83 | \definecolor{Thistle} {RGB}{222, 129, 211} % Thistle Approximate PANTONE 245 84 | \definecolor{Orchid} {RGB}{201, 102, 205} % Orchid Approximate PANTONE 252 85 | \definecolor{DarkOrchid} {RGB}{153, 50, 204} % DarkOrchid No PANTONE match 86 | \definecolor{Purple} {RGB}{182, 52, 187} % Purple Approximate PANTONE PURPLE 87 | \definecolor{Plum} {RGB}{79, 50, 76} % Plum VERY-Approx PANTONE 518 88 | \definecolor{Violet} {RGB}{75, 8, 161} % Violet Approximate PANTONE VIOLET 89 | \definecolor{RoyalPurple} {RGB}{82, 35, 152} % RoyalPurple Approximate PANTONE 267 90 | \definecolor{BlueViolet} {RGB}{33, 7, 106} % BlueViolet Approximate PANTONE 2755 91 | \definecolor{Periwinkle} {RGB}{136, 132, 213} % Periwinkle Approximate PANTONE 2715 92 | \definecolor{CadetBlue} {RGB}{95, 158, 160} % CadetBlue Approximate PANTONE (534+535)/2, Could not find get on my own so used PANTONE-CADET BLUE 93 | \definecolor{CornflowerBlue} {RGB}{99, 177, 229} % CornflowerBlue Approximate PANTONE 292 94 | \definecolor{MidnightBlue} {RGB}{0, 65, 101} % MidnightBlue Approximate PANTONE 302 95 | \definecolor{NavyBlue} {RGB}{0, 70, 173} % NavyBlue Approximate PANTONE 293 96 | \definecolor{RoyalBlue} {RGB}{0, 35, 102} % RoyalBlue No PANTONE match 97 | \definecolor{Blue} {RGB}{0, 24, 168} % Blue Approximate PANTONE BLUE-072 98 | \definecolor{Cerulean} {RGB}{0, 122, 201} % Cerulean Approximate PANTONE 3005 99 | \definecolor{Cyan} {RGB}{0, 159, 218} % Cyan Approximate PANTONE PROCESS-CYAN 100 | \definecolor{ProcessBlue} {RGB}{0, 136, 206} % ProcessBlue Approximate PANTONE PROCESS-BLUE 101 | \definecolor{SkyBlue} {RGB}{91, 198, 232} % SkyBlue Approximate PANTONE 2985 102 | 103 | \definecolor{Turquoise} {RGB}{0, 255, 239} % Turquoise Approximate PANTONE (312+313)/2, Could not find get on my own so used PANTONE-TURQUOISE 104 | 105 | \definecolor{TealBlue} {RGB}{0, 124, 146} % TealBlue Approximate PANTONE 3145 106 | \definecolor{Aquamarine} {RGB}{0, 148, 179} % Aquamarine Approximate PANTONE 3135 107 | \definecolor{BlueGreen} {RGB}{0, 154, 166} % BlueGreen Approximate PANTONE 320 108 | \definecolor{Emerald} {RGB}{80, 200, 120} % Emerald No PANTONE match 109 | \definecolor{JungleGreen} {RGB}{0, 115, 99} % JungleGreen Approximate PANTONE 328 110 | \definecolor{SeaGreen} {RGB}{0, 176, 146} % SeaGreen Approximate PANTONE 3268 111 | \definecolor{Green} {RGB}{0, 173, 131} % Green VERY-Approx PANTONE GREEN 112 | \definecolor{ForestGreen} {RGB}{0, 105, 60} % ForestGreen Approximate PANTONE 349 113 | \definecolor{PineGreen} {RGB}{0, 98, 101} % PineGreen Approximate PANTONE 323 114 | \definecolor{LimeGreen} {RGB}{50, 205, 50} % LimeGreen No PANTONE match 115 | \definecolor{YellowGreen} {RGB}{146, 212, 0} % YellowGreen Approximate PANTONE 375 116 | \definecolor{SpringGreen} {RGB}{201, 221, 3} % SpringGreen Approximate PANTONE 381 117 | \definecolor{OliveGreen} {RGB}{135, 136, 0} % OliveGreen Approximate PANTONE 582 118 | \definecolor{RawSienna} {RGB}{149, 82, 20} % RawSienna Approximate PANTONE 154 119 | \definecolor{Sepia} {RGB}{98, 60, 27} % Sepia Approximate PANTONE 161 120 | \definecolor{Brown} {RGB}{134, 67, 30} % Brown Approximate PANTONE 1615 121 | \definecolor{Tan} {RGB}{210, 180, 140} % Tan No PANTONE match 122 | \definecolor{Gray} {RGB}{139, 141, 142} % Gray Approximate PANTONE COOL-GRAY-8 123 | 124 | \definecolor{Black} {RGB}{30, 30, 30} % Black Approximate PANTONE PROCESS-BLACK 125 | \definecolor{White} {RGB}{255, 255, 255} % White No PANTONE match 126 | 127 | 128 | 129 | 130 | 131 | 132 | 133 | 134 | 135 | 136 | 137 | 138 | 139 | 140 | 141 | 142 | % set the basic colors 143 | \setbeamercolor{palette primary} {fg=black,bg=white} 144 | \setbeamercolor{palette secondary} {fg=black,bg=white} 145 | \setbeamercolor{palette tertiary} {bg=jblue,fg=white} 146 | \setbeamercolor{palette quaternary}{fg=black,bg=white} 147 | \setbeamercolor{structure}{fg=jblue} 148 | \setbeamercolor{titlelike} {bg=jblue,fg=white} 149 | \setbeamercolor{frametitle} {bg=jblue!10,fg=jblue} 150 | \setbeamercolor{cboxb}{fg=black,bg=jblue} 151 | \setbeamercolor{cboxr}{fg=black,bg=red} 152 | 153 | % set colors for itemize/enumerate 154 | \setbeamercolor{item}{fg=ngreen} 155 | \setbeamercolor{item projected}{fg=white,bg=ngreen} 156 | 157 | % set colors for blocks 158 | \setbeamercolor{block title}{fg=ngreen,bg=white} 159 | \setbeamercolor{block body}{fg=black,bg=white} 160 | 161 | % set colors for alerted blocks (blocks with frame) 162 | \setbeamercolor{block alerted title}{fg=white,bg=jblue} 163 | \setbeamercolor{block alerted body}{fg=black,bg=jblue!10} 164 | 165 | % set the fonts 166 | \setbeamerfont{section in head/foot}{series=\bfseries} 167 | \setbeamerfont{block title}{series=\bfseries} 168 | \setbeamerfont{block alerted title}{series=\bfseries} 169 | \setbeamerfont{frametitle}{series=\bfseries} 170 | \setbeamerfont{frametitle}{size=\Large} 171 | \setbeamerfont{block body}{series=\rmfamily} 172 | 173 | % set some beamer theme options 174 | \setbeamertemplate{title page}[default][colsep=-4bp,rounded=true] 175 | \setbeamertemplate{sections/subsections in toc}[square] 176 | \setbeamertemplate{items}[circle] 177 | \setbeamertemplate{blocks}[width=0.0] 178 | \beamertemplatenavigationsymbolsempty 179 | 180 | % set bibliography style 181 | \setbeamertemplate{bibliography item}[text] 182 | \setbeamercolor{bibliography item}{fg=black,bg=white} 183 | \setbeamercolor{bibliography entry author}{fg=black,bg=white} 184 | \setbeamercolor{bibliography item}{fg=black,bg=white} 185 | 186 | % define some length variables that are used by the template 187 | \newlength{\inboxwd} 188 | \newlength{\iinboxwd} 189 | \newlength{\inboxrule} 190 | \makeatletter 191 | \makeatother 192 | 193 | %============================================================================== 194 | % build the poster title 195 | %============================================================================== 196 | \setbeamertemplate{headline}{ 197 | \leavevmode 198 | \begin{columns} 199 | \begin{column}{\linewidth} 200 | \vskip1cm 201 | \centering 202 | \usebeamercolor{title in headline}{\color{jblue}\Huge{\textbf{\inserttitle}}\\[0.5ex]} 203 | \usebeamercolor{author in headline}{\color{fg}\Large{\insertauthor}\\[1ex]} 204 | \usebeamercolor{institute in headline}{\color{fg}\large{\insertinstitute}\\[1ex]} 205 | \vskip1cm 206 | \end{column} 207 | \vspace{1cm} 208 | \end{columns} 209 | \vspace{0.5in} 210 | \hspace{0.5in}\begin{beamercolorbox}[wd=47in,colsep=0.15cm]{cboxb}\end{beamercolorbox} 211 | \vspace{0.1in} 212 | } 213 | 214 | % Block definition 215 | \setbeamertemplate{block begin} 216 | { 217 | \par\vskip\medskipamount 218 | \begin{beamercolorbox}[colsep*=0ex,dp={2ex},center]{block title} 219 | \vskip-0.25cm 220 | \usebeamerfont{block title}\large\insertblocktitle 221 | \begin{flushleft} 222 | \vskip-1cm 223 | \begin{tikzpicture}[remember picture,overlay] 224 | \shade [inner color=gray,outer color=white] 225 | (0,0) rectangle (\textwidth,0.3cm); 226 | \end{tikzpicture} 227 | \end{flushleft} 228 | \end{beamercolorbox} 229 | {\parskip0pt\par} 230 | \ifbeamercolorempty[bg]{block title} 231 | {} 232 | {\ifbeamercolorempty[bg]{block body}{}{\nointerlineskip\vskip-0.5pt}} 233 | \usebeamerfont{block body} 234 | \vskip-0.5cm 235 | \begin{beamercolorbox}[colsep*=0ex,vmode]{block body} 236 | \justifying 237 | } 238 | 239 | \setbeamertemplate{block end} 240 | { 241 | \end{beamercolorbox} 242 | \vskip\smallskipamount 243 | } 244 | 245 | % Alert block definition (with frame) 246 | \setbeamertemplate{block alerted begin} 247 | { 248 | \par\vskip\medskipamount 249 | \begin{beamercolorbox}[sep=0ex,rounded=true,center,dp={2ex}]{block alerted title} 250 | \vskip0.01cm 251 | \usebeamerfont{block title}\large\insertblocktitle 252 | \end{beamercolorbox} 253 | {\parskip0pt\par} 254 | \usebeamerfont{block body} 255 | \vskip-0.8cm 256 | \begin{beamercolorbox}[sep=0.5cm, rounded=true,center]{block alerted title} 257 | \setlength{\inboxwd}{\linewidth} 258 | \addtolength{\inboxwd}{-1cm} 259 | \begin{beamercolorbox}[rounded=true,wd={\inboxwd},center]{block alerted body} 260 | \setlength{\iinboxwd}{\inboxwd} 261 | \setlength{\inboxrule}{\inboxwd} 262 | \addtolength{\iinboxwd}{-0.5cm} 263 | \addtolength{\inboxrule}{0.5cm} 264 | \begin{center} 265 | \begin{minipage}{\iinboxwd} 266 | \justifying 267 | } 268 | 269 | \setbeamertemplate{block alerted end} 270 | { 271 | \end{minipage} 272 | \end{center} 273 | \end{beamercolorbox} 274 | \end{beamercolorbox} 275 | \vskip\smallskipamount 276 | } -------------------------------------------------------------------------------- /poster/cudnn_heatmap_gilr.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eamartin/parallelizing_linear_rnns/da10de34701313e5556d67f4df4aacb76983e20c/poster/cudnn_heatmap_gilr.png -------------------------------------------------------------------------------- /poster/cumsum.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eamartin/parallelizing_linear_rnns/da10de34701313e5556d67f4df4aacb76983e20c/poster/cumsum.png -------------------------------------------------------------------------------- /poster/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eamartin/parallelizing_linear_rnns/da10de34701313e5556d67f4df4aacb76983e20c/poster/logo.png -------------------------------------------------------------------------------- /poster/main.aux: -------------------------------------------------------------------------------- 1 | \relax 2 | \providecommand\hyper@newdestlabel[2]{} 3 | \providecommand\HyperFirstAtBeginDocument{\AtBeginDocument} 4 | \HyperFirstAtBeginDocument{\ifx\hyper@anchor\@undefined 5 | \global\let\oldcontentsline\contentsline 6 | \gdef\contentsline#1#2#3#4{\oldcontentsline{#1}{#2}{#3}} 7 | \global\let\oldnewlabel\newlabel 8 | \gdef\newlabel#1#2{\newlabelxx{#1}#2} 9 | \gdef\newlabelxx#1#2#3#4#5#6{\oldnewlabel{#1}{{#2}{#3}}} 10 | \AtEndDocument{\ifx\hyper@anchor\@undefined 11 | \let\contentsline\oldcontentsline 12 | \let\newlabel\oldnewlabel 13 | \fi} 14 | \fi} 15 | \global\let\hyper@last\relax 16 | \gdef\HyperFirstAtBeginDocument#1{#1} 17 | \providecommand\HyField@AuxAddToFields[1]{} 18 | \providecommand\HyField@AuxAddToCoFields[2]{} 19 | \citation{*} 20 | \bibstyle{unsrt} 21 | \@writefile{toc}{\beamer@endinputifotherversion {3.36pt}} 22 | \@writefile{nav}{\beamer@endinputifotherversion {3.36pt}} 23 | \pgfsyspdfmark {pgfid7}{5622840}{99578032} 24 | \pgfsyspdfmark {pgfid8}{5622840}{29742552} 25 | \pgfsyspdfmark {pgfid9}{79448901}{137795316} 26 | \pgfsyspdfmark {pgfid10}{79448901}{93183638} 27 | \newlabel{table:rnn-throughput}{{1}{1}{}{Doc-Start}{}} 28 | \@writefile{snm}{\beamer@slide {table:rnn-throughput}{1}} 29 | \pgfsyspdfmark {pgfid11}{153274963}{137795316} 30 | \pgfsyspdfmark {pgfid12}{153274963}{59847630} 31 | \pgfsyspdfmark {pgfid13}{153274963}{22694728} 32 | \@writefile{nav}{\headcommand {\slideentry {0}{0}{1}{1/1}{}{0}}} 33 | \@writefile{nav}{\headcommand {\beamer@framepages {1}{1}}} 34 | \@writefile{nav}{\headcommand {\beamer@partpages {1}{1}}} 35 | \@writefile{nav}{\headcommand {\beamer@subsectionpages {1}{1}}} 36 | \@writefile{nav}{\headcommand {\beamer@sectionpages {1}{1}}} 37 | \@writefile{nav}{\headcommand {\beamer@documentpages {1}}} 38 | \@writefile{nav}{\headcommand {\def \inserttotalframenumber {1}}} 39 | -------------------------------------------------------------------------------- /poster/main.nav: -------------------------------------------------------------------------------- 1 | \beamer@endinputifotherversion {3.36pt} 2 | \headcommand {\slideentry {0}{0}{1}{1/1}{}{0}} 3 | \headcommand {\beamer@framepages {1}{1}} 4 | \headcommand {\beamer@partpages {1}{1}} 5 | \headcommand {\beamer@subsectionpages {1}{1}} 6 | \headcommand {\beamer@sectionpages {1}{1}} 7 | \headcommand {\beamer@documentpages {1}} 8 | \headcommand {\def \inserttotalframenumber {1}} 9 | -------------------------------------------------------------------------------- /poster/main.out: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eamartin/parallelizing_linear_rnns/da10de34701313e5556d67f4df4aacb76983e20c/poster/main.out -------------------------------------------------------------------------------- /poster/main.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eamartin/parallelizing_linear_rnns/da10de34701313e5556d67f4df4aacb76983e20c/poster/main.pdf -------------------------------------------------------------------------------- /poster/main.snm: -------------------------------------------------------------------------------- 1 | \beamer@slide {table:rnn-throughput}{1} 2 | -------------------------------------------------------------------------------- /poster/main.tex: -------------------------------------------------------------------------------- 1 | %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% 2 | % Jacobs Landscape Poster 3 | % LaTeX Template 4 | % Version 1.1 (14/06/14) 5 | % 6 | % Created by: 7 | % Computational Physics and Biophysics Group, Jacobs University 8 | % https://teamwork.jacobs-university.de:8443/confluence/display/CoPandBiG/LaTeX+Poster 9 | % 10 | % Further modified by: 11 | % Nathaniel Johnston (nathaniel@njohnston.ca) 12 | % 13 | % This template has been downloaded from: 14 | % http://www.LaTeXTemplates.com 15 | % 16 | % License: 17 | % CC BY-NC-SA 3.0 (http://creativecommons.org/licenses/by-nc-sa/3.0/) 18 | % 19 | %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% 20 | 21 | %---------------------------------------------------------------------------------------- 22 | % PACKAGES AND OTHER DOCUMENT CONFIGURATIONS 23 | %---------------------------------------------------------------------------------------- 24 | 25 | \documentclass[final]{beamer} 26 | 27 | \usepackage[scale=1.24]{beamerposter} % Use the beamerposter package for laying out the poster 28 | 29 | \usetheme{confposter} % Use the confposter theme supplied with this template 30 | 31 | \setbeamercolor{block title}{fg=ngreen,bg=white} % Colors of the block titles 32 | \setbeamercolor{block body}{fg=black,bg=white} % Colors of the body of blocks 33 | \setbeamercolor{block alerted title}{fg=white,bg=dblue!70} % Colors of the highlighted block titles 34 | \setbeamercolor{block alerted body}{fg=black,bg=dblue!10} % Colors of the body of highlighted blocks 35 | % Many more colors are available for use in beamerthemeconfposter.sty 36 | 37 | %----------------------------------------------------------- 38 | % Define the column widths and overall poster size 39 | % To set effective sepwid, onecolwid and twocolwid values, first choose how many columns you want and how much separation you want between columns 40 | % In this template, the separation width chosen is 0.024 of the paper width and a 4-column layout 41 | % onecolwid should therefore be (1-(# of columns+1)*sepwid)/# of columns e.g. (1-(4+1)*0.024)/4 = 0.22 42 | % Set twocolwid to be (2*onecolwid)+sepwid = 0.464 43 | % Set threecolwid to be (3*onecolwid)+2*sepwid = 0.708 44 | 45 | \newlength{\sepwid} 46 | \newlength{\onecolwid} 47 | \newlength{\twocolwid} 48 | \newlength{\threecolwid} 49 | \setlength{\paperwidth}{48in} % A0 width: 46.8in 50 | \setlength{\paperheight}{36in} % A0 height: 33.1in 51 | \setlength{\sepwid}{0.024\paperwidth} % Separation width (white space) between columns 52 | \setlength{\onecolwid}{0.30\paperwidth} % Width of one column 53 | \setlength{\twocolwid}{0.624\paperwidth} % Width of two columns 54 | \setlength{\threecolwid}{0.708\paperwidth} % Width of three columns 55 | \setlength{\topmargin}{-0.5in} % Reduce the top margin size 56 | %----------------------------------------------------------- 57 | 58 | \usepackage{graphicx} % Required for including images 59 | 60 | \usepackage{booktabs} % Top and bottom rules for tables 61 | 62 | %---------------------------------------------------------------------------------------- 63 | % TITLE SECTION 64 | %---------------------------------------------------------------------------------------- 65 | 66 | \title{Parallelizing Linear Recurrent Neural Nets Over Sequence Length} % Poster title 67 | 68 | \author{Eric Martin and Chris Cundy*} % Author(s) 69 | 70 | \institute{*UC Berkeley} % Institution(s) 71 | 72 | %---------------------------------------------------------------------------------------- 73 | 74 | \begin{document} 75 | 76 | \addtobeamertemplate{block end}{}{\vspace*{2ex}} % White space under blocks 77 | \addtobeamertemplate{block alerted end}{}{\vspace*{2ex}} % White space under highlighted (alert) blocks 78 | 79 | \setlength{\belowcaptionskip}{2ex} % White space under figures 80 | \setlength\belowdisplayshortskip{2ex} % White space under equations 81 | 82 | \begin{frame}[t] % The whole poster is enclosed in one beamer frame 83 | 84 | \begin{columns}[t] % The whole poster consists of three major columns, the second of which is split into two columns twice - the [t] option aligns each column's content to the top 85 | 86 | \begin{column}{\sepwid}\end{column} % Empty spacer column 87 | 88 | \begin{column}{\onecolwid} % The first column 89 | 90 | %---------------------------------------------------------------------------------------- 91 | % OBJECTIVES 92 | %---------------------------------------------------------------------------------------- 93 | 94 | \begin{alertblock}{Abstract} 95 | RNN training and inference generally takes time linear in the sequence length because 96 | of non-linear sequential dependencies. 97 | We show the training and inference of RNNs with only linear 98 | sequential dependencies can be parallelized over the sequence length using the 99 | parallel scan algorithm, leading to rapid training on long sequences even with 100 | small minibatch size. We use this insight and a parallel linear recurrence CUDA 101 | kernel to accelerate several state of the art RNN architectures by up to 9x and 102 | to solve a synthetic sequence classification task with a one million time step 103 | dependency. 104 | \end{alertblock} 105 | 106 | %---------------------------------------------------------------------------------------- 107 | % INTRODUCTION 108 | %---------------------------------------------------------------------------------------- 109 | \begin{block}{Introduction} 110 | Large minibatches are necessary for computational performance but create large memory 111 | requirements and damage model generalization ability. 112 | 113 | \vspace{1ex} 114 | 115 | Linear RNNs and convolutional models such as strongly typed RNNs, Wavenet, Bytenet, 116 | Quasi-RNNs, and simple recurrent units 117 | have achieved state of the art results on many sequential tasks with rapid training times. 118 | 119 | \vspace{1ex} 120 | 121 | Given $x_t$, $\lambda_t$ can compute $h_t=\lambda_t h_{t-1} + x_t$ for $t=1\ldots T$ on $p$ 122 | processors in $O(T/p + \log(p))$ with the classic \textbf{parallel scan algorithm}. Backpropagation 123 | of gradient can also be parallelized with the same algorithm. We implemented a parallel linear 124 | recurrence operation in CUDA and integrated it with TensorFlow. 125 | \end{block} 126 | 127 | %------------------------------------------------ 128 | 129 | \begin{figure} 130 | \includegraphics[width=1.0\linewidth]{cumsum.png} 131 | \caption{Example of parallelizing cumulative sum over 3 processors} 132 | \end{figure} 133 | 134 | %------------------------------------------------ 135 | 136 | \begin{block}{Gated Impulse Linear Recurrence} 137 | Given a fast algorithm for evaluating linear recurrences, we introduce a new 138 | linear recurrent layer called \textbf{gated impulse linear recurrence (GILR)} 139 | 140 | \begin{align*} 141 | g_t &= \sigma(Ux_t + b_g) \\ 142 | i_t &= \tau(Vx_t + b_i) \\ 143 | h_t &= g_t \odot h_{t-1} + (1-g_t)\odot i_t 144 | \end{align*} 145 | \end{block} 146 | 147 | %---------------------------------------------------------------------------------------- 148 | 149 | \end{column} % End of the first column 150 | 151 | \begin{column}{\sepwid}\end{column} % Empty spacer column 152 | 153 | \begin{column}{\onecolwid} % Begin a column which is two columns wide (column 2) 154 | 155 | %---------------------------------------------------------------------------------------- 156 | % MATERIALS 157 | %---------------------------------------------------------------------------------------- 158 | 159 | \begin{block}{Linear Surrogate RNNs} 160 | 161 | RNNs have a transition function $s_t = f(s_{t-1},x_t)$. $s_t$ serves dual roles as a 162 | summary of the past as well as the output of the unit. Non-linear $f$ in units such 163 | as vanilla RNN and LSTM prevents parallelization over sequence length. 164 | \vspace{1ex} 165 | 166 | Replacing the summary of the past $s_{t-1}$ with a linear surrogate $\tilde{s}_{t-1}$ 167 | allows the easy adaption of any existing RNN architecture for parallel computation. 168 | Several recent linear RNNs can be viewed as linear surrogate RNNs. 169 | \vspace{1ex} 170 | 171 | The state of an LSTM consists of $(c_t, h_t)$. $c_t$ is already computed 172 | by linear recurrence, so a linear surrogate LSTM must only compute a 173 | linear $\tilde{h}_t$. A \textbf{GILR-LSTM} uses $\tilde{h} = \text{GILR}(x)$ 174 | \end{block} 175 | 176 | %---------------------------------------------------------------------------------------- 177 | 178 | %---------------------------------------------------------------------------------------- 179 | % MATHEMATICAL SECTION 180 | %---------------------------------------------------------------------------------------- 181 | 182 | \begin{block}{Training Runtime Results} 183 | \begin{table}[t] 184 | \begin{center} 185 | \begin{tabular}{lrrrr} 186 | \label{table:rnn-throughput} 187 | \small{Seq.\ Len.} & \small{SRU} & \small{QRNN(2)} 188 | & \small{QRNN(10)} & \small{GILR-LSTM}\\ \midrule 189 | 16 & 0.28 & 0.38 & 0.78 & 0.61\\ 190 | 256 & 0.84 & 0.86 & 0.99 & 0.91\\ 191 | 4,096 & 1.38 & 1.18 & 1.05 & 0.98\\ 192 | 65,536 & 9.21 & 6.68 & 2.05 & 1.41\\ \bottomrule 193 | \end{tabular} 194 | \end{center} 195 | \vspace{2ex} 196 | \caption{Parallel kernel speedup compared to serial linear recurrence for a variety of LS-RNNs 197 | All models use two stacked RNN layers with 256 hidden 198 | units, keeping the GPU memory usage constant by fixing $bT = 65,536$ 199 | for minibatch size $b$ and sequence length $T$. QRNN($k$) refers to a 200 | QRNN with filter size $k$. %\citep{bradbury2017quasi} 201 | } 202 | \end{table} 203 | \end{block} 204 | 205 | %---------------------------------------------------------------------------------------- 206 | \vspace{-1ex} 207 | \begin{figure} 208 | \includegraphics[width=0.9\linewidth]{cudnn_heatmap_gilr.png} 209 | \caption{Throughput (thousand steps/s) for 2 layer 256 unit cuDNN LSTM and GILR-LSTM as 210 | a function of batch size and sequence length. LSTM throughput is independent of sequence 211 | length. GILR-LSTM can achieve much greater throughput at small batch sizes.} 212 | \end{figure} 213 | 214 | 215 | \end{column} % End of column 2.1 216 | 217 | 218 | 219 | 220 | 221 | % \begin{table} 222 | % \vspace{2ex} 223 | % \begin{tabular}{l l l} 224 | % \toprule 225 | % \textbf{Treatments} & \textbf{Response 1} & \textbf{Response 2}\\ 226 | % \midrule 227 | % Treatment 1 & 0.0003262 & 0.562 \\ 228 | % Treatment 2 & 0.0015681 & 0.910 \\ 229 | % Treatment 3 & 0.0009271 & 0.296 \\ 230 | % \bottomrule 231 | % \end{tabular} 232 | % \caption{Table caption} 233 | % \end{table} 234 | 235 | % \end{block} 236 | 237 | %---------------------------------------------------------------------------------------- 238 | 239 | \begin{column}{\sepwid}\end{column} % Empty spacer column 240 | 241 | \begin{column}{\onecolwid} % The third column 242 | 243 | \begin{block}{Learning Long-Term Dependencies} 244 | 245 | Task: Learn to remember 1 bit of information for $T$ time steps. 246 | 247 | \vspace{1ex} 248 | We measured time until convergence for a 2 layer GILR-LSTM and LSTM for $T$ ranging from 249 | 1,000 to 1,000,000. 250 | 251 | \vspace{1ex} 252 | The GILR-LSTM converged in over \textbf{6x} less wall time. We demonstrated a 253 | GILR-LSTM could learn a \textbf{one million time step} sequential dependency, which 254 | is at least a 100x longer dependency than previously learned. 255 | 256 | \vspace{2ex} 257 | \begin{figure} 258 | \includegraphics[width=1.0\linewidth]{8k_for_poster.png} 259 | \caption{Accuracy on the memorization task with 8,192 sequence length} 260 | \end{figure} 261 | 262 | \end{block} 263 | 264 | 265 | %---------------------------------------------------------------------------------------- 266 | % CONCLUSION 267 | %---------------------------------------------------------------------------------------- 268 | 269 | \begin{block}{Conclusion} 270 | 271 | Parallel linear recurrence enables rapid learning on extremely long sequences at small 272 | minibatch sizes. A significant portion of deep learning's current success can be 273 | attributed to highly efficient matrix multiplication and convolution kernels. 274 | We hope that parallel linear recurrence can join these algorithms and 275 | be to large scale sequence 276 | modelling what fast convolution is to image recognition. 277 | 278 | \vspace{1ex} 279 | Possible future work includes parallel training of memory augmented models, applications 280 | to autoregressive flows, and replacing decay vector $\lambda_t$ with structured matrix 281 | $\Lambda_t$. 282 | 283 | \end{block} 284 | 285 | 286 | %---------------------------------------------------------------------------------------- 287 | % REFERENCES 288 | %---------------------------------------------------------------------------------------- 289 | 290 | \begin{block}{References} 291 | 292 | \nocite{*} % Insert publications even if they are not cited in the poster 293 | \small{\bibliographystyle{unsrt}} 294 | D.~Balduzzi and M.~Ghifary. 295 | Strongly-typed recurrent neural networks. 296 | 297 | G.~E. Blelloch. 298 | Prefix sums and their applications. 299 | 300 | J.~Bradbury, S.~Merity, C.~Xiong, and R.~Socher. 301 | Quasi-recurrent neural networks. 302 | 303 | N.~Kalchbrenner, et al. 304 | Neural machine translation in linear time. 305 | 306 | T.~Lei, Y.~Zhang, 307 | Training RNNs as fast as CNNs. 308 | 309 | A.~van~den Oord, et al. 310 | Wavenet: A generative model for raw audio. 311 | 312 | \end{block} 313 | 314 | 315 | %---------------------------------------------------------------------------------------- 316 | 317 | \end{column} % End of the third column 318 | 319 | \end{columns} % End of all the columns in the poster 320 | 321 | \end{frame} % End of the enclosing frame 322 | 323 | \end{document} 324 | -------------------------------------------------------------------------------- /poster/main.toc: -------------------------------------------------------------------------------- 1 | \beamer@endinputifotherversion {3.36pt} 2 | -------------------------------------------------------------------------------- /poster/placeholder.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eamartin/parallelizing_linear_rnns/da10de34701313e5556d67f4df4aacb76983e20c/poster/placeholder.jpg -------------------------------------------------------------------------------- /poster/sample.bib: -------------------------------------------------------------------------------- 1 | @BOOK{Smith:2012qr, 2 | title = {{B}ook {T}itle}, 3 | publisher = {Publisher}, 4 | author = {Smith, J.~M. and Jones, A.~B.}, 5 | year = {2012}, 6 | edition = {7th}, 7 | } 8 | 9 | @ARTICLE{Smith:2013jd, 10 | author = {Jones, A.~B. and Smith, J.~M.}, 11 | title = {{A}rticle {T}itle}, 12 | journal = {Journal title}, 13 | year = {2013}, 14 | volume = {13}, 15 | pages = {123-456}, 16 | number = {52}, 17 | month = {March}, 18 | publisher = {Publisher} 19 | } --------------------------------------------------------------------------------