import tensorflow as tf


class DFSMN(object):
    """Deep Feedforward Sequential Memory Network (DFSMN) layer, TF1 graph mode.

    Creates the layer's trainable variables at construction time and, when
    called, builds a banded "memory matrix" over the time axis that mixes a
    window of past/future frames into each step before a linear projection.

    NOTE(review): the memory-matrix construction in ``__call__`` mixes Python
    list operations with Tensor slices and uses the pre-1.0
    ``tf.concat(axis, values)`` argument order — it appears unfinished and
    cannot run as written; see the inline notes there.
    """

    def __init__(self, memory_size_left, memory_size_right, stride_l, stride_r,
                 input_size, output_size, dtype=tf.float32):
        # Number of past (left) / future (right) frames covered by the memory block.
        self._memory_size_left = memory_size_left
        self._memory_size_right = memory_size_right
        # Total taps: look-back + look-ahead + the current frame.
        self._memory_size = memory_size_left + memory_size_right + 1
        # Subsampling strides applied to the left/right memory taps.
        self._stride_l = stride_l
        self._stride_r = stride_r
        self._input_size = input_size
        self._output_size = output_size
        self._dtype = dtype
        # Variables are created eagerly at construction time (TF1 graph mode),
        # so instantiating this class has graph side effects.
        self._build_graph()

    def _build_graph(self):
        """Create the layer's trainable variables via the TF1 `tf.get_variable` API."""
        # Projection from the memory-block output to the layer output.
        self._W = tf.get_variable(
            "dfsmnn_w", [self._input_size, self._output_size],
            initializer=tf.truncated_normal_initializer(stddev=5e-2, dtype=self._dtype))
        self._bias = tf.get_variable(
            "dfsmnn_bias", [self._output_size],
            initializer=tf.constant_initializer(0.0, dtype=self._dtype))
        # One scalar weight per (feature, tap) pair of the memory block,
        # initialised to 1.0.
        self._memory_weights = tf.get_variable(
            "memory_weights", [self._input_size, self._memory_size],
            initializer=tf.constant_initializer(1.0, dtype=self._dtype))

    def __call__(self, skip_con, input_data):
        """Apply the DFSMN layer to one input sequence.

        Args:
            skip_con: skip-connection input from the previous layer's memory
                block (``p`` in the DFSMN formulation); currently passed through
                ``skip_connection_func`` unchanged.
            input_data: input sequence Tensor — assumes static shape
                (batch, num_steps, input_size); TODO confirm against callers.

        Returns:
            A pair ``(h, p_hatt)``: the projected layer output and the
            memory-block output fed to the next layer's skip connection.
        """
        # TF1 static-shape API; Dimension.value is None for unknown dims, which
        # would break the Python loops below — assumes fully static shapes.
        size = input_data.get_shape()
        batch_size = size[0].value
        num_steps = size[1].value

        def skip_connection_func(input_skip):
            # Identity for now — the skip-connection transform is not yet
            # implemented (original TODO).
            return input_skip

        # Build one per-step weight band, then stack the bands into the
        # (num_steps x num_steps-ish) memory matrix.
        memory_matrix = []
        for step in range(num_steps):
            # Amount of zero padding needed on each side when the memory
            # window is clipped at the sequence boundary.
            left_num = tf.maximum(0, step - self._memory_size_left)
            right_num = tf.maximum(0, num_steps - step - self._memory_size_right - 1)
            # Slice bounds into the tap axis of `_memory_weights`
            # (negative indices, traversed in reverse with step -1).
            weight_start = tf.minimum(-self._memory_size_left + step - 1, -1)
            weight_middle = -self._memory_size_left - 1
            weight_end = tf.maximum(step - self._memory_size_left - num_steps, -self._memory_size)
            # NOTE(review): slicing with Tensor-valued bounds yields Tensors,
            # yet the code below treats mem_l/mem_r as Python lists
            # (len(), .insert()) — this cannot execute as written; the whole
            # construction looks unfinished.
            mem_l = self._memory_weights[weight_start:weight_middle:-1]
            mem_r = self._memory_weights[weight_middle - 1:weight_end:-1]
            left_num = left_num - self._stride_l * len(mem_l)
            right_num = right_num - self._stride_r * len(mem_r)
            # Interleave zero rows between taps to realise the left/right strides.
            ele_1, ele_2 = 1, 0
            while ele_1 <= len(mem_l):
                for count1 in range(self._stride_l):
                    mem_l.insert(ele_1, self._input_size * [0])
                ele_1 += self._stride_l
            while ele_2 < len(mem_r):
                # NOTE(review): iterates range(self._stride_l) although this
                # loop pads the RIGHT memory — presumably should be
                # self._stride_r (the advance below does use _stride_r); confirm.
                for count2 in range(self._stride_l):
                    mem_r.insert(ele_2, self._input_size * [0])
                ele_2 += self._stride_r

            # NOTE(review): `+` between Tensors is element-wise addition, not
            # concatenation, so this does not splice the centre tap between
            # the two halves as the surrounding code seems to intend.
            mem = mem_l + self._memory_weights[self._memory_size_right] + mem_r
            mem = mem[tf.maximum(len(mem_l) - step, 0): tf.minimum(num_steps - step - 1 - len(mem_r), -1)]

            # Clamp the pad amounts to zero before padding.
            # NOTE(review): left_num/right_num are Tensors at this point, so
            # these Python `if` comparisons on them would not behave as
            # intended in TF1 graph mode.
            if left_num <= 0:
                left_num = 0
            if right_num <= 0:
                right_num = 0
            d_batch = tf.pad(mem, [[left_num, right_num], [0, 0]])
            memory_matrix.append([d_batch])
        # NOTE(review): pre-TF-1.0 argument order; TF >= 1.0 signature is
        # tf.concat(values, axis).
        memory_matrix = tf.concat(0, memory_matrix)

        # Layer output: h = (p_s + M·x) W + b, with the shared matrices tiled
        # along the batch dimension via Python list replication.
        h_hatt = tf.matmul([memory_matrix] * batch_size, input_data)
        p_s = skip_connection_func(skip_con)
        p_hatt = p_s + h_hatt
        h = tf.matmul(p_hatt, [self._W] * batch_size) + self._bias
        return h, p_hatt