├── README.md ├── LICENSE ├── .gitignore ├── feedforward_controller.py ├── visual_util.py ├── recurrent_controller.py ├── utility.py ├── controller.py ├── synthetic_task.py ├── memory.py └── uw_dnc.py /README.md: -------------------------------------------------------------------------------- 1 | # UW-DNC 2 | Source code for the paper "Learning to Remember More with Less Memorization": 3 | https://openreview.net/forum?id=r1xlvi0qYm 4 | Reference implementation: https://github.com/Mostafa-Samir/DNC-tensorflow 5 | 6 | # Synthetic tasks 7 | Training: 8 | 1. Regular DNC: 9 | python synthetic_task.py --task=copy --mode=train 10 | 2. UW-DNC: 11 | python synthetic_task.py --task=copy --mode=train --hold_mem_mode=2 12 | 3. CUW-DNC: 13 | python synthetic_task.py --task=copy --mode=train --hold_mem_mode=2 --cache_attend_dim=16 --cache_size=10 --hidden_dim=100 14 | 15 | Testing: 16 | run the same command with --mode=test 17 | 18 | Random write schedule: 19 | add --memo_type=random 20 | 21 | Feel free to experiment with the other command-line parameters. 22 | 23 | 24 |
-------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Tony 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 |
-------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | 106 | .idea/ 107 | data/save/ 108 | data/summary/ 109 | -------------------------------------------------------------------------------- /feedforward_controller.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from controller import BaseController 3 | import numpy as np 4 | 5 | class FeedforwardController(BaseController): 6 | def network_vars(self): 7 | initial_std = lambda in_nodes: np.min([1e-2, np.sqrt(2.0 / in_nodes)]) 8 | input_ = int(self.nn_input_size) 9 | 10 | self.W1 = tf.Variable(tf.truncated_normal([input_, self.hidden_dim], stddev=initial_std(input_)), name='layer1_W') 11 | self.W2 = tf.Variable(tf.truncated_normal([self.hidden_dim, self.hidden_dim*2], 12 | stddev=initial_std(self.hidden_dim)), name='layer2_W') 13 | self.b1 = tf.Variable(tf.zeros([self.hidden_dim]), name='layer1_b') 14 | self.b2 = tf.Variable(tf.zeros([self.hidden_dim*2]), name='layer2_b') 15 | 16 | def network_op(self, X): 17 | l1_output = tf.matmul(X, self.W1) + self.b1 18 | l1_activation = tf.nn.relu(l1_output) 19 | 20 | l2_output = tf.matmul(l1_activation, self.W2) + self.b2 21 | l2_activation = tf.nn.relu(l2_output) 22 | 23 | return l2_activation 24 | 25 | def initials(self): 26 | initial_std = lambda in_nodes: np.min([1e-2, np.sqrt(2.0 / in_nodes)]) 27 | 28 | # defining internal weights of the controller 29 | self.interface_weights = tf.Variable( 30 | tf.truncated_normal([self.nn_output_size, self.interface_vector_size], 31 | stddev=initial_std(self.nn_output_size)), 32 | name='interface_weights' 33 | ) 34 | self.nn_output_weights = tf.Variable( 35 | tf.truncated_normal([self.nn_output_size, self.output_size], stddev=initial_std(self.nn_output_size)), 36 | name='nn_output_weights' 37 | ) 38 | self.mem_output_weights = tf.Variable( 39 | tf.truncated_normal([self.word_size * self.read_heads, self.output_size], 40 | stddev=initial_std(self.word_size * self.read_heads)), 41 | name='mem_output_weights') -------------------------------------------------------------------------------- /visual_util.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | import pickle 4 | import time 5 | import sys 6 | import os 7 | import matplotlib as mpl 8 | # mpl.use('Agg') 9 | from matplotlib import pyplot as plt 10 | plt.rcParams["font.family"] = "Times New 
Roman" 11 | 12 | def plot_memory(emem_view, dmem_view, hold, brin, batch_ind=0): 13 | print(brin[batch_ind]) 14 | 15 | 16 | eraddress = emem_view['read_weightings'][batch_ind,:,:,0] 17 | for i in range(eraddress.shape[0]): 18 | if hold[i]: 19 | eraddress[i,:] = np.zeros(eraddress[i,:].shape) 20 | ewaddress = emem_view['write_weightings'][batch_ind,:,:] 21 | for i in range(ewaddress.shape[0]): 22 | if hold[i]: 23 | ewaddress[i,:] = np.zeros(ewaddress[i,:].shape) 24 | 25 | ewgates = emem_view['write_gates'][batch_ind, :,] 26 | 27 | # ewaddress2 = ewaddress[:-1,:] 28 | # ewaddress = ewaddress2 29 | 30 | draddress = dmem_view['read_weightings'][batch_ind, :, :, 0] 31 | # draddress2 = draddress[1:,:] 32 | # draddress = draddress2 33 | 34 | dwaddress = dmem_view['write_weightings'][batch_ind, :, :] 35 | 36 | fig = plt.figure() 37 | 38 | 39 | 40 | # a = fig.add_subplot(1, 3, 1) 41 | # a.set_title('E Read Weight') 42 | # a.set_aspect('auto') 43 | # a.xaxis.set_ticks(np.arange(0, eraddress.shape[1]),eraddress.shape[1]//min(eraddress.shape[1],10)) 44 | # a.yaxis.set_ticks(np.arange(0, eraddress.shape[0]),eraddress.shape[0]//min(eraddress.shape[0],10)) 45 | # plt.imshow(eraddress, interpolation='nearest', cmap='gray', aspect='auto') 46 | 47 | a = fig.add_subplot(1, 3, 1) 48 | a.set_title('Encoding Write gate') 49 | a.xaxis.set_ticks(np.arange(0, 1)) 50 | a.yaxis.set_ticks(np.arange(0, ewgates.shape[0]),ewgates.shape[0]//min(ewgates.shape[0],10)) 51 | plt.imshow(ewgates, interpolation='nearest', cmap='gray') 52 | 53 | a = fig.add_subplot(1, 3, 2) 54 | a.set_title('Encoding Write Weight') 55 | a.set_aspect('auto') 56 | a.xaxis.set_ticks(np.arange(0, ewaddress.shape[1]), ewaddress.shape[1]//min(ewaddress.shape[1],10)) 57 | a.yaxis.set_ticks(np.arange(0, ewaddress.shape[0]), ewaddress.shape[0]//min(ewaddress.shape[0],10)) 58 | plt.imshow(ewaddress, interpolation='nearest', cmap='gray', aspect='auto') 59 | 60 | a = fig.add_subplot(1, 3, 3) 61 | a.set_title('Decoding Read Weight') 62 | a.set_aspect('auto') 63 | a.xaxis.set_ticks(np.arange(0, draddress.shape[1]), draddress.shape[1]//min(draddress.shape[1],10)) 64 | a.yaxis.set_ticks(np.arange(0, draddress.shape[0]), draddress.shape[0]//min(draddress.shape[0],10)) 65 | plt.imshow(draddress, interpolation='nearest', cmap='gray', aspect='auto') 66 | 67 | # a = fig.add_subplot(2, 2, 4) 68 | # a.set_title('D Write Weight') 69 | # a.set_aspect('auto') 70 | # a.xaxis.set_ticks(np.arange(0, dwaddress.shape[1]), dwaddress.shape[1]//min(dwaddress.shape[1],10)) 71 | # a.yaxis.set_ticks(np.arange(0, dwaddress.shape[0]), dwaddress.shape[0]//min(dwaddress.shape[0],10)) 72 | # plt.imshow(dwaddress, interpolation='nearest', cmap='gray', aspect='auto') 73 | 74 | plt.show() 75 | 76 | -------------------------------------------------------------------------------- /recurrent_controller.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from controller import BaseController 3 | 4 | 5 | 6 | class StatelessRecurrentController(BaseController): 7 | def network_vars(self): 8 | print('--define core rnn stateless controller variables--') 9 | 10 | cell = None 11 | 12 | if self.cell_type == "nlstm": 13 | cell= tf.contrib.rnn.LayerNormBasicLSTMCell(self.hidden_dim, layer_norm = self.batch_norm, 14 | dropout_keep_prob=self.drop_out_keep) 15 | 16 | self.cell_weight_name="layer_norm_basic_lstm_cell/kernel" 17 | else: 18 | if self.cell_type == "lstm": 19 | cell = tf.nn.rnn_cell.LSTMCell(self.hidden_dim) 20 | if self.cell_type 
== "igru": 21 | cell = tf.nn.rnn_cell.GRUCell(self.hidden_dim, activation=tf.nn.relu, 22 | kernel_initializer=tf.contrib.keras.initializers.Identity(gain=1.0)) 23 | elif self.cell_type == "gru": 24 | cell = tf.nn.rnn_cell.GRUCell(self.hidden_dim) 25 | elif self.cell_type == "rnn": 26 | cell = tf.nn.rnn_cell.BasicRNNCell(self.hidden_dim) 27 | if not isinstance(self.drop_out_keep, int): 28 | cell = tf.contrib.rnn.DropoutWrapper(cell, 29 | input_keep_prob=self.drop_out_keep) 30 | 31 | if self.nlayer==1: 32 | print('1 layer') 33 | self.controller_cell = cell 34 | else: 35 | print('{} layers'.format(self.nlayer)) 36 | if self.cell_type == "nlstm": 37 | self.controller_cell = tf.nn.rnn_cell.MultiRNNCell([tf.contrib.rnn.LayerNormBasicLSTMCell(self.hidden_dim, layer_norm = self.batch_norm, 38 | dropout_keep_prob=self.drop_out_keep) for _ in range(self.nlayer)]) 39 | elif self.cell_type == "lstm": 40 | layers=[] 41 | for _ in range(self.nlayer): 42 | cell = tf.nn.rnn_cell.LSTMCell(self.hidden_dim) 43 | if not isinstance(self.drop_out_keep, int): 44 | cell = tf.contrib.rnn.DropoutWrapper(cell, 45 | input_keep_prob=self.drop_out_keep) 46 | layers.append(cell) 47 | self.controller_cell = tf.nn.rnn_cell.MultiRNNCell(layers) 48 | elif self.cell_type == "gru": 49 | layers=[] 50 | for _ in range(self.nlayer): 51 | cell = tf.nn.rnn_cell.GRUCell(self.hidden_dim) 52 | if not isinstance(self.drop_out_keep, int): 53 | cell = tf.contrib.rnn.DropoutWrapper(cell, 54 | input_keep_prob=self.drop_out_keep) 55 | layers.append(cell) 56 | self.controller_cell = tf.nn.rnn_cell.MultiRNNCell(layers) 57 | elif self.cell_type == "rnn": 58 | layers = [] 59 | for _ in range(self.nlayer): 60 | cell = tf.nn.rnn_cell.BasicRNNCell(self.hidden_dim) 61 | if not isinstance(self.drop_out_keep, int): 62 | cell = tf.contrib.rnn.DropoutWrapper(cell, 63 | input_keep_prob=self.drop_out_keep) 64 | layers.append(cell) 65 | self.controller_cell = tf.nn.rnn_cell.MultiRNNCell(layers) 66 | 67 | print("controller cell") 68 | print(self.controller_cell) 69 | 70 | self.state = self.controller_cell.zero_state(self.batch_size, tf.float32) 71 | 72 | 73 | 74 | 75 | 76 | def network_op(self, X, state): 77 | print('--operation rnn stateless controller variables--') 78 | X = tf.convert_to_tensor(X) 79 | return self.controller_cell(X, state) 80 | 81 | def get_state(self): 82 | return self.state 83 | 84 | def update_state(self, new_state): 85 | return tf.no_op() 86 | 87 | def zero_state(self): 88 | return self.controller_cell.zero_state(self.batch_size, tf.float32) -------------------------------------------------------------------------------- /utility.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | 4 | 5 | def pairwise_add(u, v=None, is_batch=False): 6 | """ 7 | performs a pairwise summation between vectors (possibly the same) 8 | Parameters: 9 | ---------- 10 | u: Tensor (n, ) | (n, 1) 11 | v: Tensor (n, ) | (n, 1) [optional] 12 | is_batch: bool 13 | a flag for whether the vectors come in a batch 14 | ie.: whether the vectors has a shape of (b,n) or (b,n,1) 15 | Returns: Tensor (n, n) 16 | Raises: ValueError 17 | """ 18 | u_shape = u.get_shape().as_list() 19 | 20 | if len(u_shape) > 2 and not is_batch: 21 | raise ValueError("Expected at most 2D tensors, but got %dD" % len(u_shape)) 22 | if len(u_shape) > 3 and is_batch: 23 | raise ValueError("Expected at most 2D tensor batches, but got %dD" % len(u_shape)) 24 | 25 | if v is None: 26 | v = u 27 | else: 28 | 
v_shape = v.get_shape().as_list() 29 | if u_shape != v_shape: 30 | raise ValueError("Shapes %s and %s do not match" % (u_shape, v_shape)) 31 | 32 | n = u_shape[0] if not is_batch else u_shape[1] 33 | 34 | column_u = tf.reshape(u, (-1, 1) if not is_batch else (-1, n, 1)) 35 | U = tf.concat([column_u] * n, 1 if not is_batch else 2) 36 | 37 | if v is u: 38 | return U + tf.transpose(U, None if not is_batch else [0, 2, 1]) 39 | else: 40 | row_v = tf.reshape(v, (1, -1) if not is_batch else (-1, 1, n)) 41 | V = tf.concat([row_v] * n, 0 if not is_batch else 1) 42 | 43 | return U + V 44 | 45 | 46 | def decaying_softmax(shape, axis): 47 | rank = len(shape) # num dim 48 | max_val = shape[axis] 49 | 50 | weights_vector = np.arange(1, max_val + 1, dtype=np.float32) 51 | weights_vector = weights_vector[::-1] # reversed 52 | weights_vector = np.exp(weights_vector) / np.sum(np.exp(weights_vector)) # softmax weights 53 | 54 | container = np.zeros(shape, dtype=np.float32) 55 | broadcastable_shape = [1] * rank 56 | broadcastable_shape[axis] = max_val 57 | 58 | return container + np.reshape(weights_vector, broadcastable_shape) # the weight matrix is built, with axis is filled with softmax weights 59 | 60 | def unpack_into_tensorarray(value, axis, size=None): 61 | """ 62 | unpacks a given tensor along a given axis into a TensorArray 63 | Parameters: 64 | ---------- 65 | value: Tensor 66 | the tensor to be unpacked 67 | axis: int 68 | the axis to unpack the tensor along 69 | size: int 70 | the size of the array to be used if shape inference resulted in None 71 | Returns: TensorArray 72 | the unpacked TensorArray 73 | """ 74 | 75 | shape = value.get_shape().as_list() 76 | rank = len(shape) 77 | dtype = value.dtype 78 | array_size = shape[axis] if not shape[axis] is None else size 79 | 80 | if array_size is None: 81 | raise ValueError("Can't create TensorArray with size None") 82 | 83 | array = tf.TensorArray(dtype=dtype, size=array_size) #size of the axis 84 | dim_permutation = [axis] +list(range(1, axis)) + [0] + list(range(axis + 1, rank)) 85 | unpack_axis_major_value = tf.transpose(value, dim_permutation)# move axis values to the 0 dim 86 | full_array = array.unstack(unpack_axis_major_value) 87 | 88 | return full_array 89 | 90 | def pack_into_tensor(array, axis): 91 | """ 92 | packs a given TensorArray into a tensor along a given axis 93 | Parameters: 94 | ---------- 95 | array: TensorArray 96 | the tensor array to pack 97 | axis: int 98 | the axis to pack the array along 99 | Returns: Tensor 100 | the packed tensor 101 | """ 102 | 103 | packed_tensor = array.stack() # add 1 dimension at the 0 dim 104 | shape = packed_tensor.get_shape() 105 | try: 106 | rank = len(shape) 107 | except: 108 | print("unknow length of tensor array!!! 
assume rank 3") 109 | rank = 3 110 | 111 | dim_permutation = [axis] + list(range(1, axis)) + [0] + list(range(axis + 1, rank)) 112 | correct_shape_tensor = tf.transpose(packed_tensor, dim_permutation)# put the extra dimension to axis you want 113 | 114 | return correct_shape_tensor 115 | 116 | def pack_into_tensor2(array, axis): 117 | """ 118 | packs a given TensorArray into a tensor along a given axis 119 | Parameters: 120 | ---------- 121 | array: TensorArray 122 | the tensor array to pack 123 | axis: int 124 | the axis to pack the array along 125 | Returns: Tensor 126 | the packed tensor 127 | """ 128 | 129 | packed_tensor = array.stack() # add 1 dimension at the 0 dim 130 | 131 | 132 | dim_permutation = [axis] + list(range(1, axis)) + [0] + list(range(axis + 1, 3)) 133 | correct_shape_tensor = tf.transpose(packed_tensor, dim_permutation)# put the extra dimension to axis you want 134 | 135 | return correct_shape_tensor -------------------------------------------------------------------------------- /controller.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | 4 | class BaseController: 5 | 6 | def __init__(self, input_size, output_size, memory_read_heads, memory_word_size, batch_size=1, 7 | use_mem=True, hidden_dim=256, is_two_mem=0, cell_type="nlstm", 8 | drop_out_keep=1, batch_norm=False, vae_mode=False, nlayer=1, clip_output=0): 9 | """ 10 | constructs a controller as described in the DNC paper: 11 | http://www.nature.com/nature/journal/vaop/ncurrent/full/nature20101.html 12 | 13 | Parameters: 14 | ---------- 15 | input_size: int 16 | the size of the data input vector 17 | output_size: int 18 | the size of the data output vector 19 | memory_read_heads: int 20 | the number of read heads in the associated external memory 21 | memory_word_size: int 22 | the size of the word in the associated external memory 23 | batch_size: int 24 | the size of the input data batch [optional, usually set by the DNC object] 25 | """ 26 | self.use_mem = use_mem 27 | self.input_size = input_size 28 | self.output_size = output_size 29 | self.read_heads = memory_read_heads # in dnc there are many read head but one write head 30 | self.word_size = memory_word_size 31 | self.batch_size = batch_size 32 | self.hidden_dim = hidden_dim 33 | self.drop_out_keep = drop_out_keep 34 | self.batch_norm= batch_norm 35 | self.vae_mode = vae_mode 36 | self.nlayer = nlayer 37 | self.cell_type = cell_type 38 | self.clip_output = clip_output 39 | # indicates if the internal neural network is recurrent 40 | # by the existence of recurrent_update and get_state methods 41 | # subclass should implement these methods if it is rnn based controller 42 | has_recurrent_update = callable(getattr(self, 'update_state', None)) 43 | has_get_state = callable(getattr(self, 'get_state', None)) 44 | self.has_recurrent_nn = has_recurrent_update and has_get_state 45 | 46 | # the actual size of the neural network input after flatenning and 47 | # concatenating the input vector with the previously read vctors from memory 48 | if use_mem or self.vae_mode: 49 | if is_two_mem>0: 50 | self.nn_input_size = self.word_size * self.read_heads*2 + self.input_size 51 | elif self.vae_mode: 52 | self.nn_input_size = self.word_size + self.input_size 53 | else: 54 | self.nn_input_size = self.word_size * self.read_heads + self.input_size 55 | else: 56 | self.nn_input_size = self.input_size 57 | 58 | self.interface_vector_size = self.word_size * self.read_heads #R read keys 59 
| self.interface_vector_size += 3 * self.word_size #1 write key, 1 erase, 1 content 60 | self.interface_vector_size += 5 * self.read_heads #R read key strengths, R free gates, 3xR read modes (each mode for each read has 3 values) 61 | self.interface_vector_size += 3 # 1 write strength, 1 allocation gate, 1 write gate 62 | 63 | self.interface_weights = self.nn_output_weights = self.mem_output_weights = None 64 | self.is_two_mem = is_two_mem 65 | 66 | # define network vars 67 | with tf.name_scope("controller"): 68 | self.network_vars() 69 | 70 | self.nn_output_size = None # not yet defined in the general scope --> output of the controller not of the whole 71 | with tf.variable_scope("shape_inference"): 72 | #depend on model type --> seperate variable scope 73 | self.nn_output_size = self.get_nn_output_size() 74 | 75 | self.initials() 76 | 77 | def initials(self): 78 | """ 79 | sets the initial values of the controller transformation weights matrices 80 | this method can be overwritten to use a different initialization scheme 81 | """ 82 | # defining internal weights of the controller 83 | if self.is_two_mem==2: 84 | self.interface_weights = tf.Variable( 85 | tf.random_normal([self.nn_output_size, self.interface_vector_size*2], stddev=0.1), 86 | name='interface_weights' 87 | ) # function to compute interface: i = H x iW 88 | else: 89 | self.interface_weights = tf.Variable( 90 | tf.random_normal([self.nn_output_size, self.interface_vector_size], stddev=0.1), 91 | name='interface_weights' 92 | ) # function to compute interface: i = H x iW 93 | 94 | self.nn_output_weights = tf.Variable( 95 | tf.random_normal([self.nn_output_size, self.output_size], stddev=0.1), 96 | name='nn_output_weights' 97 | ) # function to compute output of the whole : v = H x yW 98 | if self.is_two_mem>0: 99 | self.mem_output_weights = tf.Variable( 100 | tf.random_normal([2*self.word_size * self.read_heads, self.output_size], stddev=0.1), 101 | name='mem_output_weights' 102 | ) 103 | 104 | else: 105 | # if self.vae_mode: 106 | # final_win=self.word_size 107 | # else: 108 | final_win = self.word_size * self.read_heads 109 | 110 | self.mem_output_weights = tf.Variable( 111 | tf.random_normal([final_win, self.output_size], stddev=0.1), 112 | name='mem_output_weights' 113 | ) # function to compute final output of the whole, combine output and read values: y = v + rs x Wr 114 | 115 | 116 | 117 | def network_vars(self): 118 | """ 119 | defines the variables needed by the internal neural network 120 | [the variables should be attributes of the class, i.e. 
self.*] 121 | """ 122 | raise NotImplementedError("network_vars is not implemented") 123 | 124 | 125 | def network_op(self, X): 126 | """ 127 | defines the controller's internal neural network operation 128 | 129 | Parameters: 130 | ---------- 131 | X: Tensor (batch_size, word_size * read_haeds + input_size) 132 | the input data concatenated with the previously read vectors from memory 133 | 134 | Returns: Tensor (batch_size, nn_output_size) 135 | """ 136 | raise NotImplementedError("network_op method is not implemented") 137 | 138 | 139 | def get_nn_output_size(self): 140 | """ 141 | retrives the output size of the defined neural network 142 | 143 | Returns: int 144 | the output's size 145 | 146 | Raises: ValueError 147 | """ 148 | 149 | input_vector = np.zeros([self.batch_size, self.nn_input_size], dtype=np.float32) #dummy data to get output size 150 | 151 | if self.has_recurrent_nn: 152 | output_vector,_ = self.network_op(input_vector, self.get_state()) # connacate all steps hidden state vector 153 | else: 154 | output_vector = self.network_op(input_vector) # just hidden state vector 155 | 156 | shape = output_vector.get_shape().as_list() # batch x output_size 157 | 158 | if len(shape) > 2: 159 | raise ValueError("Expected the neural network to output a 1D vector, but got %dD" % (len(shape) - 1)) 160 | else: 161 | return shape[1] 162 | 163 | 164 | def parse_interface_vector(self, interface_vector): 165 | """ 166 | pasres the flat interface_vector into its various components with their 167 | correct shapes 168 | 169 | Parameters: 170 | ---------- 171 | interface_vector: Tensor (batch_size, interface_vector_size) 172 | the flattened inetrface vector to be parsed 173 | 174 | Returns: dict 175 | a dictionary with the components of the interface_vector parsed 176 | """ 177 | if self.clip_output>0: 178 | interface_vector = tf.clip_by_value(interface_vector, -self.clip_output, self.clip_output) 179 | parsed = {} 180 | 181 | r_keys_end = self.word_size * self.read_heads 182 | r_strengths_end = r_keys_end + self.read_heads 183 | w_key_end = r_strengths_end + self.word_size 184 | erase_end = w_key_end + 1 + self.word_size 185 | write_end = erase_end + self.word_size 186 | free_end = write_end + self.read_heads 187 | 188 | r_keys_shape = (-1, self.word_size, self.read_heads) 189 | r_strengths_shape = (-1, self.read_heads) 190 | w_key_shape = (-1, self.word_size, 1) 191 | write_shape = erase_shape = (-1, self.word_size) 192 | free_shape = (-1, self.read_heads) 193 | modes_shape = (-1, 3, self.read_heads) 194 | 195 | # parsing the vector into its individual components 196 | parsed['read_keys'] = tf.reshape(interface_vector[:, :r_keys_end], r_keys_shape) #batch x N x R 197 | parsed['read_strengths'] = tf.reshape(interface_vector[:, r_keys_end:r_strengths_end], r_strengths_shape) #batch x R 198 | parsed['write_key'] = tf.reshape(interface_vector[:, r_strengths_end:w_key_end], w_key_shape) #batch x N x 1 --> share similarity function with read 199 | parsed['write_strength'] = tf.reshape(interface_vector[:, w_key_end], (-1, 1)) # batch x 1 200 | parsed['erase_vector'] = tf.reshape(interface_vector[:, w_key_end + 1:erase_end], erase_shape) #batch x N 201 | parsed['write_vector'] = tf.reshape(interface_vector[:, erase_end:write_end], write_shape)# batch x N 202 | parsed['free_gates'] = tf.reshape(interface_vector[:, write_end:free_end], free_shape)# batch x R 203 | parsed['allocation_gate'] = tf.expand_dims(interface_vector[:, free_end], 1)# batch x 1 204 | parsed['write_gate'] = 
tf.expand_dims(interface_vector[:, free_end + 1], 1)# batch x 1 205 | parsed['read_modes'] = tf.reshape(interface_vector[:, free_end + 2:], modes_shape)# batch x 3 x R 206 | 207 | # transforming the components to ensure they're in the right ranges 208 | parsed['read_strengths'] = 1 + tf.nn.softplus(parsed['read_strengths']) 209 | parsed['write_strength'] = 1 + tf.nn.softplus(parsed['write_strength']) 210 | parsed['erase_vector'] = tf.nn.sigmoid(parsed['erase_vector']) 211 | parsed['free_gates'] = tf.nn.sigmoid(parsed['free_gates']) 212 | parsed['allocation_gate'] = tf.nn.sigmoid(parsed['allocation_gate']) 213 | parsed['write_gate'] = tf.nn.sigmoid(parsed['write_gate']) 214 | parsed['read_modes'] = tf.nn.softmax(parsed['read_modes'], 1) 215 | 216 | return parsed # dict of tensors 217 | 218 | def process_zero(self): 219 | pre_output = tf.zeros([self.batch_size, self.output_size],dtype=np.float32) 220 | interface = tf.zeros([self.batch_size, self.interface_vector_size],dtype=np.float32) 221 | parsed_interface = self.parse_interface_vector(interface) 222 | parsed_interface['read_strengths'] *= 0 223 | parsed_interface['write_strength'] *= 0 224 | parsed_interface['erase_vector'] *= 0 225 | parsed_interface['free_gates'] *= 0 226 | parsed_interface['allocation_gate'] *= 0 227 | parsed_interface['write_gate'] *= 0 228 | parsed_interface['read_modes'] *= 0 229 | if self.has_recurrent_nn: 230 | return pre_output, parsed_interface, self.lstm_cell.zero_state(self.batch_size, tf.float32) 231 | else: 232 | return pre_output, parsed_interface 233 | 234 | def process_input(self, X, last_read_vectors, state=None, compute_interface=True): 235 | """ 236 | processes input data through the controller network and returns the 237 | pre-output and interface_vector 238 | 239 | Parameters: 240 | ---------- 241 | X: Tensor (batch_size, input_size) 242 | the input data batch 243 | last_read_vectors: (batch_size, word_size, read_heads) 244 | the last batch of read vectors from memory 245 | state: Tuple 246 | state vectors if the network is recurrent 247 | 248 | Returns: Tuple 249 | pre-output: Tensor (batch_size, output_size) 250 | parsed_interface_vector: dict 251 | """ 252 | 253 | flat_read_vectors = tf.reshape(last_read_vectors, (self.batch_size, -1)) #flatten R read vectors: batch x RN 254 | if self.use_mem or self.vae_mode: 255 | complete_input = tf.concat([X, flat_read_vectors], 1)#concat input --> read data 256 | else: 257 | complete_input = X 258 | 259 | # print('---') 260 | # print(X.shape) 261 | # print(flat_read_vectors.shape) 262 | # print(complete_input.shape) 263 | if self.has_recurrent_nn: 264 | nn_output, nn_state = self.network_op(complete_input, state) 265 | print('recurrent state') 266 | print(nn_state) 267 | else: 268 | nn_output = self.network_op(complete_input) 269 | 270 | pre_output = tf.matmul(nn_output, self.nn_output_weights) #batch x output_dim -->later combine with new read vector 271 | 272 | if compute_interface: 273 | interface = tf.matmul(nn_output, self.interface_weights) #batch x interface_dim 274 | else: 275 | interface = tf.zeros([self.batch_size, self.interface_vector_size]) 276 | if self.is_two_mem==2: 277 | interface1, interface2 = tf.split(interface, num_or_size_splits=2, axis=-1) 278 | parsed_interface = (self.parse_interface_vector(interface1), 279 | self.parse_interface_vector(interface2)) 280 | else: 281 | parsed_interface = self.parse_interface_vector(interface) #use to read write into vector 282 | 283 | if self.has_recurrent_nn: 284 | return pre_output, parsed_interface, 
nn_state 285 | else: 286 | return pre_output, parsed_interface 287 | 288 | 289 | def final_output(self, pre_output, new_read_vectors): 290 | """ 291 | returns the final output by taking rececnt memory changes into account 292 | 293 | Parameters: 294 | ---------- 295 | pre_output: Tensor (batch_size, output_size) 296 | the ouput vector from the input processing step 297 | new_read_vectors: Tensor (batch_size, words_size, read_heads) 298 | the newly read vectors from the updated memory 299 | 300 | Returns: Tensor (batch_size, output_size) 301 | """ 302 | 303 | flat_read_vectors = tf.reshape(new_read_vectors, (self.batch_size, -1)) # batch_size x flatten 304 | # final output is combine output from controller and read vectors --> just like concat hidden to read vectors 305 | # then linear transform 306 | 307 | final_output = pre_output 308 | 309 | if self.use_mem: 310 | final_output+=tf.matmul(flat_read_vectors, self.mem_output_weights) 311 | 312 | return final_output #same size as pre_output: batch_size x outputdim (classification problem, outputdim=number of labels) 313 | 314 | -------------------------------------------------------------------------------- /synthetic_task.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | import pickle 4 | import time 5 | import sys 6 | import os 7 | sys.path.append(os.path.dirname(os.path.abspath(__file__))+'/../') 8 | from uw_dnc import DNC 9 | from recurrent_controller import StatelessRecurrentController 10 | import visual_util 11 | 12 | def exact_acc(target_batch, predict_batch, stop_S=-1, pprint=1.0): 13 | acc=[] 14 | for b in range(target_batch.shape[0]): 15 | trim_target = [] 16 | trim_predict = [] 17 | 18 | for ti, t in enumerate(target_batch[b]): 19 | if t != stop_S: 20 | trim_target.append(t) 21 | 22 | for ti, t in enumerate(predict_batch[b]): 23 | if t != stop_S: 24 | trim_predict.append(t) 25 | 26 | if np.random.rand()>pprint or b==0: 27 | print('{} vs {}'.format(trim_target, trim_predict)) 28 | ac=0 29 | for n1,n2 in zip(trim_predict, trim_target): 30 | if n1==n2: 31 | ac+=1 32 | 33 | acc.append(float(ac/max(len(trim_target), len(trim_predict))))#have to be correct all 34 | return np.mean(acc) 35 | 36 | def llprint(message): 37 | sys.stdout.write(message) 38 | sys.stdout.flush() 39 | 40 | def load(path): 41 | return pickle.load(open(path, 'rb')) 42 | 43 | def onehot(index, size): 44 | # print('-----') 45 | # print(index) 46 | vec = np.zeros(size, dtype=np.float32) 47 | vec[int(index)] = 1.0 48 | return vec 49 | 50 | 51 | 52 | 53 | def copy_sample(vocab_lower, vocab_upper, length_from, length_to): 54 | def random_length(): 55 | if length_from == length_to: 56 | return length_from 57 | return np.random.randint(length_from, length_to + 1) 58 | seed = np.random.choice(list(range(int(vocab_lower),int(vocab_upper))), 59 | int(random_length()), replace=True) 60 | inp = seed.tolist() 61 | inp = inp + [0] 62 | out = seed.tolist() 63 | out = out + [0] 64 | 65 | return inp, out 66 | 67 | 68 | 69 | def sum_sample(vocab_lower, vocab_upper, length_from, length_to): 70 | def random_length(): 71 | if length_from == length_to: 72 | return length_from 73 | return np.random.randint(length_from, length_to + 1) 74 | seed = np.random.choice(list(range(int(vocab_lower),int(vocab_upper))), 75 | int(random_length()), replace=True) 76 | inp = seed.tolist() 77 | out=[] 78 | for i in range(len(inp)//2): 79 | out.append((inp[i]+inp[-1-i])//2) 80 | inp = inp + [0] 81 | out = out + [0] 
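# Illustrative example of the "sum" task constructed above: for a sampled seed [5, 1, 8, 2] the target is
# [(5 + 2) // 2, (1 + 8) // 2] = [3, 4], i.e. the floored mean of each symmetric pair (first with last,
# second with second-to-last); a trailing 0 is appended to both input and output as the end-of-sequence marker.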
82 | 83 | return inp, out 84 | 85 | 86 | def reverse_sample(vocab_lower, vocab_upper, length_from, length_to): 87 | def random_length(): 88 | if length_from == length_to: 89 | return length_from 90 | return np.random.randint(length_from, length_to + 1) 91 | seed = np.random.choice(list(range(int(vocab_lower),int(vocab_upper))), 92 | int(random_length()), replace=True) 93 | inp = seed.tolist() 94 | out = inp[::-1] 95 | inp = inp + [0] 96 | # out1 = seed[:len(seed)//2].tolist() 97 | # out2 = seed[len(seed)//2:].tolist() 98 | out = out + [0] 99 | # out = sorted(out1) + sorted(out2, reverse=True) + [0] 100 | 101 | return inp, out 102 | 103 | def double_sample(vocab_lower, vocab_upper, length_from, length_to): 104 | def random_length(): 105 | if length_from == length_to: 106 | return length_from 107 | return np.random.randint(length_from, length_to + 1) 108 | seed = np.random.choice(list(range(int(vocab_lower),int(vocab_upper))), 109 | int(random_length()), replace=True) 110 | inp = seed.tolist() 111 | 112 | out=inp+inp 113 | 114 | inp = inp + [0] 115 | out = out + [0] 116 | 117 | return inp, out 118 | 119 | def max_sample(vocab_lower, vocab_upper, length_from, length_to): 120 | def random_length(): 121 | if length_from == length_to: 122 | return length_from 123 | return np.random.randint(length_from, length_to + 1) 124 | seed = np.random.choice(list(range(int(vocab_lower),int(vocab_upper))), 125 | int(random_length()), replace=True) 126 | inp = seed.tolist() 127 | out=[] 128 | for i in range(len(inp)//2): 129 | if inp[i*2]>inp[i*2+1]: 130 | out.append(inp[i*2]) 131 | else: 132 | out.append(inp[i*2+1]) 133 | 134 | 135 | inp = inp + [0] 136 | out = out + [0] 137 | 138 | return inp, out 139 | 140 | def prepare_batch(bs, vocab_size, length_from, length_to, args): 141 | length=np.random.randint(length_from, length_to + 1) 142 | inps=np.zeros(shape= [bs,length+1, vocab_size]) 143 | lout=length 144 | if "sum" in args.task: 145 | lout=length//2 146 | if "max" in args.task: 147 | lout = length // 2 148 | if "double" in args.task: 149 | lout=length*2 150 | oups=np.zeros(shape=[bs,lout+1, vocab_size]) 151 | oups2=np.zeros(shape=[bs,lout+1, vocab_size]) 152 | hold_mem = np.zeros(length + 1, dtype=bool) 153 | if args.hold_mem_mode>0: 154 | hold_mem = np.ones(length+1, dtype=bool) 155 | # print(hold_mem) 156 | 157 | holdstep=(length+1)//(args.mem_size+1) 158 | holdstep=min(holdstep, args.cache_size) 159 | 160 | if "random" in args.memo_type: 161 | hold_mem=global_var["hold_mem_random"] 162 | else: 163 | if holdstep>0: 164 | for iii in range(holdstep, int(length+1), holdstep): 165 | hold_mem[iii] = False 166 | else: 167 | hold_mem[(length+1)//2] = False 168 | # print(hold_mem) 169 | lin=[] 170 | lou=[] 171 | for b in range(bs): 172 | if "copy" in args.task: 173 | i,o=copy_sample(1,vocab_size,length, length) 174 | elif "sum" in args.task: 175 | i,o=sum_sample(1,vocab_size,length, length) 176 | elif "double" in args.task: 177 | i,o=double_sample(1,vocab_size,length, length) 178 | elif "reverse" in args.task: 179 | i,o=reverse_sample(1, vocab_size, length, length) 180 | elif "max" in args.task: 181 | i,o=max_sample(1,vocab_size,length, length) 182 | 183 | 184 | lin.append(i) 185 | lou.append(o) 186 | c=0 187 | for c1 in i : 188 | inps[b, c, :]=onehot(c1, vocab_size) 189 | c+=1 190 | c = 0 191 | for c2 in o: 192 | oups[b, c, :] = onehot(c2, vocab_size) 193 | c += 1 194 | 195 | 196 | 197 | return inps, oups, oups2, length+1, lout+1, lin, lou, hold_mem 198 | 199 | def get_size_model(): 200 | total_parameters = 0 
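# The loop below sums, over every trainable variable, the product of its shape dimensions to obtain the
# total parameter count (note: dim.value is TF 1.x API; under TF 2.x shape dimensions are plain ints).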
201 | for variable in tf.trainable_variables(): 202 | # shape is an array of tf.Dimension 203 | shape = variable.get_shape() 204 | variable_parameters = 1 205 | for dim in shape: 206 | variable_parameters *= dim.value 207 | total_parameters += variable_parameters 208 | return total_parameters 209 | 210 | 211 | def synthetic_task(args): 212 | dirname = os.path.dirname(os.path.abspath(__file__)) + '/data/save/' 213 | print(dirname) 214 | ckpts_dir = os.path.join(dirname, 'checkpoints_{}_task'.format(args.task)) 215 | 216 | llprint("Loading Data ... ") 217 | 218 | batch_size = args.batch_size 219 | input_size = args.number_range 220 | output_size = args.number_range 221 | print('dim out {}'.format(output_size)) 222 | words_count = args.mem_size 223 | word_size = args.word_size 224 | read_heads = args.read_heads 225 | 226 | learning_rate = args.learning_rate 227 | momentum = 0.9 228 | 229 | iterations = args.iterations 230 | start_step = 0 231 | 232 | config = tf.ConfigProto(device_count={'CPU': args.cpu_num}) 233 | config.intra_op_parallelism_threads = args.cpu_num 234 | 235 | config.gpu_options.allow_growth = True 236 | # config.gpu_options.per_process_gpu_memory_fraction = args.gpu_ratio 237 | graph = tf.Graph() 238 | with graph.as_default(): 239 | tf.contrib.framework.get_or_create_global_step() 240 | with tf.Session(graph=graph, config=config) as session: 241 | 242 | llprint("Building Computational Graph ... ") 243 | 244 | ncomputer = DNC( 245 | StatelessRecurrentController, 246 | input_size, 247 | output_size, 248 | output_size, 249 | words_count, 250 | word_size, 251 | read_heads, 252 | batch_size, 253 | use_mem=args.use_mem, 254 | controller_cell_type=args.cell_type, 255 | use_emb_encoder=False, 256 | use_emb_decoder=False, 257 | hold_mem_mode=args.hold_mem_mode, 258 | hidden_controller_dim=args.hidden_dim, 259 | cache_attend_dim=args.cache_attend_dim, 260 | nlayer=args.nlayer, 261 | clip_output=20, 262 | batch_norm=args.batch_norm, 263 | pass_encoder_state=True, 264 | parallel_rnn=10, 265 | name='m'+str(args.cache_attend_dim)+str(args.hold_mem_mode)+args.memo_type+str(args.cache_size) 266 | ) 267 | 268 | 269 | 270 | # optimizer = tf.train.RMSPropOptimizer(learning_rate, momentum=momentum) 271 | optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate) 272 | 273 | 274 | _, prob, loss, apply_gradients = ncomputer.build_loss_function(optimizer, clip_s=5) 275 | 276 | llprint("Done!\n") 277 | llprint("Done!\n") 278 | 279 | llprint("Initializing Variables ... ") 280 | session.run(tf.global_variables_initializer()) 281 | llprint("Done!\n") 282 | variables_names = [v.name for v in tf.trainable_variables()] 283 | values = session.run(variables_names) 284 | print("SHOW VARIABLES") 285 | for k, v in zip(variables_names, values): 286 | print("Variable: {} shape {} ".format(k, v.shape)) 287 | # print (v) 288 | print("*************") 289 | 290 | 291 | if args.from_checkpoint is not '': 292 | if args.from_checkpoint == 'default': 293 | from_checkpoint = ncomputer.print_config() 294 | else: 295 | from_checkpoint = args.from_checkpoint 296 | llprint("Restoring Checkpoint %s ... 
" % from_checkpoint) 297 | ncomputer.restore(session, ckpts_dir, from_checkpoint) 298 | llprint("Done!\n") 299 | 300 | last_100_losses = [] 301 | 302 | print('no param {}'.format(ncomputer.get_size_model())) 303 | 304 | start = 1 if start_step == 0 else start_step + 1 305 | end = start_step + iterations + 1 306 | if args.mode == 'test': 307 | start = 0 308 | end = start 309 | 310 | 311 | start_time_100 = time.time() 312 | 313 | avg_100_time = 0. 314 | avg_counter = 0 315 | if args.mode == 'train': 316 | log_dir = './data/summary/log_{}/'.format(args.task) 317 | if not os.path.isdir(log_dir): 318 | os.mkdir(log_dir) 319 | log_dir = '{}/{}/'.format(log_dir, ncomputer.print_config()) 320 | if not os.path.isdir(log_dir): 321 | os.mkdir(log_dir) 322 | train_writer = tf.summary.FileWriter(log_dir, session.graph) 323 | min_tloss = 0 324 | for i in range(start, end + 1): 325 | try: 326 | llprint("\rIteration %d/%d" % (i, end)) 327 | input_data, target_output, itarget, seq_len, decoder_length, _, _, hold = \ 328 | prepare_batch(batch_size,args.number_range, args.length_from, args.length_to,args) 329 | fd={ 330 | ncomputer.input_encoder: input_data, 331 | ncomputer.input_decoder: itarget, 332 | ncomputer.target_output: target_output, 333 | ncomputer.sequence_length: seq_len, 334 | ncomputer.decode_length: decoder_length, 335 | } 336 | if args.hold_mem_mode>0: 337 | fd[ncomputer.hold_mem]=hold 338 | summerize = (i % args.valid_time == 0) 339 | if args.mode == 'train': 340 | loss_value, _ = session.run([ 341 | loss, 342 | apply_gradients 343 | ], feed_dict=fd) 344 | last_100_losses.append(loss_value) 345 | if summerize: 346 | llprint("\n\t episode %d -->Avg. Cross-Entropy: %.7f\n" % (i, np.mean(last_100_losses))) 347 | trscores_acc = [] 348 | 349 | 350 | summary = tf.Summary() 351 | summary.value.add(tag='batch_train_loss', simple_value=np.mean(last_100_losses)) 352 | 353 | for ii in range(5): 354 | input_data, target_output, itarget, seq_len, decoder_length, brin, brout, hold = \ 355 | prepare_batch(batch_size, args.number_range, args.length_from, args.length_to, args) 356 | 357 | fd = { 358 | ncomputer.input_encoder: input_data, 359 | ncomputer.input_decoder: itarget, 360 | ncomputer.target_output: target_output, 361 | ncomputer.sequence_length: seq_len, 362 | ncomputer.decode_length: decoder_length, 363 | } 364 | if args.hold_mem_mode > 0: 365 | fd[ncomputer.hold_mem] = hold 366 | 367 | out ,emem_v, dmem_v = session.run([prob, 368 | ncomputer.packed_memory_view_encoder, 369 | ncomputer.packed_memory_view_decoder], feed_dict=fd) 370 | 371 | out = np.reshape(np.asarray(out), [-1, decoder_length, output_size]) 372 | out = np.argmax(out, axis=-1) 373 | bout_list = [] 374 | for b in range(out.shape[0]): 375 | out_list = [] 376 | for io in range(out.shape[1]): 377 | # if out[b][io] == 0: 378 | # break 379 | out_list.append(out[b][io]) 380 | bout_list.append(out_list) 381 | trscores_acc.append(exact_acc(np.asarray(brout), np.asarray(bout_list), pprint=1)) 382 | # visual_util.plot_memory(emem_v, dmem_v, hold, brin) 383 | 384 | tpre=np.mean(trscores_acc) 385 | print('acc {}'.format(tpre)) 386 | if args.mode == 'train': 387 | summary.value.add(tag='train_acc', simple_value=tpre) 388 | train_writer.add_summary(summary, i) 389 | train_writer.flush() 390 | 391 | end_time_100 = time.time() 392 | elapsed_time = (end_time_100 - start_time_100) / 60 393 | avg_counter += 1 394 | avg_100_time += (1. / avg_counter) * (elapsed_time - avg_100_time) 395 | estimated_time = (avg_100_time * ((end - i) / 100.)) / 60. 
396 | print("\tAvg. 100 iterations time: %.2f minutes" % (avg_100_time)) 397 | print("\tApprox. time to completion: %.2f hours" % (estimated_time)) 398 | 399 | start_time_100 = time.time() 400 | last_100_losses = [] 401 | if args.mode == 'train' and tpre > min_tloss: 402 | min_tloss = tpre 403 | 404 | llprint("\nSaving Checkpoint ... "), 405 | 406 | ncomputer.save(session, ckpts_dir, ncomputer.print_config()) 407 | 408 | llprint("Done!\n") 409 | 410 | 411 | except KeyboardInterrupt: 412 | sys.exit(0) 413 | 414 | 415 | def str2bool(v): 416 | if v.lower() in ('yes', 'true', 't', 'y', '1'): 417 | return True 418 | elif v.lower() in ('no', 'false', 'f', 'n', '0'): 419 | return False 420 | else: 421 | raise argparse.ArgumentTypeError('Boolean value expected.') 422 | 423 | global_var={"hold_mem_random":None} 424 | 425 | 426 | def limit_copy(args): 427 | args.task = "copy500" 428 | args.cell_type = "lstm" 429 | args.mem_type = "dnc" 430 | args.hidden_dim=256 431 | args.mem_size = 49 432 | args.word_size = 64 433 | args.batch_size = 64 434 | args.number_range = 10 435 | args.length_from = 500 436 | args.length_to = 500 437 | args.iterations = 200000 438 | # args.hold_mem_mode = 2 439 | args.cache_sze = 1000 440 | return args 441 | 442 | 443 | if __name__ == '__main__': 444 | import argparse 445 | 446 | parser = argparse.ArgumentParser() 447 | parser.add_argument('--mode', default="train") 448 | parser.add_argument('--use_mem', default=True, type=str2bool) 449 | parser.add_argument('--cell_type', default="nlstm") 450 | parser.add_argument('--mem_type', default="dnc") 451 | parser.add_argument('--task', default="copy", help="support 5 tasks: copy/reverse/double/sum/max") 452 | parser.add_argument('--from_checkpoint', default="") 453 | parser.add_argument('--hidden_dim', default=128, type=int) 454 | parser.add_argument('--cache_attend_dim', default=0, type=int) 455 | parser.add_argument('--mem_size', default=4, type=int) 456 | parser.add_argument('--word_size', default=64, type=int) 457 | parser.add_argument('--batch_size', default=64, type=int) 458 | parser.add_argument('--read_heads', default=1, type=int) 459 | parser.add_argument('--read_heads_decode', default=1, type=int) 460 | parser.add_argument('--batch_norm', default=True, type=str2bool) 461 | parser.add_argument('--number_range', default=10, type=int) 462 | parser.add_argument('--length_from', default=50, type=int) 463 | parser.add_argument('--length_to', default=50, type=int) 464 | parser.add_argument('--memo_type', default="", type=str) 465 | parser.add_argument('--hold_mem_mode', default=0, type=int) 466 | parser.add_argument('--cache_size', default=100, type=int) 467 | parser.add_argument('--nlayer', default=1, type=int) 468 | parser.add_argument('--learning_rate', default=0.001, type=float) 469 | parser.add_argument('--lr_decay_step', default=10000, type=float) 470 | parser.add_argument('--lr_decay_rate', default=0.9, type=float) 471 | parser.add_argument('--iterations', default=10000, type=int) 472 | parser.add_argument('--valid_time', default=100, type=int) 473 | parser.add_argument('--gpu_ratio', default=0.4, type=float) 474 | parser.add_argument('--cpu_num', default=5, type=int) 475 | parser.add_argument('--gpu_device', default="1,2,3", type=str) 476 | 477 | args = parser.parse_args() 478 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_device 479 | 480 | 481 | 482 | 483 | 484 | # args = limit_copy(args) 485 | 486 | print(args) 487 | 488 | hold_mem_random = np.ones(args.length_to + 1, dtype=bool) 489 | c=0 490 | if "random" in 
args.memo_type: 491 | for iii in range(int(args.length_to + 1)): 492 | if np.random.rand() > 0.5: 493 | hold_mem_random[iii] = False 494 | c+=1 495 | 496 | global_var["hold_mem_random"]=hold_mem_random 497 | 498 | synthetic_task(args) -------------------------------------------------------------------------------- /memory.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | import utility 4 | 5 | class Memory: 6 | 7 | def __init__(self, words_num=256, word_size=64, read_heads=1, batch_size=1): 8 | """ 9 | constructs a memory matrix with read heads and a write head as described 10 | in the DNC paper 11 | http://www.nature.com/nature/journal/vaop/ncurrent/full/nature20101.html 12 | Parameters: 13 | ---------- 14 | words_num: int 15 | the maximum number of words that can be stored in the memory at the 16 | same time 17 | word_size: int 18 | the size of the individual word in the memory 19 | read_heads: int 20 | the number of read heads that can read simultaneously from the memory 21 | batch_size: int 22 | the size of input data batch 23 | """ 24 | 25 | self.words_num = words_num 26 | self.word_size = word_size 27 | self.read_heads = read_heads 28 | self.batch_size = batch_size 29 | 30 | # a words_num x words_num identity matrix 31 | self.I = tf.constant(np.identity(words_num, dtype=np.float32)) # to support calculate link matrix 32 | 33 | # maps the indecies from the 2D array of free list per batch to 34 | # their corresponding values in the flat 1D array of ordered_allocation_weighting --> vector a --> need to be sorted 35 | self.index_mapper = tf.constant( 36 | np.cumsum([0] + [words_num] * (batch_size - 1), dtype=np.int32)[:, np.newaxis]# [[0], [word_num], [word_num*2], [word_num*3], ...] 
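# e.g. with batch_size=3 and words_num=4 this constant is [[0], [4], [8]]; it offsets each batch
# element's free-list indices so they can be scattered into one flat TensorArray in
# get_allocation_weighting below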
37 | ) 38 | 39 | def init_memory(self, read_heads=None): 40 | """ 41 | returns the initial values for the memory Parameters 42 | Returns: Tuple 43 | """ 44 | if read_heads is None: 45 | return ( 46 | # each sample in batch has its own version of memory 47 | tf.fill([self.batch_size, self.words_num, self.word_size], 1e-6), # initial memory matrix 48 | tf.zeros([self.batch_size, self.words_num]), # initial usage vector u 49 | tf.zeros([self.batch_size, self.words_num]), # initial precedence vector p 50 | tf.zeros([self.batch_size, self.words_num, self.words_num]), # initial link matrix L 51 | tf.fill([self.batch_size, self.words_num], 1e-6), # initial write weighting 52 | tf.fill([self.batch_size, self.words_num, self.read_heads], 1e-6), # initial read weightings 53 | tf.fill([self.batch_size, self.word_size, self.read_heads], 1e-6), # initial read vectors 54 | ) 55 | else: 56 | return ( 57 | # each sample in batch has its own version of memory 58 | tf.fill([self.batch_size, self.words_num, self.word_size], 1e-6), # initial memory matrix 59 | tf.zeros([self.batch_size, self.words_num]), # initial usage vector u 60 | tf.zeros([self.batch_size, self.words_num]), # initial precedence vector p 61 | tf.zeros([self.batch_size, self.words_num, self.words_num]), # initial link matrix L 62 | tf.fill([self.batch_size, self.words_num], 1e-6), # initial write weighting 63 | tf.fill([self.batch_size, self.words_num, read_heads], 1e-6), # initial read weightings 64 | tf.fill([self.batch_size, self.word_size, read_heads], 1e-6), # initial read vectors 65 | ) 66 | ''' 67 | USE FOR BOTH READ WRITE 68 | ''' 69 | @staticmethod 70 | def get_lookup_weighting(memory_matrix, keys, strengths): 71 | """ 72 | retrives a content-based adderssing weighting given the keys 73 | Parameters: 74 | ---------- 75 | memory_matrix: Tensor (batch_size, words_num, word_size) 76 | the memory matrix to lookup in 77 | keys: Tensor (batch_size, word_size, number_of_keys) 78 | the keys to query the memory with 79 | strengths: Tensor (batch_size, number_of_keys, ) 80 | the list of strengths for each lookup key 81 | Returns: Tensor (batch_size, words_num, number_of_keys) 82 | The list of lookup weightings for each provided key 83 | """ 84 | 85 | normalized_memory = tf.nn.l2_normalize(memory_matrix, 2) #M=M/|M| 86 | normalized_keys = tf.nn.l2_normalize(keys, 1) #k=k/|k| 87 | 88 | similiarity = tf.matmul(normalized_memory, normalized_keys) #cosine sim: (batch_size, word_num, number_of_keys) 89 | strengths = tf.expand_dims(strengths, 1) #(batch_size, 1, number_of_keys) 90 | 91 | # (batch_size, word_num, number_of_keys) --softmax on 1-->(batch_size, word_num, number_of_keys) 92 | return tf.nn.softmax(similiarity * strengths, 1) #each batch, every row of mem is multiplied with strength and then softmax 93 | 94 | 95 | ''' 96 | WRITE PART 97 | ''' 98 | @staticmethod 99 | def update_usage_vector(usage_vector, read_weightings, write_weighting, free_gates): 100 | """ 101 | updates and returns the usgae vector given the values of the free gates 102 | and the usage_vector, read_weightings, write_weighting from previous step 103 | Parameters: 104 | ---------- 105 | usage_vector: Tensor (batch_size, words_num) 106 | read_weightings: Tensor (batch_size, words_num, read_heads) 107 | write_weighting: Tensor (batch_size, words_num) 108 | free_gates: Tensor (batch_size, read_heads, ) 109 | Returns: Tensor (batch_size, words_num, ) 110 | the updated usage vector 111 | """ 112 | free_gates = tf.expand_dims(free_gates, 1) #(batch_size, 1, read_heads ) 113 | 
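# Memory retention per the DNC usage rule: psi = prod_i (1 - free_gate_i * read_weighting_i), taken over
# the read heads for every memory slot, so a location is kept unless some read head both attended to it
# and signalled via its free gate that it may be freed; the update below then computes
# u_t = (u_{t-1} + w_w - u_{t-1} * w_w) * psi.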
114 | retention_vector = tf.reduce_prod(1 - read_weightings * free_gates, 2) # (batch_size, word_num) 115 | updated_usage = (usage_vector + write_weighting - usage_vector * write_weighting) * retention_vector 116 | 117 | return updated_usage 118 | 119 | 120 | def get_allocation_weighting(self, sorted_usage, free_list): 121 | """ 122 | retreives the writing allocation weighting based on the usage free list 123 | Parameters: 124 | ---------- 125 | sorted_usage: Tensor (batch_size, words_num, ) 126 | the usage vector sorted ascendingly 127 | free_list: Tensor (batch, words_num, ) 128 | the original indecies of the sorted usage vector: free_list[0] = the least use location --> calculated by sorting usage vector 129 | Returns: Tensor (batch_size, words_num, ) 130 | the allocation weighting for each word in memory 131 | """ 132 | # cum product makes the first index of result (correspond to less usage one) has bigger value --> should be allocate 133 | shifted_cumprod = tf.cumprod(sorted_usage, axis = 1, exclusive=True) 134 | # multiply with this even make larger for less usage ones 135 | unordered_allocation_weighting = (1 - sorted_usage) * shifted_cumprod # batch_size x words_num, the first element is weight for least use 136 | 137 | mapped_free_list = free_list + self.index_mapper# boardcast add with the offset correspond to batch id 138 | flat_unordered_allocation_weighting = tf.reshape(unordered_allocation_weighting, (-1,))# flatten 139 | flat_mapped_free_list = tf.reshape(mapped_free_list, (-1,))# flatten 140 | flat_container = tf.TensorArray(tf.float32, self.batch_size * self.words_num) 141 | 142 | # fill the weights to the original locations 143 | flat_ordered_weightings = flat_container.scatter( 144 | flat_mapped_free_list, 145 | flat_unordered_allocation_weighting 146 | ) 147 | 148 | packed_wightings = flat_ordered_weightings.stack() 149 | return tf.reshape(packed_wightings, (self.batch_size, self.words_num)) 150 | 151 | 152 | @staticmethod 153 | def update_write_weighting(lookup_weighting, allocation_weighting, write_gate, allocation_gate): 154 | """ 155 | updates and returns the current write_weighting 156 | Parameters: 157 | ---------- 158 | lookup_weighting: Tensor (batch_size, words_num, 1) 159 | the weight of the lookup operation in writing --> diff from one in reading 160 | allocation_weighting: Tensor (batch_size, words_num) 161 | the weight of the allocation operation in writing 162 | write_gate: (batch_size, 1) 163 | the fraction of writing to be done 164 | allocation_gate: (batch_size, 1) 165 | the fraction of allocation to be done 166 | Returns: Tensor (batch_size, words_num) 167 | the updated write_weighting 168 | """ 169 | 170 | # remove the dimension of 1 from the lookup_weighting (the third dim, because num write head =1) 171 | lookup_weighting = tf.squeeze(lookup_weighting) 172 | 173 | # the write gate is the final decisor may help protect memory despite other factors 174 | # allocation gate is computed based on usage 175 | # allocation gate interpolate between usage and content lookup 176 | updated_write_weighting = write_gate * (allocation_gate * allocation_weighting + (1 - allocation_gate) * lookup_weighting) 177 | 178 | return updated_write_weighting 179 | 180 | 181 | @staticmethod 182 | def update_memory(memory_matrix, write_weighting, write_vector, erase_vector): 183 | """ 184 | updates and returns the memory matrix given the weighting, write and erase vectors 185 | and the memory matrix from previous step 186 | Parameters: 187 | ---------- 188 | memory_matrix: 
Tensor (batch_size, words_num, word_size) 189 | the memory matrix from previous step 190 | write_weighting: Tensor (batch_size, words_num) 191 | the weight of writing at each memory location 192 | write_vector: Tensor (batch_size, word_size) 193 | a vector specifying what to write 194 | erase_vector: Tensor (batch_size, word_size) 195 | a vector specifying what to erase from memory 196 | Returns: Tensor (batch_size, words_num, word_size) 197 | the updated memory matrix 198 | """ 199 | 200 | # expand data with a dimension of 1 at multiplication-adjacent location 201 | # to force matmul to behave as an outer product 202 | write_weighting = tf.expand_dims(write_weighting, 2) #(batch_size, words_num, 1) 203 | write_vector = tf.expand_dims(write_vector, 1)# (batch_size, 1, word_size) 204 | erase_vector = tf.expand_dims(erase_vector, 1)# (batch_size, 1, word_size) 205 | 206 | # weight and erase are out product to create a matrix erase 207 | # erase value is reflected differently in each location by the weight 208 | erasing = memory_matrix * (1 - tf.matmul(write_weighting, erase_vector)) #(batch_size, words_num, word_size) 209 | writing = tf.matmul(write_weighting, write_vector) #(batch_size, words_num, word_size) 210 | updated_memory = erasing + writing #(batch_size, words_num, word_size) 211 | 212 | return updated_memory 213 | 214 | ''' 215 | READ PART 216 | ''' 217 | @staticmethod 218 | def update_precedence_vector(precedence_vector, write_weighting): 219 | """ 220 | updates the precedence vector given the latest write weighting --> contain info of writting information 221 | and the precedence_vector from last step 222 | Parameters: 223 | ---------- 224 | precedence_vector: Tensor (batch_size. words_num) 225 | the precedence vector from the last time step 226 | write_weighting: Tensor (batch_size,words_num) 227 | the latest write weighting for the memory 228 | Returns: Tensor (batch_size, words_num) 229 | the updated precedence vector 230 | """ 231 | 232 | # if current write is to full memory --> no need to refer to writing information from the past--> like write weight 233 | reset_factor = 1 - tf.reduce_sum(write_weighting, 1, keep_dims=True) 234 | updated_precedence_vector = reset_factor * precedence_vector + write_weighting 235 | 236 | return updated_precedence_vector 237 | 238 | 239 | def update_link_matrix(self, precedence_vector, link_matrix, write_weighting): 240 | """ 241 | updates and returns the temporal link matrix for the latest write 242 | given the precedence vector and the link matrix from previous step 243 | Parameters: 244 | ---------- 245 | precedence_vector: Tensor (batch_size, words_num) 246 | the precedence vector from the last time step 247 | link_matrix: Tensor (batch_size, words_num, words_num) 248 | the link matrix form the last step 249 | write_weighting: Tensor (batch_size, words_num) 250 | the latest write_weighting for the memory 251 | Returns: Tensor (batch_size, words_num, words_num) 252 | the updated temporal link matrix 253 | """ 254 | 255 | write_weighting = tf.expand_dims(write_weighting, 2) #(batch_size, words_num, 1) 256 | precedence_vector = tf.expand_dims(precedence_vector, 1)#(batch_size, 1, words_num) 257 | 258 | # remove old link between all i and j because now we have new weight write 259 | reset_factor = 1 - utility.pairwise_add(write_weighting, is_batch=True)#(batch_size, words_num, 1) matrix[i,j]=1-w[i]-w[j] 260 | 261 | # add current link between last write (precedence vector) and cur write weight 262 | updated_link_matrix = reset_factor * 
link_matrix + tf.matmul(write_weighting, precedence_vector)#(batch_size, words_num, words_num) 263 | 264 | # diagnoal position should be 0 265 | updated_link_matrix = (1 - self.I) * updated_link_matrix # eliminates self-links 266 | 267 | return updated_link_matrix 268 | 269 | 270 | @staticmethod 271 | def get_directional_weightings(read_weightings, link_matrix): 272 | """ 273 | computes and returns the forward and backward reading weightings 274 | given the read_weightings from the previous step 275 | Parameters: 276 | ---------- 277 | read_weightings: Tensor (batch_size, words_num, read_heads) 278 | the read weightings from the last time step 279 | link_matrix: Tensor (batch_size, words_num, words_num) 280 | the temporal link matrix 281 | Returns: Tuple 282 | forward weighting: Tensor (batch_size, words_num, read_heads), 283 | backward weighting: Tensor (batch_size, words_num, read_heads) 284 | """ 285 | 286 | # if your last reading location is i, forward lead you to the next location that is written after i (current write j) 287 | forward_weighting = tf.matmul(link_matrix, read_weightings) 288 | # if your last reading location is i, backward lead you to the previous location that is written before i (last write k) 289 | backward_weighting = tf.matmul(link_matrix, read_weightings, transpose_a=True)# tranpose link and mul 290 | 291 | return forward_weighting, backward_weighting 292 | 293 | 294 | @staticmethod 295 | def update_read_weightings(lookup_weightings, forward_weighting, backward_weighting, read_mode): 296 | """ 297 | updates and returns the current read_weightings 298 | Parameters: 299 | ---------- 300 | lookup_weightings: Tensor (batch_size, words_num, read_heads) 301 | the content-based read weighting 302 | forward_weighting: Tensor (batch_size, words_num, read_heads) 303 | the forward direction read weighting 304 | backward_weighting: Tensor (batch_size, words_num, read_heads) 305 | the backward direction read weighting 306 | read_mode: Tesnor (batch_size, 3, read_heads) 307 | the softmax distribution between the three read modes 308 | Returns: Tensor (batch_size, words_num, read_heads) 309 | """ 310 | 311 | # interpolate 3 component: backward forward content 312 | 313 | backward_mode = tf.expand_dims(read_mode[:, 0, :], 1) * backward_weighting 314 | lookup_mode = tf.expand_dims(read_mode[:, 1, :], 1) * lookup_weightings 315 | forward_mode = tf.expand_dims(read_mode[:, 2, :], 1) * forward_weighting 316 | updated_read_weightings = backward_mode + lookup_mode + forward_mode 317 | 318 | return updated_read_weightings 319 | 320 | 321 | @staticmethod 322 | def update_read_vectors(memory_matrix, read_weightings): 323 | """ 324 | reads, updates, and returns the read vectors of the recently updated memory 325 | Parameters: 326 | ---------- 327 | memory_matrix: Tensor (batch_size, words_num, word_size) 328 | the recently updated memory matrix 329 | read_weightings: Tensor (batch_size, words_num, read_heads) 330 | the amount of info to read from each memory location by each read head 331 | Returns: Tensor (word_size, read_heads) 332 | """ 333 | 334 | # the read values 335 | updated_read_vectors = tf.matmul(memory_matrix, read_weightings, transpose_a=True) 336 | 337 | return updated_read_vectors 338 | 339 | ''' 340 | WRAPPER FOR WRITE PROCESS 341 | ''' 342 | def write(self, memory_matrix, usage_vector, read_weightings, write_weighting, 343 | precedence_vector, link_matrix, key, strength, free_gates, 344 | allocation_gate, write_gate, write_vector, erase_vector): 345 | """ 346 | defines 
the complete pipeline of writing to memory given the write variables 347 | and the memory_matrix, usage_vector, link_matrix, and precedence_vector from 348 | previous step 349 | Parameters: 350 | ---------- 351 | memory_matrix: Tensor (batch_size, words_num, word_size) 352 | the memory matrix from previous step 353 | usage_vector: Tensor (batch_size, words_num) 354 | the usage_vector from the last time step 355 | read_weightings: Tensor (batch_size, words_num, read_heads) 356 | the read_weightings from the last time step 357 | write_weighting: Tensor (batch_size, words_num) 358 | the write_weighting from the last time step 359 | precedence_vector: Tensor (batch_size, words_num) 360 | the precedence vector from the last time step 361 | link_matrix: Tensor (batch_size, words_num, words_num) 362 | the link_matrix from previous step 363 | key: Tensor (batch_size, word_size, 1) 364 | the key to query the memory location with 365 | strength: (batch_size, 1) 366 | the strength of the query key 367 | free_gates: Tensor (batch_size, read_heads) 368 | the degree to which location at read haeds will be freed 369 | allocation_gate: (batch_size, 1) 370 | the fraction of writing that is being allocated in a new locatio 371 | write_gate: (batch_size, 1) 372 | the amount of information to be written to memory 373 | write_vector: Tensor (batch_size, word_size) 374 | specifications of what to write to memory 375 | erase_vector: Tensor(batch_size, word_size) 376 | specifications of what to erase from memory 377 | Returns : Tuple 378 | the updated usage vector: Tensor (batch_size, words_num) 379 | the updated write_weighting: Tensor(batch_size, words_num) 380 | the updated memory_matrix: Tensor (batch_size, words_num, words_size) 381 | the updated link matrix: Tensor(batch_size, words_num, words_num) 382 | the updated precedence vector: Tensor (batch_size, words_num) 383 | """ 384 | 385 | lookup_weighting = self.get_lookup_weighting(memory_matrix, key, strength) 386 | 387 | new_usage_vector = self.update_usage_vector(usage_vector, read_weightings, write_weighting, free_gates) 388 | 389 | sorted_usage, free_list = tf.nn.top_k(-1 * new_usage_vector, self.words_num)#make it from min to max 390 | sorted_usage = -1 * sorted_usage #convert to normal values 391 | 392 | allocation_weighting = self.get_allocation_weighting(sorted_usage, free_list) 393 | new_write_weighting = self.update_write_weighting(lookup_weighting, allocation_weighting, write_gate, allocation_gate) 394 | new_link_matrix = self.update_link_matrix(precedence_vector, link_matrix, new_write_weighting) 395 | new_precedence_vector = self.update_precedence_vector(precedence_vector, new_write_weighting) 396 | 397 | new_memory_matrix = self.update_memory(memory_matrix, new_write_weighting, write_vector, erase_vector) 398 | 399 | return new_usage_vector, new_write_weighting, new_memory_matrix, new_link_matrix, new_precedence_vector 400 | 401 | def read_zero(self, read_heads=None): 402 | if read_heads is None: 403 | return tf.fill([self.batch_size, self.words_num, self.read_heads], 1e-6), \ 404 | tf.fill([self.batch_size, self.word_size, self.read_heads], 1e-6) 405 | else: 406 | return tf.fill([self.batch_size, self.words_num, read_heads], 1e-6), \ 407 | tf.fill([self.batch_size, self.word_size, read_heads], 1e-6) 408 | 409 | def read(self, memory_matrix, read_weightings, keys, strengths, link_matrix, read_modes): 410 | """ 411 | defines the complete pipeline for reading from memory 412 | Parameters: 413 | ---------- 414 | memory_matrix: Tensor (batch_size, 
words_num, word_size) 415 | the updated memory matrix from the last writing 416 | read_weightings: Tensor (batch_size, words_num, read_heads) 417 | the read weightings form the last time step 418 | keys: Tensor (batch_size, word_size, read_heads) 419 | the kyes to query the memory locations with 420 | strengths: Tensor (batch_size, read_heads) 421 | the strength of each read key 422 | link_matrix: Tensor (batch_size, words_num, words_num) 423 | the updated link matrix from the last writing 424 | read_modes: Tensor (batch_size, 3, read_heads) 425 | the softmax distribution between the three read modes 426 | Returns: Tuple 427 | the updated read_weightings: Tensor(batch_size, words_num, read_heads) 428 | the recently read vectors: Tensor (batch_size, word_size, read_heads) 429 | """ 430 | 431 | lookup_weighting = self.get_lookup_weighting(memory_matrix, keys, strengths) # content weight: later use to produce read weight 432 | 433 | 434 | # need last read weights to infer forward, backward --> just mul with link matrix 435 | forward_weighting, backward_weighting = self.get_directional_weightings(read_weightings, link_matrix) 436 | new_read_weightings = self.update_read_weightings(lookup_weighting, forward_weighting, backward_weighting, read_modes) 437 | 438 | new_read_vectors = self.update_read_vectors(memory_matrix, new_read_weightings) 439 | 440 | return new_read_weightings, new_read_vectors 441 | 442 | -------------------------------------------------------------------------------- /uw_dnc.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | from tensorflow.python.ops.rnn_cell import LSTMStateTuple 4 | from memory import Memory 5 | import utility 6 | import os 7 | 8 | class DNC: 9 | 10 | def __init__(self, controller_class, input_encoder_size, input_decoder_size, output_size, 11 | memory_words_num = 256, memory_word_size = 64, memory_read_heads = 1, 12 | batch_size = 1,hidden_controller_dim=256, use_emb_encoder=True, 13 | use_emb_decoder=True, train_emb=True, 14 | use_mem=True, decoder_mode=False, emb_size=64, parallel_rnn=1, 15 | write_protect=False, hold_mem_mode=0, 16 | dual_controller=False, dual_emb=True, controller_cell_type="lstm", 17 | use_teacher=False, cache_attend_dim=0, 18 | use_encoder_output=False, clip_output=0, 19 | pass_encoder_state=True, 20 | memory_read_heads_decode=None, enable_drop_out=False, 21 | enable_rnn_drop_out=False, batch_norm=False, 22 | nlayer=1, name='UW'): 23 | """ 24 | constructs a complete DNC architecture as described in the DNC paper 25 | http://www.nature.com/nature/journal/vaop/ncurrent/full/nature20101.html 26 | Parameters: 27 | ----------- 28 | controller_class: BaseController 29 | a concrete implementation of the BaseController class 30 | input_size: int 31 | the size of the input vector 32 | output_size: int 33 | the size of the output vector 34 | max_sequence_length: int 35 | the maximum length of an input sequence 36 | memory_words_num: int 37 | the number of words that can be stored in memory 38 | memory_word_size: int 39 | the size of an individual word in memory 40 | memory_read_heads: int 41 | the number of read heads in the memory 42 | batch_size: int 43 | the size of the data batch 44 | """ 45 | saved_args = locals() 46 | print("saved_args is", saved_args) 47 | self.name=name 48 | self.parallel_rnn=parallel_rnn 49 | self.input_encoder_size = input_encoder_size 50 | self.input_decoder_size = input_decoder_size 51 | self.output_size = output_size 52 | 
self.words_num = memory_words_num 53 | self.word_size = memory_word_size 54 | self.read_heads = memory_read_heads 55 | self.batch_norm=batch_norm 56 | self.clip_output=clip_output 57 | 58 | if memory_read_heads_decode is None: 59 | self.read_heads_decode = memory_read_heads 60 | else: 61 | self.read_heads_decode = memory_read_heads_decode 62 | 63 | 64 | self.batch_size = batch_size 65 | self.unpacked_input_encoder_data = None 66 | self.unpacked_input_decoder_data = None 67 | self.packed_output = None 68 | self.packed_output_encoder = None 69 | self.packed_memory_view_encoder = None 70 | self.packed_memory_view_decoder = None 71 | self.decoder_mode = decoder_mode 72 | self.emb_size = emb_size 73 | self.emb_size2 = emb_size 74 | self.dual_emb = dual_emb 75 | self.use_mem = use_mem 76 | self.controller_cell_type = controller_cell_type 77 | self.use_emb_encoder = use_emb_encoder 78 | self.use_emb_decoder = use_emb_decoder 79 | self.hidden_controller_dim = hidden_controller_dim 80 | self.cache_attend_dim = cache_attend_dim 81 | self.use_teacher = use_teacher 82 | self.teacher_force = tf.placeholder(tf.bool,[None], name='teacher') 83 | self.hold_mem_mode=hold_mem_mode 84 | if self.hold_mem_mode>0: 85 | self.hold_mem = tf.placeholder(tf.bool, [None], name='hold_mem') 86 | else: 87 | self.hold_mem=None 88 | 89 | self.use_encoder_output=use_encoder_output 90 | self.pass_encoder_state=pass_encoder_state 91 | self.clear_mem = tf.placeholder(tf.bool,None, name='clear_mem') 92 | self.drop_out_keep = tf.placeholder_with_default(1.0, None, name='drop_out_keep') 93 | self.drop_out_rnn_keep = tf.placeholder_with_default(1.0, None, name='drop_out_rnn_keep') 94 | 95 | self.nlayer=nlayer 96 | self.drop_out_v = 1 97 | self.drop_out_rnnv = 1 98 | if enable_drop_out: 99 | self.drop_out_v = self.drop_out_keep 100 | 101 | if enable_rnn_drop_out: 102 | self.drop_out_rnnv = self.drop_out_rnn_keep 103 | 104 | 105 | 106 | self.controller_out = self.output_size 107 | 108 | 109 | 110 | if self.use_emb_encoder is False: 111 | self.emb_size=input_encoder_size 112 | 113 | if self.use_emb_decoder is False: 114 | self.emb_size2=input_decoder_size #pointer mode not use 115 | 116 | 117 | if self.cache_attend_dim>0: 118 | 119 | 120 | self.cW_a = tf.get_variable('cW_a', [self.hidden_controller_dim, self.cache_attend_dim], 121 | initializer=tf.random_normal_initializer(stddev=0.1)) 122 | 123 | value_size = self.hidden_controller_dim 124 | 125 | self.cU_a = tf.get_variable('cU_a', [value_size, self.cache_attend_dim], 126 | initializer=tf.random_normal_initializer(stddev=0.1)) 127 | if self.use_mem: 128 | self.cV_a = tf.get_variable('cV_a', [self.read_heads*self.word_size, self.cache_attend_dim], 129 | initializer=tf.random_normal_initializer(stddev=0.1)) 130 | self.cv_a = tf.get_variable('cv_a', [self.cache_attend_dim], 131 | initializer=tf.random_normal_initializer(stddev=0.1)) 132 | 133 | # DNC (or NTM) should be structurized into 2 main modules: 134 | # all the graph is setup inside these twos: 135 | self.W_emb_encoder = tf.get_variable('embe_w', [self.input_encoder_size, self.emb_size], trainable=train_emb, 136 | initializer=tf.random_uniform_initializer(minval=-1, maxval=1)) 137 | if self.dual_emb: 138 | self.W_emb_decoder = tf.get_variable('embd_w', [self.output_size, self.emb_size2],trainable=train_emb, 139 | initializer=tf.random_uniform_initializer(minval=-1, maxval=1)) 140 | 141 | self.memory = Memory(self.words_num, self.word_size, self.read_heads, self.batch_size) 142 | with tf.variable_scope('controller_scope'): 143 | 
self.controller = controller_class(self.emb_size, self.controller_out, self.read_heads, 144 | self.word_size, self.batch_size, use_mem, 145 | cell_type=controller_cell_type, batch_norm=batch_norm, 146 | hidden_dim=hidden_controller_dim, nlayer=nlayer, 147 | drop_out_keep=self.drop_out_rnnv, clip_output=self.clip_output) 148 | 149 | self.dual_controller = dual_controller 150 | if self.dual_controller: 151 | with tf.variable_scope('controller2_scope'): 152 | if use_mem: 153 | self.controller2 = controller_class(self.emb_size2, self.controller_out, self.read_heads_decode, 154 | self.word_size, self.batch_size, use_mem, 155 | cell_type=controller_cell_type, batch_norm=batch_norm, clip_output=self.clip_output, 156 | hidden_dim=hidden_controller_dim, drop_out_keep=self.drop_out_rnnv, nlayer=nlayer) 157 | else: 158 | 159 | self.controller2 = controller_class(self.emb_size2+hidden_controller_dim, self.controller_out, self.read_heads_decode, 160 | self.word_size, self.batch_size, use_mem, 161 | cell_type=controller_cell_type, batch_norm=batch_norm, clip_output=self.clip_output, 162 | hidden_dim=hidden_controller_dim, drop_out_keep=self.drop_out_rnnv, nlayer=nlayer) 163 | self.write_protect = write_protect 164 | 165 | 166 | # input data placeholders 167 | 168 | self.target_output = tf.placeholder(tf.float32, [batch_size, None, output_size], name='targets') 169 | 170 | self.input_encoder = tf.placeholder(tf.float32, [batch_size, None, input_encoder_size], name='input_encoder') 171 | 172 | self.input_decoder = tf.placeholder(tf.float32, [batch_size, None, input_decoder_size], name='input_decoder') 173 | 174 | self.mask = tf.placeholder(tf.bool, [batch_size, None], name='mask') 175 | self.sequence_length = tf.placeholder(tf.int32, name='sequence_length')# variant length? 176 | self.decode_length = tf.placeholder(tf.int32, name='decode_length') # variant length? 
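        # Illustrative usage sketch: these placeholders are fed per batch at session-run time.
        # Assuming the caller has prepared numpy arrays `enc_batch`, `dec_batch`, `targets`,
        # `masks` and python ints `enc_len`, `dec_len` (hypothetical names, not defined here):
        #
        #     feed = {
        #         self.input_encoder: enc_batch,      # (batch_size, enc_len, input_encoder_size)
        #         self.input_decoder: dec_batch,      # (batch_size, dec_len, input_decoder_size)
        #         self.target_output: targets,        # (batch_size, dec_len, output_size)
        #         self.mask: masks,                   # (batch_size, dec_len), boolean
        #         self.sequence_length: enc_len,
        #         self.decode_length: dec_len,
        #     }
        #     if self.use_teacher:
        #         feed[self.teacher_force] = DNC.get_bool_rand(dec_len)   # one flag per decode step
        #     if self.hold_mem_mode > 0:
        #         feed[self.hold_mem] = hold_schedule                     # boolean vector, one flag per encode step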
177 | 178 | 179 | self.build_graph() 180 | 181 | 182 | 183 | # The nature of DNC is to process data by step and remmeber data at each time step when necessary 184 | # If input has sequence format --> suitable with RNN core controller --> each time step in RNN equals 1 time step in DNC 185 | # or just feed input to MLP --> each feed is 1 time step 186 | def _step_op_encoder(self,time, time2, step, memory_state, controller_state=None, cache_controller_hidden=None): 187 | """ 188 | performs a step operation on the input step data 189 | Parameters: 190 | ---------- 191 | step: Tensor (batch_size, input_size) 192 | memory_state: Tuple 193 | a tuple of current memory parameters 194 | controller_state: Tuple 195 | the state of the controller if it's recurrent 196 | Returns: Tuple 197 | output: Tensor (batch_size, output_size) 198 | memory_view: dict 199 | """ 200 | 201 | last_read_vectors = memory_state[6] # read values from memory 202 | last_read_weights = memory_state[5] 203 | pre_output, interface, nn_state = None, None, None 204 | var = {"time2": time2, "step": step, "compute_interface":True} 205 | cache_controller_hidden = cache_controller_hidden.write(time, controller_state) 206 | 207 | def let_compute(): 208 | var["compute_interface"] = True 209 | ns2 = controller_state 210 | 211 | if self.cache_attend_dim > 0: 212 | # values = utility.pack_into_tensor(cache_controller_hidden, axis=1) 213 | values = cache_controller_hidden.gather(tf.range(time-time2, time+1)) 214 | 215 | value_size = self.hidden_controller_dim 216 | 217 | encoder_outputs = \ 218 | tf.reshape(values, [self.batch_size, -1, value_size]) # bs x Lin x h 219 | v = tf.reshape(tf.matmul(tf.reshape(encoder_outputs, [-1, value_size]), self.cU_a), 220 | [self.batch_size, -1, self.cache_attend_dim]) 221 | 222 | if self.use_mem: 223 | v += tf.reshape( 224 | tf.matmul(tf.reshape(last_read_vectors, [-1, self.read_heads * self.word_size]), 225 | self.cV_a), 226 | [self.batch_size, 1, self.cache_attend_dim]) 227 | ns, statetype = self.get_hidden_value_from_state(controller_state) 228 | print("state typeppppp") 229 | print(controller_state) 230 | print(ns) 231 | v += tf.reshape( 232 | tf.matmul(tf.reshape(ns, [-1, self.hidden_controller_dim]), self.cW_a), 233 | [self.batch_size, 1, self.cache_attend_dim]) # bs.Lin x h_att 234 | print('state include only h') 235 | 236 | v = tf.reshape(tf.tanh(v), [-1, self.cache_attend_dim]) 237 | eijs = tf.matmul(v, tf.expand_dims(self.cv_a, 1)) # bs.Lin x 1 238 | eijs = tf.reshape(eijs, [self.batch_size, -1]) # bs x Lin 239 | alphas = tf.nn.softmax(eijs) 240 | 241 | att = tf.reduce_sum(encoder_outputs * tf.expand_dims(alphas, 2), 1) # bs x h x 1 242 | att = tf.reshape(att, [self.batch_size, value_size]) # bs x h 243 | # step = tf.concat([var["step"], att], axis=-1) # bs x (encoder_input_size + h) 244 | # step = tf.matmul(step, self.cW_ah) # bs x encoder_input_size (or emb_size) 245 | if statetype==1: 246 | ns2=list(controller_state) 247 | ns2[-1][-1]=att 248 | ns2=tuple(ns2) 249 | elif statetype==2 or statetype==3: 250 | # ns2 = list(controller_state) 251 | ns2 = LSTMStateTuple(controller_state[0],att) 252 | # ns2 = tuple(ns2) 253 | elif statetype==4: 254 | return att 255 | 256 | return ns2 257 | 258 | 259 | def hold_compute(): 260 | var["compute_interface"] =False 261 | return controller_state 262 | 263 | if self.hold_mem_mode>0: 264 | controller_state = tf.cond(self.hold_mem[time], hold_compute, let_compute) 265 | 266 | #controller_state = let_compute() 267 | # compute oututs from controller 268 | 269 | 
if self.controller.has_recurrent_nn: 270 | compute_interface = var["compute_interface"] 271 | # controller state is the rnn cell state pass through each time step 272 | if not self.use_emb_encoder: 273 | step2 = tf.reshape(step, [-1, self.input_encoder_size]) 274 | pre_output, interface, nn_state= self.controller.process_input(step2, last_read_vectors, 275 | controller_state, 276 | compute_interface=compute_interface) 277 | else: 278 | pre_output, interface, nn_state = self.controller.process_input(step, last_read_vectors, 279 | controller_state, 280 | compute_interface=compute_interface) 281 | else: 282 | pre_output, interface = self.controller.process_input(step, last_read_vectors) 283 | 284 | # memory_matrix isthe copy of memory for reading process later 285 | # do the write first 286 | 287 | var["state"]=nn_state 288 | 289 | if self.hold_mem_mode >0: 290 | def hold_write(): 291 | var["time2"] += 1 292 | if self.hold_mem_mode >2: 293 | return self.memory.write( 294 | memory_state[0], memory_state[1], memory_state[5], 295 | memory_state[4], memory_state[2], memory_state[3], 296 | interface['write_key'], 297 | interface['write_strength'], 298 | interface['free_gates'], 299 | interface['allocation_gate'], 300 | interface['write_gate'], 301 | interface['write_vector'], 302 | interface['erase_vector'] 303 | ) 304 | else: 305 | return memory_state[1], memory_state[4], memory_state[0], memory_state[3], memory_state[2] 306 | 307 | def let_write(): 308 | # interface["write_gate"] = (1 - self.max_lambda**tf.cast(var["time2"],tf.float32)) 309 | var["time2"] = 0 310 | # var["state"] = self.controller.zero_state() 311 | 312 | 313 | 314 | return self.memory.write( 315 | memory_state[0], memory_state[1], memory_state[5], 316 | memory_state[4], memory_state[2], memory_state[3], 317 | interface['write_key'], 318 | interface['write_strength'], 319 | interface['free_gates'], 320 | interface['allocation_gate'], 321 | interface['write_gate'], 322 | interface['write_vector'], 323 | interface['erase_vector'] 324 | ) 325 | 326 | usage_vector, write_weighting, memory_matrix, link_matrix, precedence_vector=\ 327 | tf.cond(self.hold_mem[time], hold_write, let_write) 328 | else: 329 | usage_vector, write_weighting, memory_matrix, link_matrix, precedence_vector= \ 330 | self.memory.write( 331 | memory_state[0], memory_state[1], memory_state[5], 332 | memory_state[4], memory_state[2], memory_state[3], 333 | interface['write_key'], 334 | interface['write_strength'], 335 | interface['free_gates'], 336 | interface['allocation_gate'], 337 | interface['write_gate'], 338 | interface['write_vector'], 339 | interface['erase_vector'] 340 | ) 341 | 342 | # then do the read, read after write because the write weight is needed to produce temporal linklage to guide the reading 343 | if self.hold_mem_mode>1: 344 | def hold_read(): 345 | return last_read_weights, last_read_vectors 346 | 347 | def let_read(): 348 | return self.memory.read( 349 | memory_matrix, 350 | memory_state[5], 351 | interface['read_keys'], 352 | interface['read_strengths'], 353 | link_matrix, 354 | interface['read_modes'], 355 | ) 356 | 357 | read_weightings, read_vectors = tf.cond(self.hold_mem[time], hold_read, let_read) 358 | 359 | else: 360 | 361 | read_weightings, read_vectors = self.memory.read( 362 | memory_matrix, 363 | memory_state[5], 364 | interface['read_keys'], 365 | interface['read_strengths'], 366 | link_matrix, 367 | interface['read_modes'], 368 | ) 369 | fout=None 370 | if self.use_encoder_output: 371 | fout = 
self.controller.final_output(pre_output, read_vectors) 372 | 373 | if self.clip_output>0: 374 | fout = tf.clip_by_value(fout, -self.clip_output, self.clip_output) 375 | 376 | return [ 377 | # report new memory state to be updated outside the condition branch 378 | memory_matrix, #0 379 | 380 | # neccesary for next step to compute memory stuffs 381 | usage_vector, #1 382 | precedence_vector, #2 383 | link_matrix, #3 384 | write_weighting, #4 385 | read_weightings, #5 386 | read_vectors, #6 387 | 388 | # the final output of dnc 389 | fout, #7 390 | 391 | # the values public info to outside 392 | interface['read_modes'], #8 393 | interface['allocation_gate'], #9 394 | interface['write_gate'], #10 395 | 396 | # report new state of RNN if exists, neccesary for next step to compute inner controller stuff 397 | nn_state if nn_state is not None else tf.zeros(1), #11 398 | var["time2"], #12 399 | cache_controller_hidden #13 400 | ] 401 | 402 | def _step_op_decoder(self, time, step, memory_state, 403 | controller_state=None, controller_hiddens=None): 404 | """ 405 | performs a step operation on the input step data 406 | Parameters: 407 | ---------- 408 | step: Tensor (batch_size, input_size) 409 | memory_state: Tuple 410 | a tuple of current memory parameters 411 | controller_state: Tuple 412 | the state of the controller if it's recurrent 413 | Returns: Tuple 414 | output: Tensor (batch_size, output_size) 415 | memory_view: dict 416 | """ 417 | 418 | last_read_weights = memory_state[5] 419 | last_read_vectors = memory_state[6] # read values from memory 420 | pre_output, interface, nn_state = None, None, None 421 | 422 | if self.dual_controller: 423 | controller=self.controller2 424 | else: 425 | controller=self.controller 426 | 427 | # compute outputs from controller 428 | if controller.has_recurrent_nn: 429 | if not self.use_emb_decoder: 430 | step2 = tf.reshape(step, [-1, self.output_size]) 431 | else: 432 | step2 = step 433 | pre_output, interface, nn_state = controller.process_input(step2, last_read_vectors, controller_state) 434 | 435 | else: 436 | pre_output, interface = controller.process_input(step, last_read_vectors) 437 | 438 | # memory_matrix isthe copy of memory for reading process later 439 | # do the write first 440 | if self.write_protect: 441 | usage_vector, write_weighting, memory_matrix, link_matrix, precedence_vector \ 442 | =memory_state[1], memory_state[4], memory_state[0], memory_state[3], memory_state[2] 443 | 444 | else: 445 | usage_vector, write_weighting, memory_matrix, link_matrix, precedence_vector = self.memory.write( 446 | memory_state[0], memory_state[1], memory_state[5], 447 | memory_state[4], memory_state[2], memory_state[3], 448 | interface['write_key'], 449 | interface['write_strength'], 450 | interface['free_gates'], 451 | interface['allocation_gate'], 452 | interface['write_gate'], 453 | interface['write_vector'], 454 | interface['erase_vector'] 455 | ) 456 | 457 | # then do the read, read after write because the write weight is needed to produce temporal linklage to guide the reading 458 | 459 | 460 | read_weightings, read_vectors = self.memory.read( 461 | memory_matrix, 462 | memory_state[5], 463 | interface['read_keys'], 464 | interface['read_strengths'], 465 | link_matrix, 466 | interface['read_modes'], 467 | ) 468 | fout = controller.final_output(pre_output, read_vectors) # bs x output_size 469 | 470 | if self.clip_output>0: 471 | fout = tf.clip_by_value(fout, -self.clip_output, self.clip_output) 472 | 473 | return [ 474 | # report new memory state to be 
updated outside the condition branch 475 | memory_matrix, # 0 476 | 477 | # neccesary for next step to compute memory stuffs 478 | usage_vector, # 1 479 | precedence_vector, # 2 480 | link_matrix, # 3 481 | write_weighting, # 4 482 | read_weightings, # 5 483 | read_vectors, # 6 484 | 485 | # the final output of dnc 486 | fout, # 7 487 | 488 | # the values public info to outside 489 | interface['read_modes'], # 8 490 | interface['allocation_gate'], # 9 491 | interface['write_gate'], # 10 492 | 493 | # report new state of RNN if exists, neccesary for next step to compute inner controller stuff 494 | nn_state if nn_state is not None else tf.zeros(1), # 11 495 | ] 496 | 497 | ''' 498 | THIS WRAPPER FOR ONE STEP OF COMPUTATION --> INTERFACE FOR SCAN/WHILE LOOP 499 | ''' 500 | def _loop_body_encoder(self, time, memory_state, outputs, free_gates, allocation_gates, write_gates, 501 | read_weightings, write_weightings, usage_vectors, controller_state, 502 | outputs_cache, controller_hiddens, time2, cache_controller_hiddens): 503 | """ 504 | the body of the DNC sequence processing loop 505 | Parameters: 506 | ---------- 507 | time: Tensor 508 | outputs: TensorArray 509 | memory_state: Tuple 510 | free_gates: TensorArray 511 | allocation_gates: TensorArray 512 | write_gates: TensorArray 513 | read_weightings: TensorArray, 514 | write_weightings: TensorArray, 515 | usage_vectors: TensorArray, 516 | controller_state: Tuple 517 | Returns: Tuple containing all updated arguments 518 | """ 519 | 520 | # dynamic tensor array input 521 | 522 | if self.use_emb_encoder: 523 | step_input = tf.matmul(self.unpacked_input_encoder_data.read(time), self.W_emb_encoder) 524 | else: 525 | step_input = self.unpacked_input_encoder_data.read(time) 526 | 527 | # compute one step of controller 528 | op = self._step_op_encoder 529 | 530 | output_list = op(time, time2, step_input, memory_state, controller_state, cache_controller_hiddens) 531 | # update memory parameters 532 | 533 | # new_controller_state = tf.zeros(1) 534 | new_memory_state = tuple(output_list[0:7]) 535 | new_controller_state = output_list[11] #state hidden values 536 | hstate, _ = self.get_hidden_value_from_state(new_controller_state) 537 | 538 | controller_hiddens = controller_hiddens.write(time, hstate) 539 | 540 | if self.use_encoder_output: 541 | outputs = outputs.write(time, output_list[7])# new output is updated 542 | outputs_cache = outputs_cache.write(time, output_list[7])# new output is updated 543 | # collecting memory view for the current step 544 | free_gates = free_gates.write(time, output_list[8]) 545 | allocation_gates = allocation_gates.write(time, output_list[9]) 546 | write_gates = write_gates.write(time, output_list[10]) 547 | read_weightings = read_weightings.write(time, output_list[5]) 548 | write_weightings = write_weightings.write(time, output_list[4]) 549 | usage_vectors = usage_vectors.write(time, output_list[1]) 550 | 551 | # all variables have been updated should be return for next step reference 552 | return ( 553 | time + 1, #0 554 | new_memory_state, #1 555 | outputs, #2 556 | free_gates,allocation_gates, write_gates, #3 4 5 557 | read_weightings, write_weightings, usage_vectors, #6 7 8 558 | new_controller_state, #9 559 | outputs_cache, #10 560 | controller_hiddens, #11 561 | output_list[-2], #12 562 | output_list[-1], #13 563 | ) 564 | 565 | def _loop_body_decoder(self, time, memory_state, outputs, free_gates, allocation_gates, write_gates, 566 | read_weightings, write_weightings, usage_vectors, controller_state, 567 | 
outputs_cache, controller_hiddens, 568 | encoder_write_weightings, encoder_controller_hiddens): 569 | """ 570 | the body of the DNC sequence processing loop 571 | Parameters: 572 | ---------- 573 | time: Tensor 574 | outputs: TensorArray 575 | memory_state: Tuple 576 | free_gates: TensorArray 577 | allocation_gates: TensorArray 578 | write_gates: TensorArray 579 | read_weightings: TensorArray, 580 | write_weightings: TensorArray, 581 | usage_vectors: TensorArray, 582 | controller_state: Tuple 583 | Returns: Tuple containing all updated arguments 584 | """ 585 | 586 | # dynamic tensor array input 587 | if self.decoder_mode: 588 | def fn1(): 589 | return tf.zeros([self.batch_size, self.output_size]) 590 | def fn2(): 591 | def fn2_1(): 592 | return self.target_output[:, time - 1, :] 593 | 594 | def fn2_2(): 595 | inds = tf.argmax(outputs_cache.read(time - 1), axis=-1) 596 | return tf.one_hot(inds, depth=self.output_size) 597 | 598 | if self.use_teacher: 599 | return tf.cond(self.teacher_force[time - 1], fn2_1, fn2_2) 600 | else: 601 | return fn2_2() 602 | 603 | feed_value = tf.cond(time>0,fn2,fn1) 604 | 605 | 606 | if not self.use_emb_decoder: 607 | r = tf.reshape(feed_value, [self.batch_size, self.input_decoder_size]) 608 | step_input = r 609 | elif self.dual_emb: 610 | step_input = tf.matmul(feed_value, self.W_emb_decoder) 611 | else: 612 | step_input = tf.matmul(feed_value, self.W_emb_encoder) 613 | 614 | else: 615 | if self.use_emb_decoder: 616 | if self.dual_emb: 617 | step_input = tf.matmul(self.unpacked_input_decoder_data.read(time), self.W_emb_decoder) 618 | else: 619 | step_input = tf.matmul(self.unpacked_input_decoder_data.read(time), self.W_emb_encoder) 620 | else: 621 | step_input = self.unpacked_input_decoder_data.read(time) 622 | print(step_input.shape) 623 | print('ssss') 624 | 625 | # compute one step of controller 626 | output_list = self._step_op_decoder(time, step_input, memory_state, controller_state) 627 | # update memory parameters 628 | 629 | # new_controller_state = tf.zeros(1) 630 | new_memory_state = tuple(output_list[0:7]) 631 | new_controller_state = output_list[11] # state hidden values 632 | 633 | if self.nlayer>1: 634 | try: 635 | controller_hiddens = controller_hiddens.write(time, new_controller_state[-1][-1]) 636 | print('state include c and h') 637 | except: 638 | controller_hiddens = controller_hiddens.write(time, new_controller_state[-1]) 639 | print('state include only h') 640 | else: 641 | controller_hiddens = controller_hiddens.write(time, new_controller_state[-1]) 642 | print('single layer') 643 | outputs = outputs.write(time, output_list[7]) # new output is updated 644 | outputs_cache = outputs_cache.write(time, output_list[7]) # new output is updated 645 | # collecting memory view for the current step 646 | free_gates = free_gates.write(time, output_list[8]) 647 | allocation_gates = allocation_gates.write(time, output_list[9]) 648 | write_gates = write_gates.write(time, output_list[10]) 649 | read_weightings = read_weightings.write(time, output_list[5]) 650 | write_weightings = write_weightings.write(time, output_list[4]) 651 | usage_vectors = usage_vectors.write(time, output_list[1]) 652 | 653 | # all variables have been updated should be return for next step reference 654 | return ( 655 | time + 1, # 0 656 | new_memory_state, # 1 657 | outputs, # 2 658 | free_gates, allocation_gates, write_gates, # 3 4 5 659 | read_weightings, write_weightings, usage_vectors, # 6 7 8 660 | new_controller_state, # 9 661 | outputs_cache, # 10 662 | controller_hiddens, 
# 11 663 | encoder_write_weightings, #12 664 | encoder_controller_hiddens, #13 665 | ) 666 | 667 | def build_graph(self): 668 | """ 669 | builds the computational graph that performs a step-by-step evaluation 670 | of the input data batches 671 | """ 672 | 673 | # make dynamic time step length tensor 674 | self.unpacked_input_encoder_data = utility.unpack_into_tensorarray(self.input_encoder, 1, self.sequence_length) 675 | 676 | # want to store all time step values of these variables 677 | eoutputs = tf.TensorArray(tf.float32, self.sequence_length) 678 | eoutputs_cache = tf.TensorArray(tf.float32, self.sequence_length) 679 | efree_gates = tf.TensorArray(tf.float32, self.sequence_length) 680 | eallocation_gates = tf.TensorArray(tf.float32, self.sequence_length) 681 | ewrite_gates = tf.TensorArray(tf.float32, self.sequence_length) 682 | eread_weightings = tf.TensorArray(tf.float32, self.sequence_length, clear_after_read=False) 683 | ewrite_weightings = tf.TensorArray(tf.float32, self.sequence_length, clear_after_read=False) 684 | eusage_vectors = tf.TensorArray(tf.float32, self.sequence_length, clear_after_read=False) 685 | econtroller_hiddens = tf.TensorArray(tf.float32, self.sequence_length, clear_after_read=False) 686 | cache_econtroller_hiddens = tf.TensorArray(tf.float32, self.sequence_length, clear_after_read=False, dynamic_size=True) 687 | 688 | # make dynamic time step length tensor 689 | self.unpacked_input_decoder_data = utility.unpack_into_tensorarray(self.input_decoder, 1, self.decode_length) 690 | 691 | # want to store all time step values of these variables 692 | doutputs = tf.TensorArray(tf.float32, self.decode_length) 693 | doutputs_cache = tf.TensorArray(tf.float32, self.decode_length) 694 | dfree_gates = tf.TensorArray(tf.float32, self.decode_length) 695 | dallocation_gates = tf.TensorArray(tf.float32, self.decode_length) 696 | dwrite_gates = tf.TensorArray(tf.float32, self.decode_length) 697 | dread_weightings = tf.TensorArray(tf.float32, self.decode_length) 698 | dwrite_weightings = tf.TensorArray(tf.float32, self.decode_length, clear_after_read=False) 699 | dusage_vectors = tf.TensorArray(tf.float32, self.decode_length) 700 | dcontroller_hiddens = tf.TensorArray(tf.float32, self.decode_length, clear_after_read=False) 701 | 702 | # inital state for RNN controller 703 | controller_state = self.controller.zero_state() 704 | print(controller_state) 705 | memory_state = self.memory.init_memory() 706 | 707 | 708 | 709 | # final_results = None 710 | with tf.variable_scope("sequence_encoder_loop"): 711 | time = tf.constant(0, dtype=tf.int32) 712 | time2 = tf.constant(0, dtype=tf.int32) 713 | # use while instead of scan --> suitable with dynamic time step 714 | encoder_results = tf.while_loop( 715 | cond=lambda time, *_: time < self.sequence_length, 716 | body=self._loop_body_encoder, 717 | loop_vars=( 718 | time, memory_state, eoutputs, 719 | efree_gates, eallocation_gates, ewrite_gates, 720 | eread_weightings, ewrite_weightings, 721 | eusage_vectors, controller_state, 722 | eoutputs_cache, econtroller_hiddens, time2, cache_econtroller_hiddens 723 | ), # do not need to provide intial values, the initial value lies in the variables themselves 724 | parallel_iterations=self.parallel_rnn, 725 | swap_memory=True 726 | ) 727 | 728 | memory_state2 = self.memory.init_memory(self.read_heads_decode) 729 | if self.read_heads_decode!=self.read_heads: 730 | encoder_results_state=(encoder_results[1][0],encoder_results[1][1],encoder_results[1][2], 731 | 
encoder_results[1][3],encoder_results[1][4], memory_state2[5],memory_state2[6]) 732 | else: 733 | encoder_results_state=encoder_results[1] 734 | 735 | 736 | 737 | with tf.variable_scope("sequence_decoder_loop"): 738 | time = tf.constant(0, dtype=tf.int32) 739 | nstate = controller_state 740 | if self.pass_encoder_state: 741 | nstate = encoder_results[9] 742 | 743 | self.final_encoder_state = nstate 744 | self.final_encoder_readw = encoder_results[6].read(encoder_results[0]-1) 745 | self.final_encoder_memory_mat = encoder_results[1][0] 746 | 747 | # use while instead of scan --> suitable with dynamic time step 748 | final_results = tf.while_loop( 749 | cond=lambda time, *_: time < self.decode_length, 750 | body=self._loop_body_decoder, 751 | loop_vars=( 752 | time, encoder_results_state, doutputs, 753 | dfree_gates, dallocation_gates, dwrite_gates, 754 | dread_weightings, dwrite_weightings, 755 | dusage_vectors, nstate, 756 | doutputs_cache, dcontroller_hiddens, 757 | encoder_results[7], encoder_results[11] 758 | ), # do not need to provide intial values, the initial value lies in the variables themselves 759 | parallel_iterations=self.parallel_rnn, 760 | swap_memory=True 761 | ) 762 | 763 | 764 | dependencies = [] 765 | if self.controller.has_recurrent_nn: 766 | # tensor array of pair of hidden and state values of rnn 767 | dependencies.append(self.controller.update_state(final_results[9])) 768 | 769 | with tf.control_dependencies(dependencies): 770 | # convert output tensor array to normal tensor 771 | self.packed_output = utility.pack_into_tensor(final_results[2], axis=1) 772 | if self.use_encoder_output: 773 | self.packed_output_encoder = utility.pack_into_tensor(encoder_results[2], axis=1) 774 | 775 | self.packed_memory_view_encoder = { 776 | 'free_gates': utility.pack_into_tensor(encoder_results[3], axis=1), 777 | 'allocation_gates': utility.pack_into_tensor(encoder_results[4], axis=1), 778 | 'write_gates': utility.pack_into_tensor(encoder_results[5], axis=1), 779 | 'read_weightings': utility.pack_into_tensor(encoder_results[6], axis=1), 780 | 'write_weightings': utility.pack_into_tensor(encoder_results[7], axis=1), 781 | 'usage_vectors': utility.pack_into_tensor(encoder_results[8], axis=1), 782 | 'final_controller_ch': encoder_results[9], 783 | } 784 | 785 | 786 | self.packed_memory_view_decoder = { 787 | 'free_gates': utility.pack_into_tensor(final_results[3], axis=1), 788 | 'allocation_gates': utility.pack_into_tensor(final_results[4], axis=1), 789 | 'write_gates': utility.pack_into_tensor(final_results[5], axis=1), 790 | 'read_weightings': utility.pack_into_tensor(final_results[6], axis=1), 791 | 'write_weightings': utility.pack_into_tensor(final_results[7], axis=1), 792 | 'usage_vectors': utility.pack_into_tensor(final_results[8], axis=1), 793 | 'final_controller_ch':final_results[9], 794 | } 795 | 796 | 797 | 798 | def get_outputs(self): 799 | """ 800 | returns the graph nodes for the output and memory view 801 | Returns: Tuple 802 | outputs: Tensor (batch_size, time_steps, output_size) 803 | memory_view: dict 804 | """ 805 | if self.use_encoder_output: 806 | return self.packed_output_encoder, self.packed_memory_view_encoder, self.packed_memory_view_decoder 807 | return self.packed_output, self.packed_memory_view_encoder, self.packed_memory_view_decoder 808 | 809 | def get_hidden_value_from_state(self, state): 810 | print("state") 811 | if self.nlayer > 1: 812 | if 'lstm' in self.controller_cell_type: 813 | ns = state[-1][-1] 814 | print('multilayer state include c and h') 815 | 
statetype = 1 816 | else: 817 | ns = state[-1] 818 | print('multilayer state include only h') 819 | statetype = 2 820 | else: 821 | if 'lstm' in self.controller_cell_type: 822 | ns = state[-1] 823 | statetype = 3 824 | else: 825 | ns = state 826 | statetype = 4 827 | print('single layer') 828 | print("{}:{}".format(statetype, ns)) 829 | return ns, statetype 830 | 831 | def get_single_output(self): 832 | h, _ = self.get_hidden_value_from_state(self.final_encoder_state) 833 | h = tf.nn.relu(h) 834 | if not self.use_mem: 835 | self.sWo = tf.get_variable('sWo', [self.hidden_controller_dim, self.hidden_controller_dim//4], 836 | initializer=tf.random_normal_initializer(stddev=0.1)) 837 | else: 838 | self.sWo = tf.get_variable('sWo', [self.hidden_controller_dim+self.word_size*self.read_heads, 839 | self.hidden_controller_dim//4], 840 | initializer=tf.random_normal_initializer(stddev=0.1)) 841 | readv = self.memory.update_read_vectors(self.final_encoder_memory_mat, self.final_encoder_readw) 842 | h = tf.concat([h,tf.reshape(readv,[self.batch_size,-1])], axis=-1) 843 | 844 | 845 | output = tf.matmul(h, self.sWo) 846 | output = tf.nn.dropout(output, keep_prob=self.drop_out_v) 847 | output = tf.nn.relu(output) 848 | self.sWo2 = tf.get_variable('sWo2', [self.hidden_controller_dim//4, self.output_size], 849 | initializer=tf.random_normal_initializer(stddev=0.1)) 850 | output = tf.matmul(output, self.sWo2) 851 | return tf.expand_dims(output, axis=1) 852 | 853 | def assign_pretrain_emb_encoder(self, sess, lookup_mat): 854 | assign_op_W_emb_encoder = self.W_emb_encoder.assign(lookup_mat) 855 | sess.run([assign_op_W_emb_encoder]) 856 | 857 | def assign_pretrain_emb_decoder(self, sess, lookup_mat): 858 | assign_op_W_emb_decoder = self.W_emb_decoder.assign(lookup_mat) 859 | sess.run([assign_op_W_emb_decoder]) 860 | 861 | 862 | def build_loss_function(self, optimizer=None, clip_s=10): 863 | print('build loss....') 864 | if optimizer is None: 865 | optimizer = tf.train.AdamOptimizer() 866 | output, _, _ = self.get_outputs() 867 | 868 | prob = tf.nn.softmax(output, dim=-1) 869 | 870 | loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits( 871 | labels=self.target_output, 872 | logits=output, dim=-1)) 873 | 874 | 875 | gradients = optimizer.compute_gradients(loss) 876 | for i, (grad, var) in enumerate(gradients): 877 | if grad is not None: 878 | if isinstance(clip_s, list): 879 | gradients[i] = (tf.clip_by_value(grad, clip_s[0], clip_s[1]), var) 880 | else: 881 | gradients[i] = (tf.clip_by_norm(grad, clip_s), var) 882 | 883 | 884 | apply_gradients = optimizer.apply_gradients(gradients) 885 | return output, prob, loss, apply_gradients 886 | 887 | def build_loss_function_regression(self, optimizer=None, clip_s=10): 888 | print('build loss....') 889 | if optimizer is None: 890 | optimizer = tf.train.AdamOptimizer() 891 | output, _, _ = self.get_outputs() 892 | 893 | 894 | loss = tf.reduce_mean(tf.squared_difference( 895 | self.target_output, 896 | output)) 897 | 898 | 899 | gradients = optimizer.compute_gradients(loss) 900 | for i, (grad, var) in enumerate(gradients): 901 | if grad is not None: 902 | if isinstance(clip_s, list): 903 | gradients[i] = (tf.clip_by_value(grad, clip_s[0], clip_s[1]), var) 904 | else: 905 | gradients[i] = (tf.clip_by_norm(grad, clip_s), var) 906 | 907 | 908 | apply_gradients = optimizer.apply_gradients(gradients) 909 | return output, output, loss, apply_gradients 910 | 911 | def build_loss_function_multi_label(self, optimizer=None, clip_s=10, prefer_one_class=False, 
is_neat=False): 912 | print('build loss multi label....') 913 | if self.use_mem: 914 | is_neat=False 915 | if optimizer is None: 916 | optimizer = tf.train.AdamOptimizer() 917 | if not is_neat: 918 | output, _, _ = self.get_outputs() 919 | else: 920 | output = self.get_single_output() 921 | print("ooooo") 922 | print(output) 923 | if prefer_one_class: 924 | prob = tf.nn.softmax(output) 925 | fn = tf.nn.softmax_cross_entropy_with_logits 926 | else: 927 | prob = tf.nn.sigmoid(output) 928 | fn = tf.nn.sigmoid_cross_entropy_with_logits 929 | 930 | loss = tf.reduce_mean(fn( 931 | labels=tf.slice(self.target_output, [0, 0, 0], 932 | [self.batch_size, 1, self.output_size]), 933 | 934 | logits=tf.slice(output, [0, 0, 0], 935 | [self.batch_size, 1, self.output_size])) 936 | ) 937 | 938 | gradients = optimizer.compute_gradients(loss) 939 | for i, (grad, var) in enumerate(gradients): 940 | if grad is not None: 941 | if isinstance(clip_s, list): 942 | gradients[i] = (tf.clip_by_value(grad, clip_s[0], clip_s[1]), var) 943 | else: 944 | gradients[i] = (tf.clip_by_norm(grad, clip_s), var) 945 | 946 | apply_gradients = optimizer.apply_gradients(gradients) 947 | return output, prob, loss, apply_gradients 948 | 949 | 950 | def print_config(self): 951 | return '{}.{}cell_{}mem_{}dec_{}dua_{}wrp_{}wsz_{}msz_{}tea_{}hid_{}nread_{}nlayer'.\ 952 | format(self.name, self.controller_cell_type, self.use_mem, 953 | self.decoder_mode, 954 | self.dual_controller, 955 | self.write_protect, 956 | self.words_num, 957 | self.word_size, 958 | self.use_teacher, 959 | self.hidden_controller_dim, 960 | self.read_heads_decode, 961 | self.nlayer) 962 | 963 | @staticmethod 964 | def get_size_model(): 965 | total_parameters = 0 966 | for variable in tf.trainable_variables(): 967 | # shape is an array of tf.Dimension 968 | shape = variable.get_shape() 969 | variable_parameters = 1 970 | for dim in shape: 971 | variable_parameters *= dim.value 972 | total_parameters += variable_parameters 973 | return total_parameters 974 | 975 | @staticmethod 976 | def save(session, ckpts_dir, name): 977 | """ 978 | saves the current values of the model's parameters to a checkpoint 979 | Parameters: 980 | ---------- 981 | session: tf.Session 982 | the tensorflow session to save 983 | ckpts_dir: string 984 | the path to the checkpoints directories 985 | name: string 986 | the name of the checkpoint subdirectory 987 | """ 988 | checkpoint_dir = os.path.join(ckpts_dir, name) 989 | 990 | if not os.path.exists(checkpoint_dir): 991 | os.makedirs(checkpoint_dir) 992 | 993 | tf.train.Saver(tf.global_variables()).save(session, os.path.join(checkpoint_dir, 'model.ckpt')) 994 | 995 | 996 | 997 | 998 | 999 | @staticmethod 1000 | def restore(session, ckpts_dir, name): 1001 | """ 1002 | session: tf.Session 1003 | the tensorflow session to restore into 1004 | ckpts_dir: string 1005 | the path to the checkpoints directories 1006 | name: string 1007 | the name of the checkpoint subdirectory 1008 | """ 1009 | tf.train.Saver(tf.global_variables()).restore(session, os.path.join(ckpts_dir, name, 'model.ckpt')) 1010 | 1011 | @staticmethod 1012 | def get_bool_rand(size_seq, prob_true=0.1): 1013 | ret = [] 1014 | for i in range(size_seq): 1015 | if np.random.rand() < prob_true: 1016 | ret.append(True) 1017 | else: 1018 | ret.append(False) 1019 | return np.asarray(ret) 1020 | 1021 | @staticmethod 1022 | def get_bool_rand_incremental(size_seq, prob_true_min=0, prob_true_max=0.25): 1023 | ret = [] 1024 | for i in range(size_seq): 1025 | 
prob_true = prob_true_min + (prob_true_max - prob_true_min) / size_seq * i  # ramp linearly from prob_true_min to prob_true_max
1026 |             if np.random.rand() < prob_true:
1027 |                 ret.append(True)
1028 |             else:
1029 |                 ret.append(False)
1030 |         return np.asarray(ret)
1031 | 
1032 |     @staticmethod
1033 |     def get_bool_rand_curriculum(size_seq, epoch, k=0.99, type='exp'):
1034 |         if type == 'exp':
1035 |             prob_true = k ** epoch
1036 |         elif type == 'sig':
1037 |             prob_true = k / (k + np.exp(epoch / k))  # inverse-sigmoid decay
1038 |         ret = []
1039 |         for i in range(size_seq):
1040 |             if np.random.rand() < prob_true:
1041 |                 ret.append(True)
1042 |             else:
1043 |                 ret.append(False)
1044 |         return np.asarray(ret)
1045 | 
--------------------------------------------------------------------------------
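For orientation, a minimal end-to-end training sketch for uw_dnc.DNC under TensorFlow 1.x is shown below. The controller class name (RecurrentController), the toy shapes, and the hyper-parameter values are illustrative assumptions only; synthetic_task.py is the full, authoritative training script.

import numpy as np
import tensorflow as tf

from uw_dnc import DNC
from recurrent_controller import RecurrentController  # assumed class name exposed by recurrent_controller.py

batch_size, enc_len, dec_len, vocab = 1, 20, 20, 10  # toy sizes for illustration

graph = tf.Graph()
with graph.as_default():
    # build the encoder-decoder DNC; hold_mem_mode=2 enables the cached (UW) writing path
    ncomputer = DNC(
        RecurrentController,
        input_encoder_size=vocab,
        input_decoder_size=vocab,
        output_size=vocab,
        memory_words_num=16,
        memory_word_size=64,
        memory_read_heads=1,
        batch_size=batch_size,
        hidden_controller_dim=100,
        decoder_mode=True,
        write_protect=True,
        hold_mem_mode=2,
    )
    output, prob, loss, apply_gradients = ncomputer.build_loss_function()

with tf.Session(graph=graph) as session:
    session.run(tf.global_variables_initializer())

    # random stand-in data; a real task would feed one-hot sequences as synthetic_task.py does
    enc_batch = np.random.rand(batch_size, enc_len, vocab).astype(np.float32)
    dec_batch = np.zeros((batch_size, dec_len, vocab), dtype=np.float32)
    targets = np.random.rand(batch_size, dec_len, vocab).astype(np.float32)

    feed = {
        ncomputer.input_encoder: enc_batch,
        ncomputer.input_decoder: dec_batch,
        ncomputer.target_output: targets,
        ncomputer.mask: np.ones((batch_size, dec_len), dtype=bool),
        ncomputer.sequence_length: enc_len,
        ncomputer.decode_length: dec_len,
        # one flag per encoder step; with hold_mem_mode=2, True skips the memory update at that step
        ncomputer.hold_mem: DNC.get_bool_rand(enc_len, prob_true=0.5),
    }
    loss_value, _ = session.run([loss, apply_gradients], feed_dict=feed)
    print('loss:', loss_value)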